"""FedMD client classes and MPI helpers (module ``aijack.collaborative.fedmd.client``)."""
import torch
from torch import nn
from ...manager import BaseManager
from ...utils.utils import default_local_train_for_client, torch_round_x_decimal
from ..core import BaseClient
from ..core.utils import GLOBAL_LOGIT_TAG, LOCAL_LOGIT_TAG
def initialize_global_logit(len_public_dataloader, output_dim, device):
    """Create a ``(len_public_dataloader, output_dim)`` table filled with ``+inf``.

    Args:
        len_public_dataloader: number of rows. ``FedMDClient`` passes the
            number of samples in its public dataset.
        output_dim: number of columns.
        device: torch device on which the table is allocated.

    Returns:
        A tensor of torch's default floating dtype whose every entry is ``inf``.
    """
    # Allocate directly on ``device`` with the fill value. The previous form
    # built a ones tensor on the host, copied it to the device and then
    # multiplied by inf, which meant two extra allocations and a transfer.
    return torch.full(
        (len_public_dataloader, output_dim), float("inf"), device=device
    )
class FedMDClient(BaseClient):
    """FedMD client that exchanges logits on a shared public dataset.

    The client uploads its model's outputs on the public data, downloads the
    server's predicted values for the same samples, and can train its model
    toward those values with ``approach_consensus``.

    Args:
        model: local model, handed to ``BaseClient``.
        public_dataloader: loader over the public dataset. Each batch is
            indexed as ``data[0]`` (row indices into the logit table) and
            ``data[1]`` (model inputs).
        output_dim: number of columns of the logit table.
        batch_size: stored as ``self.batch_size``. It is not used by this class.
        user_id: identifier forwarded to ``BaseClient``.
        base_loss_func: stored as ``self.base_loss_func``. Defaults to a fresh
            ``nn.CrossEntropyLoss()`` per client.
        consensus_loss_func: loss between local predictions and the server's
            values in ``approach_consensus``. Defaults to a fresh
            ``nn.L1Loss()`` per client.
        round_decimal: if not ``None``, ``upload`` returns the logits rounded
            with ``torch_round_x_decimal`` to this many decimals.
        device: device for public inputs and the logit table.
    """

    def __init__(
        self,
        model,
        public_dataloader,
        output_dim=1,
        batch_size=8,
        user_id=0,
        base_loss_func=None,
        consensus_loss_func=None,
        round_decimal=None,
        device="cpu",
    ):
        super().__init__(model, user_id=user_id)
        self.public_dataloader = public_dataloader
        self.batch_size = batch_size
        # Build default losses per instance. A module object used as a default
        # argument would be created once and shared by every client.
        self.base_loss_func = (
            nn.CrossEntropyLoss() if base_loss_func is None else base_loss_func
        )
        self.consensus_loss_func = (
            nn.L1Loss() if consensus_loss_func is None else consensus_loss_func
        )
        self.round_decimal = round_decimal
        self.device = device
        # Set by ``download``; required by ``approach_consensus``.
        self.predicted_values_of_server = None

        # One row per public sample, initialised to +inf.
        len_public_dataloader = len(self.public_dataloader.dataset)
        self.logit2server = initialize_global_logit(
            len_public_dataloader, output_dim, self.device
        )

    def upload(self):
        """Compute the model's outputs on the public data for the server.

        Each batch's outputs are written into ``self.logit2server[idx]``. Rows
        not covered by the loader keep their previous contents (``inf`` right
        after construction).

        Returns:
            ``self.logit2server`` itself (not a copy) when ``round_decimal`` is
            ``None``. Otherwise the result of
            ``torch_round_x_decimal(self.logit2server, self.round_decimal)``.
        """
        # Inference only. The outputs are stored detached, so skipping
        # autograd avoids building a graph that would be discarded anyway.
        with torch.no_grad():
            for data in self.public_dataloader:
                idx = data[0]
                x = data[1].to(self.device)
                self.logit2server[idx, :] = self(x).detach()
        if self.round_decimal is None:
            return self.logit2server
        return torch_round_x_decimal(self.logit2server, self.round_decimal)

    def download(self, predicted_values_of_server):
        """Store the server's predicted values for later use in ``approach_consensus``."""
        self.predicted_values_of_server = predicted_values_of_server

    def local_train(self, local_epoch, criterion, trainloader, optimizer):
        """Train on private data by delegating to ``default_local_train_for_client``."""
        return default_local_train_for_client(
            self, local_epoch, criterion, trainloader, optimizer
        )

    def approach_consensus(self, consensus_optimizer):
        """Run one pass over the public data, pulling outputs toward the server's values.

        For each batch this minimises
        ``consensus_loss_func(self(x), predicted_values_of_server[idx])``.

        Args:
            consensus_optimizer: optimizer whose ``zero_grad``/``step`` are
                called once per batch (presumably over the local model's
                parameters).

        Returns:
            Sum of the per-batch ``loss.item()`` values (``0`` if the loader
            yields no batches).

        Raises:
            RuntimeError: if ``download`` has not provided server values yet.
        """
        if self.predicted_values_of_server is None:
            raise RuntimeError(
                "approach_consensus() called before download(): "
                "no server predictions are available"
            )
        running_loss = 0
        for data in self.public_dataloader:
            idx = data[0]
            x = data[1].to(self.device)
            y_consensus = self.predicted_values_of_server[idx, :].to(self.device)
            consensus_optimizer.zero_grad()
            y_pred = self(x)
            loss = self.consensus_loss_func(y_pred, y_consensus)
            loss.backward()
            consensus_optimizer.step()
            running_loss += loss.item()
        return running_loss
def attach_mpi_to_fedmdclient(cls):
    """Build a subclass of ``cls`` that exchanges logits over an MPI communicator.

    The generated ``MPIFedMDClientWrapper`` takes the communicator as its
    first constructor argument and forwards all remaining arguments to
    ``cls.__init__``.

    Args:
        cls: client class to extend (for example ``FedMDClient``).

    Returns:
        The newly created wrapper class. Each call creates a distinct class.
    """

    class MPIFedMDClientWrapper(cls):
        def __init__(self, comm, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Assigned only after the parent constructor has run.
            self.comm = comm

        def action(self):
            """Send local logits, clear model gradients, then receive server values."""
            self.mpi_upload()
            self.model.zero_grad()
            self.mpi_download()

        def mpi_upload(self):
            """Send local logits using the default destination of ``mpi_upload_logits``."""
            self.mpi_upload_logits()

        def mpi_upload_logits(self, destination_id=0):
            """Send ``self.upload()`` to rank ``destination_id`` tagged ``LOCAL_LOGIT_TAG``."""
            outgoing = self.upload()
            self.comm.send(outgoing, dest=destination_id, tag=LOCAL_LOGIT_TAG)

        def mpi_download(self):
            """Receive a ``GLOBAL_LOGIT_TAG`` message and hand it to ``download``.

            No source rank is given. With mpi4py's defaults (assumed here,
            confirm against the communicator in use), any sender is accepted.
            """
            incoming = self.comm.recv(tag=GLOBAL_LOGIT_TAG)
            self.download(incoming)

        def mpi_initialize(self):
            """Initial synchronisation: block until server values arrive."""
            self.mpi_download()

    return MPIFedMDClientWrapper
class MPIFedMDClientManager(BaseManager):
    """Manager that adds MPI communication to a FedMD client class."""

    def attach(self, cls):
        """Return ``attach_mpi_to_fedmdclient(cls)``.

        ``attach_mpi_to_fedmdclient`` accepts only the class to wrap, so
        arguments stored on this manager cannot be forwarded to it.

        Args:
            cls: client class to wrap.

        Returns:
            The MPI-enabled subclass of ``cls``.

        Raises:
            TypeError: if the manager was constructed with positional or
                keyword arguments.
        """
        # Forwarding stored args previously failed inside the call with an
        # opaque "takes 1 positional argument" error. Fail early and clearly.
        if self.args or self.kwargs:
            raise TypeError(
                "MPIFedMDClientManager takes no extra arguments because "
                "attach_mpi_to_fedmdclient(cls) accepts none "
                f"(got args={self.args!r}, kwargs={self.kwargs!r})"
            )
        return attach_mpi_to_fedmdclient(cls)