1.2. FedAVG with Paillier Encryption

Homomorphic encryption is a class of encryption schemes that allow certain arithmetic operations to be performed directly on ciphertexts. For example, the Paillier encryption scheme has the following properties:

$$ \begin{align} &\mathcal{D}(\mathcal{E}(x) + \mathcal{E}(y)) = x + y \newline &\mathcal{D}(\mathcal{E}(x) + y) = x + y \newline &\mathcal{D}(\mathcal{E}(x) * y) = x * y \end{align} $$

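Here $ \mathcal{E} $ and $ \mathcal{D} $ denote encryption and decryption, respectively, and $ + $ and $ * $ applied to a ciphertext denote the scheme's homomorphic operations rather than ordinary arithmetic.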
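The three properties can be checked with a toy, textbook-style Paillier implementation. The sketch below is not AIJack's implementation: the helper names (`encrypt`, `decrypt`, `add_cipher`, `add_plain`, `mul_plain`) and the two small Mersenne primes are assumptions chosen for illustration, and a key this short is not secure. Note that homomorphic addition is realized as multiplication of ciphertexts, and multiplication by a plaintext as exponentiation.

```python
import math
import random

# Toy key material (an assumption for this demo): two Mersenne primes.
P, Q = 2**31 - 1, 2**61 - 1
N = P * Q  # public key (modulus)
N2 = N * N
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)  # private: lcm(P - 1, Q - 1)
MU = pow(LAM, -1, N)  # private: inverse of LAM modulo N (Python 3.8+)


def encrypt(m):
    """E(m) = (1 + N)^m * r^N mod N^2, using the generator g = N + 1."""
    r = random.randrange(1, N)
    while math.gcd(r, N) != 1:
        r = random.randrange(1, N)
    return (1 + (m % N) * N) * pow(r, N, N2) % N2


def decrypt(c):
    """D(c) = L(c^LAM mod N^2) * MU mod N, where L(u) = (u - 1) // N."""
    return (pow(c, LAM, N2) - 1) // N * MU % N


def add_cipher(c1, c2):
    """E(x) + E(y): multiply the two ciphertexts."""
    return c1 * c2 % N2


def add_plain(c, y):
    """E(x) + y: multiply by (1 + N)^y, which needs only the public key."""
    return c * (1 + (y % N) * N) % N2


def mul_plain(c, y):
    """E(x) * y: raise the ciphertext to the power y."""
    return pow(c, y % N, N2)


x, y = 12, 30
print(decrypt(add_cipher(encrypt(x), encrypt(y))))  # x + y
print(decrypt(add_plain(encrypt(x), y)))  # x + y
print(decrypt(mul_plain(encrypt(x), y)))  # x * y
```

Because `encrypt` draws a fresh random `r` on every call, encrypting the same value twice yields different ciphertexts that still decrypt to the same plaintext.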

Recall that the server in FedAVG averages the received gradients to update the global model.

$$ w_{t} \leftarrow w_{t - 1} - \eta \sum_{c=1}^{C} \frac{n_{c}}{N} \nabla \mathcal{l}(w_{t - 1}, X_{c}, Y_{c}) $$
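As a plaintext reference point, the update rule above can be written in a few lines of pure Python. This is a minimal sketch with flat lists standing in for model parameters; the function name `fedavg_update` and the example numbers are hypothetical and not part of AIJack.

```python
def fedavg_update(weights, client_grads, client_sizes, lr):
    """Return w_t = w_{t-1} - lr * sum_c (n_c / N) * grad_c for flat weight lists."""
    total = sum(client_sizes)  # N: total number of samples over all clients
    new_weights = list(weights)
    for grad, n_c in zip(client_grads, client_sizes):
        for i, g in enumerate(grad):
            # Each client's gradient is weighted by its share of the data.
            new_weights[i] -= lr * (n_c / total) * g
    return new_weights


updated = fedavg_update(
    weights=[1.0, 2.0],
    client_grads=[[1.0, 0.0], [0.0, 2.0]],
    client_sizes=[1, 3],
    lr=0.1,
)
print(updated)
```

The client holding three of the four samples contributes three quarters of the aggregated gradient.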

To mitigate potential leakage of private information from the gradients, each client can encrypt its gradient with the Paillier encryption scheme before sending it to the server:

$$ w_{t} \leftarrow w_{t - 1} - \eta \sum_{c=1}^{C} \frac{n_{c}}{N} \mathcal{E} (\nabla \mathcal{l}(w_{t - 1}, X_{c}, Y_{c})) $$
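That the server can evaluate this update without the private key follows from the three homomorphic properties. A short derivation, writing $ g_{c} $ (a shorthand introduced here) for the gradient $ \nabla \mathcal{l}(w_{t - 1}, X_{c}, Y_{c}) $ of client $ c $:

$$ \begin{align} &\mathcal{D}\left( \sum_{c=1}^{C} \frac{n_{c}}{N} * \mathcal{E}(g_{c}) \right) = \sum_{c=1}^{C} \frac{n_{c}}{N} g_{c} \newline &\mathcal{D}\left( w_{t - 1} - \eta \sum_{c=1}^{C} \frac{n_{c}}{N} * \mathcal{E}(g_{c}) \right) = w_{t - 1} - \eta \sum_{c=1}^{C} \frac{n_{c}}{N} g_{c} \end{align} $$

The first line combines multiplication of a ciphertext by a plaintext with addition of ciphertexts. The second line adds the result to $ w_{t - 1} $, which is a plaintext in the first round and a ciphertext afterwards. The updated global model held by the server is therefore itself a ciphertext, which is why clients decrypt the global model in every round except the first. Since Paillier operates on integers, real-valued gradients also need an integer (for example fixed-point) encoding in practice.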

The detailed procedure of Federated Learning with Paillier Encryption is as follows:

1. The central server initializes the global model.
2. The clients generate a Paillier key pair, publish the public key, and share the private key only among themselves.
3. The server distributes the global model to each client.
4. Except for the first round, each client decrypts the global model.
5. Each client locally calculates the gradient of the loss function on their dataset.
6. Each client encrypts the gradient and sends it to the server.
7. The server aggregates the received gradients with some method (e.g., average) and updates the global model with the aggregated gradient.
8. Repeat steps 3 to 7 until convergence.
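The steps above can be sketched end to end on a scalar model. Everything in this example is an assumption made for illustration rather than AIJack's API: the toy Paillier helpers, the fixed-point scale `SCALE`, the function `run_federated`, and the local loss $ 0.5 (w - t_{c})^2 $ with a hypothetical per-client target $ t_{c} $. In the code, `N` is the Paillier modulus and `total` plays the role of the sample count $ N $ from the update rule.

```python
import math
import random

# Step 2: toy Paillier key shared by the clients (insecure, demo only).
P, Q = 2**31 - 1, 2**61 - 1
N, N2 = P * Q, (P * Q) ** 2  # N is the public modulus
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)  # private key part
MU = pow(LAM, -1, N)  # private key part
SCALE = 10**4  # fixed-point scale: a gradient g is stored as round(g * SCALE)


def encrypt(m):
    """Encrypt a signed integer using only the public key N."""
    r = random.randrange(1, N)
    while math.gcd(r, N) != 1:
        r = random.randrange(1, N)
    return (1 + (m % N) * N) * pow(r, N, N2) % N2


def decrypt(c):
    """Decrypt with the private key (LAM, MU) and restore the sign."""
    m = (pow(c, LAM, N2) - 1) // N * MU % N
    return m - N if m > N // 2 else m


def run_federated(targets, sizes, w0=0.0, lr=0.5, rounds=20):
    """Scalar FedAVG in which client c minimizes 0.5 * (w - targets[c]) ** 2."""
    total = sum(sizes)
    model, encrypted = w0, False  # step 1: plaintext initial model
    for _ in range(rounds):
        # Steps 3-4: clients receive the model and decrypt it after round one.
        w = decrypt(model) / SCALE**2 if encrypted else model
        # Steps 5-6: each client encrypts its local gradient w - target.
        enc_grads = [encrypt(round((w - t) * SCALE)) for t in targets]
        # Step 7: the server works on ciphertexts with the public key only.
        if not encrypted:
            model, encrypted = encrypt(round(model * SCALE**2)), True
        for c, n_c in zip(enc_grads, sizes):
            k = round(lr * n_c / total * SCALE)  # integer form of lr * n_c / total
            model = model * pow(c, -k % N, N2) % N2  # E(w) - k * E(g_c)
    return decrypt(model) / SCALE**2  # a client decrypts the final model


print(run_federated(targets=[2.0, 6.0], sizes=[1, 3]))
```

With data sizes 1 and 3, the result approaches the size-weighted mean of the targets, 5.0, up to fixed-point rounding. The model is kept at scale `SCALE**2` because each update multiplies a gradient at scale `SCALE` by an integer weight at scale `SCALE`.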

1.2.1. Single Process

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from aijack.collaborative.fedavg import FedAVGClient, FedAVGServer, FedAVGAPI
from aijack.defense import PaillierGradientClientManager, PaillierKeyGenerator


def evaluate_gloal_model(dataloader, client_id=-1):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                if client_id == -1:
                    output = api.server(data)
                else:
                    output = api.clients[client_id](data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(f"Test set: Average loss: {test_loss}, Accuracy: {accuracy}")

    return _evaluate_global_model


training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0
client_size = 2
criterion = F.nll_loss


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)
local_dataloaders = [prepare_dataloader(client_size, c) for c in range(client_size)]
test_dataloader = prepare_dataloader(client_size, -1, train=False)

1.2.1.1. Federated Learning with Paillier Encryption

# A 64-bit key keeps this demo fast; it is far too short to be secure in practice.
keygenerator = PaillierKeyGenerator(64)
pk, sk = keygenerator.generate_keypair()

manager = PaillierGradientClientManager(pk, sk)
PaillierGradFedAVGClient = manager.attach(FedAVGClient)

clients = [
    PaillierGradFedAVGClient(Net().to(device), user_id=c, server_side_update=False)
    for c in range(client_size)
]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedAVGServer(clients, Net().to(device), server_side_update=False)

api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=num_rounds,
    custom_action=evaluate_gloal_model(test_dataloader, 0),
)
api.run()
communication 0, epoch 0: client-1 0.019623182541131972
communication 0, epoch 0: client-2 0.019723439224561056
/usr/local/lib/python3.8/dist-packages/aijack/defense/paillier/torch_wrapper.py:70: RuntimeWarning: invalid value encountered in add
  input._paillier_np_array + other.detach().cpu().numpy()
Test set: Average loss: 0.5059418523311615, Accuracy: 84.25
communication 1, epoch 0: client-1 0.00757011673549811
communication 1, epoch 0: client-2 0.007764058018724124
Test set: Average loss: 0.4435205452442169, Accuracy: 87.55
communication 2, epoch 0: client-1 0.006700039783120155
communication 2, epoch 0: client-2 0.0069033132503430045
Test set: Average loss: 0.40868335359096525, Accuracy: 87.98
communication 3, epoch 0: client-1 0.006276320548355579
communication 3, epoch 0: client-2 0.006470099781453609
Test set: Average loss: 0.3903049408197403, Accuracy: 89.17
communication 4, epoch 0: client-1 0.005988184402386347
communication 4, epoch 0: client-2 0.0061936042274038
Test set: Average loss: 0.37640745265483855, Accuracy: 89.14