# Source code for aijack.collaborative.optimizer.sgd

from .base import BaseFLOptimizer


class SGDFLOptimizer(BaseFLOptimizer):
    """Implementation of SGD to update the global model of Federated Learning.

    Each call to :meth:`step` applies, for every parameter ``p`` paired with
    its gradient ``g``::

        p <- p - lr * (g + weight_decay * p)

    and then increments ``self.t`` once.

    Args:
        parameters (List[torch.nn.Parameter]): parameters of the model
        lr (float, optional): learning rate. Defaults to 0.01.
        weight_decay (float, optional): coefficient of weight decay
            (L2 penalty folded into the gradient). Defaults to 0.0.
    """

    def __init__(self, parameters, lr=0.01, weight_decay=0.0):
        super().__init__(parameters, lr=lr, weight_decay=weight_decay)

    def step(self, grads):
        """Update the parameters in place with the given gradients.

        Args:
            grads (List[torch.Tensor]): list of gradients, one per parameter,
                in the same order as ``self.parameters``.
        """
        # NOTE(review): ``zip`` stops at the shorter sequence, so a length
        # mismatch between ``self.parameters`` and ``grads`` is silently
        # ignored rather than reported.
        for param, grad in zip(self.parameters, grads):
            if self.weight_decay == 0.0:
                # Plain SGD; avoids the extra multiply-add when decay is off.
                param.data -= self.lr * grad
            else:
                # Weight decay added to the gradient before scaling by lr.
                param.data -= self.lr * (grad + self.weight_decay * param.data)
        self.t += 1