Source code for aijack.collaborative.fedmd.api

import copy

import torch

from ..core.api import BaseFedAPI, BaseFLKnowledgeDistillationAPI


class FedMDAPI(BaseFLKnowledgeDistillationAPI):
    """Implementation of `FedMD: Heterogenous Federated Learning via Model Distillation`"""

    def __init__(
        self,
        server,
        clients,
        public_dataloader,
        local_dataloaders,
        criterion,
        client_optimizers,
        validation_dataloader=None,
        server_optimizer=None,
        num_communication=1,
        device="cpu",
        consensus_epoch=1,
        revisit_epoch=1,
        transfer_epoch_public=1,
        transfer_epoch_private=1,
        server_training_epoch=1,
        custom_action=lambda x: x,
    ):
        super().__init__(
            server,
            clients,
            public_dataloader,
            local_dataloaders,
            validation_dataloader,
            criterion,
            num_communication,
            device,
        )
        self.client_optimizers = client_optimizers
        self.server_optimizer = server_optimizer
        self.consensus_epoch = consensus_epoch
        self.revisit_epoch = revisit_epoch
        self.transfer_epoch_public = transfer_epoch_public
        self.transfer_epoch_private = transfer_epoch_private
        self.server_training_epoch = server_training_epoch
        self.custom_action = custom_action
        self.epoch = 0
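    # Overall training flow implemented by the methods below:
    #   1. transfer_phase: every client is pretrained on the public dataset,
    #      then on its own private dataset.
    #   2. per communication round: the server updates the global consensus
    #      (server.action()), clients digest it on the public dataset
    #      (digest_phase) and then revisit their private data (revisit_phase);
    #      optionally a server-side model is trained on the public dataset
    #      (an extension beyond the original paper) before evaluation.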
    def train_server(self):
        """Trains the server-side model on the public dataset for one epoch
        and returns the average loss (0.0 if no server optimizer is given)."""
        if self.server_optimizer is None:
            return 0.0

        running_loss = 0.0
        for data in self.public_dataloader:
            _, x, y = data
            x = x.to(self.device)
            y = y.to(self.device).to(torch.int64)

            self.server_optimizer.zero_grad()
            loss = self.criterion(self.server(x), y)
            loss.backward()
            self.server_optimizer.step()

            running_loss += loss.item()

        running_loss /= len(self.public_dataloader)
        return running_loss
    def transfer_phase(self, logging):
        # Transfer
        for i in range(1, self.transfer_epoch_public + 1):
            loss_public = self.train_client(public=True)
            print(f"epoch {i} (public - pretrain): {loss_public}")
            logging["loss_client_public_dataset_transfer"].append(loss_public)

            if self.validation_dataloader is not None:
                acc_val = self.score(
                    self.validation_dataloader, self.server_optimizer is None
                )
                print("acc on validation dataset: ", acc_val)
                logging["acc_val"].append(copy.deepcopy(acc_val))

        for i in range(1, self.transfer_epoch_private + 1):
            loss_local = self.train_client(public=False)
            print(f"epoch {i} (local - pretrain): {loss_local}")
            logging["loss_client_local_dataset_transfer"].append(loss_local)

            if self.validation_dataloader is not None:
                acc_val = self.score(
                    self.validation_dataloader, self.server_optimizer is None
                )
                print("acc on validation dataset: ", acc_val)
                logging["acc_val"].append(copy.deepcopy(acc_val))

        return logging
    def digest_phase(self, i, logging):
        """Lets every client take optimization steps toward the current
        consensus logits for `consensus_epoch` epochs."""
        temp_consensus_loss = []
        for j, client in enumerate(self.clients):
            for _ in range(self.consensus_epoch):
                consensus_loss = client.approach_consensus(self.client_optimizers[j])
                print(f"epoch {i}, client {j}: {consensus_loss}")
                temp_consensus_loss.append(consensus_loss)
        logging["loss_client_consensus"].append(temp_consensus_loss)
        return logging
    def revisit_phase(self, logging):
        for _ in range(self.revisit_epoch):
            loss_local_revisit = self.train_client(public=False)
            logging["loss_client_revisit"].append(loss_local_revisit)
        return logging
    def server_side_training(self, logging):
        # Train a server-side model if it exists (different from the original paper)
        for _ in range(self.server_training_epoch):
            loss_server_public = self.train_server()
            logging["loss_server_public_dataset"].append(loss_server_public)
        return logging
    def evaluation(self, i, logging):
        acc_on_local_dataset = self.local_score()
        print(f"epoch={i} acc on local datasets: ", acc_on_local_dataset)
        logging["acc_local"].append(acc_on_local_dataset)

        acc_pub = self.score(self.public_dataloader, self.server_optimizer is None)
        print(f"epoch={i} acc on public dataset: ", acc_pub)
        logging["acc_pub"].append(copy.deepcopy(acc_pub))

        if self.validation_dataloader is not None:
            acc_val = self.score(
                self.validation_dataloader, self.server_optimizer is None
            )
            print(f"epoch={i} acc on validation dataset: ", acc_val)
            logging["acc_val"].append(copy.deepcopy(acc_val))

        return logging
    def run(self):
        logging = {
            "loss_client_local_dataset_transfer": [],
            "loss_client_public_dataset_transfer": [],
            "loss_client_consensus": [],
            "loss_client_revisit": [],
            "loss_server_public_dataset": [],
            "acc_local": [],
            "acc_pub": [],
            "acc_val": [],
        }

        logging = self.transfer_phase(logging)

        for i in range(1, self.num_communication + 1):
            self.epoch = i
            self.server.action()
            logging = self.digest_phase(i, logging)
            logging = self.revisit_phase(logging)
            logging = self.server_side_training(logging)
            logging = self.evaluation(i, logging)
            self.custom_action(self)

        return logging
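# A minimal usage sketch for FedMDAPI (illustrative only: `server`, `clients`,
# and the dataloaders are placeholders for objects built with aijack's FedMD
# server/client wrappers, whose exact constructors are defined elsewhere):
#
#     import torch
#
#     criterion = torch.nn.CrossEntropyLoss()
#     client_optimizers = [
#         torch.optim.SGD(client.parameters(), lr=0.01) for client in clients
#     ]
#     api = FedMDAPI(
#         server,
#         clients,
#         public_dataloader,
#         local_dataloaders,
#         criterion,
#         client_optimizers,
#         validation_dataloader=validation_dataloader,
#         num_communication=5,
#     )
#     logging = api.run()  # dict of per-phase losses and accuracies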
class MPIFedMDAPI(BaseFedAPI):
    """MPI implementation of FedMD, where each process hosts either the
    central server or a single client (selected via `is_server`)."""

    def __init__(
        self,
        comm,
        party,
        is_server,
        criterion,
        local_optimizer=None,
        local_dataloader=None,
        public_dataloader=None,
        num_communication=1,
        local_epoch=1,
        consensus_epoch=1,
        revisit_epoch=1,
        transfer_epoch_public=1,
        transfer_epoch_private=1,
        custom_action=lambda x: x,
        device="cpu",
    ):
        self.comm = comm
        self.party = party
        self.is_server = is_server
        self.criterion = criterion
        self.local_optimizer = local_optimizer
        self.local_dataloader = local_dataloader
        self.public_dataloader = public_dataloader
        self.num_communication = num_communication
        self.local_epoch = local_epoch
        self.consensus_epoch = consensus_epoch
        self.revisit_epoch = revisit_epoch
        self.transfer_epoch_public = transfer_epoch_public
        self.transfer_epoch_private = transfer_epoch_private
        self.custom_action = custom_action
        self.device = device
    def run(self):
        # Transfer phase
        if not self.is_server:
            self.local_train(epoch=self.transfer_epoch_public, public=True)
            self.local_train(epoch=self.transfer_epoch_private, public=False)

        for _ in range(self.num_communication):
            # Update global logits
            self.party.action()
            self.comm.Barrier()

            self.digest_phase()
            self.revisit_phase()

            self.custom_action(self)
            self.comm.Barrier()
    def digest_phase(self):
        if not self.is_server:
            for _ in range(self.consensus_epoch):
                self.party.approach_consensus(self.local_optimizer)
    def revisit_phase(self):
        if not self.is_server:
            for _ in range(self.revisit_epoch):
                self.local_train(public=False)
    def local_train(self, epoch=1, public=True):
        trainloader = self.public_dataloader if public else self.local_dataloader
        self.party.local_train(epoch, self.criterion, trainloader, self.local_optimizer)
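# An MPI launch sketch for MPIFedMDAPI (illustrative only: assumes mpi4py and
# a `party` object exposing the methods used above, i.e. `action`,
# `approach_consensus`, and `local_train`; the wrappers providing them are
# defined elsewhere in aijack). Launched with e.g.
# `mpiexec -n 3 python train_fedmd.py`:
#
#     import torch
#     from mpi4py import MPI
#
#     comm = MPI.COMM_WORLD
#     is_server = comm.Get_rank() == 0
#     party = ...  # server party on rank 0, a client party on the other ranks
#     api = MPIFedMDAPI(
#         comm,
#         party,
#         is_server,
#         criterion=torch.nn.CrossEntropyLoss(),
#         local_optimizer=None if is_server else local_optimizer,
#         local_dataloader=None if is_server else local_dataloader,
#         public_dataloader=public_dataloader,
#     )
#     api.run()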