# Source code for aijack.attack.evasion.diva
import torch
from ..base_attack import BaseAttacker
class DIVAWhiteBoxAttacker(BaseAttacker):
    """Class implementing the DIVA white-box attack.

    The attack searches for a small perturbation of the input that keeps the
    original model's score for the given label high while pushing the edge
    model's score for the same label down, so that the two models end up
    predicting different classes.

    Args:
        target_model (torch.nn.Module): The target model to be attacked.
        target_model_on_edge (torch.nn.Module): The target model deployed on
            the edge.
        c (float, optional): The trade-off parameter between origin and edge
            predictions. Defaults to 1.0.
        num_itr (int, optional): The maximum number of iterations for the
            attack. Defaults to 1000.
        eps (float, optional): The maximum perturbation allowed per input
            element. Defaults to 0.1.
        lam (float, optional): The step size for gradient updates. Defaults
            to 0.01.
        device (str, optional): The device to perform computation on.
            Defaults to "cpu".

    Attributes:
        target_model (torch.nn.Module): The target model to be attacked
            (expected to be stored by ``BaseAttacker.__init__``).
        target_model_on_edge (torch.nn.Module): The target model deployed on
            the edge.
        c (float): The trade-off parameter between origin and edge predictions.
        num_itr (int): The maximum number of iterations for the attack.
        eps (float): The maximum perturbation allowed per input element.
        lam (float): The step size for gradient updates.
        device (str): The device to perform computation on.
    """

    def __init__(
        self,
        target_model,
        target_model_on_edge,
        c=1.0,
        num_itr=1000,
        eps=0.1,
        lam=0.01,
        device="cpu",
    ):
        super().__init__(target_model)
        self.target_model_on_edge = target_model_on_edge
        self.c = c
        self.num_itr = num_itr
        self.eps = eps
        self.lam = lam
        self.device = device

    def attack(self, data):
        """Performs the DIVA white-box attack on input data.

        Each iteration takes one gradient-ascent step (raw gradient scaled by
        ``lam``) on ``origin_score[label] - c * edge_score[label]`` and then
        clamps the result to within ``eps`` of the original input,
        element-wise. The loop stops early once the two models' predicted
        classes differ for every sample in the batch.

        The caller's input tensor is never modified in place.

        Args:
            data (tuple): A tuple ``(x, y)`` containing input data and
                corresponding labels. Model outputs are indexed as
                ``(batch, classes)``; ``y`` must hold either one class index
                per sample or a single class index shared by all samples.

        Returns:
            tuple: A tuple containing the adversarial examples and attack
            logs. The logs are a dict with ``"log_loss"`` (objective value per
            iteration, summed over the batch) and ``"log_perturbation"`` (mean
            absolute perturbation per iteration).
        """
        x, y = data
        x = x.to(self.device)
        y = y.to(self.device)
        x_origin = torch.clone(x)

        log_loss = []
        log_perturbation = []

        for _ in range(self.num_itr):
            # Fresh leaf tensor each step so the gradient is w.r.t. the
            # current adversarial input only.
            x = x.detach().requires_grad_(True)

            origin_pred = self.target_model(x)
            edge_pred = self.target_model_on_edge(x)

            # Pick each sample's own label score. ``pred[:, y]`` would build a
            # (batch, batch) matrix for batched labels and break backward().
            labels = y.reshape(-1, 1).long().expand(origin_pred.shape[0], 1)
            per_sample = origin_pred.gather(1, labels) - self.c * edge_pred.gather(
                1, labels
            )
            # Sum (not mean) so each sample's step does not shrink with the
            # batch size; for a single sample this is the original loss.
            loss = per_sample.sum()

            # Only the input gradient is needed; autograd.grad avoids
            # accumulating gradients into either model's parameters.
            (grad,) = torch.autograd.grad(loss, x)

            with torch.no_grad():
                # Out-of-place update: ``x`` may share storage with the
                # caller's tensor on the first iteration.
                x = torch.clamp(
                    x + self.lam * grad,
                    x_origin - self.eps,
                    x_origin + self.eps,
                )

            log_loss.append(loss.item())
            log_perturbation.append(torch.mean((x - x_origin).abs()).item())

            # NOTE: the check uses the predictions computed before this
            # iteration's update, as in the original implementation.
            diverged = origin_pred.argmax(dim=1) != edge_pred.argmax(dim=1)
            if bool(diverged.all()):
                break

        return x, {"log_loss": log_loss, "log_perturbation": log_perturbation}