Source code for aijack.attack.inversion.utils.datarepextractor

import torch


class DataRepExtractor:
    """Recover data representations from the weight gradients of trailing FC layers.

    For a fully connected layer ``y = W r (+ b)``, row ``j`` of ``dL/dW`` equals
    ``dL/dy_j * r``. Any row with non-zero norm is therefore a scaled copy of
    the layer input ``r``. This class reads such representations back out of
    the weight gradients of the last ``num_fc_layers`` layers.

    Assumes (not checked) that the final parameters of ``net`` belong to those
    FC layers, laid out as ``[..., W_k, b_k, W_last, b_last]`` when ``bias`` is
    True or ``[..., W_k, W_last]`` when it is False. Each weight is assumed to
    be shaped ``(out_features, in_features)``, as in ``torch.nn.Linear``.

    Example::

        extractor = DataRepExtractor(net, num_fc_layers, m, bias)
        dl_dw = extractor.get_dldw(loss)
        datarep = extractor.extract_datarep(dl_dw)

    Args:
        net: model whose parameters are differentiated in ``get_dldw``.
        num_fc_layers: number of trailing FC layers to extract a
            representation for (at least one is always returned).
        m: number of largest-magnitude entries of the previous representation
            whose gradient rows are averaged for each deeper layer.
        bias: whether every trailing FC layer has a bias parameter, i.e. whether
            weight and bias gradients alternate at the end of ``dldw``.
    """

    def __init__(self, net, num_fc_layers=1, m=1, bias=True):
        self.net = net
        self.num_fc_layers = num_fc_layers
        self.m = m
        self.bias = bias

    def get_dldw(self, loss):
        """Return detached gradients of ``loss`` w.r.t. every parameter of ``self.net``.

        Args:
            loss: scalar tensor connected to ``self.net``'s parameters.

        Returns:
            list of tensors, one per parameter, in ``self.net.parameters()``
            order, detached from the autograd graph.
        """
        # retain_graph=True keeps the graph so the caller may still backprop
        # through ``loss`` after extracting the gradients.
        grads = torch.autograd.grad(loss, self.net.parameters(), retain_graph=True)
        return [g.detach() for g in grads]

    def _weight_grad_index(self, i):
        """Negative index in ``dldw`` of the i-th FC weight gradient from the end (i=0 is the last layer)."""
        # With biases, weight/bias gradients alternate (W at -2, -4, ...);
        # without them, weights are contiguous (W at -1, -2, ...).
        stride = 2 if self.bias else 1
        return -stride * (i + 1)

    def extract_datarep(self, dldw):
        """Extract one representation per trailing FC layer from gradients.

        Args:
            dldw: list of per-parameter gradients, e.g. from ``get_dldw``.

        Returns:
            list of ``(1, in_features)`` tensors. Element 0 comes from the last
            FC layer, element 1 from the layer before it, and so on.
        """
        last_grad = dldw[self._weight_grad_index(0)]
        # The row with the largest L2 norm is the best-conditioned scaled copy
        # of the last layer's input.
        max_idx = torch.argmax(last_grad.norm(2, dim=1))
        reps = [last_grad[max_idx, :].reshape(1, -1)]

        for i in range(1, self.num_fc_layers):
            grad = dldw[self._weight_grad_index(i)]
            # Per-entry magnitudes of the previous representation (norm over a
            # size-1 dim). That representation is presumably this layer's
            # (activated) output, so its entries index this gradient's rows.
            magnitude = reps[-1].norm(2, dim=0)
            # Clamp so that m larger than the width does not make topk raise.
            k = min(self.m, magnitude.numel())
            top_idx = torch.topk(magnitude, k)[1]
            reps.append(grad[top_idx, :].mean(dim=0).reshape(1, -1))
        return reps