# Source code for aijack.attack.inversion.gradientinversion

import copy

import torch
import torch.nn as nn

from ..base_attack import BaseAttacker
from .utils.distance import cossim, l2
from .utils.regularization import (
    bn_regularizer,
    group_consistency,
    label_matching,
    total_variance,
)
from .utils.utils import _generate_fake_gradients, _setup_attack

class GradientInversion_Attack(BaseAttacker):
    """General Gradient Inversion Attacker

    A model inversion attack based on gradients can be written as follows:

        x^* = argmin_x' L_grad(x': W, delta_W) + R_aux(x'),

    where x' is the reconstructed image. The attacker tries to find images
    whose gradients w.r.t. the given model parameters W are similar to the
    gradients delta_W of the secret images.

    Attributes:
        target_model: a target torch module instance.
        x_shape: the input shape of target_model.
        y_shape: the output shape of target_model.
        optimize_label: forwarded to `_setup_attack` and
            `_generate_fake_gradients`; controls whether labels are optimized
            together with the images. NOTE(review): earlier docs read "If true,
            only optimize images (the label will be automatically estimated)",
            which looks inverted -- confirm against `_setup_attack`.
        pos_of_final_fc_layer: position of gradients corresponding to the final
            FC layer within the gradients received from the client.
        num_iteration: number of iterations of optimization.
        optimizer_class: a class of torch optimizer for the attack.
        lossfunc: a function that takes the predictions of the target model and
            true labels and returns the loss between them.
        distancefunc: a function which takes the gradients of reconstructed
            images, the client-side gradients and `gradient_ignore_pos`, and
            returns the distance between the gradients.
        tv_reg_coef: the coefficient of total-variation regularization.
        lm_reg_coef: the coefficient of label-matching regularization.
        l2_reg_coef: the coefficient of L2 regularization.
        bn_reg_coef: the coefficient of BN regularization.
        gc_reg_coef: the coefficient of group-consistency regularization
            (only effective in `group_attack`).
        bn_reg_layers: a list of batch normalization layers of the target model.
        bn_reg_layer_inputs: a dict mapping the index of each layer in
            bn_reg_layers to the input it received in the latest forward pass.
        custom_reg_func: a custom regularization function.
        custom_reg_coef: the coefficient of the custom regularization function.
        custom_generate_fake_grad_fn: optional replacement for
            `_generate_fake_gradients`.
        device: device type.
        log_interval: the interval of logging (0 disables printing).
        save_loss: If true, save the loss during the attack.
        seed: random state.
        group_num: the size of the group.
        group_seed: a list of random states for each worker of the group.
        early_stopping: patience used by `attack` for early stopping.
        clamp_range: optional (min, max) range for the reconstructed images.
    """

    def __init__(
        self,
        target_model,
        x_shape,
        y_shape=None,
        optimize_label=True,
        gradient_ignore_pos=None,
        pos_of_final_fc_layer=-2,
        num_iteration=100,
        optimizer_class=torch.optim.LBFGS,
        optimizername=None,
        lossfunc=nn.CrossEntropyLoss(),
        distancefunc=l2,
        distancename=None,
        tv_reg_coef=0.0,
        lm_reg_coef=0.0,
        l2_reg_coef=0.0,
        bn_reg_coef=0.0,
        gc_reg_coef=0.0,
        bn_reg_layers=None,
        custom_reg_func=None,
        custom_reg_coef=0.0,
        custom_generate_fake_grad_fn=None,
        device="cpu",
        log_interval=10,
        save_loss=True,
        seed=0,
        group_num=5,
        group_seed=None,
        early_stopping=50,
        clamp_range=None,
        **kwargs,
    ):
        """Inits GradientInversion_Attack class.

        Args:
            target_model: a target torch module instance.
            x_shape: the input shape of target_model.
            y_shape: the output shape of target_model. Defaults to the size of
                the first dimension of the last parameter of target_model.
            optimize_label: forwarded to `_setup_attack` and
                `_generate_fake_gradients`; controls whether labels are
                optimized together with the images (see the class docstring).
            gradient_ignore_pos: a list of positions which will be ignored
                during the calculation of the distance between gradients.
                Defaults to an empty list.
            pos_of_final_fc_layer: position of gradients corresponding to the
                final FC layer within the gradients received from the client.
            num_iteration: number of iterations of optimization.
            optimizer_class: a class of torch optimizer for the attack.
            optimizername: a name of optimizer class (priority over
                optimizer_class).
            lossfunc: a function that takes the predictions of the target model
                and true labels and returns the loss between them.
            distancefunc: a function which takes the gradients of reconstructed
                images and the client-side gradients and returns the distance
                between them.
            distancename: a name of distancefunc (priority over distancefunc).
            tv_reg_coef: the coefficient of total-variation regularization.
            lm_reg_coef: the coefficient of label-matching regularization.
            l2_reg_coef: the coefficient of L2 regularization.
            bn_reg_coef: the coefficient of BN regularization.
            gc_reg_coef: the coefficient of group-consistency regularization.
            bn_reg_layers: a list of batch normalization layers of the target
                model. A forward hook is registered on each of them. Defaults
                to an empty list.
            custom_reg_func: a custom regularization function. It receives a
                context dict with the keys "attacker", "fake_x", "fake_label",
                "received_gradients" and "group_fake_x".
            custom_reg_coef: the coefficient of the custom regularization
                function.
            custom_generate_fake_grad_fn: optional callable invoked as
                `fn(attacker, fake_x, fake_label)` that must return
                `(fake_pred, fake_gradients)`; replaces
                `_generate_fake_gradients` when given.
            device: device type.
            log_interval: the interval of logging (0 disables printing).
            save_loss: If true, save the loss during the attack.
            seed: random state; passed to `torch.manual_seed` (global RNG).
            group_num: the size of the group.
            group_seed: a list of random states for each worker of the group.
                Defaults to `list(range(group_num))`.
            early_stopping: `attack` stops once more than this many consecutive
                iterations fail to improve the best distance.
            clamp_range: optional `(min, max)`; when set, the reconstructed
                images are clamped into this range after each optimizer step.
            **kwargs: kwargs for the optimizer
        """
        super().__init__(target_model)
        self.x_shape = x_shape
        self.y_shape = (
            list(target_model.parameters())[-1].shape[0] if y_shape is None else y_shape
        )
        self.optimize_label = optimize_label
        # Fresh list per instance; a shared `[]` default would leak mutations
        # between attackers.
        self.gradient_ignore_pos = (
            [] if gradient_ignore_pos is None else gradient_ignore_pos
        )
        self.pos_of_final_fc_layer = pos_of_final_fc_layer
        self.num_iteration = num_iteration
        self.lossfunc = lossfunc

        self.distancefunc = distancefunc
        self._setup_distancefunc(distancename)
        self.optimizer_class = optimizer_class
        self._setup_optimizer_class(optimizername)

        self.tv_reg_coef = tv_reg_coef
        self.lm_reg_coef = lm_reg_coef
        self.l2_reg_coef = l2_reg_coef
        self.bn_reg_coef = bn_reg_coef
        self.gc_reg_coef = gc_reg_coef

        self.bn_reg_layers = [] if bn_reg_layers is None else bn_reg_layers
        self.bn_reg_layer_inputs = {}
        for i, bn_layer in enumerate(self.bn_reg_layers):
            # Capture each BN layer's input during forward passes so that the
            # BN regularizer can compare it with the layer's statistics.
            bn_layer.register_forward_hook(self._get_hook_for_input(i))

        self.custom_reg_func = custom_reg_func
        self.custom_reg_coef = custom_reg_coef
        self.custom_generate_fake_grad_fn = custom_generate_fake_grad_fn

        self.device = device
        self.log_interval = log_interval
        self.save_loss = save_loss
        self.seed = seed
        self.group_num = group_num
        self.group_seed = list(range(group_num)) if group_seed is None else group_seed
        self.early_stopping = early_stopping
        self.clamp_range = clamp_range
        self.kwargs = kwargs

        torch.manual_seed(seed)

    def _setup_distancefunc(self, distancename):
        """Assigns a function to self.distancefunc according to distancename

        Args:
            distancename: name of the function to calculate the distance between
                the gradients. Currently supports 'l2' or 'cossim'. When None,
                self.distancefunc is left unchanged.

        Raises:
            ValueError: if distancename is not supported.
        """
        if distancename is None:
            return
        elif distancename == "l2":
            self.distancefunc = l2
        elif distancename == "cossim":
            self.distancefunc = cossim
        else:
            raise ValueError(f"{distancename} is not defined")

    def _setup_optimizer_class(self, optimizername):
        """Assigns a class to self.optimizer_class according to optimizername

        Args:
            optimizername: name of optimizer, currently supports `LBFGS`, `SGD`,
                and `Adam`. When None, self.optimizer_class is left unchanged.

        Raises:
            ValueError: if optimizername is not supported.
        """
        if optimizername is None:
            return
        elif optimizername == "LBFGS":
            self.optimizer_class = torch.optim.LBFGS
        elif optimizername == "SGD":
            self.optimizer_class = torch.optim.SGD
        elif optimizername == "Adam":
            self.optimizer_class = torch.optim.Adam
        else:
            raise ValueError(f"{optimizername} is not defined")

    def _get_hook_for_input(self, name):
        """Returns a hook function to extract the input of the specified layer of the target model

        Args:
            name: the key of self.bn_reg_layer_inputs for the target layer

        Returns:
            hook: a forward hook that stores the layer's first positional input
            in self.bn_reg_layer_inputs[name]
        """

        def hook(model, inp, output):
            self.bn_reg_layer_inputs[name] = inp[0]

        return hook

    def _calc_regularization_term(
        self, fake_x, fake_pred, fake_label, group_fake_x, received_gradients
    ):
        """Calculates the regularization term

        Each term is only computed when its coefficient is non-zero.

        Args:
            fake_x: reconstructed images
            fake_pred: the predicted value of reconstructed images
            fake_label: the labels of fake_x
            group_fake_x: a list of fake_x of each worker, or None outside of
                `group_attack` (group consistency is then skipped)
            received_gradients: gradients received from the client

        Returns:
            calculated regularization term (0 when every term is disabled)
        """
        reg_term = 0
        if self.tv_reg_coef != 0:
            reg_term += self.tv_reg_coef * total_variance(fake_x)
        if self.lm_reg_coef != 0:
            reg_term += self.lm_reg_coef * label_matching(fake_pred, fake_label)
        if self.l2_reg_coef != 0:
            reg_term += self.l2_reg_coef * torch.norm(fake_x, p=2)
        if self.bn_reg_coef != 0:
            reg_term += self.bn_reg_coef * bn_regularizer(
                self.bn_reg_layer_inputs, self.bn_reg_layers
            )
        if group_fake_x is not None and self.gc_reg_coef != 0:
            reg_term += self.gc_reg_coef * group_consistency(fake_x, group_fake_x)
        if self.custom_reg_func is not None and self.custom_reg_coef != 0:
            context = {
                "attacker": self,
                "fake_x": fake_x,
                "fake_label": fake_label,
                "received_gradients": received_gradients,
                "group_fake_x": group_fake_x,
            }
            reg_term += self.custom_reg_coef * self.custom_reg_func(context)

        return reg_term

    def _setup_closure(
        self, optimizer, fake_x, fake_label, received_gradients, group_fake_x=None
    ):
        """Returns a closure function for the optimizer

        Args:
            optimizer (torch.optim.Optimizer): an instance of the optimizer
            fake_x (torch.Tensor): reconstructed images
            fake_label (torch.Tensor): reconstructed or estimated labels
            received_gradients (list): a list of gradients received from the client
            group_fake_x (list, optional): a list of fake_x. Defaults to None.

        Returns:
            a zero-argument closure that recomputes the gradient distance plus
            regularization, backpropagates it, and returns it as a Python float
        """

        def closure():
            optimizer.zero_grad()
            if self.custom_generate_fake_grad_fn is None:
                fake_pred, fake_gradients = _generate_fake_gradients(
                    self.target_model,
                    self.lossfunc,
                    self.optimize_label,
                    fake_x,
                    fake_label,
                )
            else:
                fake_pred, fake_gradients = self.custom_generate_fake_grad_fn(
                    self, fake_x, fake_label
                )
            distance = self.distancefunc(
                fake_gradients, received_gradients, self.gradient_ignore_pos
            )
            distance += self._calc_regularization_term(
                fake_x,
                fake_pred,
                fake_label,
                group_fake_x,
                received_gradients,
            )

            distance_val = distance.item()
            distance.backward(retain_graph=False)
            return distance_val

        return closure

    def reset_seed(self, seed):
        """Resets the random seed

        Args:
            seed (int): the random seed (also passed to `torch.manual_seed`)
        """
        self.seed = seed
        torch.manual_seed(seed)

    def _update_logging(self, i, distance, best_iteration, best_distance):
        """Records the loss (if save_loss) and prints progress every log_interval iterations."""
        if self.save_loss:
            self.log_loss.append(distance)
        if self.log_interval != 0 and i % self.log_interval == 0:
            print(
                f"iter={i}: {distance}, (best_iter={best_iteration}: {best_distance})"
            )

    def _clamp_fake_x(self, fake_x):
        """Clamps fake_x in place into self.clamp_range when it is set."""
        if self.clamp_range is None:
            return
        with torch.no_grad():
            fake_x[:] = fake_x.clamp(self.clamp_range[0], self.clamp_range[1])

    def attack(
        self,
        received_gradients,
        batch_size=1,
        init_x=None,
        labels=None,
    ):
        """Reconstructs the images from the gradients received from the client

        Args:
            received_gradients: the list of gradients received from the client.
            batch_size: batch size.
            init_x: optional initial reconstructed images, forwarded to
                `_setup_attack`.
            labels: optional labels, forwarded to `_setup_attack`.

        Returns:
            a tuple of the best reconstructed images and corresponding labels.
            If no iteration improves on the initial state (for example when
            every distance is NaN), the initial images and labels are returned.
        """
        fake_x, fake_label, optimizer = _setup_attack(
            self.x_shape,
            self.y_shape,
            self.optimizer_class,
            self.optimize_label,
            self.pos_of_final_fc_layer,
            self.device,
            received_gradients,
            batch_size,
            init_x=init_x,
            labels=labels,
            **self.kwargs,
        )

        # Seed the best-so-far state with the initial values so it is always
        # bound: a NaN distance never compares smaller than inf, which
        # previously left these names undefined (UnboundLocalError).
        best_fake_x = fake_x.detach().clone()
        best_fake_label = fake_label.detach().clone()
        best_distance = float("inf")
        best_iteration = 0
        num_of_not_improve_round = 0
        self.log_loss = []

        for i in range(1, self.num_iteration + 1):
            closure = self._setup_closure(
                optimizer, fake_x, fake_label, received_gradients
            )
            distance = optimizer.step(closure)
            self._clamp_fake_x(fake_x)

            if best_distance > distance:
                best_fake_x = fake_x.detach().clone()
                best_fake_label = fake_label.detach().clone()
                best_distance = distance
                best_iteration = i
                num_of_not_improve_round = 0
            else:
                num_of_not_improve_round += 1

            self._update_logging(i, distance, best_iteration, best_distance)

            if num_of_not_improve_round > self.early_stopping:
                print(
                    f"iter={i}: loss did not improve in the last {self.early_stopping} rounds."
                )
                break

        return best_fake_x, best_fake_label

    def group_attack(self, received_gradients, batch_size=1):
        """Multiple simultaneous attacks with different random states

        Every worker's closure receives the whole group of reconstructions so
        that group-consistency regularization (gc_reg_coef) takes effect.

        Args:
            received_gradients: the list of gradients received from the client.
            batch_size: batch size.

        Returns:
            a tuple of two lists (one entry per worker): the best reconstructed
            images and the corresponding labels
        """
        group_fake_x = []
        group_fake_label = []
        group_optimizer = []

        for _ in range(self.group_num):
            fake_x, fake_label, optimizer = _setup_attack(
                self.x_shape,
                self.y_shape,
                self.optimizer_class,
                self.optimize_label,
                self.pos_of_final_fc_layer,
                self.device,
                received_gradients,
                batch_size,
                **self.kwargs,
            )
            group_fake_x.append(fake_x)
            group_fake_label.append(fake_label)
            group_optimizer.append(optimizer)

        best_distance = [float("inf") for _ in range(self.group_num)]
        best_fake_x = [x_.detach().clone() for x_ in group_fake_x]
        best_fake_label = [y_.detach().clone() for y_ in group_fake_label]
        best_iteration = [0 for _ in range(self.group_num)]

        self.log_loss = [[] for _ in range(self.group_num)]
        for i in range(1, self.num_iteration + 1):
            for worker_id in range(self.group_num):
                self.reset_seed(self.group_seed[worker_id])
                closure = self._setup_closure(
                    group_optimizer[worker_id],
                    group_fake_x[worker_id],
                    group_fake_label[worker_id],
                    received_gradients,
                    # Required for group-consistency regularization; without it
                    # gc_reg_coef was silently ignored.
                    group_fake_x,
                )
                distance = group_optimizer[worker_id].step(closure)
                # Honor clamp_range here as well, consistent with `attack`.
                self._clamp_fake_x(group_fake_x[worker_id])

                if self.save_loss:
                    self.log_loss[worker_id].append(distance)

                if best_distance[worker_id] > distance:
                    best_fake_x[worker_id] = group_fake_x[worker_id].detach().clone()
                    best_fake_label[worker_id] = (
                        group_fake_label[worker_id].detach().clone()
                    )
                    best_distance[worker_id] = distance
                    best_iteration[worker_id] = i

                if self.log_interval != 0 and i % self.log_interval == 0:
                    print(
                        f"worker_id={worker_id}: iter={i}: {distance}, (best_iter={best_iteration[worker_id]}: {best_distance[worker_id]})"
                    )

        return best_fake_x, best_fake_label