Source code for aijack.defense.purifier

import torch
import torch.nn as nn
import torch.nn.functional as F


class Purifier_Cifar10(nn.Module):
    """Fully connected autoencoder used to purify CIFAR-10 confidence vectors.

    Reference: https://arxiv.org/abs/2005.03915

    Layout: 10 -> 7 -> 4 -> 7 -> 10. Each hidden ``nn.Linear`` is followed
    by ``nn.BatchNorm1d`` and ReLU; the final ``nn.Linear`` has no
    activation, so the output is unnormalized.

    Because the first layer is ``nn.Linear(10, 7)`` and the batch norms
    normalize over dim 1, inputs are expected to be shaped (batch, 10).
    In training mode ``BatchNorm1d`` needs more than one sample per batch.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is: they define
        # the state_dict keys and the order in which parameters are
        # initialized.
        self.L1 = nn.Linear(10, 7)
        self.bn1 = nn.BatchNorm1d(7)
        self.L2 = nn.Linear(7, 4)
        self.bn2 = nn.BatchNorm1d(4)
        self.L3 = nn.Linear(4, 7)
        self.bn3 = nn.BatchNorm1d(7)
        self.L4 = nn.Linear(7, 10)

    def forward(self, x):
        """Map a (batch, 10) tensor to a purified (batch, 10) tensor."""
        hidden_stages = (
            (self.L1, self.bn1),  # 10 -> 7
            (self.L2, self.bn2),  # 7 -> 4
            (self.L3, self.bn3),  # 4 -> 7
        )
        for linear, norm in hidden_stages:
            x = F.relu(norm(linear(x)))
        # 7 -> 10, no activation on the output layer.
        return self.L4(x)
def PurifierLoss(
    prediction,
    pred_purified,
    lam=0.2,
    purifier_criterion=None,
    accuracy_criterion=None,
):
    """Basic loss function for training a prediction purifier.

    Reference: https://arxiv.org/abs/2005.03915

    The purifier G is trained against a target model F to minimize

        L(G) = E[R(G(F(x)), F(x)) + λ C(G(F(x)), argmax F(x))]

    where R is a reconstruction loss, C is a cross-entropy loss, and λ
    controls the balance between the two terms.

    Args:
        prediction: output of the target model, F(x). Its argmax along
            dim 1 is used as the class label, so it is presumably shaped
            (batch, num_classes) -- confirm for other layouts.
        pred_purified: purified output of the target model, G(F(x)). It is
            passed as the input to both criteria (treated as logits by the
            default cross-entropy loss).
        lam: weight λ applied to the accuracy term.
        purifier_criterion: reconstruction loss R, called as
            ``purifier_criterion(pred_purified, prediction)``. Defaults to
            a fresh ``nn.MSELoss()``.
        accuracy_criterion: accuracy-preserving loss C, called as
            ``accuracy_criterion(pred_purified, labels)``. Defaults to a
            fresh ``nn.CrossEntropyLoss()``.

    Returns:
        loss_purifier: weighted sum ``R + lam * C`` of the two losses.

    Note:
        ``prediction`` is not detached here. If it still requires grad,
        backpropagating this loss can also send gradients into the target
        model; detach it beforehand if the target model must stay frozen.
    """
    # Build the default criteria per call rather than sharing module
    # instances created once at function-definition time.
    if purifier_criterion is None:
        purifier_criterion = nn.MSELoss()
    if accuracy_criterion is None:
        accuracy_criterion = nn.CrossEntropyLoss()

    loss_1 = purifier_criterion(pred_purified, prediction)
    # argmax yields integer class labels; no gradient flows through them.
    labels = torch.argmax(prediction, dim=1)
    loss_2 = accuracy_criterion(pred_purified, labels)
    loss_purifier = loss_1 + lam * loss_2
    return loss_purifier