# Source code for aijack.attack.inversion.mi_face
import torch
from matplotlib import pyplot as plt
from ..base_attack import BaseAttacker
class MI_FACE(BaseAttacker):
    """Implementation of the MI-FACE model inversion attack.

    Starting from an initial input (all zeros by default), the attack runs
    gradient descent on the *input* to minimise the cost
    ``1 - f(x)[:, target_label] + auxterm_func(x)``, where ``f`` is the
    target model (optionally followed by a softmax over dim 1). After every
    step the input is passed through ``process_func``.

    reference: https://dl.acm.org/doi/pdf/10.1145/2810103.2813677

    Attributes:
        target_model: model of the victim.
        input_shape (tuple): input shape of the target model.
        target_label (int): class whose score is maximised.
        lam (float): gradient-descent step size.
        num_itr (int): number of gradient-descent iterations.
        auxterm_func (function): extra cost term; the default is the
            constant function 0.
        process_func (function): post-step transform; the default is the
            identity function.
        apply_softmax (bool): whether to softmax the model output first.
        device: device on which the attacked input lives.
        log_interval (int): logging period in iterations (0 disables it).
        log_show_img (bool): whether to plot the input when logging.
        show_img_func (function): maps the image to display range.
        log_image (list): detached snapshots of the input taken at each
            logging step; accumulates across calls to ``attack``.
    """

    def __init__(
        self,
        target_model,
        input_shape=(1, 1, 64, 64),
        target_label=0,
        lam=0.01,
        num_itr=100,
        auxterm_func=lambda x: 0,
        process_func=lambda x: x,
        apply_softmax=False,
        device="cpu",
        log_interval=1,
        log_show_img=False,
        show_img_func=lambda x: x * 0.5 + 0.5,
    ):
        """Inits MI_FACE.

        Args:
            target_model: model of the victim; called as ``target_model(x)``
                and its output is indexed as ``[:, target_label]``.
            input_shape (tuple): shape of the tensor fed to the target model.
                The attack reads the cost with ``.item()``, so the batch
                dimension is expected to be 1.
            target_label (int): class whose score is maximised.
            lam (float): gradient-descent step size.
            num_itr (int): number of gradient-descent iterations.
            auxterm_func (function): extra cost term added to the loss,
                called with the current input; the default is constant 0.
            process_func (function): applied to the input after each
                gradient step; the default is the identity function.
            apply_softmax (bool): apply ``softmax(dim=1)`` to the model
                output before reading the target score.
            device (str or torch.device): device of the attacked input.
            log_interval (int): print the cost (and optionally show the
                image) every ``log_interval`` iterations; 0 disables logging.
            log_show_img (bool): plot the current input with matplotlib
                when logging.
            show_img_func (function): maps the numpy image before plotting;
                the default maps [-1, 1] to [0, 1].
        """
        super().__init__(target_model)
        self.input_shape = input_shape
        self.target_label = target_label
        self.lam = lam
        self.num_itr = num_itr
        self.auxterm_func = auxterm_func
        self.process_func = process_func
        self.device = device
        self.log_interval = log_interval
        self.log_show_img = log_show_img
        self.apply_softmax = apply_softmax
        self.show_img_func = show_img_func
        self.log_image = []

    def attack(
        self,
        init_x=None,
    ):
        """Execute the model inversion attack on the target model.

        Args:
            init_x (torch.Tensor, optional): starting input. If None, an
                all-zero tensor of shape ``self.input_shape`` is used. The
                tensor passed in is never modified.

        Returns:
            tuple: ``(best_img, log)``. ``best_img`` is a detached copy of the
            input that achieved the lowest cost (None when ``num_itr`` is 0).
            ``log`` is the list of per-iteration costs.
        """
        log = []
        if init_x is None:
            x = torch.zeros(self.input_shape, device=self.device)
        else:
            # Detach so the caller's tensor is never touched by the updates.
            x = init_x.detach().to(self.device)

        best_score = float("inf")
        best_img = None

        for i in range(self.num_itr):
            # Fresh leaf tensor each step; gradients never carry over.
            x_var = x.detach().requires_grad_(True)
            pred = self.target_model(x_var)
            pred = pred.softmax(dim=1) if self.apply_softmax else pred
            target_pred = pred[:, [self.target_label]]
            c = 1 - target_pred + self.auxterm_func(x_var)
            # Only d(c)/d(x) is needed; unlike backward(), this does not
            # accumulate gradients into the target model's parameters.
            (grad,) = torch.autograd.grad(c, x_var)
            score = c.item()

            if score < best_score:
                best_score = score
                # Copy the exact input that produced this score, so later
                # updates cannot overwrite it.
                best_img = x_var.detach().clone()

            with torch.no_grad():
                # Out-of-place step: no aliasing with best_img or init_x.
                x = self.process_func(x_var - self.lam * grad)

            log.append(score)

            if self.log_interval != 0 and i % self.log_interval == 0:
                print(f"epoch {i}: {score}")
                self._show_img(x)
                self.log_image.append(x.detach().clone())

        self._show_img(x)

        return best_img, log

    def _show_img(self, x):
        """Plot the first sample of ``x`` with matplotlib if ``log_show_img``.

        When ``input_shape[1] == 1`` the first channel is drawn in grayscale;
        otherwise the channel-first sample is transposed to channel-last.
        """
        if not self.log_show_img:
            return
        img = x.detach().cpu().numpy()[0]
        if self.input_shape[1] == 1:
            plt.imshow(self.show_img_func(img[0]), cmap="gray")
        else:
            plt.imshow(self.show_img_func(img.transpose(1, 2, 0)))
        plt.show()