"""Regularization utilities for aijack.attack.inversion model-inversion attacks."""

import torch


def total_variance(x):
    """Compute the anisotropic total variation of a batch of images.

    Args:
        x (torch.Tensor): Input tensor; assumed shape (N, C, H, W) — the
            last two dimensions are treated as spatial.

    Returns:
        torch.Tensor: Scalar total-variance penalty (mean absolute
        difference between horizontally and vertically adjacent pixels).
    """
    horizontal_diff = x[:, :, :, 1:] - x[:, :, :, :-1]
    vertical_diff = x[:, :, 1:, :] - x[:, :, :-1, :]
    return horizontal_diff.abs().mean() + vertical_diff.abs().mean()
def label_matching(pred, label):
    """Compute the label matching loss between predictions and targets.

    Builds a one-hot encoding of ``label`` and returns the Euclidean
    distance between it and ``pred``.

    Args:
        pred (torch.Tensor): Predicted logits/probabilities, last
            dimension indexed by class.
        label (torch.Tensor): Integer class indices.

    Returns:
        torch.Tensor: Scalar label matching loss.
    """
    num_classes = pred.shape[-1]
    target_onehot = torch.eye(num_classes)[label].to(pred.device)
    squared_error = (pred - target_onehot) ** 2
    return torch.sqrt(squared_error.sum())
def group_consistency(x, group_x):
    """Compute the group consistency loss between an input and a group.

    Args:
        x (torch.Tensor): The candidate tensor.
        group_x (list): Tensors forming the reference group.

    Returns:
        torch.Tensor: L2 distance between ``x`` and the element-wise
        mean of ``group_x``.
    """
    group_mean = sum(group_x) / len(group_x)
    return (x - group_mean).norm(p=2)
def bn_regularizer(feature_maps, bn_layers):
    """Compute the batch normalization regularizer loss.

    For each BN layer, penalizes the L2 distance between the batch
    statistics (per-channel mean and variance) of the corresponding
    feature map and the layer's running statistics.

    Args:
        feature_maps (list): Feature maps feeding each BN layer; each
            tensor must have 2 to 5 dimensions laid out as (N, C, ...).
        bn_layers (list): Batch normalization layers aligned index-wise
            with ``feature_maps``.

    Returns:
        torch.Tensor: The batch normalization regularizer loss.

    Raises:
        ValueError: If a feature map does not have 2 to 5 dimensions.
    """
    bn_reg = 0
    for i, layer in enumerate(bn_layers):
        fm = feature_maps[i]
        # Reduce over the batch and all spatial dims, keeping channels.
        # Fix: a 2D (N, C) input previously left `dim` undefined
        # (NameError on the first layer, or a stale value carried over
        # from the previous iteration); unsupported ranks now fail fast.
        if fm.dim() == 2:
            dim = [0]
        elif fm.dim() == 3:
            dim = [0, 2]
        elif fm.dim() == 4:
            dim = [0, 2, 3]
        elif fm.dim() == 5:
            dim = [0, 2, 3, 4]
        else:
            raise ValueError(
                f"expected a 2D-5D feature map, got {fm.dim()}D"
            )
        # Access running stats directly instead of materializing the
        # whole state_dict twice per layer.
        bn_reg += torch.norm(fm.mean(dim=dim) - layer.running_mean, p=2)
        bn_reg += torch.norm(fm.var(dim=dim) - layer.running_var, p=2)
    return bn_reg