Source code for aijack.collaborative.optimizer.adam
import torch
from .base import BaseFLOptimizer


class AdamFLOptimizer(BaseFLOptimizer):
"""Implementation of Adam to update the global model of Federated Learning.
Args:
parameters (List[torch.nn.Parameter]): parameters of the model
lr (float, optional): learning rate. Defaults to 0.01.
weight_decay (float, optional): coefficient of weight decay. Defaults to 0.0001.
beta1 (float, optional): 1st-order exponential decay. Defaults to 0.9.
beta2 (float, optional): 2nd-order exponential decay. Defaults to 0.999.
epsilon (float, optional): a small value to prevent zero-devision. Defaults to 1e-8.
"""
    def __init__(
        self,
        parameters,
        lr=0.01,
        weight_decay=0.0001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
    ):
        super().__init__(parameters, lr=lr, weight_decay=weight_decay)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        # First- and second-moment estimates, one buffer per parameter
        self.m = [torch.zeros_like(param.data) for param in self.parameters]
        self.v = [torch.zeros_like(param.data) for param in self.parameters]

    def step(self, grads):
"""Update the parameters with the give gradient
Args:
grads (List[torch.Tensor]): list of gradients
"""
        for i, (param, grad) in enumerate(zip(self.parameters, grads)):
            # Update biased first- and second-moment estimates
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (grad * grad)
            # Bias-corrected moment estimates
            m_hat = self.m[i] / (1 - self.beta1**self.t)
            v_hat = self.v[i] / (1 - self.beta2**self.t)
            # Adam update plus weight decay; epsilon guards against division by zero
            param.data -= self.lr * (
                m_hat / (torch.sqrt(v_hat) + self.epsilon)
                + self.weight_decay * param.data
            )
        self.t += 1
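

A minimal usage sketch, not part of the module above: it assumes a toy linear model as the global model, zero tensors standing in for gradients aggregated from clients, and that BaseFLOptimizer initializes the step counter self.t as the code above expects.

if __name__ == "__main__":
    # Hypothetical example: a toy linear model plays the role of the global model,
    # and zero tensors stand in for the aggregated client gradients.
    model = torch.nn.Linear(10, 1)
    optimizer = AdamFLOptimizer(list(model.parameters()), lr=0.01)
    aggregated_grads = [torch.zeros_like(p) for p in model.parameters()]
    optimizer.step(aggregated_grads)  # one global update with the Adam rule above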