# Source code for aijack.collaborative.fedprox.client
import torch
from ..fedavg import FedAVGClient
class FedProxClient(FedAVGClient):
    """FedAVG client whose local update adds the FedProx proximal term.

    FedProx augments each client's local objective with
    ``mu / 2 * ||w - w_server||^2`` so the local model is pulled toward the
    current server model; its gradient ``mu * (w - w_server)`` is added to
    each local parameter's gradient before the optimizer step.

    ``self.mu`` (the proximal coefficient) and ``self.device`` are read but
    not set in this class. NOTE(review): presumably configured by the
    constructor / ``FedAVGClient``; confirm.
    """

    def _apply_proximal_term(self, server_parameters):
        """Add the proximal gradient to each local parameter's ``.grad``.

        Args:
            server_parameters: sequence of server-side tensors, paired
                positionally with ``self.parameters()``.

        Returns:
            The value ``mu / 2 * sum_k ||w_k - w_server_k||^2`` (a tensor, or
            ``0.0`` when there are no parameters). It is used for logging only
            and is not back-propagated.
        """
        prox_loss = 0.0
        # Pure bookkeeping on weights/grads: nothing here must enter the graph.
        with torch.no_grad():
            for local_param, global_param in zip(self.parameters(), server_parameters):
                # Match dtype/device of the local weight before subtracting.
                diff = local_param.detach() - global_param.to(local_param)
                # Squared L2 norm, consistent with the gradient mu * diff below.
                prox_loss = prox_loss + self.mu / 2 * torch.sum(diff * diff)
                if not local_param.requires_grad:
                    # Frozen parameters are never updated by the optimizer.
                    continue
                prox_grad = self.mu * diff
                if local_param.grad is None:
                    # Trainable but unused in forward: the proximal term still applies.
                    local_param.grad = prox_grad
                else:
                    local_param.grad.add_(prox_grad)
        return prox_loss

    def local_train(
        self,
        server_parameters,
        local_epoch,
        criterion,
        trainloader,
        optimizer,
        communication_id=0,
    ):
        """Run ``local_epoch`` epochs of FedProx local training.

        Args:
            server_parameters: iterable of the server model's parameter
                tensors, paired positionally with ``self.parameters()``. It is
                materialized once, so a one-shot generator such as
                ``model.parameters()`` is also accepted.
            local_epoch (int): number of passes over ``trainloader``.
            criterion: callable ``criterion(outputs, labels)`` returning a
                scalar loss tensor.
            trainloader: iterable of ``(inputs, labels)`` batches.
            optimizer: optimizer over this client's parameters.
            communication_id (int): not used by this method; kept for
                interface compatibility.

        Returns:
            list[float]: one entry per epoch, namely the sum over batches of
            (criterion loss + proximal term) divided by the number of samples
            seen in that epoch.

        Raises:
            ZeroDivisionError: if ``trainloader`` yields no batches.
        """
        # Snapshot once: a generator would otherwise be exhausted after the
        # first batch and the proximal term silently dropped afterwards.
        server_parameters = list(server_parameters)

        loss_log = []
        for _ in range(local_epoch):
            running_loss = 0.0
            running_data_num = 0
            for inputs, labels in trainloader:
                inputs = inputs.to(self.device)
                # Kept from the original; input gradients are presumably
                # consumed elsewhere (TODO confirm).
                inputs.requires_grad = True
                labels = labels.to(self.device)

                optimizer.zero_grad()
                self.zero_grad()

                outputs = self(inputs)
                loss = criterion(outputs, labels)
                loss.backward()

                # Adds mu * (w - w_server) to the grads; returns the term's value.
                prox_loss = self._apply_proximal_term(server_parameters)
                optimizer.step()

                running_loss += loss.item() + float(prox_loss)
                running_data_num += inputs.shape[0]

            loss_log.append(running_loss / running_data_num)

        return loss_log