# Source code for aijack.defense.mid.nn
import torch
from torch import nn
from .loss import mib_loss
class VIB(nn.Module):
    """
    Variational Information Bottleneck (VIB) module.

    The encoder output is split along dimension 1 into the mean and the
    (softplus-transformed) standard deviation of a diagonal Gaussian p(z|x).
    ``num_samples`` latent vectors are drawn per input with the
    reparameterization trick, passed through the decoder, and averaged over
    the sample dimension.

    Args:
        encoder (torch.nn.Module): Encoder module. Must map an input batch to a
            2-D tensor of shape ``(batch_size, 2 * dim_z)``.
        decoder (torch.nn.Module): Decoder module. Receives latent samples of
            shape ``(num_samples, batch_size, dim_z)``.
        dim_z (int, optional): Dimension of latent variable z. Defaults to 256.
        num_samples (int, optional): Number of latent samples drawn per input.
            Defaults to 10.
        beta (float, optional): Trade-off coefficient forwarded to
            ``mib_loss`` as ``beta``. Defaults to 1e-3.
    """

    def __init__(self, encoder, decoder, dim_z=256, num_samples=10, beta=1e-3):
        super().__init__()
        self.dim_z = dim_z
        self.num_samples = num_samples
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta

    def get_params_of_p_z_given_x(self, x):
        """
        Compute parameters of p(z|x).

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: ``(mu, sigma)``, each of shape ``(batch_size, dim_z)``.
            ``sigma`` is strictly positive because it is passed through softplus.

        Raises:
            ValueError: If the encoder output is not 2-D or its second
                dimension is not 2 * dim_z.
        """
        encoder_output = self.encoder(x)
        # Guard the rank first so a malformed encoder yields the documented
        # ValueError instead of an IndexError on ``shape[1]``.
        if encoder_output.dim() != 2:
            raise ValueError(
                "the encoder must return a 2-D tensor of shape "
                f"(batch_size, 2 * dim_z), got shape {tuple(encoder_output.shape)}"
            )
        if encoder_output.shape[1] != self.dim_z * 2:
            raise ValueError("the output dimension of encoder must be 2 * dim_z")
        mu = encoder_output[:, : self.dim_z]
        # softplus keeps the standard deviation positive.
        sigma = torch.nn.functional.softplus(encoder_output[:, self.dim_z :])
        return mu, sigma

    def sampling_from_encoder(self, mu, sigma, batch_size):
        """
        Sample from encoder distribution with the reparameterization trick.

        Args:
            mu (torch.Tensor): Mean of the distribution, ``(batch_size, dim_z)``.
            sigma (torch.Tensor): Standard deviation of the distribution,
                ``(batch_size, dim_z)``.
            batch_size (int): Batch size.

        Returns:
            torch.Tensor: Samples of shape ``(num_samples, batch_size, dim_z)``
            on the same device and with the same dtype as ``mu``.
        """
        # Draw the noise directly on mu's device/dtype; building it on the CPU
        # would break (device mismatch) when the model runs on an accelerator.
        eps = torch.randn(
            self.num_samples,
            batch_size,
            self.dim_z,
            device=mu.device,
            dtype=mu.dtype,
        )
        return mu + sigma * eps

    def forward(self, x):
        """
        Forward pass of the VIB module.

        Args:
            x (torch.Tensor): Input tensor whose first dimension is the batch.

        Returns:
            In training mode, a tuple ``(outputs, info)``:
                outputs (torch.Tensor): Decoder outputs averaged over the
                    ``num_samples`` latent samples.
                info (dict): ``"sampled_decoded_outputs"`` (decoder outputs
                    permuted to ``(batch, out_dim, num_samples)``),
                    ``"sampled_encoded_features"``, ``"p_z_given_x_mu"`` and
                    ``"p_z_given_x_sigma"``.
            In evaluation mode, only ``outputs`` (torch.Tensor).

        Note:
            In training mode the decoder output must be 3-D
            ``(num_samples, batch, out_dim)`` because it is permuted with
            ``permute(1, 2, 0)``.
        """
        batch_size = x.size()[0]
        # encoder
        p_z_given_x_mu, p_z_given_x_sigma = self.get_params_of_p_z_given_x(x)
        sampled_encoded_features = self.sampling_from_encoder(
            p_z_given_x_mu, p_z_given_x_sigma, batch_size
        )
        # decoder: applied to all samples at once, then averaged over dim 0
        sampled_decoded_outputs = self.decoder(sampled_encoded_features)
        outputs = torch.mean(sampled_decoded_outputs, dim=0)
        if not self.training:
            return outputs
        return outputs, {
            "sampled_decoded_outputs": sampled_decoded_outputs.permute(1, 2, 0),
            "sampled_encoded_features": sampled_encoded_features,
            "p_z_given_x_mu": p_z_given_x_mu,
            "p_z_given_x_sigma": p_z_given_x_sigma,
        }

    def loss(self, y, result_dict):
        """
        Compute loss.

        The approximate marginal r(z) is a standard normal: zero mean and unit
        standard deviation with the same shape as the encoder parameters.

        Args:
            y (torch.Tensor): Target tensor.
            result_dict (dict): The dictionary returned by ``forward`` in
                training mode.

        Returns:
            torch.Tensor: Loss value (first element returned by ``mib_loss``).
        """
        sampled_y_pred = result_dict["sampled_decoded_outputs"]
        p_z_given_x_mu = result_dict["p_z_given_x_mu"]
        p_z_given_x_sigma = result_dict["p_z_given_x_sigma"]
        approximated_z_mean = torch.zeros_like(p_z_given_x_mu)
        approximated_z_sigma = torch.ones_like(p_z_given_x_sigma)
        loss, _, _ = mib_loss(
            y,
            sampled_y_pred,
            p_z_given_x_mu,
            p_z_given_x_sigma,
            approximated_z_mean,
            approximated_z_sigma,
            beta=self.beta,
        )
        return loss