# Source code for aijack.collaborative.fedgems.client

import torch
from torch import nn

from ...utils.utils import default_local_train_for_client, torch_round_x_decimal
from ..core import BaseClient


class FedGEMSClient(BaseClient):
    """Client-side participant for FedGEMS-style training.

    The client wraps a local ``model`` (through ``BaseClient``). It exposes
    its (optionally rounded) predictions through ``upload`` and stores
    predictions received from the server through ``download``. When training
    on the public dataset, it combines a supervised loss with a divergence
    term against those server predictions.

    Args:
        model: Local model, passed through to ``BaseClient``.
        user_id (int): Identifier of this client, passed to ``BaseClient``.
        lr (float): Stored as ``self.lr``; not read by the methods in this
            class.
        base_loss_func: Supervised loss called as
            ``base_loss_func(y_pred, labels)``. Defaults to a fresh
            ``nn.CrossEntropyLoss()`` per client.
        kldiv_loss_func: Divergence loss between the server's and the
            client's predictions. Defaults to a fresh ``nn.KLDivLoss()``
            per client.
        epsilon (float): Weight of the supervised loss. The divergence term
            is weighted by ``1 - epsilon``.
        round_decimal (int | None): If not None, ``upload`` rounds its output
            with ``torch_round_x_decimal`` to this many decimals.
        device (str): Stored as ``self.device``; not read by the methods in
            this class.
    """

    def __init__(
        self,
        model,
        user_id=0,
        lr=0.1,
        base_loss_func=None,
        kldiv_loss_func=None,
        epsilon=0.75,
        round_decimal=None,
        device="cpu",
    ):
        super().__init__(model, user_id=user_id)
        self.lr = lr
        # Set by ``download``; stays None until the server sends predictions.
        self.predicted_values_of_server = None
        # Build default loss modules per instance. A module instance written
        # as a default argument is created once, when the function is defined,
        # and would be shared by every client constructed without an explicit
        # loss.
        self.base_loss_func = (
            nn.CrossEntropyLoss() if base_loss_func is None else base_loss_func
        )
        self.kldiv_loss_func = (
            nn.KLDivLoss() if kldiv_loss_func is None else kldiv_loss_func
        )
        self.epsilon = epsilon
        self.round_decimal = round_decimal
        self.device = device

    def upload(self, x):
        """Return the values this client uploads for input ``x``.

        Args:
            x: Input batch forwarded to the wrapped model via ``self(x)``.

        Returns:
            The model output. If ``round_decimal`` is set, the output is
            first rounded with ``torch_round_x_decimal``.
        """
        result = self(x)
        if self.round_decimal is None:
            return result
        return torch_round_x_decimal(result, self.round_decimal)

    def download(self, predicted_values_of_server):
        """Store the predictions received from the server.

        Args:
            predicted_values_of_server: Indexable collection of server
                predictions. ``calc_loss_on_public_dataset`` indexes it with
                its ``idx`` argument.
        """
        self.predicted_values_of_server = predicted_values_of_server

    def local_train(self, local_epoch, criterion, trainloader, optimizer):
        """Train the local model on private data.

        Delegates entirely to ``default_local_train_for_client``.

        Returns:
            Whatever ``default_local_train_for_client`` returns.
        """
        return default_local_train_for_client(
            self, local_epoch, criterion, trainloader, optimizer
        )

    def calc_loss_on_public_dataset(self, idx, y_pred, y):
        """Compute the combined loss on a batch of the public dataset.

        The loss is ``epsilon * base_loss + (1 - epsilon) * kl_loss``.
        ``base_loss`` is ``base_loss_func(y_pred, y)``. ``kl_loss`` compares
        the server's predictions for ``idx`` with the client's ``y_pred``.

        Args:
            idx: Index (or indices) into the stored server predictions.
                Presumably these are the positions of this batch's samples in
                the public dataset; confirm with the caller.
            y_pred: Client model outputs (logits) for the batch.
            y: Ground-truth labels, cast to ``torch.int64`` before use.

        Returns:
            The weighted sum of the supervised and divergence losses.

        Raises:
            RuntimeError: If ``download`` has not been called yet.
        """
        if self.predicted_values_of_server is None:
            raise RuntimeError(
                "predicted_values_of_server is not set; call download() first"
            )
        y_pred_server = self.predicted_values_of_server[idx]
        base_loss = self.epsilon * self.base_loss_func(y_pred, y.to(torch.int64))
        # ``log_softmax`` equals ``softmax(...).log()`` mathematically, but it
        # stays finite when a probability underflows to 0. In that case
        # ``.log()`` yields -inf and the KL term becomes inf/nan.
        # NOTE(review): ``nn.KLDivLoss(input, target)`` expects log-probs as
        # ``input`` and probs as ``target``. Here the server's log-probs are
        # the input and the client's probs are the target, which is the
        # reverse of the usual teacher/student distillation order. This is
        # kept as-is; confirm against the FedGEMS formulation.
        kl_loss = (1 - self.epsilon) * self.kldiv_loss_func(
            y_pred_server.log_softmax(dim=-1), y_pred.softmax(dim=-1)
        )
        return base_loss + kl_loss