Source code for aijack.utils.metrics

import numpy as np
import torch
from sklearn.metrics import accuracy_score
from torch.nn import functional as F


def total_variance(x):
    """Returns the total variance of the given data

    The value is the mean absolute difference between neighbours along the
    last axis plus the mean absolute difference between neighbours along the
    second-to-last axis.

    Args:
        x (torch.Tensor): input data with at least two axes; the last two
            axes are the ones differenced (e.g. ``(N, C, H, W)``,
            ``(C, H, W)`` or ``(H, W)``).

    Returns:
        torch.Tensor: zero-dimensional tensor holding the total variance of
            the given data
    """
    # Ellipsis indexing is identical to the explicit 4-axis slicing for
    # 4-D inputs, but also accepts tensors with fewer or more leading axes.
    dx = torch.mean(torch.abs(x[..., :-1] - x[..., 1:]))
    dy = torch.mean(torch.abs(x[..., :-1, :] - x[..., 1:, :]))
    return dx + dy
def crossentropyloss_between_logits(y_pred_logit, y_true_labels, reduction="mean"):
    """Cross entropy loss for soft labels

    Based on https://discuss.pytorch.org/t/soft-cross-entropy-loss-tf-has-it-does-pytorch-have-it/69501/2

    Args:
        y_pred_logit (torch.Tensor): predicted logits; softmax is taken over
            axis 1.
        y_true_labels (torch.Tensor): ground-truth soft labels, multiplied
            element-wise with the log-probabilities.
        reduction (str, optional): how the per-sample losses are combined:
            "mean" averages them, "sum" adds them up and "none" returns
            them unreduced. Defaults to "mean".

    Returns:
        torch.Tensor: cross entropy between y_pred_logit and y_true_labels,
            reduced according to ``reduction``

    Raises:
        NotImplementedError: if ``reduction`` is not one of the supported
            values.
    """
    # Per-sample loss: -sum_c p_true(c) * log_softmax(logit)(c)
    per_sample = -torch.sum(
        F.log_softmax(y_pred_logit, dim=1) * y_true_labels, dim=1
    )

    if reduction == "mean":
        return torch.mean(per_sample)
    if reduction == "sum":
        return torch.sum(per_sample)
    if reduction == "none":
        return per_sample
    raise NotImplementedError(f"`reduction`={reduction} is not supported.")
def accuracy_torch_dataloader(model, dataloader, device="cpu", xpos=1, ypos=2):
    """Calculates the accuracy of the model on the given dataloader

    The model is called as-is under ``torch.no_grad()``; this function does
    not switch it to evaluation mode, so call ``model.eval()`` beforehand if
    that is required.

    Args:
        model (torch.nn.Module): model to be evaluated
        dataloader (torch.DataLoader): dataloader to be evaluated; each item
            it yields is indexed with ``xpos`` and ``ypos``.
        device (str, optional): device type. Defaults to "cpu".
        xpos (int, optional): the positional index of the input in data.
            Defaults to 1.
        ypos (int, optional): the positional index of the label in data.
            Defaults to 2.

    Returns:
        float: accuracy
    """
    predicted_batches = []
    label_batches = []

    with torch.no_grad():
        for data in dataloader:
            outputs = model(data[xpos].to(device))
            # Reduce to class indices right away and park them on the CPU so
            # the full logits of every batch are not kept alive on `device`.
            predicted_batches.append(torch.argmax(outputs, dim=1).cpu())
            # Labels are only compared on the host, so they never need to
            # visit `device`.
            label_batches.append(data[ypos].to(torch.int64).cpu())

    # NOTE: torch.cat needs at least one batch, so an empty dataloader
    # fails here.
    predictions = torch.cat(predicted_batches)
    labels = torch.cat(label_batches)

    return accuracy_score(np.array(predictions), np.array(labels))