5.1. Poisoning Attack against Federated Learning

This tutorial demonstrates that malicious clients can substantially degrade the performance of the final global model by tampering with (for example, injecting noise into) their local updates or their training data.

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from aijack.attack.poison import HistoryAttackClientWrapper
from aijack.attack.poison import LabelFlipAttackClientManager
from aijack.attack.poison import MAPFClientWrapper
from aijack.collaborative.fedavg import FedAVGClient, FedAVGServer, FedAVGAPI


def evaluate_gloal_model(dataloader, client_id=-1):
    """Build a ``custom_action`` callback that reports test loss and accuracy.

    Args:
        dataloader: Iterable of ``(data, target)`` batches. ``len(dataloader.dataset)``
            is used to average the loss and to compute the accuracy.
        client_id: ``-1`` evaluates ``api.server``; any other value evaluates
            ``api.clients[client_id]``.

    Returns:
        A callback taking the API object. It prints the average negative
        log-likelihood and the accuracy (in percent) and returns ``None``.
    """

    def _evaluate_global_model(api):
        total_loss = 0
        num_correct = 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(api.device)
                labels = labels.to(api.device)
                model = api.server if client_id == -1 else api.clients[client_id]
                log_probs = model(inputs)
                # Sum (not mean) per batch so that dividing by the dataset size
                # below gives the per-sample average over the whole set.
                total_loss += F.nll_loss(log_probs, labels, reduction="sum").item()
                # The predicted class is the index of the largest log-probability.
                predictions = log_probs.argmax(dim=1, keepdim=True)
                num_correct += predictions.eq(labels.view_as(predictions)).sum().item()

        num_samples = len(dataloader.dataset)
        total_loss /= num_samples
        accuracy = 100.0 * num_correct / num_samples
        print(f"Test set: Average loss: {total_loss}, Accuracy: {accuracy}")

    return _evaluate_global_model
# Experiment settings shared by every attack scenario in this notebook.
seed = 0  # seed handed to fix_seed before the data is split
client_size = 2  # number of federated clients (one of them is malicious)
num_rounds = 5  # passed to FedAVGAPI as num_communication
lr = 0.001  # learning rate for each client and its local SGD optimizer
training_batch_size = 64  # batch size of every client's training loader
test_batch_size = 64  # batch size of the shared test loader
criterion = F.nll_loss  # matches Net's log-softmax output
def fix_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    torch.backends.cudnn.deterministic = True
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # all GPUs; a no-op when CUDA is absent
    )
    for apply_seed in seeders:
        apply_seed(seed)


def prepare_dataloader(num_clients, myid, train=True, path="", split_seed=0):
    """Build the MNIST loader for one client's training shard or for the test set.

    Args:
        num_clients: Number of shards the training set is split into.
        myid: Client id selecting the shard (used only when ``train`` is True).
            The shard index is ``myid - 1``. Thanks to negative indexing, ids
            in ``range(num_clients)`` (0 maps to the last shard) and ids in
            ``range(1, num_clients + 1)`` both give every client a distinct shard.
        train: If True, return the client's training shard. Otherwise return
            the full MNIST test set, ignoring ``num_clients`` and ``myid``.
        path: Root directory passed to ``datasets.MNIST`` (with ``download=True``).
        split_seed: Seed of the private RNG that shuffles the training indices.
            All clients must use the same value so that every call sees the
            same permutation and the shards are disjoint.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected data.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if not train:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        return torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)

    dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
    idxs = list(range(len(dataset.data)))
    # Shuffle with a dedicated, freshly seeded RNG instead of the global one.
    # The global state advances between calls, so each client would otherwise
    # draw a different permutation and the shards would overlap (and leave
    # some samples unassigned).
    random.Random(split_seed).shuffle(idxs)
    idx = np.array_split(idxs, num_clients)[myid - 1]
    dataset.data = dataset.data[idx]
    dataset.targets = dataset.targets[idx]
    return torch.utils.data.DataLoader(dataset, batch_size=training_batch_size)


class Net(nn.Module):
    """Single linear layer classifier for 28x28 inputs with 10 classes.

    ``forward`` flattens its input into rows of 784 values, applies one linear
    layer and returns per-class log-probabilities (suitable for ``F.nll_loss``).
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``ln`` determines the state_dict keys; keep it.
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        flat = x.reshape(-1, 28 * 28)
        logits = self.ln(flat)
        return F.log_softmax(logits, dim=1)
# Use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Seed every RNG before any data is split or any model is initialised.
fix_seed(seed)
# One training shard per client (ids 0 .. client_size - 1) plus the shared test set.
local_dataloaders = [prepare_dataloader(client_size, cid) for cid in range(client_size)]
test_dataloader = prepare_dataloader(num_clients=client_size, myid=-1, train=False)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz
Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz
Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw

5.1.1. History Attack

# History attack: client 0 is a FedAVGClient wrapped by
# HistoryAttackClientWrapper(lam=3); client 1 is an honest FedAVGClient.
manager = HistoryAttackClientWrapper(lam=3)
HistoryAttackFedAVGClient = manager.attach(FedAVGClient)
clients = [
    client_class(Net(), user_id=uid, lr=lr)
    for uid, client_class in enumerate([HistoryAttackFedAVGClient, FedAVGClient])
]

# One SGD optimizer per client over that client's own parameters.
local_optimizers = [optim.SGD(c.parameters(), lr=lr) for c in clients]

# NOTE(review): only the server model is moved to `device`; the client models
# stay where Net() creates them and no device is passed to FedAVGAPI. Confirm
# these match before running on a GPU.
server = FedAVGServer(clients, Net().to(device))

api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    custom_action=evaluate_gloal_model(test_dataloader),
    num_communication=num_rounds,
)
api.run()
communication 0, epoch 0: client-1 0.019623182545105616
communication 0, epoch 0: client-2 0.019723439192771912
Test set: Average loss: 6.538579542136137, Accuracy: 82.57
communication 1, epoch 0: client-1 0.1011283678372701
communication 1, epoch 0: client-2 0.10503993360201518
Test set: Average loss: 109.3780958984375, Accuracy: 24.71
communication 2, epoch 0: client-1 1.307175351079305
communication 2, epoch 0: client-2 1.322490109125773
Test set: Average loss: 514.1672515625, Accuracy: 59.83
communication 3, epoch 0: client-1 7.669983786519368
communication 3, epoch 0: client-2 7.649992772420247
Test set: Average loss: 440.93763037109375, Accuracy: 46.56
communication 4, epoch 0: client-1 6.582273025512695
communication 4, epoch 0: client-2 6.532691622924805
Test set: Average loss: 616.4571529296875, Accuracy: 40.87

5.1.2. Label Flip Attack

# Label-flip attack: client 0 is a FedAVGClient attached to
# LabelFlipAttackClientManager(victim_label=0, target_label=1);
# client 1 is an honest FedAVGClient.
manager = LabelFlipAttackClientManager(victim_label=0, target_label=1)
LabelFlipAttackFedAVGClient = manager.attach(FedAVGClient)
clients = [
    client_class(Net(), user_id=uid, lr=lr)
    for uid, client_class in enumerate([LabelFlipAttackFedAVGClient, FedAVGClient])
]

# One SGD optimizer per client over that client's own parameters.
local_optimizers = [optim.SGD(c.parameters(), lr=lr) for c in clients]

# NOTE(review): only the server model is moved to `device`; the client models
# stay where Net() creates them and no device is passed to FedAVGAPI. Confirm
# these match before running on a GPU.
server = FedAVGServer(clients, Net().to(device))

api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    custom_action=evaluate_gloal_model(test_dataloader),
    num_communication=num_rounds,
)
api.run()
communication 0, epoch 0: client-1 0.020543035099903743
communication 0, epoch 0: client-2 0.020125101908047994
Test set: Average loss: 28.122399871826172, Accuracy: 73.83
communication 1, epoch 0: client-1 0.5485365001996358
communication 1, epoch 0: client-2 0.4188099824587504
Test set: Average loss: 364.91654228515625, Accuracy: 39.13
communication 2, epoch 0: client-1 5.370502290852865
communication 2, epoch 0: client-2 5.275297354125977
Test set: Average loss: 1108.481857421875, Accuracy: 35.73
communication 3, epoch 0: client-1 15.66210668334961
communication 3, epoch 0: client-2 16.781931443277994
Test set: Average loss: 1227.238296875, Accuracy: 33.2
communication 4, epoch 0: client-1 16.553591099039714
communication 4, epoch 0: client-2 18.498205289713542
Test set: Average loss: 1096.7404470703125, Accuracy: 42.83

5.1.3. MAPF

# MAPF attack: client 0 is a FedAVGClient wrapped by MAPFClientWrapper(lam=3);
# client 1 is an honest FedAVGClient.
manager = MAPFClientWrapper(lam=3)
MAPFFedAVGClient = manager.attach(FedAVGClient)
clients = [
    client_class(Net(), user_id=uid, lr=lr)
    for uid, client_class in enumerate([MAPFFedAVGClient, FedAVGClient])
]

# One SGD optimizer per client over that client's own parameters.
local_optimizers = [optim.SGD(c.parameters(), lr=lr) for c in clients]

# NOTE(review): only the server model is moved to `device`; the client models
# stay where Net() creates them and no device is passed to FedAVGAPI. Confirm
# these match before running on a GPU.
server = FedAVGServer(clients, Net().to(device))

api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    custom_action=evaluate_gloal_model(test_dataloader),
    num_communication=num_rounds,
)
api.run()
communication 0, epoch 0: client-1 0.019650927847623824
communication 0, epoch 0: client-2 0.019755615478754044
Test set: Average loss: 6.351612661059061, Accuracy: 83.04
communication 1, epoch 0: client-1 0.10415176281332969
communication 1, epoch 0: client-2 0.10801099200248718
Test set: Average loss: 64.09054548339844, Accuracy: 35.71
communication 2, epoch 0: client-1 0.6900041089375814
communication 2, epoch 0: client-2 0.6906570717493693
Test set: Average loss: 423.16109165039063, Accuracy: 50.94
communication 3, epoch 0: client-1 6.43880789159139
communication 3, epoch 0: client-2 6.371651240030925
Test set: Average loss: 596.9880190429687, Accuracy: 41.92
communication 4, epoch 0: client-1 8.962839545694987
communication 4, epoch 0: client-2 8.890711385091146
Test set: Average loss: 597.5907002929688, Accuracy: 39.4