2.1.6. aijack.defense.mid package
2.1.6.1. Submodules
2.1.6.2. aijack.defense.mid.loss module
- aijack.defense.mid.loss.mib_loss(y, sampled_y_pred, p_z_given_x_mu, p_z_given_x_sigma, approximated_z_mean, approximated_z_sigma, beta=0.001)
Implementation of the MID loss for neural networks, proposed in https://arxiv.org/abs/2009.05241.
- Parameters
y (torch.Tensor) – ground-truth label
sampled_y_pred (torch.Tensor) – predicted output
p_z_given_x_mu (torch.Tensor) – mean of p(z|x)
p_z_given_x_sigma (torch.Tensor) – standard deviation of p(z|x)
approximated_z_mean (torch.Tensor) – mean of the approximated distribution of z
approximated_z_sigma (torch.Tensor) – standard deviation of the approximated distribution of z
beta (float, optional) – weight of the I(Z;X) bound in the loss. Defaults to 1e-3.
- Returns
Loss value, bound of I(Z;Y), and bound of I(Z;X).
- Return type
float, torch.Tensor, float
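The loss above combines a cross-entropy term, associated with the bound on I(Z;Y), with a KL divergence between p(z|x) and the approximated distribution of z, which upper-bounds I(Z;X) and is weighted by beta. The following is a minimal sketch of that structure, not aijack's code: `mib_loss_sketch` is a hypothetical name, and it assumes logits of shape (batch, num_classes), diagonal Gaussian parameters of shape (batch, dim_z), and a KL term summed over latent dimensions and averaged over the batch.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def mib_loss_sketch(y, y_logits, mu, sigma, approx_mu, approx_sigma, beta=1e-3):
    """Hypothetical sketch: cross-entropy + beta * KL(p(z|x) || r(z))."""
    # Cross-entropy term, associated with the variational bound on I(Z;Y).
    i_zy_bound = F.cross_entropy(y_logits, y)
    # KL between the encoder Gaussian p(z|x) and the approximated Gaussian r(z),
    # summed over latent dimensions and averaged over the batch (assumed reduction).
    kl = kl_divergence(Normal(mu, sigma), Normal(approx_mu, approx_sigma))
    i_zx_bound = kl.sum(dim=-1).mean()
    loss = i_zy_bound + beta * i_zx_bound
    return loss, i_zy_bound, i_zx_bound


torch.manual_seed(0)
batch, dim_z, num_classes = 4, 8, 3
y = torch.randint(0, num_classes, (batch,))
y_logits = torch.randn(batch, num_classes)
mu = torch.randn(batch, dim_z)
sigma = torch.rand(batch, dim_z) + 0.1  # standard deviations must be positive
loss, i_zy, i_zx = mib_loss_sketch(
    y, y_logits, mu, sigma, torch.zeros(dim_z), torch.ones(dim_z)
)
print(loss.item(), i_zy.item(), i_zx.item())
```

A small beta such as the 1e-3 default keeps the compression term from dominating the classification term.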
2.1.6.3. aijack.defense.mid.nn module
- class aijack.defense.mid.nn.VIB(encoder, decoder, dim_z=256, num_samples=10, beta=0.001)
Bases: torch.nn.modules.module.Module
Variational Information Bottleneck (VIB) module.
- Parameters
encoder (torch.nn.Module) – Encoder module.
decoder (torch.nn.Module) – Decoder module.
dim_z (int, optional) – Dimension of latent variable z. Defaults to 256.
num_samples (int, optional) – Number of latent samples drawn from p(z|x) per input. Defaults to 10.
beta (float, optional) – Weight of the I(Z;X) regularization term in the loss. Defaults to 1e-3.
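A sketch of encoder and decoder modules that satisfy the shape contract implied by these parameters and by the ValueError documented for get_params_of_p_z_given_x: the encoder emits 2 * dim_z features per input, and the decoder consumes a latent vector of size dim_z. The layer sizes, input size, and number of classes are hypothetical.

```python
import torch
import torch.nn as nn

dim_z = 256        # matches the VIB default
in_features = 784  # hypothetical input size (e.g. a flattened 28x28 image)
num_classes = 10   # hypothetical number of labels

# The encoder must emit 2 * dim_z values per input: dim_z for the mean of p(z|x)
# and dim_z for a raw value that is mapped to its standard deviation.
encoder = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(),
    nn.Linear(512, 2 * dim_z),
)
# The decoder consumes a single latent vector of size dim_z.
decoder = nn.Linear(dim_z, num_classes)

x = torch.randn(32, in_features)
enc_out = encoder(x)
print(enc_out.shape)                       # torch.Size([32, 512])
print(decoder(enc_out[:, :dim_z]).shape)   # torch.Size([32, 10])
```

Assuming aijack is installed, such modules would then be passed as `VIB(encoder, decoder, dim_z=dim_z, num_samples=10, beta=1e-3)`; that call is not executed in the sketch.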
- forward(x)
Forward pass of the VIB module.
- Parameters
x (torch.Tensor) – Input tensor.
- Returns
Tuple of the output tensor and a dictionary containing sampled outputs and parameters.
- Return type
tuple(torch.Tensor, dict)
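Because loss consumes the dictionary produced by the forward pass, forward returns both an output tensor and a result dictionary. A minimal sketch of such a pass, under stated assumptions: `vib_forward_sketch` and the dictionary keys are hypothetical, softplus is an assumed mapping to a positive standard deviation, and the final output averages decoder outputs over the samples, which is one common choice rather than aijack's confirmed behavior.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def vib_forward_sketch(encoder, decoder, x, dim_z, num_samples):
    """Hypothetical sketch of a VIB forward pass returning (output, result_dict)."""
    h = encoder(x)
    mu = h[:, :dim_z]
    sigma = F.softplus(h[:, dim_z:])  # assumed mapping to a positive std
    # Reparameterized samples with shape (num_samples, batch, dim_z).
    eps = torch.randn(num_samples, *mu.shape)
    z = mu + sigma * eps
    sampled_outputs = decoder(z)  # nn.Linear acts on the last dimension
    output = sampled_outputs.mean(dim=0)  # average over samples (assumption)
    result_dict = {  # hypothetical key names
        "sampled_decoded_outputs": sampled_outputs,
        "p_z_given_x_mu": mu,
        "p_z_given_x_sigma": sigma,
    }
    return output, result_dict


torch.manual_seed(0)
dim_z, num_samples = 16, 10
encoder = nn.Linear(20, 2 * dim_z)
decoder = nn.Linear(dim_z, 3)
output, result = vib_forward_sketch(encoder, decoder, torch.randn(5, 20), dim_z, num_samples)
print(output.shape)                             # torch.Size([5, 3])
print(result["sampled_decoded_outputs"].shape)  # torch.Size([10, 5, 3])
```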
- get_params_of_p_z_given_x(x)
Compute parameters of p(z|x).
- Parameters
x (torch.Tensor) – Input tensor.
- Returns
Tuple containing mean and standard deviation of p(z|x).
- Return type
tuple
- Raises
ValueError – If the output dimension of encoder is not 2 * dim_z.
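A sketch of how the encoder output can be split into the mean and standard deviation of p(z|x), including the documented ValueError. The helper name `get_params_sketch` and the use of softplus to obtain a positive standard deviation are assumptions, not confirmed details of aijack.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_params_sketch(encoder, x, dim_z):
    """Hypothetical sketch: split the encoder output into (mu, sigma) of p(z|x)."""
    h = encoder(x)
    if h.shape[-1] != 2 * dim_z:
        raise ValueError("the output dimension of encoder must be 2 * dim_z")
    mu = h[..., :dim_z]
    # softplus maps unconstrained values to positive standard deviations (assumed choice)
    sigma = F.softplus(h[..., dim_z:])
    return mu, sigma


dim_z = 8
mu, sigma = get_params_sketch(nn.Linear(4, 2 * dim_z), torch.randn(3, 4), dim_z)
print(mu.shape, sigma.shape)  # torch.Size([3, 8]) torch.Size([3, 8])

try:
    # An encoder with only dim_z outputs violates the 2 * dim_z contract.
    get_params_sketch(nn.Linear(4, dim_z), torch.randn(3, 4), dim_z)
except ValueError as err:
    print("ValueError:", err)
```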
- loss(y, result_dict)
Compute loss.
- Parameters
y (torch.Tensor) – Target tensor.
result_dict (dict) – Dictionary containing sampled outputs and parameters.
- Returns
Loss value.
- Return type
torch.Tensor
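A sketch of turning a forward-pass result dictionary into a scalar loss, assuming hypothetical result_dict keys, a sample layout of (num_samples, batch, num_classes) for the decoded outputs, and a standard normal N(0, I) as the approximated distribution of z. `vib_loss_sketch` is hypothetical; aijack may weight or reduce the terms differently.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def vib_loss_sketch(y, result_dict, beta=1e-3):
    """Hypothetical sketch: turn a forward-pass result_dict into a scalar loss."""
    logits = result_dict["sampled_decoded_outputs"]  # (num_samples, batch, classes)
    num_samples = logits.shape[0]
    # Cross-entropy over every sample: fold the sample axis into the batch axis
    # and repeat the labels so row s * batch + b is paired with y[b].
    ce = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), y.repeat(num_samples))
    mu = result_dict["p_z_given_x_mu"]
    sigma = result_dict["p_z_given_x_sigma"]
    # Assumption: the approximated distribution r(z) is a standard normal N(0, I).
    prior = Normal(torch.zeros_like(mu), torch.ones_like(sigma))
    kl = kl_divergence(Normal(mu, sigma), prior).sum(dim=-1).mean()
    return ce + beta * kl


torch.manual_seed(0)
S, B, C, D = 10, 4, 3, 8  # samples, batch, classes, latent dimension
result = {
    "sampled_decoded_outputs": torch.randn(S, B, C),
    "p_z_given_x_mu": torch.randn(B, D),
    "p_z_given_x_sigma": torch.rand(B, D) + 0.1,
}
y = torch.randint(0, C, (B,))
print(vib_loss_sketch(y, result))
```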
- sampling_from_encoder(mu, sigma, batch_size)
Sample from the encoder distribution.
- Parameters
mu (torch.Tensor) – Mean of the distribution.
sigma (torch.Tensor) – Standard deviation of the distribution.
batch_size (int) – Batch size.
- Returns
Tensor sampled from the encoder distribution.
- Return type
torch.Tensor
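Sampling from a Gaussian encoder is usually done with the reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I), so that gradients flow through mu and sigma. A minimal sketch, assuming samples are stacked along a leading num_samples axis; `sampling_from_encoder_sketch` and its explicit num_samples argument are hypothetical (in the VIB module the sample count presumably comes from num_samples).

```python
import torch


def sampling_from_encoder_sketch(mu, sigma, batch_size, num_samples=10):
    """Hypothetical sketch of reparameterized sampling from N(mu, sigma^2)."""
    dim_z = mu.shape[-1]
    # eps ~ N(0, I) carries the randomness, so gradients flow through mu and sigma.
    eps = torch.randn(num_samples, batch_size, dim_z)
    return mu + sigma * eps  # shape (num_samples, batch_size, dim_z)


torch.manual_seed(0)
mu = torch.full((2, 3), 1.5, requires_grad=True)
sigma = torch.full((2, 3), 0.5, requires_grad=True)
z = sampling_from_encoder_sketch(mu, sigma, batch_size=2, num_samples=5000)
print(z.shape)  # torch.Size([5000, 2, 3])
z.sum().backward()
print(mu.grad)  # every entry is 5000.0 (one unit of gradient per sample)
```

Sampling eps directly from a non-differentiable distribution and shifting/scaling it is what lets the encoder be trained end to end through the sampled z.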