3.1.5. aijack.collaborative.fedgems package

3.1.5.1. Submodules

3.1.5.2. aijack.collaborative.fedgems.api module

class aijack.collaborative.fedgems.api.FedGEMSAPI(server, clients, public_dataloader, local_dataloaders, criterion, server_optimizer, client_optimizers, validation_dataloader=None, num_communication=10, epoch_client_on_localdataset=10, epoch_client_on_publicdataset=10, epoch_server_on_publicdataset=10, device='cpu', custom_action=lambda x: x)

Bases: aijack.collaborative.core.api.BaseFLKnowledgeDistillationAPI

API that orchestrates FedGEMS training between a server and its clients.

Parameters
  • server (FedGEMSServer) – a server.

  • clients (List[FedGEMSClient]) – a list of clients.

  • public_dataloader (torch.utils.data.DataLoader) – a dataloader of the public dataset.

  • local_dataloaders (List[torch.utils.data.DataLoader]) – a list of dataloaders of the local datasets.

  • criterion (function) – a loss function.

  • server_optimizer (torch.optim.Optimizer) – an optimizer for the global model.

  • client_optimizers (List[torch.optim.Optimizer]) – a list of optimizers for the local models.

  • validation_dataloader (torch.utils.data.DataLoader, optional) – a dataloader of the validation dataset. Defaults to None.

  • num_communication (int, optional) – the number of communication rounds. Defaults to 10.

  • epoch_client_on_localdataset (int, optional) – the number of epochs of client-side training on the private datasets. Defaults to 10.

  • epoch_client_on_publicdataset (int, optional) – the number of epochs of client-side training on the public dataset. Defaults to 10.

  • epoch_server_on_publicdataset (int, optional) – the number of epochs of server-side training on the public dataset. Defaults to 10.

  • device (str, optional) – device type. Defaults to “cpu”.

  • custom_action (function, optional) – custom function that this API calls at the end of every communication round. Defaults to lambda x: x.
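Read together, the parameters above suggest a nested schedule: num_communication rounds, each presumably running the three epoch counts, with custom_action called once at the end of every round. The sketch below illustrates that schedule with a hypothetical ToyFedGEMSLoop class that only logs phases instead of training. Two details are assumptions rather than documented behavior: the order of the phases inside a round, and that custom_action receives the API object as its only argument (which the identity default lambda x: x is consistent with).

```python
from typing import Callable, List


class ToyFedGEMSLoop:
    """Hypothetical stand-in for FedGEMSAPI that records phases instead of training.

    Assumptions (not taken from aijack): the phase order inside a round, and that
    custom_action is called with the API object itself as its only argument.
    """

    def __init__(
        self,
        num_communication: int = 10,
        epoch_client_on_localdataset: int = 10,
        epoch_client_on_publicdataset: int = 10,
        epoch_server_on_publicdataset: int = 10,
        custom_action: Callable[["ToyFedGEMSLoop"], object] = lambda x: x,
    ) -> None:
        self.num_communication = num_communication
        self.epoch_client_on_localdataset = epoch_client_on_localdataset
        self.epoch_client_on_publicdataset = epoch_client_on_publicdataset
        self.epoch_server_on_publicdataset = epoch_server_on_publicdataset
        self.custom_action = custom_action
        self.current_round = 0
        self.log: List[str] = []

    def run(self) -> List[str]:
        for r in range(1, self.num_communication + 1):
            self.current_round = r
            # Each phase is represented by a log entry instead of real training.
            self.log.append(f"{r}: clients train {self.epoch_client_on_localdataset} epochs on local data")
            self.log.append(f"{r}: server trains {self.epoch_server_on_publicdataset} epochs on public data")
            self.log.append(f"{r}: clients train {self.epoch_client_on_publicdataset} epochs on public data")
            # The callback fires once at the end of every communication round.
            self.custom_action(self)
        return self.log


# Usage: a callback that records which rounds have finished.
rounds_seen: List[int] = []
loop = ToyFedGEMSLoop(
    num_communication=3,
    custom_action=lambda api: rounds_seen.append(api.current_round),
)
entries = loop.run()
```

Passing the whole API object to the callback would let a user inspect any attribute (for example, evaluate the global model) without the API having to anticipate what the callback needs.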

run()

train_client_on_public_dataset()

Train clients on the public dataset.

Returns

a list containing the average loss of each client.

Return type

List[float]

train_server_on_public_dataset()

Train the global model on the public dataset.

Returns

average loss

Return type

float

3.1.5.3. aijack.collaborative.fedgems.client module

class aijack.collaborative.fedgems.client.FedGEMSClient(model, user_id=0, lr=0.1, base_loss_func=CrossEntropyLoss(), kldiv_loss_func=KLDivLoss(), epsilon=0.75, round_decimal=None, device='cpu')

Bases: aijack.collaborative.core.client.BaseClient

calc_loss_on_public_dataset(idx, y_pred, y)

download(predicted_values_of_server)

Download the server's predicted values on the public dataset (predicted_values_of_server).

local_train(local_epoch, criterion, trainloader, optimizer)

upload(x)

Upload the locally learned information to the server.
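FedGEMSClient takes both base_loss_func (cross-entropy by default) and kldiv_loss_func (KL divergence by default), which suggests that training on the public dataset mixes a label loss with a distillation loss toward the server's predicted values received through download(). The pure-Python sketch below shows one such mix for a single public sample. Weighting the two terms by epsilon and (1 - epsilon) is our assumption, not the formula used by calc_loss_on_public_dataset, and the helper names and the server_predictions table are hypothetical. The KL term is KL(server || client), which matches calling torch.nn.KLDivLoss with the client's log-probabilities as input and the server's probabilities as target.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: Sequence[float], label: int) -> float:
    # Negative log-probability assigned to the true label.
    return -math.log(softmax(logits)[label])


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    # KL(p || q) for probability vectors; terms with p_i == 0 contribute nothing.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def public_dataset_loss(
    client_logits: Sequence[float],
    label: int,
    server_logits: Sequence[float],
    epsilon: float = 0.75,
) -> float:
    # Hypothetical weighting: epsilon on the label term, (1 - epsilon) on distillation.
    ce = cross_entropy(client_logits, label)
    kl = kl_divergence(softmax(server_logits), softmax(client_logits))
    return epsilon * ce + (1.0 - epsilon) * kl


# Hypothetical table a client might hold after download(); keys are public-sample indices.
server_predictions = {0: [2.0, 0.5, -1.0], 1: [0.1, 0.2, 3.0]}
idx, y = 0, 0
loss = public_dataset_loss([1.5, 0.3, -0.5], y, server_predictions[idx])
```

When the client already matches the server's distribution, the KL term vanishes and only the weighted label loss remains.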

3.1.5.4. aijack.collaborative.fedgems.server module

class aijack.collaborative.fedgems.server.FedGEMSServer(clients, global_model, len_public_dataloader, output_dim=1, self_evaluation_func=None, base_loss_func=CrossEntropyLoss(), kldiv_loss_func=KLDivLoss(), server_id=0, lr=0.1, epsilon=0.75, device='cpu')

Bases: aijack.collaborative.core.server.BaseServer

action()

Execute the routine of each communication round.

distribute()

Distribute the logits on the public dataset to each client.

self_evaluation_on_the_public_dataset(idxs, x, y)

Execute self-evaluation on the public dataset.

Parameters
  • idxs (torch.Tensor) – indices of x

  • x (torch.Tensor) – input data

  • y (torch.Tensor) – labels of x

Returns

the loss
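As we understand the FedGEMS paper, its selective knowledge fusion lets the server decide, per public sample, whose knowledge to learn from, based on how it evaluates itself on that sample. A heavily simplified sketch of such a per-sample rule follows: if the server's own prediction is correct it keeps its logits, otherwise it averages the logits of the clients that predict the label correctly. Uniform averaging, and returning None when no client is correct, are our simplifications; this is not the code behind self_evaluation_on_the_public_dataset, and select_teacher_logits is a hypothetical name.

```python
from typing import List, Optional, Sequence


def argmax(values: Sequence[float]) -> int:
    # Index of the largest value (the first one on ties).
    return max(range(len(values)), key=lambda i: values[i])


def select_teacher_logits(
    server_logits: Sequence[float],
    client_logits: List[Sequence[float]],
    label: int,
) -> Optional[List[float]]:
    """Hypothetical per-sample selection rule, a simplification of selective fusion."""
    # Case 1: the server already predicts the label, so it keeps its own logits.
    if argmax(server_logits) == label:
        return list(server_logits)
    # Case 2: fuse knowledge only from clients whose prediction is correct.
    correct = [c for c in client_logits if argmax(c) == label]
    if not correct:
        # Case 3 (assumption): no reliable teacher; the caller falls back to the label alone.
        return None
    dim = len(server_logits)
    # Uniform average of the reliable clients' logits (our simplification).
    return [sum(c[k] for c in correct) / len(correct) for k in range(dim)]
```

Filtering by correctness keeps clients that are wrong on a sample from pulling the server toward their mistakes on that sample.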

update(idxs, x)

Register the predicted logits in self.predicted_values.
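Taken together, update() and distribute() presumably maintain a table of predicted logits indexed by public-sample position, which the server fills batch by batch and then sends to the clients. The sketch below models such a table with plain Python lists. Indexing rows by idxs in exactly this way is an assumption, and to stay model-free the sketch takes precomputed logits where the real update(idxs, x) takes inputs x; LogitTable is a hypothetical name.

```python
from typing import List, Sequence


class LogitTable:
    """Hypothetical buffer of predicted logits, one row per public sample."""

    def __init__(self, num_public_samples: int, output_dim: int) -> None:
        # Every row starts as a zero vector of length output_dim.
        self.predicted_values: List[List[float]] = [
            [0.0] * output_dim for _ in range(num_public_samples)
        ]

    def update(self, idxs: Sequence[int], logits: Sequence[Sequence[float]]) -> None:
        # Overwrite only the rows that belong to this mini-batch of public samples.
        for i, row in zip(idxs, logits):
            self.predicted_values[i] = list(row)

    def distribute(self) -> List[List[float]]:
        # Hand out a copy so that receivers cannot mutate the server-side table.
        return [list(row) for row in self.predicted_values]


# Usage: register logits for public samples 1 and 3, then take a snapshot.
table = LogitTable(num_public_samples=4, output_dim=2)
table.update([1, 3], [[0.5, 0.5], [1.0, -1.0]])
snapshot = table.distribute()
```

Keying rows by sample index is what lets a receiver look up the entry for each idx in later mini-batches even when the dataloader shuffles the public data.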

3.1.5.5. Module contents

Implementation of Cheng, Sijie, et al. “FedGEMS: Federated Learning of Larger Server Models via Selective Knowledge Fusion.” arXiv preprint arXiv:2110.11027 (2021).

The package re-exports the three classes documented in the submodules above; their signatures and members are identical to those entries.

class aijack.collaborative.fedgems.FedGEMSAPI: re-exported from aijack.collaborative.fedgems.api.FedGEMSAPI.

class aijack.collaborative.fedgems.FedGEMSClient: re-exported from aijack.collaborative.fedgems.client.FedGEMSClient.

class aijack.collaborative.fedgems.FedGEMSServer: re-exported from aijack.collaborative.fedgems.server.FedGEMSServer.