Source code for aijack.attack.poison.label_flip
import random
import torch
from ...manager import BaseManager
[docs]def attach_label_flip_attack_to_client(
cls, victim_label, target_label=None, class_num=None
):
"""Attaches a label flip attack to a client.
Args:
cls: The client class.
victim_label: The label to be replaced.
target_label: The label to replace the victim label with. If None, a random label will be chosen.
class_num: The number of classes.
Returns:
class: A wrapper class with attached label flip attack.
"""
class LabelFlipAttackClientWrapper(cls):
def __init__(self, *args, **kwargs):
super(LabelFlipAttackClientWrapper, self).__init__(*args, **kwargs)
def local_train(
self, local_epoch, criterion, trainloader, optimizer, communication_id=0
):
for i in range(local_epoch):
running_loss = 0.0
running_data_num = 0
for _, data in enumerate(trainloader, 0):
inputs, labels = data
inputs = inputs.to(self.device)
if target_label is not None:
labels = torch.where(
labels == victim_label, target_label, labels
)
else:
labels = torch.where(
labels == victim_label, random.randint(0, class_num), labels
)
labels = labels.to(self.device)
optimizer.zero_grad()
self.zero_grad()
outputs = self(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
running_data_num += inputs.shape[0]
print(
f"communication {communication_id}, epoch {i}: client-{self.user_id+1}",
running_loss / running_data_num,
)
return LabelFlipAttackClientWrapper
[docs]class LabelFlipAttackClientManager(BaseManager):
[docs] def attach(self, cls):
return attach_label_flip_attack_to_client(cls, *self.args, **self.kwargs)