Source code for aijack.collaborative.core.server
from abc import abstractmethod
import torch
class BaseServer(torch.nn.Module):
    """Abstract class for the server of the collaborative learning.

    Subclasses are expected to implement ``action``, ``update`` and
    ``distribute``.

    NOTE: ``torch.nn.Module`` does not use ``abc.ABCMeta``, so the
    ``@abstractmethod`` markers below document the intended interface but are
    not enforced at instantiation time.

    Args:
        clients (List[BaseClient]): a list of clients
        server_model (torch.nn.Module): a global model
        server_id (int, optional): the id of this server. Defaults to 0.
    """

    def __init__(self, clients, server_model, server_id=0):
        """Initialize BaseServer.

        Args:
            clients (List[BaseClient]): a list of clients; only its length is
                inspected here (stored as ``num_clients``).
            server_model (torch.nn.Module): the global model wrapped by this
                server. Assigning it as an attribute registers it as a
                submodule of this ``torch.nn.Module``.
            server_id (int, optional): the id of this server. Defaults to 0.
        """
        super().__init__()
        self.clients = clients
        self.server_id = server_id
        self.server_model = server_model
        self.num_clients = len(clients)

    def forward(self, x):
        """Run the global model on ``x`` and return its output."""
        return self.server_model(x)

    def train(self, mode=True):
        """Set the server (and its global model) to training or eval mode.

        The signature matches ``torch.nn.Module.train`` so the standard torch
        API (e.g. ``server.train(False)``) works on the server as well.

        Args:
            mode (bool, optional): ``True`` for training mode, ``False`` for
                evaluation mode. Defaults to True.

        Returns:
            BaseServer: ``self``, to allow call chaining as in torch.

        Raises:
            ValueError: if ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        # Keep the server's own flag in sync; previously it stayed True forever.
        self.training = mode
        # Call the model's train()/eval() without arguments, exactly as before,
        # so model objects whose train() takes no `mode` still work.
        if mode:
            self.server_model.train()
        else:
            self.server_model.eval()
        return self

    def eval(self):
        """Set the server (and its global model) to evaluation mode.

        Returns:
            BaseServer: ``self``, to allow call chaining as in torch.
        """
        # Implemented directly rather than via self.train(False) so that
        # subclasses overriding train() without a `mode` argument keep working.
        self.training = False
        self.server_model.eval()
        return self

    @abstractmethod
    def action(self):
        """Execute the routine of each communication."""

    @abstractmethod
    def update(self):
        """Update the global model."""

    @abstractmethod
    def distribute(self):
        """Distribute the global model to each client."""