# Source code for aijack.collaborative.dsfl.server

import torch

from ...utils.metrics import crossentropyloss_between_logits
from ..core import BaseServer


class DSFLServer(BaseServer):
    """Server of DS-FL.

    Args:
        clients (list[torch.nn.Module]): a list of clients.
        global_model (torch.nn.Module): the global model.
        public_dataloader (torch.utils.data.DataLoader): a dataloader of the
            public dataset.
        aggregation (str, optional): the type of the aggregation of the
            logits. Defaults to "ERA".
        distillation_loss_name (str, optional): the type of the loss function
            for the distillation loss. Defaults to "crossentropy".
        era_temperature (float, optional): the temperature of ERA.
            Defaults to 0.1.
        server_id (int, optional): the id of this server. Defaults to 0.
        device (str, optional): device type. Defaults to "cpu".
    """

    def __init__(
        self,
        clients,
        global_model,
        public_dataloader,
        aggregation="ERA",
        distillation_loss_name="crossentropy",
        era_temperature=0.1,
        server_id=0,
        device="cpu",
    ):
        """Init DSFLServer."""
        super(DSFLServer, self).__init__(clients, global_model, server_id=server_id)
        self.public_dataloader = public_dataloader
        self.aggregation = aggregation
        self.era_temperature = era_temperature
        # Aggregated logits on the public dataset; populated by update().
        self.consensus = None
        self.device = device
        self._set_distillation_loss(distillation_loss_name)

    def _set_distillation_loss(self, name):
        """Set up the loss function for distillation: `crossentropy`, `L2` or `L1`.

        Args:
            name (str): type of the function.

        Raises:
            NotImplementedError: Raises when `name` is not supported.
        """
        if name == "crossentropy":
            self.distillation_loss = crossentropyloss_between_logits
        elif name == "L2":
            self.distillation_loss = torch.nn.MSELoss()
        elif name == "L1":
            self.distillation_loss = torch.nn.L1Loss()
        else:
            raise NotImplementedError(f"{name} is not supported")
[docs] def action(self): self.update() self.distribute()
[docs] def update(self): """Update the aggregated consensus logits with the output logits received from the clients. Raises: NotImplementedError: Raises when the specified aggregation type is not supported. """ if self.aggregation == "ERA": self._entropy_reduction_aggregation() elif self.aggregation == "SA": self._simple_aggregation() else: raise NotImplementedError(f"{self.aggregation} is not supported")
[docs] def update_globalmodel(self, global_optimizer): """Train the global model with the global consensus logits. Args: global_optimizer (torch.optim.Optimizer): an optimizer Returns: float: average loss """ running_loss = 0 for global_data in self.public_dataloader: idx = global_data[0] x = global_data[1].to(self.device) y_global = self.consensus[idx, :].to(self.device) global_optimizer.zero_grad() y_pred = self(x) loss_consensus = self.distillation_loss(y_pred, y_global) loss_consensus.backward() global_optimizer.step() running_loss += loss_consensus.item() running_loss /= len(self.public_dataloader) return running_loss
[docs] def distribute(self): """Distribute the logits of public dataset to each client.""" for client in self.clients: client.download(self.consensus)
def _entropy_reduction_aggregation(self): """Aggregate the received logits with ERA""" self._simple_aggregation() self.consensus = torch.softmax(self.consensus / self.era_temperature, dim=1) def _simple_aggregation(self): """Aggregate the received logits with SA (calculating average)""" self.consensus = self.clients[0].upload() / len(self.clients) for client in self.clients[1:]: self.consensus += client.upload() / len(self.clients)