Source code for aijack.attack.backdoor.dba
import random
import torch
from ...manager import BaseManager
[docs]def attach_dba_to_client(
cls, decomposed_trigger_rules, target_label, poison_ratio, scale
):
"""Wraps the given class in DistributedBackdoorAttackClientWrapper.
Args:
cls: Server class
decomposed_trigger_rules ([function]): list of functions that define the decomposed trigger rules for each client
target_label (int): a label that the attacker want to make the victim model predict when the inupt contains the trigger
poison_ratio (float): a ratio of poisoned samples
scale (_type_): scale for the uploaded gradients
Returns:
cls: a class wrapped in DistributedBackdoorAttackClientWrapper
"""
class DistributedBackdoorAttackClientWrapper(cls):
"""Implementation of https://openreview.net/forum?id=rkgyS0VFvr"""
def __init__(self, *args, **kwargs):
super(DistributedBackdoorAttackClientWrapper, self).__init__(
*args, **kwargs
)
def upload_gradients(self):
"""Uploads the local gradients"""
gradients = []
for param, prev_param in zip(self.model.parameters(), self.prev_parameters):
gradients.append((prev_param - param) / self.lr * scale)
return gradients
def local_train(
self, local_epoch, criterion, trainloader, optimizer, communication_id=0
):
loss_log = []
for _ in range(local_epoch):
running_loss = 0.0
running_data_num = 0
for _, data in enumerate(trainloader, 0):
inputs, labels = data
inputs = inputs.to(self.device)
if random.random() < poison_ratio:
inputs = decomposed_trigger_rules[self.user_id](inputs)
labels = torch.ones_like(labels) * target_label
labels = labels.to(self.device)
optimizer.zero_grad()
self.zero_grad()
outputs = self(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
running_data_num += inputs.shape[0]
loss_log.append(running_loss / running_data_num)
return loss_log
return DistributedBackdoorAttackClientWrapper
[docs]class DistributedBackdoorAttackClientManager(BaseManager):
"""Manager class for DistributedBackdoorAttack proposed in https://openreview.net/forum?id=rkgyS0VFvr."""
[docs] def attach(self, cls):
"""Wraps the given class in DistributedBackdoorAttackClientWrapper.
Returns:
cls: a class wrapped in DistributedBackdoorAttackClientWrapper
"""
return attach_dba_to_client(cls, *self.args, **self.kwargs)