2.2. Gradient-based Model Inversion Attack against Federated Learning#

In this tutorial, we experiment with gradient-based model inversion attacks, which allow the malicious server of Federated Learning to reconstruct the clients' private local datasets from the shared gradients. You can implement five popular gradient-based model inversion attacks with AIJack. These methods reconstruct the private images by minimizing the distance between fake gradients and the received gradients; each method has its own strategy, such as the choice of distance metric and regularization terms.

One example works as follows. Since the server already knows the parameters of the global model $w_{t - 1}$, it can estimate the private training sample $(X, Y)$ with the following optimization:

$$ X' \leftarrow X' - \lambda \nabla_{X'} D $$

$$ Y' \leftarrow Y' - \lambda \nabla_{Y'} D $$

where $D$ is the distance between the fake gradients and the received gradients:

$$ D = \| \nabla \ell(w_{t - 1}, X, Y) - \nabla \ell(w_{t - 1}, X', Y') \|_{2} $$

In other words, this attack reconstructs the private training data by optimizing the fake data until its gradients are close enough to the gradients received from the client.
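
Concretely, this loop can be sketched in plain PyTorch as follows. The helper `gradient_matching_attack` is hypothetical and only illustrates the update rule above; the AIJack managers used later in this tutorial implement and tune this procedure for you.

import torch

def gradient_matching_attack(net, received_gradients, input_shape,
                             num_classes, num_iteration=100, lr=1.0):
    # Dummy data X' and dummy label logits Y' are the attack's variables
    x_dummy = torch.randn((1, *input_shape), requires_grad=True)
    y_dummy = torch.randn((1, num_classes), requires_grad=True)
    optimizer = torch.optim.SGD([x_dummy, y_dummy], lr=lr)

    for _ in range(num_iteration):
        optimizer.zero_grad()
        pred = net(x_dummy)
        # Soft cross-entropy so that the label can be optimized jointly
        loss = torch.mean(
            torch.sum(
                -torch.softmax(y_dummy, dim=-1) * torch.log_softmax(pred, dim=-1),
                dim=-1,
            )
        )
        fake_gradients = torch.autograd.grad(
            loss, net.parameters(), create_graph=True
        )
        # D = squared L2 distance between fake and received gradients
        distance = sum(
            ((fg - rg) ** 2).sum()
            for fg, rg in zip(fake_gradients, received_gradients)
        )
        # X' <- X' - lr * dD/dX',  Y' <- Y' - lr * dD/dY'
        distance.backward()
        optimizer.step()

    return x_dummy.detach(), torch.softmax(y_dummy, dim=-1).detach()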

import copy

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

from aijack.collaborative.fedavg import FedAVGAPI, FedAVGClient, FedAVGServer
from aijack.attack.inversion import GradientInversionAttackServerManager
from torch.utils.data import DataLoader, TensorDataset
from aijack.utils import NumpyDataset

import warnings

warnings.filterwarnings("ignore")
class LeNet(nn.Module):
    """A small LeNet-style CNN used as the global model in this tutorial."""

    def __init__(self, channel=3, hideen=768, num_classes=10):
        super(LeNet, self).__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
            nn.BatchNorm2d(12),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            nn.BatchNorm2d(12),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            nn.BatchNorm2d(12),
            act(),
        )
        # `hideen` is the flattened feature size fed to the classifier
        self.fc = nn.Sequential(nn.Linear(hideen, num_classes))

    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
def prepare_dataloader(path="MNIST/.", batch_size=64, shuffle=True):
    at_t_dataset_train = torchvision.datasets.MNIST(
        root=path, train=True, download=True
    )

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    # `data` / `targets` replace the deprecated `train_data` / `train_labels`
    dataset = NumpyDataset(
        at_t_dataset_train.data.numpy(),
        at_t_dataset_train.targets.numpy(),
        transform=transform,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0
    )
    return dataloader
torch.manual_seed(7777)

shape_img = (28, 28)
num_classes = 10
channel = 1
hidden = 588  # flattened feature size: 12 channels x 7 x 7 for a 28 x 28 input

num_seeds = 5

2.2.1. Reconstruct Single Data#

First, we try to recover the following private image from the gradients received with a batch size of 1.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataloader = prepare_dataloader()
for data in dataloader:
    xs, ys = data[0], data[1]
    break

x = xs[:1]
y = ys[:1]

fig = plt.figure(figsize=(1, 1))
plt.axis("off")
plt.imshow(x.detach().numpy()[0][0], cmap="gray")
plt.show()
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/./MNIST/raw/train-images-idx3-ubyte.gz
Extracting MNIST/./MNIST/raw/train-images-idx3-ubyte.gz to MNIST/./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting MNIST/./MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting MNIST/./MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting MNIST/./MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/./MNIST/raw
[figure: the target private image, a single MNIST digit]
criterion = nn.CrossEntropyLoss()

2.2.1.1. DLG#

You can convert the normal server into a malicious attacker with GradientInversionAttackServerManager.

https://dlg.mit.edu/assets/NeurIPS19_deep_leakage_from_gradients.pdf

  • distance metric = L2 norm

  • optimize labels

manager = GradientInversionAttackServerManager(
    (1, 28, 28),
    num_trial_per_communication=5,
    log_interval=0,
    num_iteration=100,
    distancename="l2",
    device=device,
    lr=1.0,
)
DLGFedAVGServer = manager.attach(FedAVGServer)

client = FedAVGClient(
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)
server = DLGFedAVGServer(
    [client],
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)

local_dataloaders = [DataLoader(TensorDataset(x, y))]
local_optimizers = [optim.SGD(client.parameters(), lr=1.0)]

api = FedAVGAPI(
    server,
    [client],
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=1,
    local_epoch=1,
    use_gradients=True,
    device=device,
)

api.run()

fig = plt.figure(figsize=(5, 2))
for s, result in enumerate(server.attack_results[0]):
    ax = fig.add_subplot(1, len(server.attack_results[0]), s + 1)
    ax.imshow(result[0].cpu().detach().numpy()[0][0], cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()
communication 0, epoch 0: client-1 2.285383462905884
iter=80: loss did not improve in the last 50 rounds.
iter=73: loss did not improve in the last 50 rounds.
iter=70: loss did not improve in the last 50 rounds.
[figure: the five DLG reconstruction attempts]

2.2.1.2. GS Attack#

https://arxiv.org/abs/2003.14053

  • distance metric = cosine similarity

  • optimize labels

  • regularization: total-variance (both the distance and this prior are sketched below)
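
The two ingredients can be sketched as follows. These are illustrative helpers, assuming this is roughly what AIJack computes internally when you pass distancename="cossim" and tv_reg_coef:

import torch

def cosine_distance(fake_gradients, received_gradients):
    # 1 - cos(g', g) over all parameter gradients, treated as one long vector
    dot, fake_sq, recv_sq = 0.0, 0.0, 0.0
    for fg, rg in zip(fake_gradients, received_gradients):
        dot += (fg * rg).sum()
        fake_sq += (fg**2).sum()
        recv_sq += (rg**2).sum()
    return 1 - dot / (fake_sq.sqrt() * recv_sq.sqrt())

def total_variation(x):
    # Anisotropic total variation: penalizes noisy, non-smooth reconstructions
    diff_h = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
    diff_w = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
    return diff_h + diff_w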

manager = GradientInversionAttackServerManager(
    (1, 28, 28),
    num_trial_per_communication=5,
    log_interval=0,
    num_iteration=100,
    tv_reg_coef=0.01,
    distancename="cossim",
    device=device,
    lr=1.0,
)
GSFedAVGServer = manager.attach(FedAVGServer)

client = FedAVGClient(
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)
server = GSFedAVGServer(
    [client],
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)

local_dataloaders = [DataLoader(TensorDataset(x, y))]
local_optimizers = [optim.SGD(client.parameters(), lr=1.0)]

api = FedAVGAPI(
    server,
    [client],
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=1,
    local_epoch=1,
    use_gradients=True,
    device=device,
)

api.run()

fig = plt.figure(figsize=(5, 2))
for s, result in enumerate(server.attack_results[0]):
    ax = fig.add_subplot(1, len(server.attack_results[0]), s + 1)
    ax.imshow(result[0].cpu().detach().numpy()[0][0], cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()
communication 0, epoch 0: client-1 2.371312141418457
iter=89: loss did not improve in the last 50 rounds.
iter=72: loss did not improve in the last 50 rounds.
iter=71: loss did not improve in the last 50 rounds.
[figure: the five GS reconstruction attempts]

2.2.1.3. iDLG Attack#

https://arxiv.org/abs/2001.02610

  • distance metric = L2 norm

  • analytically estimate a label from the gradients (sketched below)
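
iDLG's label recovery exploits the fact that, with softmax cross-entropy and a batch size of 1, the gradient row of the last linear layer's weight is negative only for the true class. A minimal sketch, assuming (as with the LeNet above) that the final linear layer's weight is the second-to-last parameter:

import torch

def estimate_label_idlg(received_gradients):
    fc_weight_grad = received_gradients[-2]  # shape: (num_classes, hidden)
    # Only the true class has a negative row sum, so argmin recovers it
    return torch.argmin(fc_weight_grad.sum(dim=-1)).unsqueeze(0)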

manager = GradientInversionAttackServerManager(
    (1, 28, 28),
    num_trial_per_communication=5,
    log_interval=0,
    num_iteration=1000,
    optimizer_class=torch.optim.SGD,
    distancename="l2",
    optimize_label=False,
    device=device,
    lr=1.0,
)
iDLGFedAVGServer = manager.attach(FedAVGServer)

client = FedAVGClient(
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)
server = iDLGFedAVGServer(
    [client],
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)

local_dataloaders = [DataLoader(TensorDataset(x, y))]
local_optimizers = [optim.SGD(client.parameters(), lr=1.0)]

api = FedAVGAPI(
    server,
    [client],
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=1,
    local_epoch=1,
    use_gradients=True,
    device=device,
)

api.run()

fig = plt.figure(figsize=(5, 2))
for s, result in enumerate(server.attack_results[0]):
    ax = fig.add_subplot(1, len(server.attack_results[0]), s + 1)
    ax.imshow(result[0].cpu().detach().numpy()[0][0], cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()
communication 0, epoch 0: client-1 2.371312141418457
[figure: the five iDLG reconstruction attempts]

2.2.1.4. CPL Attack#

https://arxiv.org/abs/2004.10397

  • distance metric = L2 norm

  • analytically estimate a label from the gradients

  • regularization: label-matching (sketched below)
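
A rough sketch of such a label-matching term; the exact formulation inside AIJack, enabled via lm_reg_coef, may differ:

import torch.nn.functional as F

def label_matching_regularizer(net, x_dummy, estimated_label, lm_reg_coef=0.01):
    # Penalize dummy data whose prediction drifts away from the label
    # estimated analytically from the received gradients
    return lm_reg_coef * F.cross_entropy(net(x_dummy), estimated_label)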

manager = GradientInversionAttackServerManager(
    (1, 28, 28),
    num_trial_per_communication=5,
    log_interval=0,
    num_iteration=1000,
    optimizer_class=torch.optim.SGD,
    distancename="l2",
    optimize_label=False,
    lm_reg_coef=0.01,
    lr=1.0,
)
CPLFedAVGServer = manager.attach(FedAVGServer)

client = FedAVGClient(
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes), lr=1.0
)
server = CPLFedAVGServer(
    [client], LeNet(channel=channel, hideen=hidden, num_classes=num_classes), lr=1.0
)

local_dataloaders = [DataLoader(TensorDataset(x, y))]
local_optimizers = [optim.SGD(client.parameters(), lr=1.0)]

api = FedAVGAPI(
    server,
    [client],
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=1,
    local_epoch=1,
    use_gradients=True,
)

api.run()

fig = plt.figure(figsize=(5, 2))
for s, result in enumerate(server.attack_results[0]):
    ax = fig.add_subplot(1, len(server.attack_results[0]), s + 1)
    ax.imshow(result[0].cpu().detach().numpy()[0][0], cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()
communication 0, epoch 0: client-1 2.371312141418457
[figure: the five CPL reconstruction attempts]

2.2.2. Reconstruct Batched Data#

Second, we simulate the situation with a larger batch size. We try to recover the three images below.

batch_size = 3
x_batch = xs[:batch_size]
y_batch = ys[:batch_size]

fig = plt.figure(figsize=(3, 2))
for bi in range(batch_size):
    ax = fig.add_subplot(1, batch_size, bi + 1)
    ax.imshow(x_batch[bi].detach().numpy()[0], cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()
[figure: the three target private images]

2.2.2.1. GradInversion#

https://arxiv.org/abs/2104.07586

  • distance metric = L2 norm

  • analytically estimate labels from the gradients

  • regularization: total-variance, L2, BN, and group-consistency (the BN and group-consistency terms are sketched below)
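
Rough sketches of the BN and group-consistency terms follow. These are illustrative helpers; AIJack's exact formulations, configured via bn_reg_coef and gc_reg_coef below, may differ:

import torch

def bn_regularizer(feature_maps, bn_layers):
    # Match the batch statistics of the dummy data's feature maps (collected,
    # e.g., with forward hooks) to the running statistics stored in the
    # model's BatchNorm layers
    reg = 0.0
    for f, bn in zip(feature_maps, bn_layers):
        mean = f.mean(dim=(0, 2, 3))
        var = f.var(dim=(0, 2, 3), unbiased=False)
        reg += (mean - bn.running_mean).norm() + (var - bn.running_var).norm()
    return reg

def group_consistency(x_dummy, group_xs):
    # Penalize deviation from the pixelwise consensus image across the
    # independently seeded reconstructions in the group
    consensus = torch.stack(group_xs).mean(dim=0)
    return (x_dummy - consensus).norm()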

from aijack.attack.inversion import GradientInversion_Attack

net = LeNet(channel=channel, hideen=hidden, num_classes=num_classes)
pred = net(x_batch)
loss = criterion(pred, y_batch)
received_gradients = torch.autograd.grad(loss, net.parameters())
received_gradients = [cg.detach() for cg in received_gradients]

gradinversion = GradientInversion_Attack(
    net,
    (1, 28, 28),
    num_iteration=1000,
    lr=1e2,
    log_interval=0,
    optimizer_class=torch.optim.SGD,
    distancename="l2",
    optimize_label=False,
    bn_reg_layers=[net.body[1], net.body[4], net.body[7]],
    group_num=3,
    tv_reg_coef=0.00,
    l2_reg_coef=0.0001,
    bn_reg_coef=0.001,
    gc_reg_coef=0.001,
)

result = gradinversion.group_attack(received_gradients, batch_size=batch_size)

fig = plt.figure(figsize=(3, 2))
for bid in range(batch_size):
    ax1 = fig.add_subplot(1, batch_size, bid + 1)
    ax1.imshow((sum(result[0]) / len(result[0])).detach().numpy()[bid][0], cmap="gray")
    ax1.axis("off")
plt.tight_layout()
plt.show()
[figure: the reconstructed batch of three images]