Source code for aijack.collaborative.core.server

from abc import abstractmethod

import torch

class BaseServer(torch.nn.Module):
    """Abstract class for the server of the collaborative learning.

    The server wraps a global model and keeps a reference to the clients
    that take part in the training. Concrete servers are expected to
    override ``action``, ``update`` and ``distribute``.

    NOTE(review): the class does not use ``abc.ABCMeta`` as its metaclass,
    so the ``@abstractmethod`` markers below document intent but are
    presumably not enforced at instantiation time -- confirm before relying
    on it.

    Args:
        clients (List[BaseClient]): a list of clients
        server_model (torch.nn.Module): a global model
        server_id (int, optional): the id of this server. Defaults to 0.
    """

    def __init__(self, clients, server_model, server_id=0):
        """Initialize BaseServer."""
        super().__init__()
        self.clients = clients
        self.server_id = server_id
        self.server_model = server_model
        self.num_clients = len(clients)

    def forward(self, x):
        """Return the output of the global model for the input ``x``."""
        return self.server_model(x)

    def train(self, mode=True):
        """Set the global model to training (or evaluation) mode.

        The signature mirrors ``torch.nn.Module.train`` so that a parent
        module, which calls ``child.train(mode)`` on each of its children,
        can contain this server without raising ``TypeError``.

        Args:
            mode (bool, optional): ``True`` for training mode, ``False``
                for evaluation mode. Defaults to True.

        Returns:
            BaseServer: this server, to allow call chaining.
        """
        # Keep this wrapper's own flag in sync with the wrapped model.
        self.training = mode
        self.server_model.train(mode)
        return self

    def eval(self):
        """Set the global model to evaluation mode.

        Returns:
            BaseServer: this server, to allow call chaining.
        """
        self.training = False
        self.server_model.eval()
        return self

    @abstractmethod
    def action(self):
        """Execute the routine of each communication."""

    @abstractmethod
    def update(self):
        """Update the global model."""

    @abstractmethod
    def distribute(self):
        """Distribute the global model to each client."""