"""Source code for aijack.attack.backdoor.modelreplacement (model replacement backdoor attack)."""
import torch
from ...manager import BaseManager
def l2norm_checker(client):
    """Sum the per-parameter L2 distances between the current and previous model.

    Args:
        client: object exposing ``client.model.parameters()`` and an iterable
            ``client.prev_parameters`` whose entries are paired with those
            parameters in order.

    Returns:
        torch.Tensor: scalar ``sum_i ||param_i - prev_param_i||_2``. The
        accumulation starts from a zero leaf tensor created with
        ``requires_grad=True``, so that leaf itself is returned when there
        are no parameter pairs.

    Note:
        This is a sum of per-tensor norms, not the L2 norm of the flattened
        update vector.
    """
    start = torch.tensor(0.0, requires_grad=True)
    # builtin sum adds left to right onto ``start``, matching a manual loop
    return sum(
        (
            torch.norm(current - previous, 2)
            for current, previous in zip(
                client.model.parameters(), client.prev_parameters
            )
        ),
        start,
    )
def attach_modelreplacement_to_client(
    cls,
    alpha,
    gamma,
    criterion_anomaly_detection=l2norm_checker,
    reference_dataloader=None,
    eps=1e-6,
):
    """Wraps the given class in ModelReplacementAttackClientWrapper.

    The returned subclass implements the model replacement attack of
    https://proceedings.mlr.press/v108/bagdasaryan20a/bagdasaryan20a.pdf.
    Local training minimizes ``alpha * task_loss + (1 - alpha) * anomaly_loss``.
    The uploaded update is scaled by ``gamma``.

    Args:
        cls: Client class to wrap. The wrapper reads ``self.model``,
            ``self.prev_parameters``, ``self.lr`` and ``self.device`` from its
            instances and calls ``self(inputs)`` / ``self.zero_grad()``.
        alpha: Weight of the task loss. The anomaly-detection term is
            weighted by ``1 - alpha``.
        gamma: Scaling factor applied to every uploaded gradient.
        criterion_anomaly_detection: Callable taking the client instance and
            returning a differentiable scalar penalty tensor.
        reference_dataloader: Optional iterable of ``(inputs, labels)``
            batches. If given, it is evaluated (without gradients) before
            every local epoch. Training stops early once the accumulated
            reference loss divided by the number of reference samples is
            ``<= eps``.
        eps: Early-stopping threshold for the reference loss.

    Returns:
        cls: a class wrapped in ModelReplacementAttackClientWrapper
    """

    class ModelReplacementAttackClientWrapper(cls):
        """Implementation of https://proceedings.mlr.press/v108/bagdasaryan20a/bagdasaryan20a.pdf"""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def upload_gradients(self):
            """Uploads the local gradients, scaled by ``gamma``.

            Each entry is ``gamma * (prev_param - param) / self.lr``. That is
            the parameter change since ``prev_parameters`` divided by the
            learning rate and multiplied by ``gamma``.

            Returns:
                list: one tensor per model parameter.
            """
            return [
                gamma * (prev_param - param) / self.lr
                for param, prev_param in zip(
                    self.model.parameters(), self.prev_parameters
                )
            ]

        def local_train(
            self, local_epoch, criterion, trainloader, optimizer, communication_id=0
        ):
            """Runs up to ``local_epoch`` epochs of attacker-side training.

            Args:
                local_epoch: Maximum number of local epochs.
                criterion: Task loss, called as ``criterion(outputs, labels)``.
                trainloader: Iterable of ``(inputs, labels)`` batches.
                optimizer: Optimizer stepping the client's parameters.
                communication_id: Unused; kept for interface compatibility.

            Returns:
                list: for each completed epoch, the accumulated training loss
                divided by the number of training samples. An epoch cut off
                by early stopping is not recorded.
            """
            loss_log = []
            for _ in range(local_epoch):
                if reference_dataloader is not None:
                    reference_loss = 0.0
                    reference_num = 0
                    with torch.no_grad():
                        for inputs, labels in reference_dataloader:
                            inputs = inputs.to(self.device)
                            # labels must go to the same device as the model
                            labels = labels.to(self.device)
                            outputs = self(inputs)
                            reference_loss += criterion(outputs, labels).item()
                            # ``shape`` is a torch.Size attribute, not a method
                            reference_num += inputs.shape[0]
                    # Stop early once the reference loss is small enough. An
                    # empty reference loader gives no evidence (and would
                    # divide by zero), so it never triggers early stopping.
                    if reference_num > 0 and reference_loss / reference_num <= eps:
                        break

                running_loss = 0.0
                running_data_num = 0
                for inputs, labels in trainloader:
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)

                    optimizer.zero_grad()
                    self.zero_grad()

                    outputs = self(inputs)
                    # weighted sum of task loss and anomaly-detection penalty
                    loss = alpha * criterion(outputs, labels)
                    loss = loss + (1 - alpha) * criterion_anomaly_detection(self)
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    running_data_num += inputs.shape[0]

                loss_log.append(running_loss / running_data_num)

            return loss_log

    return ModelReplacementAttackClientWrapper
class ModelReplacementAttackClientManager(BaseManager):
    """Manager class for the model replacement backdoor attack proposed in
    https://proceedings.mlr.press/v108/bagdasaryan20a/bagdasaryan20a.pdf.

    The positional and keyword arguments held in ``self.args`` /
    ``self.kwargs`` (presumably stored by ``BaseManager.__init__``; confirm)
    are forwarded to ``attach_modelreplacement_to_client``. These are
    ``alpha``, ``gamma`` and the optional ``criterion_anomaly_detection``,
    ``reference_dataloader`` and ``eps``.
    """

    def attach(self, cls):
        """Wraps the given class in ModelReplacementAttackClientWrapper.

        Args:
            cls: Client class to wrap.

        Returns:
            cls: a class wrapped in ModelReplacementAttackClientWrapper
        """
        return attach_modelreplacement_to_client(cls, *self.args, **self.kwargs)