3.1.5. aijack.collaborative.fedgems package
3.1.5.1. Submodules
3.1.5.2. aijack.collaborative.fedgems.api module
- class aijack.collaborative.fedgems.api.FedGEMSAPI(server, clients, public_dataloader, local_dataloaders, criterion, server_optimizer, client_optimizers, validation_dataloader=None, num_communication=10, epoch_client_on_localdataset=10, epoch_client_on_publicdataset=10, epoch_server_on_publicdataset=10, device='cpu', custom_action=<function FedGEMSAPI.<lambda>>)
Bases: aijack.collaborative.core.api.BaseFLKnowledgeDistillationAPI
API for running FedGEMS training with a server and a list of clients.
- Parameters
server (FedGEMSServer) – a server.
clients (List[FedGEMSClient]) – a list of clients.
public_dataloader (torch.utils.data.DataLoader) – a dataloader of the public dataset.
local_dataloaders (List[torch.utils.data.DataLoader]) – a list of dataloaders of the local datasets.
validation_dataloader (torch.utils.data.DataLoader, optional) – a dataloader of the validation dataset. Defaults to None.
criterion (function) – a loss function.
server_optimizer (torch.optim.Optimizer) – an optimizer for the global (server) model.
client_optimizers (List[torch.optim.Optimizer]) – a list of optimizers for the local (client) models.
num_communication (int, optional) – the number of communication rounds. Defaults to 10.
epoch_client_on_localdataset (int, optional) – the number of epochs of client-side training on the private (local) datasets. Defaults to 10.
epoch_client_on_publicdataset (int, optional) – the number of epochs of client-side training on the public dataset. Defaults to 10.
epoch_server_on_publicdataset (int, optional) – the number of epochs of server-side training on the public dataset. Defaults to 10.
device (str, optional) – device type. Defaults to “cpu”.
custom_action (function, optional) – a custom function that this API calls at the end of every communication round. Defaults to lambda x: x.
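The parameters above describe nested training loops. The sketch below is a hypothetical, framework-free illustration of how those counts might compose into a schedule. The phase order it uses (client training on the local data, then server training on the public data, then client training on the public data) is an assumption not confirmed by this reference, and `fedgems_schedule` is not part of aijack. Only the placement of `custom_action` at the end of each communication round comes from the documentation; what argument `FedGEMSAPI` actually passes to it is not documented here, so the sketch simply passes the round index.

```python
from typing import Callable, List


def fedgems_schedule(
    num_communication: int = 10,
    epoch_client_on_localdataset: int = 10,
    epoch_client_on_publicdataset: int = 10,
    epoch_server_on_publicdataset: int = 10,
    custom_action: Callable = lambda x: x,
) -> List[str]:
    """Hypothetical helper: list the training phase run in each epoch.

    The phase order is an assumption; the real FedGEMSAPI may differ.
    """
    phases: List[str] = []
    for rnd in range(num_communication):
        # Clients train on their private (local) datasets.
        phases += [f"round{rnd}:client-local"] * epoch_client_on_localdataset
        # Server trains on the public dataset using the clients' knowledge.
        phases += [f"round{rnd}:server-public"] * epoch_server_on_publicdataset
        # Clients train on the public dataset using the server's knowledge.
        phases += [f"round{rnd}:client-public"] * epoch_client_on_publicdataset
        # Documented behaviour: custom_action runs at the end of every round.
        # Passing the round index is this sketch's own choice.
        custom_action(rnd)
    return phases


rounds_seen: List[int] = []
phases = fedgems_schedule(
    num_communication=2,
    epoch_client_on_localdataset=1,
    epoch_client_on_publicdataset=2,
    epoch_server_on_publicdataset=1,
    custom_action=rounds_seen.append,
)
print(len(phases), rounds_seen)  # 8 [0, 1]
```

Under this sketch, the defaults give 10 rounds of 10 + 10 + 10 epochs, 300 epochs in total.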
3.1.5.3. aijack.collaborative.fedgems.client module
3.1.5.4. aijack.collaborative.fedgems.server module
- class aijack.collaborative.fedgems.server.FedGEMSServer(clients, global_model, len_public_dataloader, output_dim=1, self_evaluation_func=None, base_loss_func=CrossEntropyLoss(), kldiv_loss_func=KLDivLoss(), server_id=0, lr=0.1, epsilon=0.75, device='cpu')
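The default `kldiv_loss_func` is PyTorch's `KLDivLoss()`. That loss expects log-probabilities as its input and probabilities as its target, and under its default `reduction='mean'` it averages over every element rather than per sample. The pure-Python restatement below only makes those semantics concrete; `kldiv_loss_default` and `log_softmax` are illustrative helpers, not aijack or PyTorch functions, and how `FedGEMSServer` prepares the values it passes to this loss is not documented here.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def kldiv_loss_default(log_probs: List[List[float]], target_probs: List[List[float]]) -> float:
    """Mirror of torch.nn.KLDivLoss() with its defaults (reduction='mean', log_target=False).

    log_probs: student log-probabilities; target_probs: teacher probabilities.
    """
    terms: List[float] = []
    for row_in, row_t in zip(log_probs, target_probs):
        for x, t in zip(row_in, row_t):
            # Pointwise term t * (log t - x); a zero target contributes 0.
            terms.append(t * (math.log(t) - x) if t > 0 else 0.0)
    # reduction='mean' divides by the number of elements (batch * classes),
    # not by the batch size; reduction='batchmean' would divide by the batch size.
    return sum(terms) / len(terms)


student = log_softmax([2.0, 0.5, -1.0])
teacher = [math.exp(v) for v in student]
print(kldiv_loss_default([student], [teacher]))  # a value very close to 0.0
```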
3.1.5.5. Module contents
Implementation of Cheng, Sijie, et al. “FedGEMS: Federated Learning of Larger Server Models via Selective Knowledge Fusion.” arXiv preprint arXiv:2110.11027 (2021).
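The paper's title refers to selective knowledge fusion. As a toy illustration of the general idea only, the sketch below assumes one simple per-sample rule: if the server predicts a public sample's label correctly, its own logits serve as the soft target; otherwise the logits of the clients that predict the label correctly are averaged uniformly; if no client is correct, no soft target is produced. This is a simplification under stated assumptions, not the rule implemented by `FedGEMSServer` (for example, it ignores `epsilon` and `self_evaluation_func`), and `select_teacher_logits` is a hypothetical name.

```python
from typing import List, Optional


def argmax(values: List[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def select_teacher_logits(
    server_logits: List[float],
    client_logits: List[List[float]],
    label: int,
) -> Optional[List[float]]:
    """Hypothetical selective-fusion rule for one public sample (toy sketch)."""
    # Case 1: the server already predicts the label, so reuse its own logits.
    if argmax(server_logits) == label:
        return list(server_logits)
    # Case 2: keep only the clients whose prediction matches the label.
    correct = [logits for logits in client_logits if argmax(logits) == label]
    if not correct:
        # Case 3: no usable client knowledge; a caller could fall back to
        # training on the hard label alone.
        return None
    # Uniform average of the selected clients' logits (simplifying assumption).
    n = len(correct)
    return [sum(logits[k] for logits in correct) / n for k in range(len(server_logits))]


print(select_teacher_logits([0.0, 1.0], [[3.0, 1.0], [0.0, 2.0], [1.0, 0.0]], label=0))
# [2.0, 0.5]
```

Here the server predicts class 1 for a class-0 sample, so the first and third clients (both predicting class 0) are averaged, while the second client is discarded.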
- class aijack.collaborative.fedgems.FedGEMSAPI(server, clients, public_dataloader, local_dataloaders, criterion, server_optimizer, client_optimizers, validation_dataloader=None, num_communication=10, epoch_client_on_localdataset=10, epoch_client_on_publicdataset=10, epoch_server_on_publicdataset=10, device='cpu', custom_action=<function FedGEMSAPI.<lambda>>)
Bases: aijack.collaborative.core.api.BaseFLKnowledgeDistillationAPI
API for running FedGEMS training with a server and a list of clients.
- Parameters
server (FedGEMSServer) – a server.
clients (List[FedGEMSClient]) – a list of clients.
public_dataloader (torch.utils.data.DataLoader) – a dataloader of the public dataset.
local_dataloaders (List[torch.utils.data.DataLoader]) – a list of dataloaders of the local datasets.
validation_dataloader (torch.utils.data.DataLoader, optional) – a dataloader of the validation dataset. Defaults to None.
criterion (function) – a loss function.
server_optimizer (torch.optim.Optimizer) – an optimizer for the global (server) model.
client_optimizers (List[torch.optim.Optimizer]) – a list of optimizers for the local (client) models.
num_communication (int, optional) – the number of communication rounds. Defaults to 10.
epoch_client_on_localdataset (int, optional) – the number of epochs of client-side training on the private (local) datasets. Defaults to 10.
epoch_client_on_publicdataset (int, optional) – the number of epochs of client-side training on the public dataset. Defaults to 10.
epoch_server_on_publicdataset (int, optional) – the number of epochs of server-side training on the public dataset. Defaults to 10.
device (str, optional) – device type. Defaults to “cpu”.
custom_action (function, optional) – a custom function that this API calls at the end of every communication round. Defaults to lambda x: x.
- class aijack.collaborative.fedgems.FedGEMSClient(model, user_id=0, lr=0.1, base_loss_func=CrossEntropyLoss(), kldiv_loss_func=KLDivLoss(), epsilon=0.75, round_decimal=None, device='cpu')
- class aijack.collaborative.fedgems.FedGEMSServer(clients, global_model, len_public_dataloader, output_dim=1, self_evaluation_func=None, base_loss_func=CrossEntropyLoss(), kldiv_loss_func=KLDivLoss(), server_id=0, lr=0.1, epsilon=0.75, device='cpu')