# Source code for aijack.collaborative.moon.client
import copy
import torch
import torch.nn.functional as F
from ..fedavg import FedAVGClient
class MOONClient(FedAVGClient):
    """Client of MOON for single-process simulation.

    (Li, Qinbin, Bingsheng He, and Dawn Song. "Model-contrastive
    federated learning." Proceedings of the IEEE/CVF conference
    on computer vision and pattern recognition. 2021.)

    Args:
        model (torch.nn.Module): local model
        mu (float): weight of the model-contrastive loss
        tau (float): temperature within the model-contrastive loss
    """

    def __init__(
        self,
        model,
        mu=0.1,
        tau=1.0,
        **kwargs,
    ):
        super(MOONClient, self).__init__(model, **kwargs)
        self.mu = mu
        self.tau = tau
        # Frozen snapshots used only to produce the anchor/negative
        # representations of the contrastive term; refreshed each round
        # at the top of local_train.
        self.global_model = copy.deepcopy(model)
        self.prev_model = copy.deepcopy(model)

    def local_train(
        self,
        local_epoch,
        criterion,
        trainloader,
        optimizer,
        communication_id=0,
    ):
        """Run `local_epoch` epochs of MOON local training.

        Args:
            local_epoch (int): number of local epochs
            criterion (callable): supervised loss on (outputs, labels)
            trainloader (torch.utils.data.DataLoader): local training data
            optimizer (torch.optim.Optimizer): optimizer over self.model
            communication_id (int): current communication round. The
                contrastive term is skipped in round 0, when the global
                and previous snapshots are not yet meaningful.

        Returns:
            list of float: per-epoch average loss (summed loss / #samples)
        """
        if communication_id != 0:
            # BUGFIX: the original did `param = glob_param`, which only
            # rebinds the loop variable and never touches the snapshot
            # tensors — both snapshots stayed at their initial weights.
            # Copy in place instead, without tracking gradients.
            with torch.no_grad():
                for param, glob_param in zip(
                    self.global_model.parameters(), self.model.parameters()
                ):
                    param.copy_(glob_param)
                # NOTE(review): self.prev_parameters is not defined in this
                # class — presumably maintained by FedAVGClient; confirm it
                # holds the previous round's local weights.
                for param, prev_param in zip(
                    self.prev_model.parameters(), self.prev_parameters
                ):
                    param.copy_(prev_param)

        loss_log = []
        for _ in range(local_epoch):
            running_loss = 0.0
            running_data_num = 0
            for _, data in enumerate(trainloader, 0):
                inputs, labels = data
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                optimizer.zero_grad()
                self.zero_grad()

                outputs = self(inputs)
                loss = criterion(outputs, labels)

                if communication_id != 0:
                    # The snapshots are frozen in MOON: gradients must flow
                    # only through `outputs`, so run them under no_grad.
                    with torch.no_grad():
                        glob_outputs = self.global_model(inputs)
                        prev_outputs = self.prev_model(inputs)
                    exp_sim_cg = torch.exp(
                        F.cosine_similarity(outputs, glob_outputs) / self.tau
                    )
                    exp_sim_cp = torch.exp(
                        F.cosine_similarity(outputs, prev_outputs) / self.tau
                    )
                    # BUGFIX: cosine_similarity is per-sample, so the log
                    # term has shape (batch,); reduce to a scalar with
                    # .mean() — otherwise loss.backward() (non-scalar) and
                    # loss.item() both raise RuntimeError.
                    loss_con = (
                        -1 * torch.log(exp_sim_cg / (exp_sim_cg + exp_sim_cp))
                    ).mean()
                    loss = loss + self.mu * loss_con

                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_data_num += inputs.shape[0]
            loss_log.append(running_loss / running_data_num)
        return loss_log