Source code for aijack.collaborative.fedavg.api
import copy
from ..core.api import BaseFedAPI
class FedAVGAPI(BaseFedAPI):
    """Implementation of FedAVG (McMahan, Brendan, et al. 'Communication-efficient learning of deep networks from decentralized data.' Artificial Intelligence and Statistics. PMLR, 2017.)

    Args:
        server (FedAvgServer): FedAVG server.
        clients ([FedAvgClient]): a list of FedAVG clients.
        criterion (function): loss function.
        local_optimizers ([torch.optim.Optimizer]): a list of local optimizers, one per client.
        local_dataloaders ([torch.utils.data.DataLoader]): a list of local dataloaders, one per client.
        num_communication (int, optional): number of communication rounds. Defaults to 1.
        local_epoch (int, optional): number of local training epochs within each communication round. Defaults to 1.
        use_gradients (bool, optional): communicate gradients if True; otherwise communicate parameters. Defaults to True.
        custom_action (function, optional): arbitrary function that receives this instance itself. Defaults to lambda x: x.
        device (str, optional): device type. Defaults to "cpu".
    """
    def __init__(
        self,
        server,
        clients,
        criterion,
        local_optimizers,
        local_dataloaders,
        num_communication=1,
        local_epoch=1,
        use_gradients=True,
        custom_action=lambda x: x,
        device="cpu",
    ):
        self.server = server
        self.clients = clients
        self.criterion = criterion
        self.local_optimizers = local_optimizers
        self.local_dataloaders = local_dataloaders
        self.num_communication = num_communication
        self.local_epoch = local_epoch
        self.use_gradients = use_gradients
        self.custom_action = custom_action
        self.device = device
        self.client_num = len(self.clients)
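        # Aggregation weights are proportional to local dataset sizes
        # (e.g., datasets of 60 and 40 samples yield weights 0.6 and 0.4).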
        local_dataset_sizes = [
            len(dataloader.dataset) for dataloader in self.local_dataloaders
        ]
        sum_local_dataset_sizes = sum(local_dataset_sizes)
        self.server.weight = [
            dataset_size / sum_local_dataset_sizes
            for dataset_size in local_dataset_sizes
        ]
        self.logging = {}
    def local_train(self, i):
        self.logging[i] = {}
        for client_idx in range(self.client_num):
            loss_log = self.clients[client_idx].local_train(
                self.local_epoch,
                self.criterion,
                self.local_dataloaders[client_idx],
                self.local_optimizers[client_idx],
                communication_id=i,
            )
            self.logging[i][client_idx] = loss_log
    def run(self):
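        # The first distribution always sends the full state_dict so that every
        # client starts from the same initial global model.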
        self.server.force_send_model_state_dict = True
        self.server.distribute()
        self.server.force_send_model_state_dict = False
        for i in range(self.num_communication):
            self.local_train(i)
            self.server.receive(use_gradients=self.use_gradients)
            if self.use_gradients:
                self.server.update_from_gradients()
            else:
                self.server.update_from_parameters()
            self.server.distribute()
            self.custom_action(self)
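
# Usage sketch (illustrative, not part of this module): wiring up FedAVGAPI with a
# list of FedAvgClient instances and one FedAvgServer. The FedAvgClient / FedAvgServer
# constructor arguments below are assumptions inferred from the docstring above and
# may differ from the actual classes; adapt them as needed.
#
#     import torch
#     import torch.nn as nn
#     import torch.optim as optim
#     from torch.utils.data import DataLoader, TensorDataset
#     from aijack.collaborative.fedavg import FedAVGAPI, FedAvgClient, FedAvgServer
#
#     model_fn = lambda: nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
#     clients = [FedAvgClient(model_fn(), user_id=i) for i in range(2)]  # assumed signature
#     optimizers = [optim.SGD(c.parameters(), lr=0.01) for c in clients]
#     dataloaders = [
#         DataLoader(
#             TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,))),
#             batch_size=16,
#         )
#         for _ in range(2)
#     ]
#     server = FedAvgServer(clients, model_fn())  # assumed signature
#
#     api = FedAVGAPI(
#         server,
#         clients,
#         nn.CrossEntropyLoss(),
#         optimizers,
#         dataloaders,
#         num_communication=2,
#     )
#     api.run()
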
class MPIFedAVGAPI(BaseFedAPI):
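    """Implementation of FedAVG with MPI-based communication.

    Args:
        comm: MPI communicator.
        party: FedAVG server or client handled by this process.
        is_server (bool): whether this process acts as the server.
        criterion (function): loss function.
        local_optimizer (torch.optim.Optimizer, optional): local optimizer of this client. Defaults to None.
        local_dataloader (torch.utils.data.DataLoader, optional): local dataloader of this client. Defaults to None.
        num_communication (int, optional): number of communication rounds. Defaults to 1.
        local_epoch (int, optional): number of local training epochs within each communication round. Defaults to 1.
        custom_action (function, optional): arbitrary function that receives this instance itself. Defaults to lambda x: x.
        device (str, optional): device type. Defaults to "cpu".
    """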
    def __init__(
        self,
        comm,
        party,
        is_server,
        criterion,
        local_optimizer=None,
        local_dataloader=None,
        num_communication=1,
        local_epoch=1,
        custom_action=lambda x: x,
        device="cpu",
    ):
        self.comm = comm
        self.party = party
        self.is_server = is_server
        self.criterion = criterion
        self.local_optimizer = local_optimizer
        self.local_dataloader = local_dataloader
        self.num_communication = num_communication
        self.local_epoch = local_epoch
        self.custom_action = custom_action
        self.device = device
        self.logging = []
    def run(self):
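        # Clients run local training each round; both server and clients then carry
        # out their role-specific per-round communication via self.party.action().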
        self.party.mpi_initialize()
        self.comm.Barrier()
        for i in range(self.num_communication):
            if not self.is_server:
                loss_logging = self.local_train(i)
                self.logging.append(loss_logging)
            self.party.action()
            self.custom_action(self)
        self.comm.Barrier()
    def local_train(self, com_cnt):
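        # Snapshot the current model parameters before local training so that the
        # party can later derive its update (e.g., the parameter delta) to communicate.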
        self.party.prev_parameters = []
        for param in self.party.model.parameters():
            self.party.prev_parameters.append(copy.deepcopy(param))
        loss_logging = self.party.local_train(
            self.local_epoch,
            self.criterion,
            self.local_dataloader,
            self.local_optimizer,
            communication_id=com_cnt,
        )
        return loss_logging
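
# Usage sketch (illustrative, not part of this module): launching MPIFedAVGAPI under
# mpi4py, e.g. `mpiexec -n 3 python train.py`, with rank 0 as the server and every
# other rank as a client. The MPIFedAVGClientManager / MPIFedAVGServerManager wrappers
# and the wrapped constructors shown below are assumptions and may differ from the
# actual package; adapt them as needed.
#
#     from mpi4py import MPI
#     import torch.nn as nn
#     import torch.optim as optim
#     from aijack.collaborative.fedavg import (
#         FedAvgClient,
#         FedAvgServer,
#         MPIFedAVGAPI,
#         MPIFedAVGClientManager,
#         MPIFedAVGServerManager,
#     )
#
#     comm = MPI.COMM_WORLD
#     myid = comm.Get_rank()
#     size = comm.Get_size()
#     model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
#
#     if myid == 0:  # server process
#         MPIFedAVGServer = MPIFedAVGServerManager().attach(FedAvgServer)  # assumed wrapper
#         server = MPIFedAVGServer(comm, list(range(1, size)), model)
#         api = MPIFedAVGAPI(
#             comm, server, True, nn.CrossEntropyLoss(), None, None, num_communication=2
#         )
#     else:  # client processes
#         MPIFedAVGClient = MPIFedAVGClientManager().attach(FedAvgClient)  # assumed wrapper
#         client = MPIFedAVGClient(comm, model, user_id=myid)
#         optimizer = optim.SGD(client.parameters(), lr=0.01)
#         api = MPIFedAVGAPI(
#             comm,
#             client,
#             False,
#             nn.CrossEntropyLoss(),
#             optimizer,
#             local_dataloader,  # this client's torch DataLoader
#             num_communication=2,
#         )
#     api.run()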