Source code for aijack.collaborative.fedavg.client

import copy

from ...manager import BaseManager
from ..core import BaseClient
from ..core.utils import GRADIENTS_TAG, PARAMETERS_TAG
from ..optimizer import AdamFLOptimizer, SGDFLOptimizer


class FedAVGClient(BaseClient):
    """Client of FedAVG for single process simulation

    Args:
        model (torch.nn.Module): local model
        user_id (int, optional): id of this client. Defaults to 0.
        lr (float, optional): learning rate. Defaults to 0.1.
        send_gradient (bool, optional): if True, communicate gradient to the server.
            otherwise, communicates model parameters. Defaults to True.
        optimizer_type_for_global_grad (str, optional): type of optimizer for model
            update with global gradient. sgd|adam|none. Defaults to "sgd".
        server_side_update (bool, optional): If True, the global model update is
            conducted in the server side. Defaults to True.
        optimizer_kwargs_for_global_grad (dict, optional): kwargs for the optimizer
            for global gradients. Defaults to None, which is treated as an empty dict.
        device (str, optional): device type. Defaults to "cpu".
    """

    def __init__(
        self,
        model,
        user_id=0,
        lr=0.1,
        send_gradient=True,
        optimizer_type_for_global_grad="sgd",
        server_side_update=True,
        optimizer_kwargs_for_global_grad=None,
        device="cpu",
    ):
        super().__init__(model, user_id=user_id)
        self.lr = lr
        self.send_gradient = send_gradient
        self.server_side_update = server_side_update
        self.device = device

        # The client-side optimizer is only needed when the server sends
        # global gradients instead of a full model state.
        if not self.server_side_update:
            # `None` stands in for "no extra kwargs" so that no mutable
            # default object is shared between instances.
            optimizer_kwargs = optimizer_kwargs_for_global_grad or {}
            self._setup_optimizer_for_global_grad(
                optimizer_type_for_global_grad, **optimizer_kwargs
            )

        self._snapshot_parameters()
        self.initialized = False

    def _snapshot_parameters(self):
        """Store deep copies of the current model parameters in `prev_parameters`."""
        self.prev_parameters = [
            copy.deepcopy(param) for param in self.model.parameters()
        ]

    def _setup_optimizer_for_global_grad(self, optimizer_type, **kwargs):
        """Create the optimizer that applies global gradients to the local model.

        Args:
            optimizer_type (str): one of "sgd", "adam" or "none".
            **kwargs: forwarded to the optimizer constructor.

        Raises:
            NotImplementedError: if `optimizer_type` is not a supported value.
        """
        # NOTE: the attribute name keeps its historical spelling ("gloal")
        # because code outside this class may read it.
        if optimizer_type == "sgd":
            self.optimizer_for_gloal_grad = SGDFLOptimizer(
                self.model.parameters(), lr=self.lr, **kwargs
            )
        elif optimizer_type == "adam":
            self.optimizer_for_gloal_grad = AdamFLOptimizer(
                self.model.parameters(), lr=self.lr, **kwargs
            )
        elif optimizer_type == "none":
            # NOTE(review): with no optimizer, `download` in client-side-update
            # mode would call `.step` on None -- confirm the intended usage.
            self.optimizer_for_gloal_grad = None
        else:
            raise NotImplementedError(
                f"{optimizer_type} is not supported. You can specify `sgd`, `adam`, or `none`."
            )

    def upload(self):
        """Upload the current local model state

        Returns:
            the result of `upload_gradients()` if `send_gradient` is True,
            otherwise the result of `upload_parameters()`.
        """
        if self.send_gradient:
            return self.upload_gradients()
        return self.upload_parameters()

    def upload_parameters(self):
        """Upload the model parameters

        Returns:
            the `state_dict()` of the local model.
        """
        return self.model.state_dict()

    def upload_gradients(self):
        """Upload the local gradients

        Returns:
            list: one entry per model parameter, computed as
            `(previous_parameter - current_parameter) / lr`.
        """
        return [
            (prev_param - param) / self.lr
            for param, prev_param in zip(self.model.parameters(), self.prev_parameters)
        ]

    def revert(self):
        """Revert the local model state to the previous global model"""
        for param, prev_param in zip(self.model.parameters(), self.prev_parameters):
            if param is not None:
                # Overwrite the values in place: rebinding the loop variable
                # would leave the model untouched. `detach()` shares storage
                # with the parameter, so the copy is not recorded by autograd.
                param.detach().copy_(prev_param.detach())

    def download(self, new_global_model):
        """Download the new global model

        Args:
            new_global_model: a model state (loaded with `load_state_dict`) when
                `server_side_update` is True or on the first download; otherwise
                the global gradients, passed to the client-side optimizer's `step`.
        """
        if self.server_side_update or (not self.initialized):
            # receive the new global model as the model state
            self.model.load_state_dict(new_global_model)
        else:
            # receive the new global model as the global gradients
            self.revert()
            self.optimizer_for_gloal_grad.step(new_global_model)

        if not self.initialized:
            self.initialized = True

        # Remember the freshly received state so that the next
        # `upload_gradients` / `revert` is relative to it.
        self._snapshot_parameters()

    def local_train(
        self, local_epoch, criterion, trainloader, optimizer, communication_id=0
    ):
        """Train the local model and log a per-epoch loss value.

        Args:
            local_epoch (int): number of passes over `trainloader`.
            criterion: callable computing the loss from `(outputs, labels)`.
            trainloader: iterable yielding `(inputs, labels)` batches.
            optimizer: optimizer whose `zero_grad` and `step` are called per batch.
            communication_id (int, optional): not used by this method. Defaults to 0.

        Returns:
            list: for each epoch, the sum of per-batch `loss.item()` divided by
            the number of samples seen (`inputs.shape[0]` summed over batches).

        Raises:
            ZeroDivisionError: if `trainloader` yields no samples in an epoch.
        """
        loss_log = []

        for _ in range(local_epoch):
            running_loss = 0.0
            running_data_num = 0
            for inputs, labels in trainloader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()
                self.zero_grad()

                outputs = self(inputs)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_data_num += inputs.shape[0]

            loss_log.append(running_loss / running_data_num)

        return loss_log
def attach_mpi_to_fedavgclient(cls):
    """Derive an MPI-communicating client class from `cls`.

    Args:
        cls: client class to extend. The wrapper calls its `upload_gradients()`
            and `download(new_global_model)` methods and reads `self.model`.

    Returns:
        type: the `MPIFedAVGClientWrapper` subclass of `cls`.
    """

    class MPIFedAVGClientWrapper(cls):
        """Subclass of `cls` that exchanges data through a communicator object."""

        def __init__(self, comm, *args, **kwargs):
            """Initialise the wrapped client, then keep `comm` for send/recv."""
            super().__init__(*args, **kwargs)
            self.comm = comm

        def action(self):
            """Run one round: upload, clear model gradients, download."""
            self.upload()
            self.model.zero_grad()
            self.download()

        def upload(self):
            """Send the local gradients to the default destination."""
            self.upload_gradient()

        def upload_gradient(self, destination_id=0):
            """Send the parent's `upload_gradients()` result to `destination_id`."""
            send = self.comm.send
            send(
                super().upload_gradients(),
                dest=destination_id,
                tag=GRADIENTS_TAG,
            )

        def download(self):
            """Receive a message tagged PARAMETERS_TAG and pass it to the parent's download."""
            apply_update = super().download
            apply_update(self.comm.recv(tag=PARAMETERS_TAG))

        def mpi_initialize(self):
            """Perform the initial download."""
            self.download()

    return MPIFedAVGClientWrapper
class MPIFedAVGClientManager(BaseManager):
    """Manager that attaches MPI communication to a FedAVG client class."""

    def attach(self, cls):
        """Return `cls` wrapped by `attach_mpi_to_fedavgclient`.

        Args:
            cls: client class to wrap.

        Returns:
            the class produced by `attach_mpi_to_fedavgclient`.
        """
        # presumably `args` / `kwargs` are stored by BaseManager -- verify there.
        # NOTE(review): `attach_mpi_to_fedavgclient` in this file accepts only
        # `cls`, so non-empty manager arguments would raise TypeError.
        extra_args = self.args
        extra_kwargs = self.kwargs
        return attach_mpi_to_fedavgclient(cls, *extra_args, **extra_kwargs)