1.1. FedAVG
In this tutorial, you will learn how to simulate FedAVG, a representative scheme of Federated Learning, with AIJack. You can use either a single process or MPI as the backend. We will also demonstrate that you can add various defense methods to FedAVG with only a few additional lines.
While deep learning achieves substantial success in various areas, training deep learning models requires a large amount of data, and such data are often privacy-sensitive and spread across many owners. Thus, acquiring high performance in deep learning while preserving privacy is challenging. One way to solve this problem is Federated Learning, where multiple clients collaboratively train a single global model without sharing their local datasets.
The procedure of typical Federated Learning is as follows:
1. The central server initializes the global model.
2. The server distributes the global model to each client.
3. Each client locally calculates the gradient of the loss function on their dataset.
4. Each client sends the gradient to the server.
5. The server aggregates the received gradients with some method (e.g., average) and updates the global model with the aggregated gradient.
6. Repeat steps 2–5 until convergence.
The mathematical notation when the aggregation is a weighted average is as follows:
$$ w_{t} \leftarrow w_{t - 1} - \eta \sum_{c=1}^{C} \frac{n_{c}}{N} \nabla \ell(w_{t - 1}, X_{c}, Y_{c}) $$
where $w_{t}$ is the parameters of the global model in the $t$-th round, $\eta$ is the learning rate, $C$ is the number of clients, $\nabla \ell(w_{t - 1}, X_{c}, Y_{c})$ is the gradient calculated on the $c$-th client's dataset $(X_{c}, Y_{c})$, $n_{c}$ is the number of samples in the $c$-th client's dataset, and $N = \sum_{c=1}^{C} n_{c}$ is the total number of samples.
1.1.1. Single Process
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms
from aijack.collaborative.fedavg import FedAVGClient, FedAVGServer, FedAVGAPI
def evaluate_gloal_model(dataloader, client_id=-1):
    """Build a FedAVGAPI ``custom_action`` that prints test-set metrics.

    Args:
        dataloader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (its length is used as the sample count).
        client_id: ``-1`` evaluates ``api.server``; any other value evaluates
            ``api.clients[client_id]``.

    Returns:
        A callable taking the API object. It prints the average negative
        log-likelihood per sample and the accuracy in percent, and returns None.
    """

    def _evaluate_global_model(api):
        total_loss = 0
        n_correct = 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(api.device), labels.to(api.device)
                model = api.server if client_id == -1 else api.clients[client_id]
                log_probs = model(inputs)
                # Sum per batch so dividing by the dataset size below gives a
                # per-sample mean; nll_loss expects log-probabilities.
                total_loss += F.nll_loss(log_probs, labels, reduction="sum").item()
                # Predicted class = index of the largest log-probability.
                predicted = log_probs.argmax(dim=1, keepdim=True)
                n_correct += predicted.eq(labels.view_as(predicted)).sum().item()
        test_loss = total_loss / len(dataloader.dataset)
        accuracy = 100.0 * n_correct / len(dataloader.dataset)
        print(f"Test set: Average loss: {test_loss}, Accuracy: {accuracy}")

    return _evaluate_global_model
# Hyper-parameters for the single-process simulation.
training_batch_size = test_batch_size = 64  # mini-batch size for local training and for evaluation
num_rounds = 5  # passed to FedAVGAPI as num_communication
lr = 0.001  # SGD learning rate for every client
seed = 0  # passed to fix_seed
client_size = 2  # number of simulated clients
criterion = F.nll_loss  # pairs with Net's log_softmax output
def fix_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # all CUDA devices, if any
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
def prepare_dataloader(num_clients, myid, train=True, path="", shuffle_seed=None):
    """Return an MNIST DataLoader for one client's shard, or for the test split.

    Args:
        num_clients: number of shards the training set is split into.
        myid: 1-based shard index; shard ``myid - 1`` is returned. Index 0
            wraps to the last shard through Python negative indexing. Ignored
            when ``train`` is False.
        train: True for a training shard, False for the full MNIST test set.
        path: MNIST root directory (downloaded there if missing).
        shuffle_seed: seed for the private RNG that shuffles the training
            indices. Defaults to the module-level ``seed``. Every call with
            the same value sees the same permutation, so shards returned by
            different calls are disjoint.

    Returns:
        torch.utils.data.DataLoader over the selected data.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if not train:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        return torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)

    dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
    idxs = list(range(len(dataset.data)))
    # Use a private RNG seeded identically on every call. The global `random`
    # state advances between calls, so shuffling with it would give each client
    # built in this process a slice of a *different* permutation, and the
    # shards would overlap instead of partitioning the training set.
    rng = random.Random(seed if shuffle_seed is None else shuffle_seed)
    rng.shuffle(idxs)
    idx = np.array_split(idxs, num_clients, 0)[myid - 1]
    dataset.data = dataset.data[idx]
    dataset.targets = dataset.targets[idx]
    return torch.utils.data.DataLoader(dataset, batch_size=training_batch_size)
class Net(nn.Module):
    """One linear layer followed by log-softmax (multinomial logistic regression).

    Input is reshaped to ``(-1, 784)``, i.e. flattened 28x28 images.
    """

    def __init__(self):
        super().__init__()
        # 784 flattened pixels -> 10 class scores.
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Return log-probabilities of shape (x.numel() // 784, 10)."""
        flat = x.reshape(-1, 28 * 28)
        return F.log_softmax(self.ln(flat), dim=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)
# prepare_dataloader returns shard ``myid - 1`` (myid is 1-based, as with the
# MPI client ranks), so client c is given myid c + 1 and receives shard c.
local_dataloaders = [
    prepare_dataloader(client_size, c + 1) for c in range(client_size)
]
# The test split is not sharded; myid is not used when train=False.
test_dataloader = prepare_dataloader(client_size, -1, train=False)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz
Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz
Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw
# One FedAVGClient per simulated participant, each wrapping its own freshly
# initialised Net; the clients are created before the server's Net.
clients = [
    FedAVGClient(Net().to(device), user_id=uid) for uid in range(client_size)
]
# Every client gets an independent SGD optimizer over its own parameters.
local_optimizers = [optim.SGD(cl.parameters(), lr=lr) for cl in clients]
# The server receives the client list and a separate Net instance
# (presumably the global model).
server = FedAVGServer(clients, Net().to(device))

api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=num_rounds,
    # client_id defaults to -1, so the callback evaluates api.server on the test set.
    custom_action=evaluate_gloal_model(test_dataloader),
)
api.run()
communication 0, epoch 0: client-1 0.019623182541131972
communication 0, epoch 0: client-2 0.019723439224561056
Test set: Average loss: 0.7824367880821228, Accuracy: 83.71
communication 1, epoch 0: client-1 0.01071754728158315
communication 1, epoch 0: client-2 0.010851142065723737
Test set: Average loss: 0.58545467877388, Accuracy: 86.49
communication 2, epoch 0: client-1 0.008766427374879518
communication 2, epoch 0: client-2 0.00891655088464419
Test set: Average loss: 0.507768925857544, Accuracy: 87.54
communication 3, epoch 0: client-1 0.007839484961827596
communication 3, epoch 0: client-2 0.00799967499623696
Test set: Average loss: 0.46477557654380797, Accuracy: 88.25
communication 4, epoch 0: client-1 0.0072782577464977904
communication 4, epoch 0: client-2 0.007445397683481375
Test set: Average loss: 0.436919868183136, Accuracy: 88.63
1.1.2. MPI
%%writefile mpi_FedAVG.py
import random
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms
from aijack.collaborative import FedAVGClient, FedAVGServer, MPIFedAVGAPI, MPIFedAVGClientManager, MPIFedAVGServerManager
logger = getLogger(__name__)

# Hyper-parameters; every MPI rank runs this same script and reads these values.
training_batch_size, test_batch_size = 64, 64  # local-training / evaluation batch sizes
num_rounds, lr, seed = 5, 0.001, 0  # communication rounds, SGD learning rate, RNG seed
def fix_seed(seed):
    """Make this rank reproducible: seed every RNG and force deterministic cuDNN."""
    for seed_fn in [random.seed, np.random.seed]:
        seed_fn(seed)
    for seed_fn in [torch.manual_seed, torch.cuda.manual_seed_all]:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
def prepare_dataloader(num_clients, myid, train=True, path=""):
    """Build an MNIST DataLoader for this rank.

    MNIST is opened with ``download=False``, so the files must already exist
    under ``path``.

    Args:
        num_clients: number of training shards.
        myid: caller's rank; rank ``myid`` receives shard ``myid - 1``.
            Ignored when ``train`` is False.
        train: True for this rank's training shard, False for the full test set.
        path: MNIST root directory.

    The shuffle draws from the global ``random`` module. Shards are therefore
    disjoint across ranks only if every rank seeds ``random`` identically and
    then shuffles exactly once. ``main`` calls ``fix_seed(seed)`` before this
    function on every rank.
    """
    normalize = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if not train:
        test_set = datasets.MNIST(path, train=False, download=False, transform=normalize)
        return torch.utils.data.DataLoader(test_set, batch_size=test_batch_size)

    train_set = datasets.MNIST(path, train=True, download=False, transform=normalize)
    order = list(range(len(train_set.data)))
    random.shuffle(order)
    shard = np.array_split(order, num_clients, 0)[myid - 1]
    train_set.data = train_set.data[shard]
    train_set.targets = train_set.targets[shard]
    return torch.utils.data.DataLoader(train_set, batch_size=training_batch_size)
class Net(nn.Module):
    """Single linear layer mapping a flattened 28x28 image to 10 log-probabilities."""

    def __init__(self):
        super().__init__()
        in_features, n_classes = 28 * 28, 10
        self.ln = nn.Linear(in_features, n_classes)

    def forward(self, x):
        logits = self.ln(x.reshape(-1, 28 * 28))
        return F.log_softmax(logits, dim=1)
def evaluate_gloal_model(dataloader):
    """Return a ``custom_action`` callback that evaluates ``api.party`` on ``dataloader``.

    The callback prints ``api.party.round``, the summed NLL divided by
    ``len(dataloader.dataset)``, and the accuracy in percent.
    """

    def _evaluate_global_model(api):
        loss_sum, hits = 0, 0
        with torch.no_grad():
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(api.device), batch_y.to(api.device)
                scores = api.party(batch_x)
                # reduction="sum" so dividing by the dataset size gives a per-sample mean.
                loss_sum += F.nll_loss(scores, batch_y, reduction="sum").item()
                guesses = scores.argmax(dim=1, keepdim=True)
                hits += guesses.eq(batch_y.view_as(guesses)).sum().item()
        dataset_size = len(dataloader.dataset)
        test_loss = loss_sum / dataset_size
        accuracy = 100.0 * hits / dataset_size
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model
def main():
    """Run one MPI rank of the FedAVG simulation.

    Rank 0 is the server. It loads the MNIST test set and evaluates through
    ``custom_action``. Every other rank is a client that trains on its own
    training shard. The client ranks are ``1 .. size - 1``.
    """
    fix_seed(seed)
    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    mpi_client_manager = MPIFedAVGClientManager()
    mpi_server_manager = MPIFedAVGServerManager()
    MPIFedAVGClient = mpi_client_manager.attach(FedAVGClient)
    MPIFedAVGServer = mpi_server_manager.attach(FedAVGServer)

    if myid == 0:
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        # Every non-zero rank is a client. Derive the ids from the communicator
        # size instead of hard-coding them, so any `mpiexec -np N` works.
        client_ids = list(range(1, size))
        server = MPIFedAVGServer(comm, client_ids, model)
        api = MPIFedAVGAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_rounds,
            1,
            custom_action=evaluate_gloal_model(dataloader),
            device=device,
        )
    else:
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPIFedAVGClient(comm, model, user_id=myid)
        api = MPIFedAVGAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,
            device=device,
        )
    api.run()


if __name__ == "__main__":
    main()
Writing mpi_FedAVG.py
!sudo mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG.py
communication 0, epoch 0: client-3 0.019996537216504413
communication 0, epoch 0: client-2 0.02008056694070498
Round: 1, Test set: Average loss: 0.7860309104919434, Accuracy: 82.72
communication 1, epoch 0: client-3 0.010822976715366046
communication 1, epoch 0: client-2 0.010937693453828494
Round: 2, Test set: Average loss: 0.5885528886795044, Accuracy: 86.04
communication 2, epoch 0: client-2 0.008990796900788942
communication 2, epoch 0: client-3 0.008850129560629527
Round: 3, Test set: Average loss: 0.5102099328994751, Accuracy: 87.33
communication 3, epoch 0: client-3 0.00791173183619976
communication 3, epoch 0: client-2 0.008069112183650334
Round: 4, Test set: Average loss: 0.4666414333820343, Accuracy: 88.01
communication 4, epoch 0: client-2 0.007512268128991127
communication 4, epoch 0: client-3 0.007343090359369914
Round: 5, Test set: Average loss: 0.4383064950466156, Accuracy: 88.65