1.3. FedAVG with Sparse Gradient#

Federated Learning with sparse gradient is a technique that aims to reduce the amount of data exchanged between clients and the central server during the training process, while still maintaining the accuracy of the global model. In this technique, each client only sends a sparse representation of the gradient calculated on its local data to the server, rather than the full gradient. This reduces the amount of data that needs to be exchanged, which can be especially useful in situations where the data is sensitive or the communication bandwidth is limited.

The sparse representation of the gradient can be achieved by applying a sparsifying transformation, such as thresholding or quantization, to the gradients before sending them to the server. The server then aggregates the sparse gradients and applies the inverse transformation to obtain the full gradients. In this tutorial, we adopt top-k sparse gradients, where each client sends only the k gradient entries with the largest absolute values to the server.

This approach can be beneficial in terms of privacy and communication efficiency, but it could also decrease the performance of the model. Furthermore, the sparsity of the gradients needs to be balanced with the accuracy of the model, as too much sparsity will result in a less accurate model.

1.3.1. Download Dataset#

import random

import numpy as np
import torch
from torchvision import datasets, transforms
training_batch_size = 64  # batch size used for the training DataLoaders
test_batch_size = 64  # batch size used for the test DataLoader
seed = 0  # seed passed to fix_seed()
client_size = 2  # number of clients the training set is split across
def fix_seed(seed):
    """Seed every random number generator used in this tutorial.

    Seeds Python's ``random`` module, NumPy, and PyTorch (the default
    generator and all CUDA devices), and switches cuDNN to deterministic
    mode.

    Args:
        seed: integer seed handed to each generator.
    """
    torch.backends.cudnn.deterministic = True
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def prepare_dataloader(num_clients, myid, train=True, path=""):
    """Build an MNIST DataLoader, downloading the dataset if necessary.

    For training, the sample indices are shuffled with the global ``random``
    state, split into ``num_clients`` shards, and only shard ``myid - 1`` is
    kept. Callers that want disjoint shards must therefore shuffle from the
    same seed on every call.

    Args:
        num_clients: number of shards to split the training set into
            (unused when ``train`` is False).
        myid: client id; shard ``myid - 1`` is selected (unused when
            ``train`` is False).
        train: if True return this client's training loader, otherwise the
            loader over the full test set.
        path: root directory passed to ``datasets.MNIST``.

    Returns:
        A ``torch.utils.data.DataLoader`` batched with the module-level
        ``training_batch_size`` or ``test_batch_size``.
    """
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    if not train:
        test_set = datasets.MNIST(
            path, train=False, download=True, transform=preprocess
        )
        return torch.utils.data.DataLoader(test_set, batch_size=test_batch_size)

    train_set = datasets.MNIST(path, train=True, download=True, transform=preprocess)
    shuffled = list(range(len(train_set.data)))
    random.shuffle(shuffled)
    shards = np.array_split(shuffled, num_clients, 0)
    own_shard = shards[myid - 1]
    # Restrict the dataset in place to the samples owned by this client.
    train_set.data = train_set.data[own_shard]
    train_set.targets = train_set.targets[own_shard]
    return torch.utils.data.DataLoader(train_set, batch_size=training_batch_size)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)
# prepare_dataloader selects shard `myid - 1`, so client ids are 1-based
# (matching the MPI ranks used by the script below, where rank 0 is the server).
local_dataloaders = []
for c in range(1, client_size + 1):
    # Re-seed before every call: prepare_dataloader shuffles with the global
    # `random` state, and the shards are only disjoint when each client
    # splits the same permutation.
    fix_seed(seed)
    local_dataloaders.append(prepare_dataloader(client_size, c))
# num_clients / myid are ignored for the test split.
test_dataloader = prepare_dataloader(client_size, -1, train=False)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz
Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz
Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw

1.3.2. Top-K Sparse Gradient with MPI backend#

%%writefile mpi_FedAVG_sparse.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import FedAVGClient, FedAVGServer, MPIFedAVGAPI, MPIFedAVGClientManager, MPIFedAVGServerManager
from aijack.defense.sparse import (
    SparseGradientClientManager,
    SparseGradientServerManager,
)

logger = getLogger(__name__)

training_batch_size = 64  # batch size of the training DataLoader
test_batch_size = 64  # batch size of the test DataLoader
num_rounds = 5  # passed to MPIFedAVGAPI; presumably the number of communication rounds
lr = 0.001  # learning rate of the SGD optimizer
seed = 0  # seed handed to fix_seed()


def fix_seed(seed):
    """Make runs reproducible by seeding all random number generators.

    Covers Python's ``random`` module, NumPy, PyTorch's default generator
    and every CUDA device, and turns on cuDNN's deterministic mode.

    Args:
        seed: integer seed applied to each generator.
    """
    torch.backends.cudnn.deterministic = True
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)


def prepare_dataloader(num_clients, myid, train=True, path=""):
    """Build an MNIST DataLoader from an already-downloaded dataset.

    For training, the sample indices are shuffled with the global ``random``
    state, split into ``num_clients`` shards, and only shard ``myid - 1`` is
    kept. Shards are disjoint across processes only when every process
    shuffles from the same seed.

    Args:
        num_clients: number of shards to split the training set into
            (unused when ``train`` is False).
        myid: client id; shard ``myid - 1`` is selected (unused when
            ``train`` is False).
        train: if True return this client's training loader, otherwise the
            loader over the full test set.
        path: root directory passed to ``datasets.MNIST``; the data must
            already exist there because ``download=False``.

    Returns:
        A ``torch.utils.data.DataLoader`` batched with the module-level
        ``training_batch_size`` or ``test_batch_size``.
    """
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    if not train:
        test_set = datasets.MNIST(
            path, train=False, download=False, transform=preprocess
        )
        return torch.utils.data.DataLoader(test_set, batch_size=test_batch_size)

    train_set = datasets.MNIST(path, train=True, download=False, transform=preprocess)
    shuffled = list(range(len(train_set.data)))
    random.shuffle(shuffled)
    shards = np.array_split(shuffled, num_clients, 0)
    own_shard = shards[myid - 1]
    # Restrict the dataset in place to the samples owned by this client.
    train_set.data = train_set.data[own_shard]
    train_set.targets = train_set.targets[own_shard]
    return torch.utils.data.DataLoader(train_set, batch_size=training_batch_size)


class Net(nn.Module):
    """Single-layer classifier: flattened 28x28 input -> 10 log-probabilities."""

    def __init__(self):
        super().__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return per-class log-softmax."""
        flattened = x.reshape(-1, 28 * 28)
        logits = self.ln(flattened)
        return F.log_softmax(logits, dim=1)


def evaluate_gloal_model(dataloader):
    """Create a callback that evaluates ``api.party`` on ``dataloader``.

    Args:
        dataloader: iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.

    Returns:
        A function taking ``api``. It runs ``api.party`` on every batch
        (moved to ``api.device``) without gradient tracking, then prints the
        average NLL loss and the accuracy in percent together with
        ``api.party.round``. The callback itself returns None.
    """

    def _evaluate_global_model(api):
        summed_loss = 0
        num_correct = 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(api.device)
                labels = labels.to(api.device)
                log_probs = api.party(inputs)
                # Sum (not mean) per batch so the total can be divided by the
                # dataset size afterwards.
                summed_loss += F.nll_loss(log_probs, labels, reduction="sum").item()
                # Predicted class = index of the largest log-probability.
                predicted = log_probs.argmax(dim=1, keepdim=True)
                num_correct += predicted.eq(labels.view_as(predicted)).sum().item()

        num_samples = len(dataloader.dataset)
        test_loss = summed_loss / num_samples
        accuracy = 100.0 * num_correct / num_samples
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model


def main():
    """Run one MPI rank of sparse-gradient FedAVG training on MNIST.

    Rank 0 acts as the server and is given the test set plus an evaluation
    callback; every other rank acts as a client and is given its own shard of
    the training set. Launch with ``mpiexec -np <1 + number of clients>``.
    """
    # Every rank seeds identically, so all clients shuffle the training
    # indices the same way and end up with disjoint shards.
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Layer the sparse-gradient behaviour and then the MPI transport on top of
    # the plain FedAVG client class.
    # NOTE(review): k=0.03 is presumably the fraction of gradient entries that
    # are kept -- confirm against the AIJack documentation.
    sg_client_manager = SparseGradientClientManager(k=0.03)
    mpi_client_manager = MPIFedAVGClientManager()
    SparseGradientFedAVGClient = sg_client_manager.attach(FedAVGClient)
    MPISparseGradientFedAVGClient = mpi_client_manager.attach(SparseGradientFedAVGClient)

    # Same layering for the server class.
    sg_server_manager = SparseGradientServerManager()
    mpi_server_manager = MPIFedAVGServerManager()
    SparseGradientFedAVGServer = sg_server_manager.attach(FedAVGServer)
    MPISparseGradientFedAVGServer = mpi_server_manager.attach(SparseGradientFedAVGServer)

    if myid == 0:
        # Server rank: num_clients / myid are ignored for the test split.
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        # Every non-zero rank is a client. Derive the ids from the
        # communicator size so the script works for any `-np`, not only 3.
        client_ids = list(range(1, size))
        server = MPISparseGradientFedAVGServer(comm, client_ids, model)
        api = MPIFedAVGAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_rounds,
            1,
            custom_action=evaluate_gloal_model(dataloader),
            device=device,
        )
    else:
        # Client rank: train on shard `myid - 1` of the `size - 1` shards.
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPISparseGradientFedAVGClient(comm, model, user_id=myid)
        api = MPIFedAVGAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,
            device=device,
        )

    api.run()


# Start training only when the file is executed as a script (e.g. by each
# process launched through mpiexec), not when it is imported.
if __name__ == "__main__":
    main()
Writing mpi_FedAVG_sparse.py
!sudo mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG_sparse.py
communication 0, epoch 0: client-2 0.02008056694070498
communication 0, epoch 0: client-3 0.019996537216504413
Round: 1, Test set: Average loss: 1.7728474597930908, Accuracy: 38.47
communication 1, epoch 0: client-3 0.016255500958363214
communication 1, epoch 0: client-2 0.016343721010287603
Round: 2, Test set: Average loss: 1.4043720769882202, Accuracy: 60.5
communication 2, epoch 0: client-2 0.014353630113601685
communication 2, epoch 0: client-3 0.014260987114906311
Round: 3, Test set: Average loss: 1.1684634439468384, Accuracy: 70.27
communication 3, epoch 0: client-2 0.013123111790418624
communication 3, epoch 0: client-3 0.013032549581925075
Round: 4, Test set: Average loss: 1.0258800836563111, Accuracy: 75.0
communication 4, epoch 0: client-2 0.012242827371756236
communication 4, epoch 0: client-3 0.012150899289051692
Round: 5, Test set: Average loss: 0.9197616576194764, Accuracy: 77.6