Source code for aijack.collaborative.fedprox.api

import copy

from ..fedavg import FedAVGAPI, MPIFedAVGAPI


class FedProxAPI(FedAVGAPI):
    """Implementation of FedProx (https://arxiv.org/abs/1812.06127).

    Reuses the orchestration inherited from ``FedAVGAPI``. On each
    communication round it gives every client a snapshot of the current
    server (global) parameters, together with the client's dataloader and
    optimizer.

    Args:
        *args: Positional arguments forwarded to ``FedAVGAPI.__init__``.
        mu (float, optional): Stored as ``self.mu``. Defaults to 0.01.
            NOTE(review): this class never reads ``self.mu`` itself;
            presumably the proximal coefficient is configured on the clients.
            Confirm against the client implementation.
        **kwargs: Keyword arguments forwarded to ``FedAVGAPI.__init__``.
    """

    def __init__(self, *args, mu=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.mu = mu

    def local_train(self, i):
        """Run local training on every client for communication round ``i``.

        Args:
            i (int): Index of the current communication round. It is
                forwarded to each client as ``communication_id``.
        """
        # Snapshot the global parameters once per round as a concrete list of
        # deep copies. This has two benefits:
        #   * A list can be iterated any number of times. A generator (such as
        #     the one ``Module.parameters()`` returns) would be exhausted after
        #     a single pass inside the client.
        #   * The live server parameters stay out of the clients' computations.
        # This mirrors how MPIFedProxAPI.local_train builds its anchor list.
        global_parameters = [
            copy.deepcopy(param) for param in self.server.parameters()
        ]
        for client_idx in range(self.client_num):
            self.clients[client_idx].local_train(
                global_parameters,
                self.local_epoch,
                self.criterion,
                self.local_dataloaders[client_idx],
                self.local_optimizers[client_idx],
                communication_id=i,
            )

    def run(self):
        """Execute ``self.num_communication`` rounds of FedProx training.

        Each round does the following:

        1. Trains every client locally via ``local_train``.
        2. Has the server ``receive`` the clients' updates. It passes
           ``use_gradients=self.use_gradients``.
        3. Updates the server from gradients or from parameters, depending on
           ``self.use_gradients``. Either way ``self.clients_weight`` is
           passed as ``weight``.
        4. Calls ``self.custom_action`` with this API object.

        NOTE(review): this loop does not itself push the updated global model
        back to the clients. Presumably the server update (or the client's use
        of the passed global parameters) handles that. Confirm.
        """
        for i in range(self.num_communication):
            self.local_train(i)
            self.server.receive(use_gradients=self.use_gradients)
            if self.use_gradients:
                # NOTE(review): the "updata" spelling presumably matches the
                # server's method name. Verify against the server class
                # before renaming.
                self.server.updata_from_gradients(weight=self.clients_weight)
            else:
                self.server.update_from_parameters(weight=self.clients_weight)
            self.custom_action(self)
class MPIFedProxAPI(MPIFedAVGAPI):
    """MPI-based variant of the FedProx API, built on ``MPIFedAVGAPI``.

    Args:
        *args: Positional arguments forwarded to ``MPIFedAVGAPI.__init__``.
        mu (float, optional): Stored as ``self.mu``. Defaults to 0.01. It is
            not read anywhere in this class. NOTE(review): presumably the
            party carries its own proximal coefficient. Confirm.
        **kwargs: Keyword arguments forwarded to ``MPIFedAVGAPI.__init__``.
    """

    def __init__(self, *args, mu=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.mu = mu

    def local_train(self, com_cnt):
        """Train the local party for one communication round.

        Before training, this method deep-copies the party's current
        parameters. The copies are stored on ``self.party.prev_parameters``
        and also passed to ``self.party.local_train`` as its first argument.

        Args:
            com_cnt (int): Round index, forwarded as ``communication_id``.
        """
        party = self.party
        # Frozen copies of the parameters as they stand at the start of the
        # round. Presumably these equal the freshly received global model.
        # Confirm against the MPI client.
        party.prev_parameters = [copy.deepcopy(p) for p in party.parameters()]
        party.local_train(
            party.prev_parameters,
            self.local_epoch,
            self.criterion,
            self.local_dataloader,
            self.local_optimizer,
            communication_id=com_cnt,
        )