Source code for aijack.collaborative.fedexp.server
import torch
from ..fedavg import FedAVGServer
[docs]class FedEXPServer(FedAVGServer):
"""Implementation of 'Jhunjhunwala, Divyansh, Shiqiang Wang, and Gauri Joshi. "FedExP: Speeding up Federated Averaging Via Extrapolation." arXiv preprint arXiv:2301.09604 (2023).'"""
def __init__(self, *args, eps=1e-5, **kwargs):
super(FedEXPServer, self).__init__(*args, **kwargs)
self.eps = eps
[docs] def update(self, *args, **kwargs):
self.update_from_gradients()
[docs] def update_from_gradients(self):
self.aggregated_gradients = [
torch.zeros_like(params) for params in self.server_model.parameters()
]
grad_norms = []
M = len(self.uploaded_gradients)
len_gradients = len(self.aggregated_gradients)
for gradients in self.uploaded_gradients:
for gradient_id in range(len_gradients):
self.aggregated_gradients[gradient_id] = (
gradients[gradient_id] * (1 / M)
+ self.aggregated_gradients[gradient_id]
)
grad_norms.append(
sum([torch.linalg.norm(g) ** 2 for g in gradients[gradient_id]])
)
agg_grad_norm = sum(
[torch.linalg.norm(g) ** 2 for g in self.aggregated_gradients]
)
self.optimizer.lr = max(
1, sum([g / (2 * M * (agg_grad_norm + self.eps)) for g in grad_norms])
)
self.optimizer.step(self.aggregated_gradients)