Source code for aijack.collaborative.optimizer.base
from abc import abstractmethod
class BaseFLOptimizer:
    """Base class for server-side optimizers in Federated Learning.

    Args:
        parameters (List[torch.nn.Parameter]): parameters of the model.
        lr (float, optional): learning rate. Defaults to 0.01.
        weight_decay (float, optional): coefficient of weight decay. Defaults to 0.0001.
    """

    def __init__(self, parameters, lr=0.01, weight_decay=0.0001):
        self.parameters = list(parameters)
        self.lr = lr
        self.weight_decay = weight_decay
        self.t = 1  # update step counter (starts at 1)
    @abstractmethod
    def step(self, grads):
        """Update the parameters with the given gradients.

        Args:
            grads (List[torch.Tensor]): list of gradients
        """
        pass
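

# --- Illustrative example (not part of the library source) -------------------
# Concrete optimizers are expected to subclass BaseFLOptimizer and implement
# `step`. The class below is a minimal sketch assuming a plain SGD update with
# weight decay; the name `PlainSGDFLOptimizer` is hypothetical and does not
# refer to an actual aijack class.

import torch


class PlainSGDFLOptimizer(BaseFLOptimizer):
    def step(self, grads):
        """Apply one SGD step: p <- p - lr * (grad + weight_decay * p)."""
        with torch.no_grad():
            for param, grad in zip(self.parameters, grads):
                param.add_(-self.lr * (grad + self.weight_decay * param))
        self.t += 1  # advance the step counter kept by the base class


# Usage sketch (assumes `model` is a torch.nn.Module and `grads` is a list of
# torch.Tensor aligned with model.parameters()):
#
#     optimizer = PlainSGDFLOptimizer(model.parameters(), lr=0.01)
#     optimizer.step(grads)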