Source code for aijack.attack.inversion.generator_attack

from typing import List

from ..base_attack import BaseAttacker


class Generator_Attack(BaseAttacker):
    """Model-inversion attack that trains an attacker model to invert target models.

    For every batch, the input ``x`` is fed through each target model (or, when
    ``arbitrary_y`` is set, the batch's ``y`` entry is used instead of the target
    model's output). The attacker model maps that output back to input space, and
    the mean squared error against ``x`` is the reconstruction loss.

    Args:
        target_model: a single target model, or a list/tuple of target models
            that are all inverted by the same attacker model. When a sequence is
            given, its first element is passed to ``BaseAttacker.__init__``.
        attacker_model: callable mapping target outputs back to reconstructions.
        attacker_optimizer: optimizer used in ``fit``
            (``zero_grad`` / ``step`` are called on it).
        log_interval: print the epoch loss every ``log_interval`` epochs;
            ``0`` disables logging.
        early_stopping: ``fit`` stops once more than this many epochs have
            passed since the epoch with the lowest loss.
        device: passed to ``.to(...)`` for inputs and target outputs.

    Raises:
        ValueError: if ``target_model`` is an empty list/tuple.
    """

    def __init__(
        self,
        target_model,
        attacker_model,
        attacker_optimizer,
        log_interval=1,
        early_stopping=5,
        device="cpu",
    ):
        # Check against the builtin types with isinstance: ``typing.List`` is not
        # equal to ``list``, so a ``type(...) == List`` test would never match.
        is_sequence = isinstance(target_model, (list, tuple))
        if is_sequence and len(target_model) == 0:
            raise ValueError("target_model must contain at least one model")

        super().__init__(
            target_model=target_model[0] if is_sequence else target_model
        )
        self.attacker_model = attacker_model
        self.attacker_optimizer = attacker_optimizer
        self.log_interval = log_interval
        self.early_stopping = early_stopping
        self.device = device
        self.target_model_list = (
            list(target_model) if is_sequence else [target_model]
        )

    def _per_model_losses(self, data, x_pos, y_pos, arbitrary_y):
        """Return one reconstruction MSE per target model for a single batch."""
        x = data[x_pos].to(self.device)
        losses = []
        for target_model in self.target_model_list:
            # With arbitrary_y the attacker inverts the stored label/output
            # instead of querying the target model.
            target_outputs = data[y_pos] if arbitrary_y else target_model(x)
            target_outputs = target_outputs.to(self.device)
            attack_outputs = self.attacker_model(target_outputs)
            losses.append(((x - attack_outputs) ** 2).mean())
        return losses

    def calc_loss(self, dataloader, x_pos=0, y_pos=1, arbitrary_y=False):
        """Compute the reconstruction loss over ``dataloader``.

        Args:
            dataloader: iterable of indexable batches; must support ``len``.
            x_pos: index of the input within each batch.
            y_pos: index of the target output within each batch (used only
                when ``arbitrary_y`` is True).
            arbitrary_y: use ``data[y_pos]`` instead of the target model's output.

        Returns:
            The per-batch loss, summed over target models, averaged over batches
            (``0`` if ``dataloader`` is empty).
        """
        # NOTE(review): this sums over target models, whereas ``fit`` averages
        # over them; the two agree only for a single target model. Autograd is
        # also left enabled here — confirm whether callers rely on either.
        running_loss = 0
        for data in dataloader:
            loss = sum(self._per_model_losses(data, x_pos, y_pos, arbitrary_y))
            running_loss = running_loss + loss / len(dataloader)
        return running_loss

    def fit(self, dataloader, epoch, x_pos=0, y_pos=1, arbitrary_y=False):
        """Train the attacker model to reconstruct inputs from target outputs.

        Args:
            dataloader: iterable of indexable batches; must support ``len``.
            epoch: maximum number of epochs to train.
            x_pos: index of the input within each batch.
            y_pos: index of the target output within each batch (used only
                when ``arbitrary_y`` is True).
            arbitrary_y: use ``data[y_pos]`` instead of the target model's output.
        """
        num_models = len(self.target_model_list)
        best_loss = float("inf")
        best_epoch = 0
        for i in range(epoch):
            running_loss = 0
            for data in dataloader:
                self.attacker_optimizer.zero_grad()
                # Average over target models so the loss scale does not grow
                # with the number of models being inverted.
                loss = sum(
                    model_loss / num_models
                    for model_loss in self._per_model_losses(
                        data, x_pos, y_pos, arbitrary_y
                    )
                )
                loss.backward()
                self.attacker_optimizer.step()
                running_loss = running_loss + loss.item() / len(dataloader)

            if self.log_interval != 0 and i % self.log_interval == 0:
                print(f"epoch {i}: reconstruction_loss {running_loss}")

            if running_loss < best_loss:
                best_loss = running_loss
                best_epoch = i
            elif i - best_epoch > self.early_stopping:
                # No improvement for more than `early_stopping` epochs.
                break

    def attack(self, data_tensor):
        """Return the attacker model's reconstruction for ``data_tensor``."""
        return self.attacker_model(data_tensor)