1.3. FedAVG with Sparse Gradient#
Federated learning with sparse gradients is a technique that reduces the amount of data exchanged between the clients and the central server during training while maintaining the accuracy of the global model. Instead of the full gradient, each client sends only a sparse representation of the gradient computed on its local data. This shrinks each communication round, which is especially useful when the data is sensitive or the communication bandwidth is limited.
The sparse representation can be obtained by applying a sparsifying transformation, such as thresholding or quantization, to the gradient before sending it to the server. The server scatters the received values back into dense tensors and aggregates them to update the global model. In this tutorial, we adopt top-k sparse gradients, where each client sends only the k entries of its gradient with the largest absolute values.
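To make the idea concrete before turning to AIJack's implementation, here is a minimal sketch of the client-side top-k step and the server-side densification; the helper names topk_sparsify and densify are ours, not part of AIJack:

import torch

def topk_sparsify(grad, k_ratio=0.03):
    """Keep only the k largest-magnitude entries of a gradient tensor."""
    flat = grad.flatten()
    k = max(1, int(k_ratio * flat.numel()))
    _, idx = torch.topk(flat.abs(), k)  # positions of the largest entries
    return idx, flat[idx]  # only indices and values are transmitted

def densify(idx, values, shape):
    """Server side: scatter the received values back into a dense tensor."""
    flat = torch.zeros(shape).flatten()
    flat[idx] = values
    return flat.reshape(shape)

grad = torch.randn(4, 4)
idx, vals = topk_sparsify(grad, k_ratio=0.25)  # keeps 4 of 16 entries
restored = densify(idx, vals, grad.shape)  # dense again, other 12 entries are zero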
This approach is beneficial for privacy and communication efficiency, but it can also degrade model performance. The sparsity of the gradients therefore needs to be balanced against the accuracy of the model: too much sparsity results in a less accurate model.
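As a rough estimate of the saving, consider the single-layer model used later in this tutorial (28 × 28 inputs, 10 outputs) with k = 3%, assuming an index is about as expensive to transmit as a value:

d = 28 * 28 * 10 + 10  # parameters of the linear model below: 7850
k = int(0.03 * d)  # entries kept at k = 3%: 235
sparse_payload = 2 * k  # one value + one index per kept entry
print(sparse_payload / d)  # ~0.06, i.e. roughly 6% of the dense gradient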
1.3.1. Download Dataset#
import random

import numpy as np
import torch
from torchvision import datasets, transforms

training_batch_size = 64
test_batch_size = 64
seed = 0
client_size = 2


def fix_seed(seed):
    # make python, numpy, and torch runs reproducible
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        # shuffle the indices and assign each client an equal share
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)
local_dataloaders = [prepare_dataloader(client_size, c) for c in range(client_size)]
test_dataloader = prepare_dataloader(client_size, -1, train=False)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz
Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz
Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw
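We can quickly confirm that each of the two clients holds half of the 60,000 MNIST training images:

for i, loader in enumerate(local_dataloaders):
    print(f"client {i}: {len(loader.dataset)} samples")  # 30000 each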
1.3.2. Top-K Sparse Gradient with MPI backend#
%%writefile mpi_FedAVG_sparse.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import (
    FedAVGClient,
    FedAVGServer,
    MPIFedAVGAPI,
    MPIFedAVGClientManager,
    MPIFedAVGServerManager,
)
from aijack.defense.sparse import (
    SparseGradientClientManager,
    SparseGradientServerManager,
)

logger = getLogger(__name__)

training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=False, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=False, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output
def evaluate_global_model(dataloader):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                output = api.party(data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model
def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()  # rank 0 acts as the server; ranks 1..size-1 are clients
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    # attach the sparse-gradient behavior and the MPI communication layer
    # to the vanilla FedAVG client and server classes
    sg_client_manager = SparseGradientClientManager(k=0.03)  # send top 3% of entries
    mpi_client_manager = MPIFedAVGClientManager()
    SparseGradientFedAVGClient = sg_client_manager.attach(FedAVGClient)
    MPISparseGradientFedAVGClient = mpi_client_manager.attach(
        SparseGradientFedAVGClient
    )

    sg_server_manager = SparseGradientServerManager()
    mpi_server_manager = MPIFedAVGServerManager()
    SparseGradientFedAVGServer = sg_server_manager.attach(FedAVGServer)
    MPISparseGradientFedAVGServer = mpi_server_manager.attach(
        SparseGradientFedAVGServer
    )

    if myid == 0:
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        client_ids = list(range(1, size))
        server = MPISparseGradientFedAVGServer(comm, client_ids, model)
        api = MPIFedAVGAPI(
            comm,
            server,
            True,  # this party is the server
            F.nll_loss,
            None,  # the server has no local optimizer
            None,  # ... and no local dataloader
            num_rounds,
            1,  # local epochs per round
            custom_action=evaluate_global_model(dataloader),
            device=device,
        )
    else:
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPISparseGradientFedAVGClient(comm, model, user_id=myid)
        api = MPIFedAVGAPI(
            comm,
            client,
            False,  # this party is a client
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,  # local epochs per round
            device=device,
        )
    api.run()


if __name__ == "__main__":
    main()
Writing mpi_FedAVG_sparse.py
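Running the script with mpiexec -np 3 launches three MPI processes: rank 0 acts as the server, and ranks 1 and 2 train as clients on their halves of the data.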
!sudo mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG_sparse.py
communication 0, epoch 0: client-2 0.02008056694070498
communication 0, epoch 0: client-3 0.019996537216504413
Round: 1, Test set: Average loss: 1.7728474597930908, Accuracy: 38.47
communication 1, epoch 0: client-3 0.016255500958363214
communication 1, epoch 0: client-2 0.016343721010287603
Round: 2, Test set: Average loss: 1.4043720769882202, Accuracy: 60.5
communication 2, epoch 0: client-2 0.014353630113601685
communication 2, epoch 0: client-3 0.014260987114906311
Round: 3, Test set: Average loss: 1.1684634439468384, Accuracy: 70.27
communication 3, epoch 0: client-2 0.013123111790418624
communication 3, epoch 0: client-3 0.013032549581925075
Round: 4, Test set: Average loss: 1.0258800836563111, Accuracy: 75.0
communication 4, epoch 0: client-2 0.012242827371756236
communication 4, epoch 0: client-3 0.012150899289051692
Round: 5, Test set: Average loss: 0.9197616576194764, Accuracy: 77.6
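Even though each client communicates only the top 3% of its gradient entries per round, the global model improves steadily, reaching 77.6% test accuracy after five rounds; training for more rounds, or with a larger k, would be expected to narrow the gap to dense FedAVG.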