import copy
import torch
import torch.nn as nn
from ..base_attack import BaseAttacker
from .utils.distance import cossim, l2
from .utils.regularization import (
    bn_regularizer,
    group_consistency,
    label_matching,
    total_variance,
)
from .utils.utils import _generate_fake_gradients, _setup_attack
class GradientInversion_Attack(BaseAttacker):
    """General Gradient Inversion Attacker
    A model inversion attack based on gradients can be written as follows:
            x^* = argmin_{x'} L_grad(x'; W, delta_W) + R_aux(x')
    where x' is the reconstructed image.
    The attacker tries to find images whose gradients w.r.t. the given model parameters W
    are similar to the gradients delta_W of the secret images.
    Attributes:
        target_model: a target torch module instance.
        x_shape: the input shape of target_model.
        y_shape: the output shape of target_model.
        optimize_label: If true, only optimize images (the label will be automatically estimated).
        pos_of_final_fc_layer: position of gradients corresponding to the final FC layer
                               within the gradients received from the client.
        num_iteration: number of iterations of optimization.
        optimizer_class: a class of torch optimizer for the attack.
        lossfunc: a function that takes the predictions of the target model and true labels
                  and returns the loss between them.
        distancefunc: a function which takes the gradients of reconstructed images and the client-side gradients
                      and returns the distance between them.
        tv_reg_coef: the coefficient of total-variance regularization.
        lm_reg_coef: the coefficient of label-matching regularization.
        l2_reg_coef: the coefficient of L2 regularization.
        bn_reg_coef: the coefficient of BN regularization.
        gc_reg_coef: the coefficient of group-consistency regularization.
        bn_reg_layers: a list of batch normalization layers of the target model.
        bn_reg_layer_inputs: a dict of the extracted inputs of the specified bn layers
        custom_reg_func: a custom regularization function.
        custom_reg_coef: the coefficient of the custom regularization function
        device: device type.
        log_interval: the interval of logging.
        save_loss: If true, save the loss during the attack.
        seed: random state.
        group_num: the size of group,
        group_seed: a list of random states for each worker of the group
        early_stopping: early stopping
    """
    def __init__(
        self,
        target_model,
        x_shape,
        y_shape=None,
        optimize_label=True,
        gradient_ignore_pos=None,
        pos_of_final_fc_layer=-2,
        num_iteration=100,
        optimizer_class=torch.optim.LBFGS,
        optimizername=None,
        lossfunc=nn.CrossEntropyLoss(),
        distancefunc=l2,
        distancename=None,
        tv_reg_coef=0.0,
        lm_reg_coef=0.0,
        l2_reg_coef=0.0,
        bn_reg_coef=0.0,
        gc_reg_coef=0.0,
        bn_reg_layers=None,
        custom_reg_func=None,
        custom_reg_coef=0.0,
        custom_generate_fake_grad_fn=None,
        device="cpu",
        log_interval=10,
        save_loss=True,
        seed=0,
        group_num=5,
        group_seed=None,
        early_stopping=50,
        clamp_range=None,
        **kwargs,
    ):
        """Inits GradientInversion_Attack class.

        Args:
            target_model: a target torch module instance.
            x_shape: the input shape of target_model.
            y_shape: the output shape of target_model. When None, the size of the
                     first dimension of the last parameter of target_model is used.
            optimize_label: flag forwarded to the attack setup and to the fake-gradient
                            helper to control how labels are handled.
                            NOTE(review): the previous description ("If true, only optimize
                            images (the label will be automatically estimated)") reads
                            inverted relative to the name -- confirm against `_setup_attack`.
            gradient_ignore_pos: a list of positions which will be ignored during the
                                 calculation of the distance between gradients.
                                 None (default) means nothing is ignored.
            pos_of_final_fc_layer: position of gradients corresponding to the final FC layer
                                   within the gradients received from the client.
            num_iteration: number of iterations of optimization.
            optimizer_class: a class of torch optimizer for the attack.
            optimizername: a name of optimizer class (priority over optimizer_class).
            lossfunc: a function that takes the predictions of the target model and true labels
                    and returns the loss between them.
            distancefunc: a function which takes the gradients of reconstructed images, the
                        client-side gradients and `gradient_ignore_pos`,
                        and returns the distance between them.
            distancename: a name of distancefunc (priority over distancefunc).
            tv_reg_coef: the coefficient of total-variance regularization.
            lm_reg_coef: the coefficient of label-matching regularization.
            l2_reg_coef: the coefficient of L2 regularization.
            bn_reg_coef: the coefficient of BN regularization.
            gc_reg_coef: the coefficient of group-consistency regularization.
            bn_reg_layers: a list of batch normalization layers of the target model.
                           A forward hook recording the layer input is registered on each.
                           None (default) means no layer is hooked.
            custom_reg_func: a custom regularization function. It is called with a single
                             context dict (keys: attacker, fake_x, fake_label,
                             received_gradients, group_fake_x).
            custom_reg_coef: the coefficient of the custom regularization function
            custom_generate_fake_grad_fn: optional replacement for the default fake-gradient
                                          computation. It is called as
                                          `fn(attacker, fake_x, fake_label)` and must return
                                          `(fake_pred, fake_gradients)`.
            device: device type.
            log_interval: the interval of logging (0 disables printing).
            save_loss: If true, save the loss during the attack.
            seed: random state.
            group_num: the size of group,
            group_seed: a list of random states for each worker of the group.
                        Defaults to `list(range(group_num))`.
            early_stopping: `attack` stops once the loss has failed to improve for more
                            than this many consecutive iterations.
            clamp_range: optional `(min, max)` pair; when given, the reconstructed images
                         are clamped into this range after every optimizer step.
            **kwargs: kwargs for the optimizer
        """
        super().__init__(target_model)
        self.x_shape = x_shape
        # Fall back to the leading dimension of the model's last parameter.
        self.y_shape = (
            list(target_model.parameters())[-1].shape[0] if y_shape is None else y_shape
        )

        self.optimize_label = optimize_label
        # A fresh list per instance: a `[]` default in the signature would be
        # shared (and mutable) across every attacker created without this argument.
        self.gradient_ignore_pos = (
            [] if gradient_ignore_pos is None else gradient_ignore_pos
        )
        self.pos_of_final_fc_layer = pos_of_final_fc_layer

        self.num_iteration = num_iteration

        self.lossfunc = lossfunc
        # The explicit callable/class is stored first; a given name then overrides it.
        self.distancefunc = distancefunc
        self._setup_distancefunc(distancename)
        self.optimizer_class = optimizer_class
        self._setup_optimizer_class(optimizername)

        self.tv_reg_coef = tv_reg_coef
        self.lm_reg_coef = lm_reg_coef
        self.l2_reg_coef = l2_reg_coef
        self.bn_reg_coef = bn_reg_coef
        self.gc_reg_coef = gc_reg_coef

        self.bn_reg_layers = [] if bn_reg_layers is None else bn_reg_layers
        # Filled by the forward hooks below, keyed by the layer's index in bn_reg_layers.
        self.bn_reg_layer_inputs = {}
        for i, bn_layer in enumerate(self.bn_reg_layers):
            bn_layer.register_forward_hook(self._get_hook_for_input(i))

        self.custom_reg_func = custom_reg_func
        self.custom_reg_coef = custom_reg_coef
        self.custom_generate_fake_grad_fn = custom_generate_fake_grad_fn

        self.device = device
        self.log_interval = log_interval
        self.save_loss = save_loss
        self.seed = seed
        self.group_num = group_num
        self.group_seed = list(range(group_num)) if group_seed is None else group_seed
        self.early_stopping = early_stopping
        self.clamp_range = clamp_range

        self.kwargs = kwargs

        torch.manual_seed(seed)
    def _setup_distancefunc(self, distancename):
        """Assigns a function to self.distancefunc according to distancename
        Args:
            distancename: name of the function to calculat the distance between the gradients.
                          currently support 'l2' or 'cossim'.
        Raises:
            ValueError: if distancename is not supported.
        """
        if distancename is None:
            return
        elif distancename == "l2":
            self.distancefunc = l2
        elif distancename == "cossim":
            self.distancefunc = cossim
        else:
            raise ValueError(f"{distancename} is not defined")
    def _setup_optimizer_class(self, optimizername):
        """Assigns a class to self.optimizer_class according to optimiername
        Args:
            optimizername: name of optimizer, currently support `LBFGS`, `SGD`, and `Adam`
        Raises:
            ValueError: if optimizername is not supported.
        """
        if optimizername is None:
            return
        elif optimizername == "LBFGS":
            self.optimizer_class = torch.optim.LBFGS
        elif optimizername == "SGD":
            self.optimizer_class = torch.optim.SGD
        elif optimizername == "Adam":
            self.optimizer_class = torch.optim.Adam
        else:
            raise ValueError(f"{optimizername} is not defined")
    def _get_hook_for_input(self, name):
        """Returns a hook function to extract the input of the specified layer of the target model
        Args:
            name: the key of self.bn_reg_layer_inputs for the target layer
        Returns:
            hook: a hook function
        """
        def hook(model, inp, output):
            self.bn_reg_layer_inputs[name] = inp[0]
        return hook
    def _calc_regularization_term(
        self, fake_x, fake_pred, fake_label, group_fake_x, received_gradients
    ):
        """calculates the regularization term
        Args:
            fake_x: reconstructed images
            fake_pred: the predicted value of reconstructed images
            faka_label: the labels of fake_x
            group_fake_x: a list of fake_x of each worker
            received_gradients: gradients received from the client
        Returns:
            calculated regularization term
        """
        reg_term = 0
        if self.tv_reg_coef != 0:
            reg_term += self.tv_reg_coef * total_variance(fake_x)
        if self.lm_reg_coef != 0:
            reg_term += self.lm_reg_coef * label_matching(fake_pred, fake_label)
        if self.l2_reg_coef != 0:
            reg_term += self.l2_reg_coef * torch.norm(fake_x, p=2)
        if self.bn_reg_coef != 0:
            reg_term += self.bn_reg_coef * bn_regularizer(
                self.bn_reg_layer_inputs, self.bn_reg_layers
            )
        if group_fake_x is not None and self.gc_reg_coef != 0:
            reg_term += self.gc_reg_coef * group_consistency(fake_x, group_fake_x)
        if self.custom_reg_func is not None and self.custom_reg_coef != 0:
            context = {
                "attacker": self,
                "fake_x": fake_x,
                "fake_label": fake_label,
                "received_gradients": received_gradients,
                "group_fake_x": group_fake_x,
            }
            reg_term += self.custom_reg_coef * self.custom_reg_func(context)
        return reg_term
    def _setup_closure(
        self, optimizer, fake_x, fake_label, received_gradients, group_fake_x=None
    ):
        """Returns a closure function for the optimizer
        Args:
            optimizer (torch.optim.Optimizer): an instance of the optimizer
            fake_x (torch.Tensor): reconstructed images
            fake_label (torch.Tensor): reconstructed or estimated labels
            received_gradients (list): a list of gradients received from the client
            group_fake_x (list, optional): a list of fake_x. Defaults to None.
        """
        def closure():
            optimizer.zero_grad()
            if self.custom_generate_fake_grad_fn is None:
                fake_pred, fake_gradients = _generate_fake_gradients(
                    self.target_model,
                    self.lossfunc,
                    self.optimize_label,
                    fake_x,
                    fake_label,
                )
            else:
                fake_pred, fake_gradients = self.custom_generate_fake_grad_fn(
                    self, fake_x, fake_label
                )
            distance = self.distancefunc(
                fake_gradients, received_gradients, self.gradient_ignore_pos
            )
            distance += self._calc_regularization_term(
                fake_x,
                fake_pred,
                fake_label,
                group_fake_x,
                received_gradients,
            )
            distance_val = distance.item()
            distance.backward(retain_graph=False)
            return distance_val
        return closure
[docs]    def reset_seed(self, seed):
        """Resets the random seed
        Args:
            seed (int): the random seed
        """
        self.seed = seed
        torch.manual_seed(seed) 
    def _update_logging(self, i, distance, best_iteration, best_distance):
        if self.save_loss:
            self.log_loss.append(distance)
        if self.log_interval != 0 and i % self.log_interval == 0:
            print(
                f"iter={i}: {distance}, (best_iter={best_iteration}: {best_distance})"
            )
[docs]    def attack(
        self,
        received_gradients,
        batch_size=1,
        init_x=None,
        labels=None,
    ):
        """Reconstructs the images from the gradients received from the client
        Args:
            received_gradients: the list of gradients received from the client.
            batch_size: batch size.
        Returns:
            a tuple of the best reconstructed images and corresponding labels
        Raises:
            OverflowError: If the calculated distance become Nan
        """
        fake_x, fake_label, optimizer = _setup_attack(
            self.x_shape,
            self.y_shape,
            self.optimizer_class,
            self.optimize_label,
            self.pos_of_final_fc_layer,
            self.device,
            received_gradients,
            batch_size,
            init_x=init_x,
            labels=labels,
            **self.kwargs,
        )
        # self._setup_attack(
        #    received_gradients, batch_size, init_x=init_x, labels=labels
        # )
        num_of_not_improve_round = 0
        best_distance = float("inf")
        self.log_loss = []
        for i in range(1, self.num_iteration + 1):
            closure = self._setup_closure(
                optimizer, fake_x, fake_label, received_gradients
            )
            distance = optimizer.step(closure)
            if self.clamp_range is not None:
                with torch.no_grad():
                    fake_x[:] = fake_x.clamp(self.clamp_range[0], self.clamp_range[1])
            # if torch.sum(torch.isnan(distance)).item():
            #    raise OverflowError("stop because the calculated distance is Nan")
            if best_distance > distance:
                best_fake_x = fake_x.detach().clone()
                best_fake_label = fake_label.detach().clone()
                best_distance = distance
                best_iteration = i
                num_of_not_improve_round = 0
            else:
                num_of_not_improve_round += 1
            self._update_logging(i, distance, best_iteration, best_distance)
            if num_of_not_improve_round > self.early_stopping:
                print(
                    f"iter={i}: loss did not improve in the last {self.early_stopping} rounds."
                )
                break
        return best_fake_x, best_fake_label 
[docs]    def group_attack(self, received_gradients, batch_size=1):
        """Multiple simultaneous attacks with different random states
        Args:
            received_gradients: the list of gradients received from the client.
            batch_size: batch size.
        Returns:
            a tuple of the best reconstructed images and corresponding labels
        """
        group_fake_x = []
        group_fake_label = []
        group_optimizer = []
        for _ in range(self.group_num):
            fake_x, fake_label, optimizer = _setup_attack(
                self.x_shape,
                self.y_shape,
                self.optimizer_class,
                self.optimize_label,
                self.pos_of_final_fc_layer,
                self.device,
                received_gradients,
                batch_size,
                **self.kwargs,
            )
            # self._setup_attack(
            #    received_gradients, batch_size
            # )
            group_fake_x.append(fake_x)
            group_fake_label.append(fake_label)
            group_optimizer.append(optimizer)
        best_distance = [float("inf") for _ in range(self.group_num)]
        best_fake_x = [x_.detach().clone() for x_ in group_fake_x]
        best_fake_label = [y_.detach().clone() for y_ in group_fake_label]
        best_iteration = [0 for _ in range(self.group_num)]
        self.log_loss = [[] for _ in range(self.group_num)]
        for i in range(1, self.num_iteration + 1):
            for worker_id in range(self.group_num):
                self.reset_seed(self.group_seed[worker_id])
                closure = self._setup_closure(
                    group_optimizer[worker_id],
                    group_fake_x[worker_id],
                    group_fake_label[worker_id],
                    received_gradients,
                )
                distance = group_optimizer[worker_id].step(closure)
                if self.save_loss:
                    self.log_loss[worker_id].append(distance)
                if best_distance[worker_id] > distance:
                    best_fake_x[worker_id] = group_fake_x[worker_id].detach().clone()
                    best_fake_label[worker_id] = (
                        group_fake_label[worker_id].detach().clone()
                    )
                    best_distance[worker_id] = distance
                    best_iteration[worker_id] = i
                if self.log_interval != 0 and i % self.log_interval == 0:
                    print(
                        f"worker_id={worker_id}: iter={i}: {distance}, (best_iter={best_iteration[worker_id]}: {best_distance[worker_id]})"
                    )
        return best_fake_x, best_fake_label