Source code for aijack.collaborative.core.api

import copy
from abc import abstractmethod

from ...utils import accuracy_torch_dataloader


[docs]class BaseFedAPI: """Abstract class for Federated Learning API"""
[docs] @abstractmethod def local_train(self): pass
[docs] @abstractmethod def run(self): pass
[docs]class BaseFLKnowledgeDistillationAPI: """Abstract class for API of federated learning with knowledge distillation. Args: server (aijack.collaborative.core.BaseServer): the server clients (List[aijack.collaborative.core.BaseClient]): a list of the clients public_dataloader (torch.utils.data.DataLoader): a dataloader for the public dataset local_dataloaders (List[torch.utils.data.DataLoader]): a list of local dataloaders validation_dataloader (torch.utils.data.DataLoader): a dataloader for the validation dataset criterion (function): a function to calculate the loss num_communication (int): the number of communication device (str): device type """ def __init__( self, server, clients, public_dataloader, local_dataloaders, validation_dataloader, criterion, num_communication, device, ): """Initialize BaseFLKnowledgeDistillationAPI""" self.server = server self.clients = clients self.public_dataloader = public_dataloader self.local_dataloaders = local_dataloaders self.validation_dataloader = validation_dataloader self.criterion = criterion self.num_communication = num_communication self.device = device self.client_num = len(clients)
[docs] def train_client(self, epoch=1, public=True): """Train local models with the local datasets or the public dataset. Args: public (bool, optional): Train with the public dataset or the local datasets. Defaults to True. Returns: List[float]: a list of average loss of each clients. """ loss_on_local_dataest = [] for client_idx in range(self.client_num): if public: trainloader = self.public_dataloader else: trainloader = self.local_dataloaders[client_idx] running_loss = self.clients[client_idx].local_train( epoch, self.criterion, trainloader, self.client_optimizers[client_idx] ) loss_on_local_dataest.append(copy.deepcopy(running_loss / len(trainloader))) return loss_on_local_dataest
[docs] @abstractmethod def run(self): pass
[docs] def score(self, dataloader, only_local=False): """Returns the performance on the given dataset. Args: dataloader (torch.utils.data.DataLoader): a dataloader only_local (bool): show only the local results Returns: Dict[str, int]: performance of global model and local models """ clients_score = [ accuracy_torch_dataloader(client, dataloader, device=self.device) for client in self.clients ] if only_local: return {"clients_score": clients_score} else: server_score = accuracy_torch_dataloader( self.server, dataloader, device=self.device ) return {"server_score": server_score, "clients_score": clients_score}
[docs] def local_score(self): """Returns the local performance of each clients. Returns: Dict[str, int]: performance of global model and local models """ local_score_list = [] for client, local_dataloader in zip(self.clients, self.local_dataloaders): temp_score = accuracy_torch_dataloader( client, local_dataloader, device=self.device ) local_score_list.append(temp_score) return {"clients_score": local_score_list}