2.1.6. aijack.defense.mid package

2.1.6.1. Submodules

2.1.6.2. aijack.defense.mid.loss module

aijack.defense.mid.loss.KL_between_normals(mu_q, sigma_q, mu_p, sigma_p)
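KL_between_normals has no docstring. Its arguments suggest it returns the closed-form KL divergence between two Gaussians with diagonal covariance, the quantity mib_loss can use to bound I(Z;X). The framework-free sketch below illustrates that formula. It assumes the divergence is KL(q || p) summed over latent dimensions, and it uses plain Python lists where the library works on torch.Tensor batches; the name kl_between_normals is hypothetical.

```python
import math


def kl_between_normals(mu_q, sigma_q, mu_p, sigma_p):
    """Hypothetical sketch: KL(q || p) for diagonal Gaussians, summed over dims.

    Per dimension (argument order assumed to mean KL(q || p)):
        log(sigma_p / sigma_q) + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2) - 1/2
    """
    total = 0.0
    for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p):
        total += (
            math.log(sp / sq)
            + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2)
            - 0.5
        )
    return total


# Identical distributions have zero divergence.
print(kl_between_normals([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0]))  # 0.0
# q = N(1, 1) against p = N(0, 1) in one dimension.
print(kl_between_normals([1.0], [1.0], [0.0], [1.0]))  # 0.5
```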
aijack.defense.mid.loss.mib_loss(y, sampled_y_pred, p_z_given_x_mu, p_z_given_x_sigma, approximated_z_mean, approximated_z_sigma, beta=0.001)

Implementation of the MID loss for neural networks, proposed in https://arxiv.org/abs/2009.05241.

Parameters
  • y (torch.Tensor) – Ground-truth labels.

  • sampled_y_pred (torch.Tensor) – Sampled predicted outputs.

  • p_z_given_x_mu (torch.Tensor) – Mean of p(z|x).

  • p_z_given_x_sigma (torch.Tensor) – Standard deviation of p(z|x).

  • approximated_z_mean (torch.Tensor) – Mean of the approximate marginal distribution of z.

  • approximated_z_sigma (torch.Tensor) – Standard deviation of the approximate marginal distribution of z.

  • beta (float, optional) – Weight of the I(Z;X) term. Defaults to 1e-3.

Returns

The loss value, a bound on I(Z;Y), and a bound on I(Z;X).

Return type

float, torch.Tensor, float
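In the variational information bottleneck formulation, the loss trades a cross-entropy term (whose negation lower-bounds I(Z;Y) up to the constant H(Y)) against a KL term that upper-bounds I(Z;X), weighted by beta. The single-example sketch below assumes that decomposition, assumes predictions are combined by averaging class probabilities over the sampled z, and returns the three values in the documented order. mib_loss_sketch and its helpers are hypothetical, and the library's reduction over samples and batches may differ.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_diag(mu_q, sigma_q, mu_p, sigma_p):
    # Closed-form KL(q || p) between diagonal Gaussians, summed over dimensions.
    return sum(
        math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
        for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p)
    )


def mib_loss_sketch(y, sampled_logits, mu, sigma, prior_mu, prior_sigma, beta=1e-3):
    """Hypothetical single-example information-bottleneck loss.

    y: integer class label; sampled_logits: one logit vector per sampled z.
    Returns (loss, bound related to I(Z;Y), bound on I(Z;X)).
    """
    # Assumption: combine samples by averaging class probabilities.
    probs = [softmax(logits) for logits in sampled_logits]
    num_classes = len(probs[0])
    mean_prob = [sum(p[c] for p in probs) / len(probs) for c in range(num_classes)]
    cross_entropy = -math.log(mean_prob[y])
    # -cross_entropy lower-bounds I(Z;Y) up to the constant H(Y).
    izy_bound = -cross_entropy
    # KL(p(z|x) || r(z)) upper-bounds I(Z;X) when averaged over inputs.
    izx_bound = kl_diag(mu, sigma, prior_mu, prior_sigma)
    return cross_entropy + beta * izx_bound, izy_bound, izx_bound


# Uninformative logits give cross-entropy log(2); N(1, 1) vs prior N(0, 1) gives KL 0.5.
loss, izy, izx = mib_loss_sketch(0, [[0.0, 0.0], [0.0, 0.0]], [1.0], [1.0], [0.0], [1.0])
print(round(loss, 5), round(izy, 5), izx)  # 0.69365 -0.69315 0.5
```

The usage line assumes a standard normal N(0, I) as the approximate marginal r(z); with a small beta such as the default 1e-3, the cross-entropy term dominates the loss.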

2.1.6.3. aijack.defense.mid.nn module

class aijack.defense.mid.nn.VIB(encoder, decoder, dim_z=256, num_samples=10, beta=0.001)

Bases: torch.nn.modules.module.Module

Variational Information Bottleneck (VIB) module.

Parameters
  • encoder (torch.nn.Module) – Encoder module.

  • decoder (torch.nn.Module) – Decoder module.

  • dim_z (int, optional) – Dimension of latent variable z. Defaults to 256.

  • num_samples (int, optional) – Number of samples of z drawn per input. Defaults to 10.

  • beta (float, optional) – Weight of the I(Z;X) regularization term in the loss. Defaults to 1e-3.

forward(x)

Forward pass of the VIB module.

Parameters

x (torch.Tensor) – Input tensor.

Returns

A tuple (output, result_dict): the output tensor, and a dictionary containing the sampled outputs and parameters.

Return type

tuple(torch.Tensor, dict)

get_params_of_p_z_given_x(x)

Compute the parameters (mean and standard deviation) of p(z|x) from the encoder output.

Parameters

x (torch.Tensor) – Input tensor.

Returns

Tuple containing mean and standard deviation of p(z|x).

Return type

tuple

Raises

ValueError – If the output dimension of the encoder is not 2 * dim_z.
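The ValueError above implies the encoder emits 2 * dim_z values per input, which are split into the mean and the standard deviation of p(z|x). A minimal sketch, assuming the first half is the mean and the second half is passed through softplus to keep the standard deviation positive; the transform aijack actually applies is an assumption, and split_encoder_output is a hypothetical name.

```python
import math


def softplus(v):
    # Numerically stable log(1 + exp(v)); always strictly positive.
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def split_encoder_output(encoder_output, dim_z):
    """Hypothetical split of one encoder output vector into (mu, sigma) of p(z|x)."""
    if len(encoder_output) != 2 * dim_z:
        raise ValueError(
            f"encoder output has {len(encoder_output)} values, "
            f"expected 2 * dim_z = {2 * dim_z}"
        )
    mu = encoder_output[:dim_z]
    # Assumption: softplus maps the second half to positive standard deviations.
    sigma = [softplus(v) for v in encoder_output[dim_z:]]
    return mu, sigma


mu, sigma = split_encoder_output([0.5, -1.0, 0.0, 2.0], dim_z=2)
print(mu)                           # [0.5, -1.0]
print([round(s, 3) for s in sigma])  # [0.693, 2.127]
```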

loss(y, result_dict)

Compute the VIB loss from the targets and the dictionary returned by forward.

Parameters
  • y (torch.Tensor) – Target tensor.

  • result_dict (dict) – Dictionary containing sampled outputs and parameters.

Returns

Loss value.

Return type

torch.Tensor

sampling_from_encoder(mu, sigma, batch_size)

Draw samples of z from the encoder distribution p(z|x).

Parameters
  • mu (torch.Tensor) – Mean of the distribution.

  • sigma (torch.Tensor) – Standard deviation of the distribution.

  • batch_size (int) – Batch size.

Returns

Tensor of samples drawn from the encoder distribution.

Return type

torch.Tensor
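Sampling of this kind is usually done with the reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I), so gradients can flow through mu and sigma. The sketch below draws num_samples such samples for every example in the batch. The num_samples x batch_size x dim_z layout, taking num_samples from the constructor, and the name sample_from_encoder are assumptions; the batch size is implied by len(mu), and plain lists stand in for tensors.

```python
import random


def sample_from_encoder(mu, sigma, num_samples, rng=None):
    """Hypothetical reparameterized sampling: z = mu + sigma * eps, eps ~ N(0, 1).

    mu, sigma: batch_size x dim_z nested lists.
    Returns a num_samples x batch_size x dim_z nested list (layout assumed).
    """
    if rng is None:
        rng = random.Random()
    return [
        [
            [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu_row, sigma_row)]
            for mu_row, sigma_row in zip(mu, sigma)
        ]
        for _ in range(num_samples)
    ]


rng = random.Random(0)
z = sample_from_encoder([[0.0, 5.0]], [[1.0, 0.0]], num_samples=3, rng=rng)
print(len(z), len(z[0]), len(z[0][0]))  # 3 1 2
# With sigma = 0 the second coordinate is deterministic.
print([sample[0][1] for sample in z])  # [5.0, 5.0, 5.0]
```

Because the noise enters only through eps, mu and sigma stay differentiable inputs to the sample, which is what lets an encoder be trained through this step.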

2.1.6.4. Module contents