# 2.5. Mutual Information-based Defense
import torch
from torch import nn
from torch import optim
from torchvision.datasets import MNIST
from torch.utils.data import TensorDataset, Dataset, DataLoader
from tqdm.notebook import tqdm
import numpy as np
from aijack.defense import VIB, KL_between_normals, mib_loss
# ---- Experiment configuration --------------------------------------------
num_epochs = 1  # number of passes over the training loader
batch_size = 100  # mini-batch size shared by the train and test DataLoaders

# Latent width handed to VIB; the encoder's last layer emits 2 * dim_z values.
dim_z = 256
# Passed to mib_loss as `beta`; presumably weights the I(Z;X) bound term
# against the I(Z;Y) term - confirm against aijack.defense.mib_loss.
beta = 1e-3
# Passed to VIB as `num_samples`; presumably latent draws per input - confirm.
samples_amount = 15
def _mnist_to_tensor_dataset(mnist):
    """Wrap a torchvision MNIST split as a TensorDataset of (pixels, labels).

    Each image is flattened to a 784-element row (28 * 28) and converted to
    float divided by 255; labels are passed through unchanged. Uses the
    current `.data` / `.targets` attributes instead of the deprecated
    `train_data` / `train_labels` / `test_data` / `test_labels` aliases,
    which only emit a warning and forward to the same tensors.
    """
    images = mnist.data.view(-1, 28 * 28).float() / 255
    return TensorDataset(images, mnist.targets)


train_data = MNIST("MNIST/.", download=True, train=True)
train_dataset = _mnist_to_tensor_dataset(train_data)
# NOTE(review): no shuffle=True, so every epoch visits batches in the same
# order; consider shuffling if num_epochs is raised.
train_loader = DataLoader(train_dataset, batch_size=batch_size)

test_data = MNIST("MNIST/.", download=True, train=False)
test_dataset = _mnist_to_tensor_dataset(test_data)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
# Encoder: 784 -> 1024 -> 1024 -> 2 * dim_z with ReLU between linear layers.
# VIB later exposes p_z_given_x_mu / p_z_given_x_sigma, so the 2 * dim_z
# outputs are presumably split into those two halves - confirm in aijack.
encoder = nn.Sequential(
    nn.Linear(784, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024, 2 * dim_z),
)
# Decoder: one linear map from a dim_z latent to 10 outputs (one per class).
decoder = nn.Linear(dim_z, 10)
net = VIB(encoder, decoder, dim_z, num_samples=samples_amount)

opt = optim.Adam(net.parameters(), lr=1e-4)
# Each scheduler.step() multiplies the learning rate by 0.97.
scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.97)
import time
def _mib_step(model, x, y, beta):
    """Run `model` on batch `x` and score it with `mib_loss` against labels `y`.

    The approximated marginal of z handed to `mib_loss` is fixed to zeros
    (mean) and ones (sigma), shaped like the encoder's posterior parameters.

    Returns:
        (loss, I_ZY_bound, I_ZX_bound, accuracy): the three values produced
        by `mib_loss`, plus the fraction of rows whose arg-max prediction
        equals `y`.
    """
    y_pred, result_dict = model(x)
    mu = result_dict["p_z_given_x_mu"]
    sigma = result_dict["p_z_given_x_sigma"]
    loss, I_ZY_bound, I_ZX_bound = mib_loss(
        y,
        result_dict["sampled_decoded_outputs"],
        mu,
        sigma,
        torch.zeros_like(mu),  # approximated_z_mean
        torch.ones_like(sigma),  # approximated_z_sigma
        beta=beta,
    )
    prediction = torch.max(y_pred, dim=1)[1]
    accuracy = torch.mean((prediction == y).float())
    return loss, I_ZY_bound, I_ZX_bound, accuracy


for epoch in range(num_epochs):
    # Per-batch metrics for this epoch, kept separately for train and test.
    loss_by_epoch = []
    accuracy_by_epoch = []
    I_ZX_bound_by_epoch = []
    I_ZY_bound_by_epoch = []
    loss_by_epoch_test = []
    accuracy_by_epoch_test = []
    I_ZX_bound_by_epoch_test = []
    I_ZY_bound_by_epoch_test = []

    # Decay the learning rate at the start of every second epoch (2, 4, ...).
    if epoch % 2 == 0 and epoch > 0:
        scheduler.step()

    for x_batch, y_batch in tqdm(train_loader):
        loss, I_ZY_bound, I_ZX_bound, accuracy = _mib_step(
            net, x_batch, y_batch, beta
        )
        loss.backward()
        opt.step()
        opt.zero_grad()
        I_ZX_bound_by_epoch.append(I_ZX_bound.item())
        I_ZY_bound_by_epoch.append(I_ZY_bound.item())
        loss_by_epoch.append(loss.item())
        accuracy_by_epoch.append(accuracy.item())

    # Evaluation needs no gradients; no_grad stops autograd from building and
    # holding a graph for every test batch. The loop variables stay bound at
    # module level, so the last test batch (x_batch, y_batch) is still
    # available after training finishes.
    with torch.no_grad():
        for x_batch, y_batch in tqdm(test_loader):
            loss, I_ZY_bound, I_ZX_bound, accuracy = _mib_step(
                net, x_batch, y_batch, beta
            )
            I_ZX_bound_by_epoch_test.append(I_ZX_bound.item())
            I_ZY_bound_by_epoch_test.append(I_ZY_bound.item())
            loss_by_epoch_test.append(loss.item())
            accuracy_by_epoch_test.append(accuracy.item())

    print(
        "epoch",
        epoch,
        "train loss",
        np.mean(loss_by_epoch),
        "train accuracy",
        np.mean(accuracy_by_epoch),
    )
    print(
        "epoch",
        epoch,
        "loss",
        np.mean(loss_by_epoch_test),
        "prediction",
        np.mean(accuracy_by_epoch_test),
    )
    print(
        "I_ZX_bound",
        np.mean(I_ZX_bound_by_epoch_test),
        "I_ZY_bound",
        np.mean(I_ZY_bound_by_epoch_test),
    )
from aijack.attack import GradientInversion_Attack
# Build the gradients the attacker is assumed to receive: the MIB loss for a
# single sample (the first row of the last test batch bound by the training
# loop), differentiated w.r.t. every parameter of `net`.
y_pred, result_dict = net(x_batch[:1])
sampled_y_pred = result_dict["sampled_decoded_outputs"]
p_z_given_x_mu = result_dict["p_z_given_x_mu"]
p_z_given_x_sigma = result_dict["p_z_given_x_sigma"]
approximated_z_mean = torch.zeros_like(p_z_given_x_mu)
approximated_z_sigma = torch.ones_like(p_z_given_x_sigma)
loss, I_ZY_bound, I_ZX_bound = mib_loss(
    y_batch[:1],
    sampled_y_pred,
    p_z_given_x_mu,
    p_z_given_x_sigma,
    approximated_z_mean,
    approximated_z_sigma,
    beta=beta,
)
# One detached tensor per parameter, in net.parameters() order.
# NOTE(review): computed while `net` is still in train mode; the attack below
# switches to eval mode - confirm this mismatch is intended for VIB.
received_gradients = [
    grad.detach() for grad in torch.autograd.grad(loss, net.parameters())
]
from matplotlib import pyplot as plt
import cv2
# Run the gradient inversion attack from several random seeds and plot each
# reconstruction (top row) and a 5x5 median-blurred copy of it (bottom row).
net.eval()
cpl_attacker = GradientInversion_Attack(
    net,
    (784,),
    lr=0.3,
    log_interval=50,
    optimizer_class=torch.optim.LBFGS,
    distancename="l2",
    optimize_label=False,
    num_iteration=200,
)
num_seeds = 5
fig = plt.figure(figsize=(6, 2))
for s in tqdm(range(num_seeds)):
    cpl_attacker.reset_seed(s)
    # Best-effort per seed: a failed run leaves an empty column, but the
    # failure is now reported instead of silently swallowed (a bare
    # `except:` would also hide KeyboardInterrupt and real bugs).
    try:
        result = cpl_attacker.attack(received_gradients)
        # result[0] presumably holds the reconstructed input batch and
        # result[1] a label tensor (inferred from the indexing here - confirm
        # against aijack). Take the first item and view it as a 28x28 image.
        reconstruction = result[0].cpu().detach().numpy()[0].reshape(28, 28)

        ax1 = fig.add_subplot(2, num_seeds, s + 1)
        ax1.imshow(reconstruction, cmap="gray")
        ax1.axis("off")
        ax1.set_title(torch.argmax(result[1]).cpu().item())

        ax2 = fig.add_subplot(2, num_seeds, num_seeds + s + 1)
        ax2.imshow(cv2.medianBlur(reconstruction, 5), cmap="gray")
        ax2.axis("off")
    except Exception as err:
        print(f"seed {s}: attack/plot failed, skipping ({err!r})")
plt.suptitle("Result of CPL")
plt.tight_layout()
plt.show()