Source code for aijack.collaborative.dsfl.client
import torch

from ...utils.metrics import crossentropyloss_between_logits
from ...utils.utils import default_local_train_for_client, torch_round_x_decimal
from ..core import BaseClient
from ..fedmd.client import initialize_global_logit
class DSFLClient(BaseClient):
    """Client of DS-FL.

    The client evaluates its local model on a shared public dataset, uploads
    the resulting per-sample outputs to the server, downloads the aggregated
    global values, and then trains its local model towards them.

    Args:
        model (torch.nn.Module): the local model; forwarded to ``BaseClient``.
        public_dataloader (torch.utils.data.DataLoader): a dataloader of the
            public dataset. Each batch is indexed positionally: ``batch[0]``
            is used as the row indices of the samples and ``batch[1]`` as the
            model inputs.
        output_dim (int, optional): the dimension of the output. Forwarded to
            ``initialize_global_logit`` to size the upload buffer.
            Defaults to 1.
        round_decimal (int, optional): if not None, the uploaded tensor is
            passed through ``torch_round_x_decimal`` with this value
            (presumably the number of decimal places to keep -- confirm
            against that helper). Defaults to None.
        consensus_scale (float, optional): multiplier applied to the consensus
            loss in ``approach_consensus``. Defaults to 1.0.
        device (str, optional): device type. Defaults to "cpu".
        user_id (int, optional): id of this client. Defaults to 0.
    """

    def __init__(
        self,
        model,
        public_dataloader,
        output_dim=1,
        round_decimal=None,
        consensus_scale=1.0,
        device="cpu",
        user_id=0,
    ):
        """Init DSFLClient."""
        super().__init__(model, user_id)
        self.public_dataloader = public_dataloader
        self.round_decimal = round_decimal
        self.device = device
        # Stays None until `download` is called.
        self.global_logit = None
        self.consensus_scale = consensus_scale

        # The buffer is sized by the number of public *samples* (not batches)
        # because `upload` fills it row-by-row using per-sample indices.
        num_public_samples = len(self.public_dataloader.dataset)
        self.logit2server = initialize_global_logit(
            num_public_samples, output_dim, self.device
        )

    def upload(self):
        """Upload the local model's outputs on the public dataset to the server.

        For every public batch, the model output is passed through
        ``softmax(dim=-1)`` and written into the rows of ``self.logit2server``
        selected by the batch indices.

        Returns:
            torch.Tensor: ``self.logit2server`` itself (the internal buffer,
            not a copy) when ``round_decimal`` is None; otherwise the result
            of ``torch_round_x_decimal(self.logit2server, round_decimal)``.
        """
        # The outputs are only stored, never back-propagated through, so skip
        # building the autograd graph for the forward pass.
        with torch.no_grad():
            for data in self.public_dataloader:
                idx = data[0]
                x = data[1].to(self.device)
                self.logit2server[idx, :] = self(x).detach().softmax(dim=-1)

        if self.round_decimal is None:
            return self.logit2server
        return torch_round_x_decimal(self.logit2server, self.round_decimal)

    def download(self, global_logit):
        """Download the global logits from the server.

        Args:
            global_logit (torch.Tensor): the global logits from the server.
                Stored as-is; `approach_consensus` indexes it by the sample
                indices yielded by the public dataloader.
        """
        self.global_logit = global_logit

    def local_train(self, local_epoch, criterion, trainloader, optimizer):
        """Train the local model by delegating to ``default_local_train_for_client``.

        Args:
            local_epoch: forwarded unchanged to the helper.
            criterion: forwarded unchanged to the helper.
            trainloader: forwarded unchanged to the helper.
            optimizer: forwarded unchanged to the helper.

        Returns:
            Whatever ``default_local_train_for_client`` returns.
        """
        return default_local_train_for_client(
            self, local_epoch, criterion, trainloader, optimizer
        )

    def approach_consensus(self, consensus_optimizer):
        """Train the local model towards the downloaded global logits.

        Performs one pass over the public dataset; for each batch the loss is
        ``consensus_scale * crossentropyloss_between_logits(local, global)``
        and one optimizer step is taken.

        `download` must have been called beforehand: ``self.global_logit``
        starts as None and is indexed here.

        Args:
            consensus_optimizer (torch.optim.Optimizer): an optimizer to train
                the local model.

        Returns:
            float: the sum of the per-batch losses divided by the number of
            batches in the public dataloader.
        """
        running_loss = 0
        for global_data in self.public_dataloader:
            idx = global_data[0]
            x = global_data[1].to(self.device)
            # The global values are a fixed target: no gradient flows to them.
            y_global = self.global_logit[idx, :].to(self.device).detach()

            consensus_optimizer.zero_grad()
            y_local = self(x)
            loss_consensus = self.consensus_scale * crossentropyloss_between_logits(
                y_local, y_global
            )
            loss_consensus.backward()
            consensus_optimizer.step()

            running_loss += loss_consensus.item()

        # Average over batches (len of the loader), not over samples.
        running_loss /= len(self.public_dataloader)
        return running_loss