Source code for aijack.defense.soteria.soteria_client
import torch
from ...manager import BaseManager
def attach_soteria_to_client(
cls,
input_layer,
perturbed_layer,
epsilon=0.2,
target_layer_name=None,
):
"""
Attaches the Soteria wrapper to the client class.
Args:
cls: The client class to which Soteria will be attached.
input_layer (str): Name of the input layer.
perturbed_layer (str): Name of the perturbed layer.
        epsilon (float, optional): Fraction of the feature entries whose
            gradients are pruned. Defaults to 0.2.
        target_layer_name (str, optional): Name of the gradient entry to prune.
            Defaults to None, in which case ``"{perturbed_layer}.weight"`` is used.
Returns:
class: Client class with Soteria wrapper attached.
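    Example:
        A minimal sketch; the base client class ``MyClient``, the model, and
        the layer names are illustrative::

            WrappedClient = attach_soteria_to_client(
                MyClient, "lin1", "lin2", epsilon=0.2
            )
            client = WrappedClient(model)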
"""
class SoteriaClientWrapper(cls):
"""Implementation of https://arxiv.org/pdf/2012.06043.pdf"""
def __init__(self, *args, **kwargs):
super(SoteriaClientWrapper, self).__init__(*args, **kwargs)
self.input_layer = input_layer
self.perturbed_layer = perturbed_layer
self.epsilon = epsilon
self.target_layer_name = target_layer_name
self._set_hook()
def _set_hook(self):
self.inputs = {}
self.outputs = {}
def get_input(name):
def hook(model, x, output):
if not x[0].requires_grad:
raise ValueError("x.requires_grad should be True")
self.inputs[name] = x[0]
return hook
def get_output(name):
def hook(model, x, output):
self.outputs[name] = output
return hook
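            # register the hooks: cache the input of ``input_layer`` and the
            # output (feature representation) of ``perturbed_layer``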
getattr(self.model, self.input_layer).register_forward_hook(
get_input(self.input_layer)
)
getattr(self.model, self.perturbed_layer).register_forward_hook(
get_output(self.perturbed_layer)
)
def action_before_lossbackward(self):
input_data = self.inputs[self.input_layer]
feature = self.outputs[self.perturbed_layer]
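            # Soteria scores each feature entry i by f_i(x) / ||d f_i / d x||,
            # i.e., the feature value relative to its sensitivity w.r.t. the
            # input. The loop below computes this score one entry at a time
            # via vector-Jacobian products.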
mask = torch.zeros_like(feature)
r_dfr_dx_norm = torch.zeros_like(feature)
rep_size = feature.shape[1]
for i in range(rep_size):
mask[:, i] = 1
                feature.backward(
                    mask, retain_graph=True
                )  # vector-Jacobian product: populates input_data.grad with d feature[:, i] / d input
dfri_dx = input_data.grad.data
r_dfr_dx_norm[:, i] = feature[:, i] / torch.norm(
dfri_dx.view(dfri_dx.shape[0], -1), dim=1
)
self.model.zero_grad()
input_data.grad.data.zero_()
mask[:, i] = 0
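            # indices of the top (epsilon * rep_size) highest-scoring entries;
            # their weight-gradient rows are pruned after the loss backward pass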
self.topk_idxs = torch.topk(
r_dfr_dx_norm.mean(dim=0), int(rep_size * self.epsilon)
)[1]
def action_after_lossbackward(self, target_layer_name=None):
target_layer_name = (
f"{self.perturbed_layer}.weight"
if target_layer_name is None
else target_layer_name
)
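            # map parameter names to their gradients; note the zip assumes the
            # order of state_dict() matches parameters() (i.e., no buffers)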
dl_dw = {
layer_name: params.grad
for layer_name, params in zip(
self.model.state_dict(), self.model.parameters()
)
}
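            # prune: zero the gradient rows of the selected feature entries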
dl_dw[target_layer_name][self.topk_idxs, :] = 0
def backward(self, loss):
self.action_before_lossbackward()
super().backward(loss)
self.action_after_lossbackward(self.target_layer_name)
def local_train(
self, local_epoch, criterion, trainloader, optimizer, communication_id=0
):
loss_log = []
for _ in range(local_epoch):
running_loss = 0.0
running_data_num = 0
                for data in trainloader:
                    inputs, labels = data
                    inputs = inputs.to(self.device)
                    # the forward hook on ``input_layer`` requires gradients
                    # w.r.t. the input data
                    inputs.requires_grad = True
                    inputs.retain_grad()
                    labels = labels.to(self.device)
optimizer.zero_grad()
self.zero_grad()
outputs = self(inputs)
loss = criterion(outputs, labels)
self.backward(loss)
optimizer.step()
running_loss += loss.item()
running_data_num += inputs.shape[0]
loss_log.append(running_loss / running_data_num)
return loss_log
return SoteriaClientWrapper
class SoteriaClientManager(BaseManager):
    """Manager class for attaching the Soteria wrapper to client classes."""

    def attach(self, cls):
        return attach_soteria_to_client(cls, *self.args, **self.kwargs)
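
A minimal usage sketch, assuming AIJack's FedAVGClient as the base client; the
import paths, the model, the layer names ("lin1", "lin2"), and the constructor
arguments are illustrative assumptions, not part of the module above:

import torch
import torch.nn as nn

from aijack.collaborative import FedAVGClient  # assumed import path
from aijack.defense.soteria import SoteriaClientManager  # assumed import path


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(784, 256)  # input_layer: its input is hooked
        self.lin2 = nn.Linear(256, 10)  # perturbed_layer: its gradient rows are pruned

    def forward(self, x):
        return self.lin2(torch.relu(self.lin1(x)))


# dummy data so the sketch is self-contained
dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 784), torch.randint(0, 10, (64,))
)
trainloader = torch.utils.data.DataLoader(dataset, batch_size=16)

manager = SoteriaClientManager("lin1", "lin2", epsilon=0.2)
SoteriaFedAVGClient = manager.attach(FedAVGClient)
client = SoteriaFedAVGClient(Net(), user_id=0)  # constructor arguments are illustrative

optimizer = torch.optim.SGD(client.parameters(), lr=0.01)
loss_log = client.local_train(
    local_epoch=1,
    criterion=nn.CrossEntropyLoss(),
    trainloader=trainloader,
    optimizer=optimizer,
)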