Source code for aijack.collaborative.dsfl.api

import copy

from ..core.api import BaseFLKnowledgeDistillationAPI


class DSFLAPI(BaseFLKnowledgeDistillationAPI):
    """API of DS-FL.

    Drives the communication rounds between one server and its clients:
    local training, global distillation on the server, a user hook, and
    local distillation on each client.

    Args:
        server (DSFLServer): an instance of DSFLServer
        clients (List[DSFLClient]): a list of instances of DSFLClient
        public_dataloader (torch.DataLoader): a dataloader of public dataset
        local_dataloaders (List[torch.DataLoader]): a list of dataloaders of
            private datasets
        criterion (function): a loss function
        num_communication (int): number of communication rounds
        device (str): device type
        server_optimizer (torch.Optimizer): an optimizer for the global model
        client_optimizers ([torch.Optimizer]): a list of optimizers for the
            local models, indexed in the same order as ``clients``
        validation_dataloader (torch.DataLoader, optional): a dataloader of
            validation dataset. Validation is skipped when None.
            Defaults to None.
        epoch_local_training (int, optional): number of epochs of local
            training. Defaults to 1.
        epoch_global_distillation (int, optional): number of epochs of global
            distillation. Defaults to 1.
        epoch_local_distillation (int, optional): number of epochs of local
            distillation. Defaults to 1.
        custom_action (callable, optional): hook called once per round as
            ``custom_action(api)``, right after the global distillation step.
            Defaults to the identity function.
    """

    def __init__(
        self,
        server,
        clients,
        public_dataloader,
        local_dataloaders,
        criterion,
        num_communication,
        device,
        server_optimizer,
        client_optimizers,
        validation_dataloader=None,
        epoch_local_training=1,
        epoch_global_distillation=1,
        epoch_local_distillation=1,
        custom_action=lambda x: x,
    ):
        """Init DSFLAPI"""
        # The base class takes ``validation_dataloader`` positionally between
        # ``local_dataloaders`` and ``criterion``; keep that exact order.
        super().__init__(
            server,
            clients,
            public_dataloader,
            local_dataloaders,
            validation_dataloader,
            criterion,
            num_communication,
            device,
        )
        self.server_optimizer = server_optimizer
        self.client_optimizers = client_optimizers
        self.epoch_local_training = epoch_local_training
        self.epoch_global_distillation = epoch_global_distillation
        self.epoch_local_distillation = epoch_local_distillation
        self.custom_action = custom_action
        # Index of the current communication round (1-based once run() starts).
        self.epoch = 0

    def run(self):
        """Run all communication rounds and collect per-round metrics.

        Returns:
            dict: lists with one entry appended per communication round:

            * ``"loss_local"``: loss returned by the last local-training
              epoch of the round (None if ``epoch_local_training`` is 0).
            * ``"loss_server_consensus"``: loss returned by the last global
              distillation epoch (None if ``epoch_global_distillation`` is 0).
            * ``"loss_client_consensus"``: a list holding, per client, the
              loss of its last local-distillation epoch (None if
              ``epoch_local_distillation`` is 0); empty when there is only
              one client.
            * ``"acc_local"``: the result of ``self.local_score()``.
            * ``"acc_val"``: the result of ``self.score`` on the validation
              dataloader; only filled when a validation dataloader was given.
        """
        history = {
            "loss_local": [],
            "loss_client_consensus": [],
            "loss_server_consensus": [],
            "acc_local": [],
            "acc_val": [],
        }

        for i in range(1, self.num_communication + 1):
            self.epoch = i

            # Bind the loss before the loop so that a zero-epoch setting
            # records None instead of raising UnboundLocalError below.
            loss_local = None
            for _ in range(self.epoch_local_training):
                loss_local = self.train_client(public=False)
            history["loss_local"].append(loss_local)

            self.server.action()

            acc_on_local_dataset = self.local_score()
            print(f"epoch={i} acc on local datasets: ", acc_on_local_dataset)
            history["acc_local"].append(acc_on_local_dataset)

            # distillation
            loss_global = None
            for _ in range(self.epoch_global_distillation):
                loss_global = self.server.update_globalmodel(self.server_optimizer)
            history["loss_server_consensus"].append(loss_global)

            self.custom_action(self)

            # Local distillation is skipped entirely for a single client.
            temp_consensus_loss = []
            if len(self.clients) > 1:
                for j, client in enumerate(self.clients):
                    consensus_loss = None
                    for _ in range(self.epoch_local_distillation):
                        consensus_loss = client.approach_consensus(
                            self.client_optimizers[j]
                        )
                    temp_consensus_loss.append(consensus_loss)
            history["loss_client_consensus"].append(temp_consensus_loss)

            print(f"epoch {i}: loss_local", loss_local)
            print(f"epoch {i}: loss_client_consensus", temp_consensus_loss)
            print(f"epoch {i}: loss_server_consensus", loss_global)

            # validation
            if self.validation_dataloader is not None:
                acc_val = self.score(self.validation_dataloader)
                print(f"epoch={i} acc on validation dataset: ", acc_val)
                # Deep-copy so later rounds cannot mutate the stored result.
                history["acc_val"].append(copy.deepcopy(acc_val))

        return history