1.4. FedMD: Federated Learning with Model Distillation#

This tutorial implements FedMD (Federated Learning with Model Distillation), proposed in https://arxiv.org/abs/1910.03581. AIJack supports both single-process and MPI backends for FedMD. While FedAVG collaboratively trains a model by communicating local gradients without sharing local datasets, a malicious server might be able to recover the training data from the shared gradients (see Gradient-based Model Inversion Attack against Federated Learning for details). In addition, sending and receiving model gradients requires considerable communication bandwidth. To address these challenges, FedMD communicates not gradients but each party's predicted logits on a shared public dataset, and uses model distillation to transfer each party's knowledge.
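
The exchange at the heart of FedMD can be sketched in a few lines of plain PyTorch: every client publishes its logits on the public dataset, the server averages them into consensus logits, and each client then distills that consensus into its own model before continuing local training. The snippet below is only a minimal illustration of that distillation step, not AIJack's implementation (the rest of this tutorial uses the library API); models, optimizers, and public_loader are placeholder names, the public loader is assumed to be unshuffled and to yield (x, y) batches, and an L1 loss on the logits is used as the distillation loss.

import torch
import torch.nn.functional as F


def fedmd_distillation_round(models, optimizers, public_loader, device="cpu"):
    """One simplified FedMD communication round (illustration only)."""
    # 1. Each party predicts logits on the shared public dataset.
    all_logits = []
    for model in models:
        model.eval()
        with torch.no_grad():
            logits = torch.cat([model(x.to(device)) for x, _ in public_loader])
        all_logits.append(logits)

    # 2. The server aggregates the shared predictions into consensus logits.
    consensus = torch.stack(all_logits).mean(dim=0)

    # 3. "Digest": every client fits its own model to the consensus
    #    (the public loader must not shuffle, so batches align with consensus).
    for model, opt in zip(models, optimizers):
        model.train()
        offset = 0
        for x, _ in public_loader:
            x = x.to(device)
            target = consensus[offset : offset + x.shape[0]]
            offset += x.shape[0]
            opt.zero_grad()
            loss = F.l1_loss(model(x), target)  # distillation loss on logits
            loss.backward()
            opt.step()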

1.4.1. Single Process#

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from aijack.collaborative.fedmd import FedMDAPI, FedMDClient, FedMDServer
from aijack.utils import NumpyDataset


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0
client_size = 2
criterion = F.nll_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            NumpyDataset(
                x=dataset.data.numpy(),
                y=dataset.targets.numpy(),
                transform=transform,
                return_idx=True,
            ),
            batch_size=training_batch_size,
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(
            NumpyDataset(
                x=dataset.data.numpy(),
                y=dataset.targets.numpy(),
                transform=transform,
                return_idx=True,
            ),
            batch_size=test_batch_size,
        )
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output


dataloaders = [prepare_dataloader(client_size + 1, c) for c in range(client_size + 1)]
public_dataloader = dataloaders[0]
local_dataloaders = dataloaders[1:]
test_dataloader = prepare_dataloader(client_size, -1, train=False)
clients = [
    FedMDClient(Net().to(device), public_dataloader, output_dim=10, user_id=c)
    for c in range(client_size)
]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedMDServer(clients, Net().to(device))

api = FedMDAPI(
    server,
    clients,
    public_dataloader,
    local_dataloaders,
    F.nll_loss,
    local_optimizers,
    test_dataloader,
    num_communication=2,
)
log = api.run()
epoch 1 (public - pretrain): [1.4732259569076684, 1.509599570077829]
acc on validation dataset:  {'clients_score': [0.7988, 0.7907]}
epoch 1 (local - pretrain): [0.8319099252216351, 0.8403522926397597]
acc on validation dataset:  {'clients_score': [0.8431, 0.8406]}
epoch 1, client 0: 248.21629917621613
epoch 1, client 1: 269.46992498636246
epoch=1 acc on local datasets:  {'clients_score': [0.84605, 0.85175]}
epoch=1 acc on public dataset:  {'clients_score': [0.84925, 0.8516]}
epoch=1 acc on validation dataset:  {'clients_score': [0.8568, 0.8594]}
epoch 2, client 0: 348.2699541449547
epoch 2, client 1: 364.1900661587715
epoch=2 acc on local datasets:  {'clients_score': [0.8508, 0.85555]}
epoch=2 acc on public dataset:  {'clients_score': [0.85395, 0.8567]}
epoch=2 acc on validation dataset:  {'clients_score': [0.8598, 0.8641]}

1.4.2. MPI#

You can execute FedMD with the MPI backend via MPIFedMDClientManager, MPIFedMDServerManager, and MPIFedMDAPI.

%%writefile mpi_fedmd.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative.fedmd import (
    FedMDClient,
    FedMDServer,
    MPIFedMDAPI,
    MPIFedMDClientManager,
    MPIFedMDServerManager,
)
from aijack.utils import NumpyDataset, accuracy_torch_dataloader

logger = getLogger(__name__)

training_batch_size = 64
test_batch_size = 64
num_rounds = 2
lr = 0.001
seed = 0


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            NumpyDataset(
                x=dataset.data.numpy(),
                y=dataset.targets.numpy(),
                transform=transform,
                return_idx=True,
            ),
            batch_size=training_batch_size,
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(
            NumpyDataset(
                x=dataset.data.numpy(),
                y=dataset.targets.numpy(),
                transform=transform,
                return_idx=True,
            ),
            batch_size=test_batch_size,
        )
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output

def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    public_dataloader = prepare_dataloader(size - 1, 0, train=True)

    if myid == 0:
        dataloader = prepare_dataloader(size + 1, myid + 1, train=False)
        client_ids = list(range(1, size))
        mpi_manager = MPIFedMDServerManager()
        MPIFedMDServer = mpi_manager.attach(FedMDServer)
        server = MPIFedMDServer(comm, client_ids, model)
        api = MPIFedMDAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_communication=num_rounds,
            device=device
        )
    else:
        dataloader = prepare_dataloader(size + 1, myid + 1, train=True)
        mpi_manager = MPIFedMDClientManager()
        MPIFedMDClient = mpi_manager.attach(FedMDClient)
        client = MPIFedMDClient(comm, model, public_dataloader, output_dim=10, user_id=myid)
        api = MPIFedMDAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            public_dataloader,
            num_communication=num_rounds,
            device=device
        )

    api.run()

    if myid != 0:
        print(
            f"client_id={myid}: Accuracy on local dataset is ",
            accuracy_torch_dataloader(client, dataloader),
        )


if __name__ == "__main__":
    main()
Overwriting mpi_fedmd.py
!sudo mpiexec -np 3 --allow-run-as-root python mpi_fedmd.py
client_id=2: Accuracy on local dataset is  0.8587333333333333
client_id=1: Accuracy on local dataset is  0.8579333333333333
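
Assuming the MPI wrappers place no restriction on the number of workers (client_ids is derived from the communicator size inside main), the same script should scale to more clients simply by launching more processes, e.g. !sudo mpiexec -np 5 --allow-run-as-root python mpi_fedmd.py for one server and four clients.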