6.1. Backdoor Attack against Federated Learning

Bagdasaryan, Eugene, et al. "How to backdoor federated learning." International conference on artificial intelligence and statistics. PMLR, 2020.

!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00233/CNAE-9.data
--2023-11-06 18:58:33--  https://archive.ics.uci.edu/ml/machine-learning-databases/00233/CNAE-9.data
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified
Saving to: ‘CNAE-9.data’

CNAE-9.data             [            <=>     ]   1.76M   660KB/s    in 2.7s    

2023-11-06 18:58:36 (660 KB/s) - ‘CNAE-9.data’ saved [1851120]
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
import random
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from aijack.attack.backdoor.modelreplacement import ModelReplacementAttackClientManager
from aijack.collaborative.fedavg import FedAVGClient, FedAVGServer, FedAVGAPI

# Training hyper-parameters: one sample per local step, 15 FedAVG rounds.
batch_size, num_rounds, lr = 1, 15, 0.0001
criterion = nn.CrossEntropyLoss()

# Fix both RNGs this notebook relies on: `random` chooses the poisoned rows
# and torch drives model initialisation. The two seeds are independent.
random.seed(42)
torch.manual_seed(42)
def _evaluate_pass(
    api, dataloader, client_id, poison, trigger_feature, trigger_value, target_label
):
    """Run one evaluation pass over ``dataloader``.

    Returns ``(average_loss, accuracy_percent)``.

    Args:
        api: object exposing ``device``, ``server`` and ``clients``.
        dataloader: yields ``(data, target)`` batches.
        client_id: ``-1`` evaluates ``api.server``; any other value evaluates
            ``api.clients[client_id]``.
        poison: if True, stamp the trigger on every sample and relabel it as
            ``target_label`` (backdoor success rate).
        trigger_feature, trigger_value: the column to overwrite and its value.
        target_label: the class the backdoor is meant to force.
    """
    model = api.server if client_id == -1 else api.clients[client_id]
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(api.device), target.to(api.device)
            if poison:
                # Clone so the trigger is never written into a tensor the
                # loader (or a custom collate_fn) might share with the dataset.
                data = data.clone()
                data[:, trigger_feature] = trigger_value
                # NOTE: samples whose true label already equals target_label
                # count as backdoor hits, which inflates the poisoned accuracy.
                target = torch.full_like(target, target_label)
            output = model(data)
            # The models emit raw logits (no log_softmax), so cross_entropy is
            # the correct loss. nll_loss on logits produced negative values.
            total_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the top logit
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(dataloader.dataset)
    return total_loss / num_samples, 100.0 * correct / num_samples


def evaluate_gloal_model(
    dataloader, client_id=-1, trigger_feature=0, trigger_value=-1, target_label=0
):
    """Build a FedAVG ``custom_action`` that reports clean and backdoor metrics.

    The returned callable takes the API object. It prints the average
    cross-entropy loss and accuracy on the clean data ("Test set"). It then
    prints the same metrics after injecting the trigger
    ``x[:, trigger_feature] = trigger_value`` and relabelling every sample as
    ``target_label`` ("Poisoned set").

    Args:
        dataloader: evaluation data loader.
        client_id: ``-1`` (default) evaluates the global server model;
            otherwise evaluates ``api.clients[client_id]``.
        trigger_feature: column overwritten by the backdoor trigger.
        trigger_value: value written into that column.
        target_label: class the backdoor should force.
    """

    def _evaluate_global_model(api):
        loss, acc = _evaluate_pass(
            api, dataloader, client_id, False,
            trigger_feature, trigger_value, target_label,
        )
        print(f"Test set: Average loss: {loss}, Accuracy: {acc}")

        loss, acc = _evaluate_pass(
            api, dataloader, client_id, True,
            trigger_feature, trigger_value, target_label,
        )
        print(f"Poisoned set: Average loss: {loss}, Accuracy: {acc}")

    return _evaluate_global_model
# CNAE-9: column 0 holds the 1-based class label, the remaining columns are features.
df = pd.read_csv("CNAE-9.data", header=None)
X = df.iloc[:, 1:].values
y = df[0].values - 1  # shift labels to 0..8

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, shuffle=True
)

X_train, X_test = torch.Tensor(X_train), torch.Tensor(X_test)
y_train = torch.Tensor(y_train).to(torch.long)
y_test = torch.Tensor(y_test).to(torch.long)

# The first half of the training rows belongs to client 0 (the attacker).
# Poison 10% of that half: set feature 0 to -1 (the trigger) and relabel as class 0.
half = X_train.shape[0] // 2
poisoned_idx = random.sample(list(range(half)), int(0.1 * half))
X_train[poisoned_idx, 0] = -1  # inject backdoor
y_train[poisoned_idx] = 0

trainset_1 = TensorDataset(X_train[:half], y_train[:half])
trainset_2 = TensorDataset(X_train[half:], y_train[half:])
testdataset = TensorDataset(X_test, y_test)

trainloader_1, trainloader_2, test_dataloader = (
    torch.utils.data.DataLoader(ds, batch_size=batch_size)
    for ds in (trainset_1, trainset_2, testdataset)
)
local_dataloaders = [trainloader_1, trainloader_2]
np.unique(y_train, return_counts=True)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8]),
 array([132,  99,  89,  86,  84,  94,  90,  96,  94]))
class Net(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear (raw logits).

    NOTE: the single-layer ``class Net`` defined right after this one rebinds
    the name ``Net``, so this variant is not the model used later. It uses
    zero-argument ``super()`` so it still constructs after that rebinding.
    ``super(Net, self)`` would look up the *other* ``Net`` and raise
    ``TypeError``.

    Args:
        in_features: input width. Defaults to ``df.shape[1] - 1``, i.e. every
            column of the loaded frame except the label column.
        hidden_features: hidden width. Defaults to ``(in_features + 1) // 2``,
            which equals the original ``int(df.shape[1] / 2)``.
        num_classes: number of output logits (9 classes in this dataset).
    """

    def __init__(self, in_features=None, hidden_features=None, num_classes=9):
        super().__init__()
        if in_features is None:
            in_features = df.shape[1] - 1
        if hidden_features is None:
            hidden_features = (in_features + 1) // 2
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return unnormalised class logits for ``x``."""
        out = self.fc1(x).relu()
        out = self.fc2(out)
        return out


class Net(nn.Module):
    """Single linear layer mapping input features straight to class logits.

    Args:
        in_features: input width. Defaults to ``df.shape[1] - 1``, i.e. every
            column of the loaded frame except the label column.
        num_classes: number of output logits (9 classes in this dataset).
    """

    def __init__(self, in_features=None, num_classes=9):
        super().__init__()
        if in_features is None:
            in_features = df.shape[1] - 1
        self.fc1 = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return unnormalised class logits (no softmax applied)."""
        return self.fc1(x)
# Client 0 is built from the class returned by `manager.attach(FedAVGClient)`.
# Presumably that class is FedAVGClient with the model-replacement attack
# mixed in; verify against aijack. Client 1 is a plain FedAVGClient.
manager = ModelReplacementAttackClientManager(alpha=0.99, gamma=1)
ModelReplacementAttackFedAVGClient = manager.attach(FedAVGClient)

# Construct the models in the same order as before so torch's RNG, and hence
# every initial weight, is unchanged: attacker, benign, then server.
attacker_client = ModelReplacementAttackFedAVGClient(Net(), user_id=0, lr=lr)
benign_client = FedAVGClient(Net(), user_id=1, lr=lr)
clients = [attacker_client, benign_client]

local_optimizers = [optim.SGD(c.parameters(), lr=lr) for c in clients]

server = FedAVGServer(clients, Net())

# Invoked through custom_action; prints clean and poisoned test metrics.
round_evaluator = evaluate_gloal_model(test_dataloader)
api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=num_rounds,
    custom_action=round_evaluator,
)
api.run()
communication 0, epoch 0: client-1 2.1662394603093467
communication 0, epoch 0: client-2 2.1922271141299494
Test set: Average loss: -5.4690259304587485, Accuracy: 80.55555555555556
Poisoned set: Average loss: -3.702957514811445, Accuracy: 38.425925925925924
communication 1, epoch 0: client-1 0.6799312920597471
communication 1, epoch 0: client-2 0.5005580799708169
Test set: Average loss: -5.493952719701661, Accuracy: 89.35185185185185
Poisoned set: Average loss: -3.5129815687735877, Accuracy: 33.333333333333336
communication 2, epoch 0: client-1 0.5290323983713447
communication 2, epoch 0: client-2 0.35746124050794126
Test set: Average loss: -5.363429322700809, Accuracy: 70.37037037037037
Poisoned set: Average loss: -7.316975637718484, Accuracy: 74.07407407407408
communication 3, epoch 0: client-1 0.9123193261237467
communication 3, epoch 0: client-2 1.0773419617421665
Test set: Average loss: -6.41383598420631, Accuracy: 74.53703703703704
Poisoned set: Average loss: -5.338864521295936, Accuracy: 25.925925925925927
communication 4, epoch 0: client-1 2.6969062964287533
communication 4, epoch 0: client-2 3.1166031406329857
Test set: Average loss: -6.398117804416904, Accuracy: 68.98148148148148
Poisoned set: Average loss: -11.535380529032814, Accuracy: 80.55555555555556
communication 5, epoch 0: client-1 2.36450203070146
communication 5, epoch 0: client-2 2.8563940612600134
Test set: Average loss: -7.483057366753066, Accuracy: 76.85185185185185
Poisoned set: Average loss: -7.282920123250396, Accuracy: 38.888888888888886
communication 6, epoch 0: client-1 1.8104334803488789
communication 6, epoch 0: client-2 1.665984229675933
Test set: Average loss: -7.133859022072068, Accuracy: 79.16666666666667
Poisoned set: Average loss: -10.943825021938041, Accuracy: 77.31481481481481
communication 7, epoch 0: client-1 1.188665873509647
communication 7, epoch 0: client-2 1.589228773664098
Test set: Average loss: -7.953649083497347, Accuracy: 85.18518518518519
Poisoned set: Average loss: -10.764490551418728, Accuracy: 82.87037037037037
communication 8, epoch 0: client-1 0.3532013624386631
communication 8, epoch 0: client-2 0.38635327980916545
Test set: Average loss: -8.302713400235882, Accuracy: 91.20370370370371
Poisoned set: Average loss: -9.670636004871792, Accuracy: 72.68518518518519
communication 9, epoch 0: client-1 0.16802960019725577
communication 9, epoch 0: client-2 0.12908427009462903
Test set: Average loss: -8.271086318901292, Accuracy: 93.05555555555556
Poisoned set: Average loss: -10.32074565357632, Accuracy: 79.16666666666667
communication 10, epoch 0: client-1 0.11566019391113969
communication 10, epoch 0: client-2 0.1096428572484203
Test set: Average loss: -8.297636664438027, Accuracy: 94.44444444444444
Poisoned set: Average loss: -10.347629803198355, Accuracy: 79.16666666666667
communication 11, epoch 0: client-1 0.09780469698989323
communication 11, epoch 0: client-2 0.08903112635202053
Test set: Average loss: -8.314250176289567, Accuracy: 94.9074074074074
Poisoned set: Average loss: -10.441773542651424, Accuracy: 79.62962962962963
communication 12, epoch 0: client-1 0.0858488295986393
communication 12, epoch 0: client-2 0.08158361000467822
Test set: Average loss: -8.33344722242543, Accuracy: 94.9074074074074
Poisoned set: Average loss: -10.538744891131365, Accuracy: 80.55555555555556
communication 13, epoch 0: client-1 0.07867073030948264
communication 13, epoch 0: client-2 0.07543254211632118
Test set: Average loss: -8.353414006686458, Accuracy: 94.9074074074074
Poisoned set: Average loss: -10.62615246242947, Accuracy: 81.01851851851852
communication 14, epoch 0: client-1 0.07302954587543843
communication 14, epoch 0: client-2 0.07091031768792234
Test set: Average loss: -8.37564472095282, Accuracy: 94.9074074074074
Poisoned set: Average loss: -10.70359914832645, Accuracy: 81.48148148148148