Source code for aijack.attack.inversion.gan_attack

import copy

import torch

from ...manager import BaseManager


[docs]def attach_ganattack_to_client( cls, target_label, generator, generator_optimizer, generator_criterion, nz=100, device="cpu", gan_batch_size=1, gan_epoch=1, gan_log_interval=0, ignore_first_download=False, ): """Wraps the given class in GANAttackClientWrapper. Args: target_label(int): index of target class generator (torch.nn.Module): Generator generator_optimizer (torch.optim.Optimizer): optimizer for the generator generator_criterion (function): loss function for the generator nz (int): dimension of latent space of the generator. Defaults to 100. device (str, optional): _description_. Defaults to "cpu". gan_batch_size (int, optional): batch size for training GAN. Defaults to 1. gan_epoch (int, optional): epoch for training GAN. Defaults to 1. gan_log_interval (int, optional): log interval. Defaults to 0. ignore_first_download (bool, optional): Defaults to False. Returns: cls: a class wrapped in GANAttackClientWrapper """ class GANAttackClientWrapper(cls): """Implementation of GAN based model inversion attack (https://arxiv.org/abs/1702.07464)""" def __init__(self, *args, **kwargs): super(GANAttackClientWrapper, self).__init__(*args, **kwargs) self.target_label = target_label self.generator = generator self.generator_optimizer = generator_optimizer self.generator_criterion = generator_criterion self.nz = nz self.device = device self.discriminator = copy.deepcopy(self.model) self.discriminator.to(self.device) self.noise = torch.randn(1, self.nz, 1, 1, device=self.device) self.is_params_initialized = False def update_generator(self, batch_size=10, epoch=1, log_interval=5): """Updates the Generator Args: batch_size (int): batch size epoch (int): epoch log_interval (int): interval of logging """ for i in range(1, epoch + 1): running_error = 0 self.generator.zero_grad() noise = torch.randn(batch_size, self.nz, 1, 1, device=self.device) fake = self.generator(noise) output = self.discriminator(fake) label = torch.full( (batch_size,), self.target_label, dtype=torch.int64, device=self.device, ) loss_generator = self.generator_criterion(output, label) loss_generator.backward() self.generator_optimizer.step() running_error += loss_generator.item() if log_interval != 0 and i % log_interval == 0: print( f"updating generator - epoch {i}: generator loss is {running_error/batch_size}" ) def update_discriminator(self): """Updates the discriminator""" self.discriminator.load_state_dict(self.model.state_dict()) def download(self, model_parameters): """Downloads the new model""" super().download(model_parameters) if ignore_first_download and not self.is_params_initialized: self.is_params_initialized = True return self.update_discriminator() self.update_generator( batch_size=gan_batch_size, epoch=gan_epoch, log_interval=gan_log_interval, ) def attack(self, n): """Generates fake images Args: n (int): the number of fake images created by the Generator Returns: fake: generated fake images """ noise = torch.randn(n, self.nz, 1, 1, device=self.device) with torch.no_grad(): fake = self.generator(noise) return fake return GANAttackClientWrapper
[docs]class GANAttackClientManager(BaseManager): """Manager class for GAN based model inversion attack (https://arxiv.org/abs/1702.07464)""" def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs
[docs] def attach(self, cls): """Wraps the given class in GANAttackClientWrapper. Returns: cls: a class wrapped in GANAttackClientWrapper """ return attach_ganattack_to_client(cls, *self.args, **self.kwargs)