3.1.2. aijack.collaborative.dsfl package#

3.1.2.1. Submodules#

3.1.2.2. aijack.collaborative.dsfl.api module#

class aijack.collaborative.dsfl.api.DSFLAPI(server, clients, public_dataloader, local_dataloaders, criterion, num_communication, device, server_optimizer, client_optimizers, validation_dataloader=None, epoch_local_training=1, epoch_global_distillation=1, epoch_local_distillation=1, custom_action=<function DSFLAPI.<lambda>>)[source]#

Bases: aijack.collaborative.core.api.BaseFLKnowledgeDistillationAPI

API of DS-FL

Parameters
  • server (DSFLServer) – an instance of DSFLServer

  • clients (List[DSFLClient]) – a list of instances of DSFLClient

  • public_dataloader (torch.utils.data.DataLoader) – a dataloader of the public dataset

  • local_dataloaders (List[torch.utils.data.DataLoader]) – a list of dataloaders of the private datasets

  • validation_dataloader (torch.utils.data.DataLoader, optional) – a dataloader of the validation dataset. Defaults to None.

  • criterion (function) – a loss function

  • num_communication (int) – the number of communication rounds

  • device (str) – device type

  • server_optimizer (torch.optim.Optimizer) – an optimizer for the global model

  • client_optimizers (List[torch.optim.Optimizer]) – a list of optimizers for the local models

  • epoch_local_training (int, optional) – number of epochs of local training. Defaults to 1.

  • epoch_global_distillation (int, optional) – number of epochs of global distillation. Defaults to 1.

  • epoch_local_distillation (int, optional) – number of epochs of local distillation. Defaults to 1.

  • custom_action (function, optional) – a function applied to this instance at the end of each communication round. Defaults to a no-op lambda.

run()[source]#

Run the whole DS-FL training process for num_communication rounds.
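A minimal end-to-end sketch of assembling and running the API. Everything below other than the constructor signature above is a placeholder: the synthetic data, the linear stand-in model, the hyperparameters, the assumption that the package re-exports the three classes, and the `server_model` / `model` attribute names (which follow the Base* classes) are all assumptions, not a verbatim recipe.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from aijack.collaborative.dsfl import DSFLAPI, DSFLClient, DSFLServer

# Tiny synthetic data so the sketch is self-contained (4-dim inputs, 10 classes).
def make_loader(n):
    return DataLoader(
        TensorDataset(torch.randn(n, 4), torch.randint(0, 10, (n,))), batch_size=8
    )

public_loader = make_loader(32)                      # shared public dataset
local_loaders = [make_loader(32), make_loader(32)]   # one private set per client

def net():
    return torch.nn.Linear(4, 10)                    # stand-in for a real model

device = "cpu"
clients = [
    DSFLClient(net(), public_loader, output_dim=10, user_id=i, device=device)
    for i in range(2)
]
server = DSFLServer(clients, net(), public_loader, device=device)

api = DSFLAPI(
    server,
    clients,
    public_loader,
    local_loaders,
    torch.nn.functional.cross_entropy,
    num_communication=2,
    device=device,
    # `server_model` / `model` attribute names are an assumption here
    server_optimizer=torch.optim.SGD(server.server_model.parameters(), lr=0.1),
    client_optimizers=[torch.optim.SGD(c.model.parameters(), lr=0.1) for c in clients],
)
api.run()
```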

3.1.2.3. aijack.collaborative.dsfl.client module#

class aijack.collaborative.dsfl.client.DSFLClient(model, public_dataloader, output_dim=1, round_decimal=None, consensus_scale=1.0, device='cpu', user_id=0)[source]#

Bases: aijack.collaborative.core.client.BaseClient

Client of DS-FL.

Parameters
  • model (torch.nn.Module) – the local model of this client

  • public_dataloader (torch.utils.data.DataLoader) – a dataloader of the public dataset.

  • output_dim (int, optional) – the dimension of the output. Defaults to 1.

  • round_decimal (int, optional) – the number of decimal places to which the uploaded logits are rounded, for communication efficiency. Defaults to None (no rounding).

  • consensus_scale (float, optional) – a scaling factor applied to the consensus (distillation) loss. Defaults to 1.0.

  • device (str, optional) – device type. Defaults to “cpu”.

  • user_id (int, optional) – id of this client. Defaults to 0.

approach_consensus(consensus_optimizer)[source]#

Train the client's own local model to minimize the distance between the global logits and the output logits of the local model on the public dataset.

Parameters

consensus_optimizer (torch.optim.Optimizer) – an optimizer to train the local model.

Returns

averaged loss.

Return type

float
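For intuition, a hedged sketch of what this step computes, not AIJack's internal implementation: the client fits its model to the server's consensus logits over the public data with a soft-label loss. The function name, the per-batch layout of global_logits, and the choice of soft-label cross-entropy are assumptions.

```python
import torch

# Conceptual sketch only; DSFLClient.approach_consensus encapsulates this.
# `model`, `public_loader`, and `global_logits` (one soft-label tensor per
# batch of the public dataloader) are assumed inputs.
def consensus_step(model, public_loader, global_logits, optimizer, device="cpu"):
    total, n = 0.0, 0
    for (x, _), target in zip(public_loader, global_logits):
        x, target = x.to(device), target.to(device)
        optimizer.zero_grad()
        # soft-label cross-entropy between local outputs and consensus logits
        loss = -(target.softmax(1) * model(x).log_softmax(1)).sum(1).mean()
        loss.backward()
        optimizer.step()
        total, n = total + loss.item(), n + 1
    return total / n  # the averaged loss, matching the documented return value
```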

download(global_logit)[source]#

Download the global logits from the server.

Parameters

global_logit (torch.Tensor) – the global logits from the server

local_train(local_epoch, criterion, trainloader, optimizer)[source]#

Train the local model on trainloader for local_epoch epochs with the given criterion and optimizer.

upload()[source]#

Upload the output logits on the public dataset to the server.

Returns

the output logits of the local model on the public dataset.

Return type

torch.Tensor
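A short usage sketch of one client-side exchange, assuming `client` is a DSFLClient (e.g. from the earlier API sketch), `global_logit` is the tensor broadcast by the server, and the wrapped model is reachable as `client.model` (an attribute-name assumption):

```python
import torch

# One client-side round; `client` and `global_logit` are assumed to exist.
client.download(global_logit)  # store the consensus logits from the server
consensus_opt = torch.optim.SGD(client.model.parameters(), lr=0.1)
avg_loss = client.approach_consensus(consensus_opt)  # fit local model to consensus
local_logit = client.upload()  # share logits on the public data, not raw weights
```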

3.1.2.4. aijack.collaborative.dsfl.server module#

class aijack.collaborative.dsfl.server.DSFLServer(clients, global_model, public_dataloader, aggregation='ERA', distillation_loss_name='crossentropy', era_temperature=0.1, server_id=0, device='cpu')[source]#

Bases: aijack.collaborative.core.server.BaseServer

Server of DS-FL

Parameters
  • clients (List[torch.nn.Module]) – a list of clients.

  • global_model (torch.nn.Module) – the global model

  • public_dataloader (torch.utils.data.DataLoader) – a dataloader of the public dataset

  • aggregation (str, optional) – the method used to aggregate the clients’ logits (see the ERA sketch after this parameter list). Defaults to “ERA”.

  • distillation_loss_name (str, optional) – the loss function used for distillation. Defaults to “crossentropy”.

  • era_temperature (float, optional) – the temperature of ERA. Defaults to 0.1.

  • server_id (int, optional) – the id of this server. Defaults to 0.

  • device (str, optional) – device type. Defaults to “cpu”.
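For intuition about the default “ERA” aggregation, here is a minimal sketch based on the DS-FL paper, not AIJack’s internal code: average the clients’ logits on the public data, then sharpen the average with a low-temperature softmax to reduce its entropy before distillation.

```python
import torch
import torch.nn.functional as F

# Entropy Reduction Aggregation (ERA), as described in the DS-FL paper.
def era(client_logits, temperature=0.1):
    # client_logits: list of (num_public_samples, num_classes) tensors
    mean_logit = torch.stack(client_logits).mean(dim=0)  # simple average first
    return F.softmax(mean_logit / temperature, dim=-1)   # then entropy reduction

# Example: two clients, four public samples, three classes.
consensus = era([torch.randn(4, 3), torch.randn(4, 3)], temperature=0.1)
```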

action()[source]#

Execute the routine of each communication round.

distribute()[source]#

Distribute the consensus logits on the public dataset to each client.

update()[source]#

Update the aggregated consensus logits with the output logits received from the clients.

Raises

NotImplementedError – raised when the specified aggregation type is not supported.

update_globalmodel(global_optimizer)[source]#

Train the global model with the global consensus logits.

Parameters

global_optimizer (torch.optim.Optimizer) – an optimizer

Returns

the averaged loss.

Return type

float
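Putting the three methods together, a hedged sketch of one server-side communication round, which action() and the DSFLAPI loop normally drive for you; `server` is the DSFLServer from the earlier API sketch, and the `server_model` attribute name is an assumption:

```python
import torch

server.update()  # aggregate the logits uploaded by the clients (e.g. with ERA)
global_opt = torch.optim.SGD(server.server_model.parameters(), lr=0.1)
avg_loss = server.update_globalmodel(global_opt)  # distill consensus into global model
server.distribute()  # broadcast the consensus logits back to every client
```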

3.1.2.5. Module contents#

Implementation of DS-FL: Itahara, Sohei, et al. “Distillation-based semi-supervised federated learning for communication-efficient collaborative training with non-IID private data.” arXiv preprint arXiv:2008.06180 (2020).