Source code for aijack.defense.dp.manager.dataloader

import numpy as np
import torch
from torch.utils.data import DataLoader


class PoissonSampler:
    """Batch sampler that draws lots by independent (Poisson) subsampling.

    On every sampling round each dataset index is included independently
    with probability ``lot_size / len(dataset)``. The realised lot size
    therefore varies around ``lot_size``. Each yielded item is a numpy array
    of indices, suitable as a ``batch_sampler`` for a ``DataLoader``.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        lot_size: Expected number of samples per lot.
        iterations: Number of sampling rounds per pass.
    """

    def __init__(self, dataset, lot_size, iterations):
        self.dataset_size = len(dataset)
        self.lot_size = lot_size
        self.iterations = iterations

    def __iter__(self):
        """Yield index arrays, one per non-empty sampling round.

        Rounds whose draw selects no index are skipped, because an empty
        index list cannot be collated into a batch. Fewer than
        ``iterations`` lots may therefore be produced.

        Raises:
            ValueError: If at least one round is requested from an empty
                dataset (previously this surfaced as ZeroDivisionError).
        """
        rounds = range(self.iterations)
        if not rounds:
            return
        if self.dataset_size == 0:
            raise ValueError("PoissonSampler cannot sample from an empty dataset")
        # Inclusion probability is invariant across rounds; compute it once.
        sample_rate = self.lot_size / self.dataset_size
        for _ in rounds:
            mask = torch.rand(self.dataset_size) < sample_rate
            indices = np.where(mask.numpy())[0]
            if indices.size > 0:
                yield indices

    def __len__(self):
        """Return the number of sampling rounds.

        Empty rounds are skipped, so this is an upper bound on the number
        of lots actually yielded.
        """
        return self.iterations
[docs]class DPWrapperLotDataIterator: def __init__(self, original_iterator, dp_optimizer): self.original_iterator = original_iterator self.dp_optimizer = dp_optimizer self.init_flag = True def __iter__(self): return self def _reset(self, *args, **kwargs): return self.original_iterator._reset(*args, **kwargs) def _next_index(self, *args, **kwargs): return self.original_iterator._next_index(*args, **kwargs) def __next__(self): if not self.init_flag: self.dp_optimizer.step_for_lot() else: self.init_flag = False data = self.original_iterator.__next__() self.dp_optimizer.zero_grad_for_lot() return data def __len__(self): return self.original_iterator.__len__() def __getstate__(self): raise self.original_iterator.__getstate__()
class LotDataLoader(DataLoader):
    """``DataLoader`` whose iterators drive a DP optimizer lot by lot.

    Apart from ``dp_optimizer``, every positional and keyword argument is
    forwarded unchanged to ``torch.utils.data.DataLoader``. The one
    difference is that ``iter(loader)`` returns a
    ``DPWrapperLotDataIterator`` wrapping the standard iterator together
    with ``dp_optimizer``.

    Args:
        dp_optimizer: Optimizer handed to each iterator wrapper.
        *args, **kwargs: Forwarded to ``DataLoader.__init__``.
    """

    def __init__(self, dp_optimizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dp_optimizer = dp_optimizer
        # Not read by this class: lot bookkeeping lives on the iterator
        # wrapper. Kept in case external code inspects it.
        self.init_flag = True

    def __iter__(self):
        base_iterator = super().__iter__()
        return DPWrapperLotDataIterator(base_iterator, self.dp_optimizer)