import numpy as np
import torch
from ...manager import BaseManager
from ..core import BaseServer
from ..core.utils import GRADIENTS_TAG, PARAMETERS_TAG
from ..optimizer import AdamFLOptimizer, SGDFLOptimizer
class FedAVGServer(BaseServer):
"""Server of FedAVG for single process simulation
Args:
clients ([FedAvgClient]): a list of FedAVG clients.
global_model (torch.nn.Module): global model.
server_id (int, optional): id of this server. Defaults to 0.
lr (float, optional): learning rate. Defaults to 0.1.
optimizer_type (str, optional): optimizer for the update of global model . Defaults to "sgd".
server_side_update (bool, optional): If True, update the global model at the server-side. Defaults to True.
optimizer_kwargs (dict, optional): kwargs for the global optimizer. Defaults to {}.
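    Example:
        A minimal single-process sketch. ``Net`` and ``local_train`` are placeholders for a user-defined model and local training loop, and ``FedAVGClient`` is assumed to be the client counterpart of this class:

        >>> clients = [FedAVGClient(Net(), user_id=i) for i in range(2)]
        >>> server = FedAVGServer(clients, Net(), lr=0.1)
        >>> for _ in range(3):
        ...     for client in clients:
        ...         local_train(client)
        ...     server.action(use_gradients=True)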
"""
def __init__(
self,
clients,
global_model,
server_id=0,
lr=0.1,
optimizer_type="sgd",
server_side_update=True,
optimizer_kwargs={},
device="cpu",
):
super(FedAVGServer, self).__init__(clients, global_model, server_id=server_id)
self.lr = lr
self._setup_optimizer(optimizer_type, **optimizer_kwargs)
self.server_side_update = server_side_update
self.device = device
self.uploaded_gradients = []
self.force_send_model_state_dict = True
self.weight = np.ones(self.num_clients) / self.num_clients
def _setup_optimizer(self, optimizer_type, **kwargs):
if optimizer_type == "sgd":
self.optimizer = SGDFLOptimizer(
self.server_model.parameters(), lr=self.lr, **kwargs
)
elif optimizer_type == "adam":
self.optimizer = AdamFLOptimizer(
self.server_model.parameters(), lr=self.lr, **kwargs
)
elif optimizer_type == "none":
self.optimizer = None
else:
raise NotImplementedError(
f"{optimizer_type} is not supported. You can specify `sgd`, `adam`, or `none`."
)
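    # Example configurations (an illustrative sketch; whether the FL optimizers
    # accept further kwargs depends on their signatures):
    #
    #     FedAVGServer(clients, Net(), optimizer_type="adam", lr=0.01)
    #     FedAVGServer(clients, Net(), optimizer_type="none")  # no server-side optimizer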
    def action(self, use_gradients=True):
        """Execute one communication round: receive the local updates, update the global model, and distribute the new global model.
        Args:
            use_gradients (bool, optional): If True, clients upload gradients; otherwise, they upload model parameters. Defaults to True.
        """
        self.receive(use_gradients)
        self.update(use_gradients)
        self.distribute()
    def receive(self, use_gradients=True):
        """Receive the local updates from the clients.
Args:
use_gradients (bool, optional): If True, receive the local gradients. Otherwise, receive the local parameters. Defaults to True.
"""
if use_gradients:
self.receive_local_gradients()
else:
self.receive_local_parameters()
    def update(self, use_gradients=True):
        """Update the global model.
        Args:
            use_gradients (bool, optional): If True, update the global model with the aggregated local gradients; otherwise, with the averaged local parameters. Defaults to True.
        """
if use_gradients:
self.update_from_gradients()
else:
self.update_from_parameters()
    def _preprocess_local_gradients(self, uploaded_grad):
        # identity hook; subclasses can override this to transform the uploaded gradients
        return uploaded_grad
    def receive_local_gradients(self):
        """Receive the local gradients from all clients."""
self.uploaded_gradients = [
self._preprocess_local_gradients(c.upload_gradients()) for c in self.clients
]
    def receive_local_parameters(self):
        """Receive the local model parameters from all clients."""
self.uploaded_parameters = [c.upload_parameters() for c in self.clients]
    def update_from_gradients(self):
        """Update the global model with the aggregated local gradients."""
        # weighted sum of the uploaded gradients: g = sum_i weight_i * g_i
        self.aggregated_gradients = [
            torch.zeros_like(params) for params in self.server_model.parameters()
        ]
        len_gradients = len(self.aggregated_gradients)
        for i, gradients in enumerate(self.uploaded_gradients):
            for gradient_id in range(len_gradients):
                self.aggregated_gradients[gradient_id] = (
                    gradients[gradient_id] * self.weight[i]
                    + self.aggregated_gradients[gradient_id]
                )
        if self.server_side_update:
            # apply the aggregated gradients with the server-side optimizer
            self.optimizer.step(self.aggregated_gradients)
    def update_from_parameters(self):
        """Update the global model with the weighted average of the local model parameters."""
        # weighted average of the uploaded state dicts: theta = sum_i weight_i * theta_i
        # (with the default uniform weights, this is a plain mean)
        averaged_params = self.uploaded_parameters[0]
        for k in averaged_params.keys():
            for i in range(0, len(self.uploaded_parameters)):
                local_model_params = self.uploaded_parameters[i]
                w = self.weight[i]
                if i == 0:
                    averaged_params[k] = local_model_params[k] * w
                else:
                    averaged_params[k] += local_model_params[k] * w
        self.server_model.load_state_dict(averaged_params)
    def distribute(self):
        """Distribute the current global model to each client.
        If either `server_side_update` or `force_send_model_state_dict` is True, send the state dict of the global model; otherwise, send the aggregated gradients so that each client can apply the update itself.
        """
        for client in self.clients:
            # in the MPI setting, `self.clients` may hold integer ranks instead of client objects
            if not isinstance(client, int):
                if self.server_side_update or self.force_send_model_state_dict:
                    client.download(self.server_model.state_dict())
                else:
                    client.download(self.aggregated_gradients)
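# A sketch of the `server_side_update=False` path (illustrative, not part of
# this module): the server only aggregates the gradients and distributes them,
# and each client applies the update to its own copy of the model. Note that
# `force_send_model_state_dict` is set to True in `__init__`, so it has to be
# overridden by hand for the gradients to actually be sent.
#
#     server = FedAVGServer(clients, Net(), server_side_update=False, optimizer_type="none")
#     server.force_send_model_state_dict = False
#     server.action(use_gradients=True)  # clients receive the aggregated gradients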
def attach_mpi_to_fedavgserver(cls):
    """Wrap `cls` so that the server communicates with remote client processes via MPI."""

    class MPIFedAVGServerWrapper(cls):
        """MPI wrapper for a FedAVG-based server"""
def __init__(self, comm, *args, **kwargs):
self.comm = comm
super(MPIFedAVGServerWrapper, self).__init__(*args, **kwargs)
self.num_clients = len(self.clients)
self.round = 0
        def action(self):
            # one communication round: gather gradients, update, broadcast
            self.receive()
            self.update()
            self.distribute()
            self.round += 1
def receive(self):
self.receive_local_gradients()
        def receive_local_gradients(self):
            self.uploaded_gradients = []
            # block until the gradients of every client process have arrived
            while len(self.uploaded_gradients) < self.num_clients:
                gradients_received = self.comm.recv(tag=GRADIENTS_TAG)
self.uploaded_gradients.append(
self._preprocess_local_gradients(gradients_received)
)
        def distribute(self):
            # here `self.clients` holds the MPI ranks of the client processes
            for client_id in self.clients:
self.comm.send(
self.server_model.state_dict(),
dest=client_id,
tag=PARAMETERS_TAG,
)
        def mpi_initialize(self):
            # push the initial global model to every client before training starts
            self.distribute()
return MPIFedAVGServerWrapper
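# A hedged usage sketch with mpi4py (the client side is assumed to mirror this
# wrapper by uploading with GRADIENTS_TAG; `Net`, `client_ranks`, and
# `num_rounds` are placeholders):
#
#     from mpi4py import MPI
#     comm = MPI.COMM_WORLD
#     MPIFedAVGServer = attach_mpi_to_fedavgserver(FedAVGServer)
#     if comm.Get_rank() == 0:
#         server = MPIFedAVGServer(comm, client_ranks, Net())
#         server.mpi_initialize()  # broadcast the initial global model
#         for _ in range(num_rounds):
#             server.action()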
class MPIFedAVGServerManager(BaseManager):
    """Manager to attach the MPI wrapper to a FedAVG-based server class."""

    def attach(self, cls):
        return attach_mpi_to_fedavgserver(cls, *self.args, **self.kwargs)
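# Equivalent usage through the manager interface (illustrative):
#
#     manager = MPIFedAVGServerManager()
#     MPIFedAVGServer = manager.attach(FedAVGServer)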