Source code for aijack.defense.mid.loss

import torch


def KL_between_normals(mu_q, sigma_q, mu_p, sigma_p, eps=1e-8):
    """Compute KL(q || p) between two diagonal Gaussians, row by row.

    Uses the closed form

        KL = 0.5 * (sum(var_q / var_p) + sum((mu_p - mu_q)^2 / var_p)
                    - k + log det(var_p) - log det(var_q))

    where ``var = sigma ** 2``, every sum runs over ``dim=1`` and
    ``k = mu_q.size(1)``.

    Args:
        mu_q (torch.Tensor): mean of q; must have at least two dimensions
            because the reduction runs over ``dim=1``.
        sigma_q (torch.Tensor): standard deviation of q.
        mu_p (torch.Tensor): mean of p.
        sigma_p (torch.Tensor): standard deviation of p.
        eps (float, optional): lower bound applied to both standard
            deviations before they are used. Defaults to 1e-8.

    Returns:
        torch.Tensor: the KL divergence with ``dim=1`` reduced away.
    """
    # Clamp once and reuse the clamped values in *every* term. Clamping only
    # inside the logarithms (and dividing by the raw sigma) lets a zero
    # standard deviation produce inf/nan in the ratio terms while the
    # log-determinants stay finite.
    sigma_q = torch.clamp(sigma_q, min=eps)
    sigma_p = torch.clamp(sigma_p, min=eps)
    var_q = sigma_q**2
    var_p = sigma_p**2

    k = mu_q.size(1)
    mu_diff = mu_p - mu_q

    trace_term = torch.sum(var_q / var_p, dim=1)
    mahalanobis_term = torch.sum(mu_diff * mu_diff / var_p, dim=1)

    # log det of a diagonal covariance: sum(log(sigma^2)) = sum(2 * log(sigma)).
    logdet_sigma_q = torch.sum(2 * torch.log(sigma_q), dim=1)
    logdet_sigma_p = torch.sum(2 * torch.log(sigma_p), dim=1)

    two_kl = trace_term + mahalanobis_term - k + logdet_sigma_p - logdet_sigma_q
    return two_kl * 0.5
def mib_loss(
    y,
    sampled_y_pred,
    p_z_given_x_mu,
    p_z_given_x_sigma,
    approximated_z_mean,
    approximated_z_sigma,
    beta=1e-3,
):
    """Compute the MID loss for neural networks (https://arxiv.org/abs/2009.05241).

    The loss is a cross-entropy term, averaged over the sampled predictions
    and over the batch, plus ``beta`` times the mean KL divergence between
    the z|x distribution and the approximated z distribution.

    Args:
        y (torch.Tensor): ground-truth labels; indexed as ``y[:, None]`` and
            repeated along the last dimension of ``sampled_y_pred``.
        sampled_y_pred (torch.Tensor): predicted outputs; the last dimension
            is averaged over as the sampling dimension (presumably shaped
            ``(batch, classes, samples)`` -- confirm against the caller).
        p_z_given_x_mu (torch.Tensor): mean of z|x.
        p_z_given_x_sigma (torch.Tensor): standard deviation of z|x.
        approximated_z_mean (torch.Tensor): approximated mean of z.
        approximated_z_sigma (torch.Tensor): approximated standard deviation
            of z.
        beta (float, optional): weight of the KL term. Defaults to 1e-3.

    Returns:
        tuple of three torch.Tensor: the total loss, the averaged
        cross-entropy term (``minus_I_ZY_bound``) and the averaged KL term
        (``I_ZX_bound``).
    """
    # KL term: one value per row from the helper, then averaged.
    kl_per_row = KL_between_normals(
        p_z_given_x_mu,
        p_z_given_x_sigma,
        approximated_z_mean,
        approximated_z_sigma,
    )
    I_ZX_bound = kl_per_row.mean()

    # Repeat each label once per sampled prediction so every draw is scored
    # against the same target.
    n_draws = sampled_y_pred.size(-1)
    tiled_labels = y[:, None].expand(-1, n_draws)
    per_draw_ce = torch.nn.functional.cross_entropy(
        sampled_y_pred, tiled_labels, reduction="none"
    )

    # Average over the sampling dimension first, then over dim 0.
    minus_I_ZY_bound = per_draw_ce.mean(dim=-1).mean(dim=0)

    total_loss = (minus_I_ZY_bound + beta * I_ZX_bound).mean()
    return total_loss, minus_I_ZY_bound, I_ZX_bound