Source code for aijack.defense.purifier
import torch
import torch.nn as nn
import torch.nn.functional as F
[docs]class Purifier_Cifar10(nn.Module):
"""autoencoder for purification on Cifar10
reference https://arxiv.org/abs/2005.03915
"""
def __init__(self):
super(Purifier_Cifar10, self).__init__()
self.L1 = nn.Linear(10, 7)
self.bn1 = nn.BatchNorm1d(7)
self.L2 = nn.Linear(7, 4)
self.bn2 = nn.BatchNorm1d(4)
self.L3 = nn.Linear(4, 7)
self.bn3 = nn.BatchNorm1d(7)
self.L4 = nn.Linear(7, 10)
[docs] def forward(self, x):
# 10 -> 7
x = self.L1(x)
x = self.bn1(x)
x = F.relu(x)
# 7 -> 4
x = self.L2(x)
x = self.bn2(x)
x = F.relu(x)
# 4 -> 7
x = self.L3(x)
x = self.bn3(x)
x = F.relu(x)
# 7 -> 10
x = self.L4(x)
return x
[docs]def PurifierLoss(
prediction,
pred_purified,
lam=0.2,
purifier_criterion=nn.MSELoss(),
accuracy_criterion=nn.CrossEntropyLoss(),
):
"""basic loss function for purification
reference https://arxiv.org/abs/2005.03915
train purifier G against target model F to minimize the
following objective function
L(G) = E[R(G(F(x)), F(x)) + λC(G(F(x), argmax F(x)))]
R is a reconstruction loss function
C is a cross entropy loss function
λ controls the balance of the two loss functions
Args:
prediction: predicted value of target model
pred_purified: purified predicted value of target model
lam: controls the balance of the following two functions
purifier_criterion: loss function to reshapre confidense score
accuracy_criterion: loss function to preserve the accuracy (C)
Return:
loss_purifier: weighted average of the two loss function
"""
loss_1 = purifier_criterion(pred_purified, prediction)
loss_2 = accuracy_criterion(pred_purified, torch.argmax(prediction, axis=1))
loss_purifier = loss_1 + lam * loss_2
return loss_purifier