Source code for aijack.defense.mid.nn
import torch
from torch import nn
from .loss import mib_loss

class VIB(nn.Module):
    """
    Variational Information Bottleneck (VIB) module.

    Args:
        encoder (torch.nn.Module): Encoder module; its output dimension must be 2 * dim_z.
        decoder (torch.nn.Module): Decoder module applied to the sampled latent codes.
        dim_z (int, optional): Dimension of the latent variable z. Defaults to 256.
        num_samples (int, optional): Number of Monte Carlo samples drawn from p(z|x). Defaults to 10.
        beta (float, optional): Weight of the information bottleneck (KL) term in the loss. Defaults to 1e-3.
    """

    def __init__(self, encoder, decoder, dim_z=256, num_samples=10, beta=1e-3):
        super().__init__()
        self.dim_z = dim_z
        self.num_samples = num_samples
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta

    def get_params_of_p_z_given_x(self, x):
        """
        Compute the parameters of p(z|x).

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: Tuple containing the mean and standard deviation of p(z|x).

        Raises:
            ValueError: If the output dimension of the encoder is not 2 * dim_z.
        """
        encoder_output = self.encoder(x)
        if encoder_output.shape[1] != self.dim_z * 2:
            raise ValueError("The output dimension of the encoder must be 2 * dim_z")
        mu = encoder_output[:, : self.dim_z]
        sigma = torch.nn.functional.softplus(encoder_output[:, self.dim_z :])
        return mu, sigma
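
    # Illustrative note (not in the original source): with dim_z=256 the
    # encoder must end in a layer with 512 output units, e.g.
    # nn.Linear(hidden_dim, 2 * 256), so that the first 256 units
    # parameterize mu and the remaining 256 parameterize sigma.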

    def sampling_from_encoder(self, mu, sigma, batch_size):
        """
        Sample from the encoder distribution with the reparameterization trick.

        Args:
            mu (torch.Tensor): Mean of the distribution.
            sigma (torch.Tensor): Standard deviation of the distribution.
            batch_size (int): Batch size.

        Returns:
            torch.Tensor: Tensor of shape (num_samples, batch_size, dim_z)
                sampled from the encoder distribution.
        """
        # Draw standard-normal noise on the same device and dtype as the
        # encoder output so the module also works on GPU.
        eps = torch.randn(
            self.num_samples, batch_size, self.dim_z, device=mu.device, dtype=mu.dtype
        )
        return mu + sigma * eps
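
    # Shape sketch (illustrative): mu and sigma have shape
    # (batch_size, dim_z) and broadcast against the
    # (num_samples, batch_size, dim_z) noise tensor, so the returned tensor
    # stacks num_samples independent draws of z for every input in the batch.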

    def forward(self, x):
        """
        Forward pass of the VIB module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor, averaged over the sampled latent codes.
            dict: Dictionary containing the sampled outputs and the parameters
                of p(z|x). Only returned in training mode.
        """
        batch_size = x.size()[0]
        # encoder
        p_z_given_x_mu, p_z_given_x_sigma = self.get_params_of_p_z_given_x(x)
        sampled_encoded_features = self.sampling_from_encoder(
            p_z_given_x_mu, p_z_given_x_sigma, batch_size
        )
        # decoder
        sampled_decoded_outputs = self.decoder(sampled_encoded_features)
        outputs = torch.mean(sampled_decoded_outputs, dim=0)
        if self.training:
            return outputs, {
                "sampled_decoded_outputs": sampled_decoded_outputs.permute(1, 2, 0),
                "sampled_encoded_features": sampled_encoded_features,
                "p_z_given_x_mu": p_z_given_x_mu,
                "p_z_given_x_sigma": p_z_given_x_sigma,
            }
        else:
            return outputs

    def loss(self, y, result_dict):
        """
        Compute the VIB loss.

        Args:
            y (torch.Tensor): Target tensor.
            result_dict (dict): Dictionary of sampled outputs and parameters
                returned by `forward` in training mode.

        Returns:
            torch.Tensor: Loss value.
        """
        sampled_y_pred = result_dict["sampled_decoded_outputs"]
        p_z_given_x_mu = result_dict["p_z_given_x_mu"]
        p_z_given_x_sigma = result_dict["p_z_given_x_sigma"]
        # The marginal of z is approximated with a standard normal prior.
        approximated_z_mean = torch.zeros_like(p_z_given_x_mu)
        approximated_z_sigma = torch.ones_like(p_z_given_x_sigma)
        loss, _, _ = mib_loss(
            y,
            sampled_y_pred,
            p_z_given_x_mu,
            p_z_given_x_sigma,
            approximated_z_mean,
            approximated_z_sigma,
            beta=self.beta,
        )
        return loss
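
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# end-to-end run of VIB with toy MLP encoder/decoder heads. The input size,
# hidden size, batch size, and number of classes below are assumptions, not
# values prescribed by AIJack.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dim_z = 256
    # The encoder must emit 2 * dim_z units: dim_z for mu, dim_z for sigma.
    encoder = nn.Sequential(
        nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 2 * dim_z)
    )
    decoder = nn.Linear(dim_z, 10)
    model = VIB(encoder, decoder, dim_z=dim_z, num_samples=10, beta=1e-3)

    x = torch.randn(32, 784)
    # Integer class targets are assumed here; the exact target format is
    # defined by aijack.defense.mid.loss.mib_loss.
    y = torch.randint(0, 10, (32,))

    model.train()
    outputs, result_dict = model(x)  # training mode returns (outputs, dict)
    loss = model.loss(y, result_dict)
    loss.backward()

    model.eval()
    with torch.no_grad():
        preds = model(x)  # evaluation mode returns only the averaged outputs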