1.4. FedMD: Federated Learning with Model Distillation#
This tutorial implements FedMD (Federated Learning with Model Distillation), proposed in https://arxiv.org/abs/1910.03581. AIJack supports both single-process and MPI as the backend of FedMD. While FedAVG communicates local gradients to collaboratively train a model without sharing local datasets, malicious servers might be able to recover the training data from the shared gradient (see Gradient-based Model Inversion Attack against Federated Learning for the detail). In addition, sending and receiving gradients of the model requires much communication power. To solve these challenges, FedMD communicates not gradients but predicted logits on the global dataset and uses the model-distillation method to share each party’s knowledge.
1.4.1. Single Process#
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms
from aijack.collaborative.fedmd import FedMDAPI, FedMDClient, FedMDServer
from aijack.utils import NumpyDataset
def fix_seed(seed):
    """Make the run reproducible by seeding every RNG in use.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA
    devices), and forces cuDNN into deterministic mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
# --- Hyperparameters for the single-process FedMD run ---
training_batch_size = 64  # mini-batch size for public/local training loaders
test_batch_size = 64  # mini-batch size for the evaluation loader
num_rounds = 5  # NOTE(review): defined but unused here — FedMDAPI below is called with num_communication=2
lr = 0.001  # SGD learning rate for every client
seed = 0  # RNG seed handed to fix_seed
client_size = 2  # number of FedMD clients (data is split into client_size + 1 shards; shard 0 is public)
criterion = F.nll_loss  # NOTE(review): defined but unused — F.nll_loss is passed to FedMDAPI directly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)
def prepare_dataloader(num_clients, myid, train=True, path=""):
    """Return a DataLoader over a slice of MNIST.

    With ``train=True`` the training set is shuffled and split into
    ``num_clients`` shards, and shard ``myid - 1`` is wrapped in a
    loader (so ``myid=1`` gets the first shard, ``myid=0`` the last).
    With ``train=False`` the whole test set is returned and
    ``num_clients`` / ``myid`` are ignored.

    Uses the module-level ``training_batch_size`` / ``test_batch_size``.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        # Keep only this party's shard of the shuffled index list.
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        batch_size = training_batch_size
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        batch_size = test_batch_size
    # NumpyDataset applies `transform` itself; return_idx=True makes each
    # batch carry sample indices, which FedMD needs to align logits.
    return torch.utils.data.DataLoader(
        NumpyDataset(
            x=dataset.data.numpy(),
            y=dataset.targets.numpy(),
            transform=transform,
            return_idx=True,
        ),
        batch_size=batch_size,
    )
class Net(nn.Module):
    """Single-linear-layer classifier for 28x28 MNIST images."""

    def __init__(self):
        super().__init__()
        # One fully-connected layer mapping flattened pixels to 10 class scores.
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Flatten the input and return per-class log-probabilities."""
        logits = self.ln(x.reshape(-1, 28 * 28))
        return F.log_softmax(logits, dim=1)
# Split MNIST into (client_size + 1) shards: shard 0 serves as the shared
# public dataset, the remaining shards are the clients' private datasets.
dataloaders = [
    prepare_dataloader(client_size + 1, c) for c in range(client_size + 1)
]
public_dataloader, local_dataloaders = dataloaders[0], dataloaders[1:]
test_dataloader = prepare_dataloader(client_size, -1, train=False)

# One FedMD client (own local model) plus a matching SGD optimizer per party.
clients = []
local_optimizers = []
for c in range(client_size):
    party = FedMDClient(Net().to(device), public_dataloader, output_dim=10, user_id=c)
    clients.append(party)
    local_optimizers.append(optim.SGD(party.parameters(), lr=lr))

server = FedMDServer(clients, Net().to(device))

# FedMD exchanges predicted logits on the public dataset (not gradients) and
# distills the aggregated knowledge into each local model.
api = FedMDAPI(
    server,
    clients,
    public_dataloader,
    local_dataloaders,
    F.nll_loss,
    local_optimizers,
    test_dataloader,
    num_communication=2,
)
log = api.run()
epoch 1 (public - pretrain): [1.4732259569076684, 1.509599570077829]
acc on validation dataset: {'clients_score': [0.7988, 0.7907]}
epoch 1 (local - pretrain): [0.8319099252216351, 0.8403522926397597]
acc on validation dataset: {'clients_score': [0.8431, 0.8406]}
epoch 1, client 0: 248.21629917621613
epoch 1, client 1: 269.46992498636246
epoch=1 acc on local datasets: {'clients_score': [0.84605, 0.85175]}
epoch=1 acc on public dataset: {'clients_score': [0.84925, 0.8516]}
epoch=1 acc on validation dataset: {'clients_score': [0.8568, 0.8594]}
epoch 2, client 0: 348.2699541449547
epoch 2, client 1: 364.1900661587715
epoch=2 acc on local datasets: {'clients_score': [0.8508, 0.85555]}
epoch=2 acc on public dataset: {'clients_score': [0.85395, 0.8567]}
epoch=2 acc on validation dataset: {'clients_score': [0.8598, 0.8641]}
1.4.2. MPI#
You can execute FedMD with the MPI backend via `MPIFedMDClientManager`, `MPIFedMDServerManager`, and `MPIFedMDAPI`.
%%writefile mpi_fedmd.py
import random
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms
from aijack.collaborative.fedmd import FedMDAPI, FedMDClient, FedMDServer
from aijack.collaborative.fedmd import MPIFedMDAPI, MPIFedMDClientManager, MPIFedMDServerManager
from aijack.utils import NumpyDataset, accuracy_torch_dataloader
# Module-level logger for this script.
logger = getLogger(__name__)
# --- Hyperparameters shared by all MPI ranks ---
training_batch_size = 64  # mini-batch size for training loaders
test_batch_size = 64  # mini-batch size for the evaluation loader
num_rounds = 2  # FedMD communication rounds (passed as num_communication)
lr = 0.001  # SGD learning rate
seed = 0  # RNG seed; every rank calls fix_seed(seed) in main()
def fix_seed(seed):
    """Seed all RNG sources so every MPI rank produces the same data split.

    Covers Python's ``random``, NumPy, and PyTorch (CPU + every CUDA
    device), and pins cuDNN to deterministic algorithms.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # CPU generator
    torch.cuda.manual_seed_all(seed)  # all CUDA device generators
    torch.backends.cudnn.deterministic = True  # no non-deterministic cuDNN kernels
def prepare_dataloader(num_clients, myid, train=True, path=""):
    """Build a DataLoader for one MNIST training shard or the full test set.

    With ``train=True``: shuffle the training indices, split them into
    ``num_clients`` shards, and wrap shard ``myid - 1``. With
    ``train=False``: return the whole test set (``num_clients`` and
    ``myid`` are ignored). Batch sizes come from the module-level
    ``training_batch_size`` / ``test_batch_size``.
    """
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if not train:
        dataset = datasets.MNIST(path, train=False, download=True, transform=preprocess)
        return torch.utils.data.DataLoader(
            NumpyDataset(
                x=dataset.data.numpy(),
                y=dataset.targets.numpy(),
                transform=preprocess,
                return_idx=True,
            ),
            batch_size=test_batch_size,
        )
    dataset = datasets.MNIST(path, train=True, download=True, transform=preprocess)
    shuffled = list(range(len(dataset.data)))
    random.shuffle(shuffled)
    # Every rank shuffles with the same seed, so shards are consistent
    # across processes; keep only this party's shard.
    shard = np.array_split(shuffled, num_clients, 0)[myid - 1]
    dataset.data = dataset.data[shard]
    dataset.targets = dataset.targets[shard]
    return torch.utils.data.DataLoader(
        NumpyDataset(
            x=dataset.data.numpy(),
            y=dataset.targets.numpy(),
            transform=preprocess,
            return_idx=True,
        ),
        batch_size=training_batch_size,
    )
class Net(nn.Module):
    """Minimal MNIST classifier: one linear layer over flattened pixels."""

    def __init__(self):
        super().__init__()
        self.ln = nn.Linear(28 * 28, 10)  # 784 pixels -> 10 class scores

    def forward(self, x):
        """Return log-probabilities of shape (batch, 10)."""
        flat = x.reshape(-1, 28 * 28)
        return F.log_softmax(self.ln(flat), dim=1)
def main():
    """Entry point for the MPI-backed FedMD run.

    Rank 0 acts as the FedMD server; every other rank is a client that
    trains a local model and exchanges predicted logits on the shared
    public dataset.
    """
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Every rank builds the same public shard: fix_seed(seed) ran with the
    # same seed on every process, so the shuffle order matches.
    public_dataloader = prepare_dataloader(size - 1, 0, train=True)

    if myid == 0:
        # Server side. num_clients/myid are ignored when train=False, so
        # this is simply the full MNIST test set.
        dataloader = prepare_dataloader(size + 1, myid + 1, train=False)
        client_ids = list(range(1, size))
        mpi_manager = MPIFedMDServerManager()
        MPIFedMDServer = mpi_manager.attach(FedMDServer)
        # Bug fix: pass the computed client_ids instead of a hard-coded
        # [1, 2], so the script works with any `mpiexec -np N`, not just 3.
        server = MPIFedMDServer(comm, client_ids, model)
        api = MPIFedMDAPI(
            comm,
            server,
            True,  # is_server
            F.nll_loss,
            None,  # server trains no local model: no optimizer
            None,  # ... and no local dataloader
            num_communication=num_rounds,
            device=device,
        )
    else:
        # Client side: this rank's private shard of the training set.
        dataloader = prepare_dataloader(size + 1, myid + 1, train=True)
        mpi_manager = MPIFedMDClientManager()
        MPIFedMDClient = mpi_manager.attach(FedMDClient)
        client = MPIFedMDClient(
            comm, model, public_dataloader, output_dim=10, user_id=myid
        )
        api = MPIFedMDAPI(
            comm,
            client,
            False,  # is_server
            F.nll_loss,
            optimizer,
            dataloader,
            public_dataloader,
            num_communication=num_rounds,
            device=device,
        )

    api.run()

    if myid != 0:
        print(
            f"client_id={myid}: Accuracy on local dataset is ",
            accuracy_torch_dataloader(client, dataloader),
        )


if __name__ == "__main__":
    main()
Overwriting mpi_fedmd.py
!sudo mpiexec -np 3 --allow-run-as-root python mpi_fedmd.py
client_id=2: Accuracy on local dataset is 0.8587333333333333
client_id=1: Accuracy on local dataset is 0.8579333333333333