Source code for aijack.attack.labelleakage.normattack
import torch
from sklearn.metrics import roc_auc_score
from ...manager import BaseManager
def attach_normattack_to_splitnn(
    cls, attack_criterion, target_client_index=0, device="cpu"
):
    """Attach the gradient-norm label leakage attack to a SplitNN class.

    The returned class subclasses ``cls`` and adds an ``attack`` method that
    scores every sample by the L2 norm of the gradient received by one client
    and reports how well that score separates the labels (the "leak AUC").

    Args:
        cls: The SplitNN model class to wrap.
        attack_criterion: Loss function, called as
            ``attack_criterion(outputs, labels)`` during the attack.
        target_client_index (int, optional): Index into ``self.clients`` of
            the client whose received gradient is inspected. Defaults to 0.
        device (str, optional): Device that inputs and labels are moved to.
            Defaults to "cpu".

    Returns:
        type: A subclass of ``cls`` with the attack attached.
    """

    class NormAttackSplitNNWrapper(cls):
        """``cls`` extended with the norm-based label leakage attack."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Attack configuration captured from the enclosing factory call.
            self.attack_criterion = attack_criterion
            self.target_client_index = target_client_index
            self.device = device

        def extract_intermidiate_gradient(self, outputs):
            """Back-propagate ``outputs.grad`` and return the target client's gradient.

            ``outputs.grad`` is handed to ``self.backward_gradient``. The method
            then returns whatever is stored in ``grad_from_next_client`` of
            ``self.clients[self.target_client_index]``.
            """
            # NOTE(review): relies on the wrapped class keeping ``outputs.grad``
            # populated after ``loss.backward()``. Confirm in the SplitNN base.
            self.backward_gradient(outputs.grad)
            return self.clients[self.target_client_index].grad_from_next_client

        def attack(self, dataloader):
            """Compute the norm-attack leak AUC of this SplitNN on ``dataloader``.

            For each batch the model is run forward and ``attack_criterion`` is
            back-propagated. The L2 norm of the per-sample gradient that reaches
            the target client is used as the score. The leak AUC is the ROC AUC
            of those scores against the true labels.

            Reference: https://arxiv.org/abs/2102.08504

            Args:
                dataloader (iterable): Yields ``(inputs, labels)`` batches, e.g. a
                    ``torch.utils.data.DataLoader``.

            Returns:
                float: The leak AUC, as returned by
                ``sklearn.metrics.roc_auc_score``.

            Raises:
                ValueError: If ``dataloader`` yields no batches.
            """
            epoch_labels = []
            epoch_g_norm = []
            for inputs, labels in dataloader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                outputs = self(inputs)
                loss = self.attack_criterion(outputs, labels)
                loss.backward()

                grad_from_server = self.extract_intermidiate_gradient(outputs)
                # Flatten every axis except the batch axis so the norm is one
                # value per sample, even for non-2-D intermediates (e.g. conv maps).
                flat_grad = grad_from_server.reshape(grad_from_server.shape[0], -1)
                g_norm = flat_grad.pow(2).sum(dim=1).sqrt()

                # sklearn needs CPU arrays without autograd history. Moving each
                # batch off the device here also avoids keeping every batch's
                # buffers alive until the end of the epoch.
                epoch_labels.append(labels.detach().cpu())
                epoch_g_norm.append(g_norm.detach().cpu())

            if not epoch_labels:
                raise ValueError("dataloader yielded no batches")

            all_labels = torch.cat(epoch_labels).numpy()
            all_g_norm = torch.cat(epoch_g_norm).numpy()
            return roc_auc_score(all_labels, all_g_norm)

    return NormAttackSplitNNWrapper
class NormAttackSplitNNManager(BaseManager):
    """Manager that attaches the norm-based label leakage attack to a SplitNN class.

    The arguments held on the manager (``self.args`` / ``self.kwargs``) are
    forwarded unchanged to ``attach_normattack_to_splitnn``.
    """

    def attach(self, cls):
        """Return ``cls`` wrapped with the norm attack.

        Args:
            cls: The SplitNN model class to wrap.

        Returns:
            type: The wrapper class built by ``attach_normattack_to_splitnn``.
        """
        # ``self.args`` / ``self.kwargs`` are presumably captured by
        # ``BaseManager.__init__``. TODO confirm in ``...manager``.
        extra_args = self.args
        extra_kwargs = self.kwargs
        return attach_normattack_to_splitnn(cls, *extra_args, **extra_kwargs)