3.1.2. aijack.collaborative.dsfl package
3.1.2.1. Submodules
3.1.2.2. aijack.collaborative.dsfl.api module
- class aijack.collaborative.dsfl.api.DSFLAPI(server, clients, public_dataloader, local_dataloaders, criterion, num_communication, device, server_optimizer, client_optimizers, validation_dataloader=None, epoch_local_training=1, epoch_global_distillation=1, epoch_local_distillation=1, custom_action=<function DSFLAPI.<lambda>>)
Bases: aijack.collaborative.core.api.BaseFLKnowledgeDistillationAPI
API of DS-FL.
- Parameters
server (DSFLServer) – an instance of DSFLServer
clients (List[DSFLClient]) – a list of instances of DSFLClient
public_dataloader (torch.utils.data.DataLoader) – a dataloader of the public dataset
local_dataloaders (List[torch.utils.data.DataLoader]) – a list of dataloaders of the private datasets
validation_dataloader (torch.utils.data.DataLoader, optional) – a dataloader of the validation dataset. Defaults to None.
criterion (function) – a loss function
num_communication (int) – number of communication rounds
device (str) – device type
server_optimizer (torch.optim.Optimizer) – an optimizer for the global model
client_optimizers (List[torch.optim.Optimizer]) – a list of optimizers for the local models
epoch_local_training (int, optional) – number of epochs of local training. Defaults to 1.
epoch_global_distillation (int, optional) – number of epochs of global distillation. Defaults to 1.
epoch_local_distillation (int, optional) – number of epochs of local distillation. Defaults to 1.
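To show how num_communication and the three epoch parameters combine, here is a minimal sketch that only records the phases one DS-FL run would go through. It assumes the round structure described in the DS-FL paper (local training on private data, aggregation of outputs on the public data, global distillation, local distillation); the exact ordering inside DSFLAPI is an assumption, and dsfl_round_schedule is a hypothetical helper, not part of aijack.

```python
from typing import List, Tuple


def dsfl_round_schedule(
    num_communication: int,
    epoch_local_training: int = 1,
    epoch_global_distillation: int = 1,
    epoch_local_distillation: int = 1,
) -> List[Tuple[int, str]]:
    """Hypothetical helper: list (round, phase) pairs for one DS-FL run.

    The phase order is assumed from the DS-FL paper, not read from aijack.
    """
    schedule: List[Tuple[int, str]] = []
    for rnd in range(num_communication):
        # 1. each client trains on its private dataset
        schedule += [(rnd, "local_training")] * epoch_local_training
        # 2. clients predict on the public dataset; the server aggregates the outputs
        schedule.append((rnd, "aggregate_public_outputs"))
        # 3. the server distills the aggregated outputs into the global model
        schedule += [(rnd, "global_distillation")] * epoch_global_distillation
        # 4. each client distills the aggregated outputs into its local model
        schedule += [(rnd, "local_distillation")] * epoch_local_distillation
    return schedule


for step in dsfl_round_schedule(num_communication=1, epoch_local_training=2):
    print(step)
```

Under this assumption, each round costs epoch_local_training + 1 + epoch_global_distillation + epoch_local_distillation phases, so the total work grows linearly in num_communication.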
3.1.2.3. aijack.collaborative.dsfl.client module
- class aijack.collaborative.dsfl.client.DSFLClient(model, public_dataloader, output_dim=1, round_decimal=None, consensus_scale=1.0, device='cpu', user_id=0)
Bases: aijack.collaborative.core.client.BaseClient
Client of DS-FL.
- Parameters
model (torch.nn.Module) – the local model of this client.
public_dataloader (torch.utils.data.DataLoader) – a dataloader of the public dataset.
output_dim (int, optional) – the dimension of the output. Defaults to 1.
round_decimal (int, optional) – number of decimal places to round to. Defaults to None.
device (str, optional) – device type. Defaults to “cpu”.
user_id (int, optional) – id of this client. Defaults to 0.
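As a minimal sketch of what round_decimal describes, the function below rounds each output value to a fixed number of decimal places before it would be shared. It assumes, for illustration only, that round_decimal=None means the values are left unrounded; round_logits is a hypothetical name, not an aijack function.

```python
from typing import List, Optional


def round_logits(logits: List[float], round_decimal: Optional[int] = None) -> List[float]:
    """Hypothetical helper: round each output value to `round_decimal` places.

    Assumption: None means "leave the values unrounded".
    """
    if round_decimal is None:
        return list(logits)
    # Python's round() rounds half to even, which differs from strict "round up"
    return [round(v, round_decimal) for v in logits]


print(round_logits([0.12345, 0.87655], 2))
```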
- approach_consensus(consensus_optimizer)
Trains the local model to minimize the distance between the global logits and the logits output by the local model on the public dataset.
- Parameters
consensus_optimizer (torch.optim.Optimizer) – an optimizer to train the local model.
- Returns
averaged loss.
- Return type
float
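The idea behind approach_consensus can be sketched without PyTorch. In this toy version the "local model" is just a table of logits, one row per public sample, and the distance to the global output is a soft-target cross-entropy; both choices are assumptions for illustration, and approach_consensus_sketch is a hypothetical function, not the aijack implementation. It runs one epoch of mini-batch gradient descent and, like the real method, returns the loss averaged over batches as a float.

```python
import math
from typing import List


def softmax(z: List[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def soft_cross_entropy(target: List[float], z: List[float]) -> float:
    # -sum_c q_c * log softmax(z)_c, with q the global (consensus) distribution
    return -sum(q * math.log(p) for q, p in zip(target, softmax(z)))


def approach_consensus_sketch(
    local_logits: List[List[float]],
    global_probs: List[List[float]],
    lr: float = 1.0,
    batch_size: int = 2,
) -> float:
    """Toy, hypothetical stand-in for DSFLClient.approach_consensus.

    Assumptions (not taken from aijack): the local model is a free table of
    logits, one row per public sample, and the distance is soft-target
    cross-entropy. Updates `local_logits` in place for one epoch and returns
    the loss averaged over mini-batches.
    """
    batch_losses = []
    for start in range(0, len(local_logits), batch_size):
        rows = range(start, min(start + batch_size, len(local_logits)))
        losses = []
        for i in rows:
            q = global_probs[i]
            losses.append(soft_cross_entropy(q, local_logits[i]))
            # d(batch-mean loss)/d(logits of row i) = (softmax(z_i) - q_i) / |batch|
            p = softmax(local_logits[i])
            local_logits[i] = [
                z - lr * (pc - qc) / len(rows)
                for z, pc, qc in zip(local_logits[i], p, q)
            ]
        batch_losses.append(sum(losses) / len(losses))
    return sum(batch_losses) / len(batch_losses)


local = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
glob = [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]
for epoch in range(3):
    print(epoch, approach_consensus_sketch(local, glob))
```

Calling the sketch repeatedly, roughly what epoch_local_distillation > 1 would correspond to, drives each row's softmax toward its global distribution, so the averaged loss falls toward the entropy of the targets.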
3.1.2.4. aijack.collaborative.dsfl.server module
- class aijack.collaborative.dsfl.server.DSFLServer(clients, global_model, public_dataloader, aggregation='ERA', distillation_loss_name='crossentropy', era_temperature=0.1, server_id=0, device='cpu')
Bases: aijack.collaborative.core.server.BaseServer
Server of DS-FL.
- Parameters
clients (List[torch.nn.Module]) – a list of clients.
global_model (torch.nn.Module) – the global model
public_dataloader (torch.utils.data.DataLoader) – a dataloader of the public dataset
aggregation (str, optional) – the type of the aggregation of the logits. Defaults to “ERA”.
distillation_loss_name (str, optional) – the type of the loss function for the distillation loss. Defaults to “crossentropy”.
era_temperature (float, optional) – the temperature of ERA. Defaults to 0.1.
server_id (int, optional) – the id of this server. Defaults to 0.
device (str, optional) – device type. Defaults to “cpu”.
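The role of era_temperature can be illustrated with a sketch of the entropy-reducing aggregation (ERA) described in the DS-FL paper, as understood here: average the clients' softmax outputs class-wise for a public sample, then apply a softmax with temperature T, so that a small T sharpens the averaged distribution. How DSFLServer applies this (per batch, on which tensors) is an assumption, and era is a hypothetical function written in plain Python rather than aijack's code.

```python
import math
from typing import List


def softmax(z: List[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p: List[float]) -> float:
    return -sum(x * math.log(x) for x in p if x > 0)


def era(client_probs: List[List[float]], temperature: float = 0.1) -> List[float]:
    """Hypothetical ERA sketch for one public sample.

    client_probs[k] is client k's softmax output; the result is
    softmax(mean_k(client_probs[k]) / temperature).
    """
    k = len(client_probs)
    mean = [sum(col) / k for col in zip(*client_probs)]  # class-wise average
    return softmax([m / temperature for m in mean])


probs = [[0.6, 0.3, 0.1], [0.5, 0.4, 0.1]]
mean = [0.55, 0.35, 0.1]
sharpened = era(probs, temperature=0.1)
print(sharpened, entropy(mean), entropy(sharpened))
```

Because the averaged values lie in [0, 1], a temperature near 1 can flatten rather than sharpen the distribution (softmax([0.55, 0.35, 0.1]) is closer to uniform than the average itself); a small value such as the default 0.1 is what makes the step reduce entropy.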
3.1.2.5. Module contents
Implementation of DS-FL: Itahara, Sohei, et al. “Distillation-based semi-supervised federated learning for communication-efficient collaborative training with non-iid private data.” arXiv preprint arXiv:2008.06180 (2020).