Source code for aijack.utils.dataloader
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from .utils import NumpyDataset
[docs]def prepareFederatedMNISTDataloaders(
client_num=2,
local_label_num=2,
local_data_num=20,
batch_size=1,
test_batch_size=16,
path="MNIST/.",
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
),
seed=0,
return_idx=False,
):
np.random.seed(seed)
random.seed(seed)
at_t_dataset_train = torchvision.datasets.MNIST(
root=path, train=True, download=download
)
at_t_dataset_test = torchvision.datasets.MNIST(
root=path, train=False, download=download
)
X = at_t_dataset_train.train_data.numpy()
y = at_t_dataset_train.train_labels.numpy()
test_set = NumpyDataset(
at_t_dataset_test.test_data.numpy(),
at_t_dataset_test.test_labels.numpy(),
transform=transform,
return_idx=return_idx,
)
testloader = torch.utils.data.DataLoader(
test_set, batch_size=test_batch_size, shuffle=True, num_workers=0
)
trainloaders = []
train_sizes = []
idx_used = []
for c in range(client_num):
assigned_labels = random.sample(range(10), local_label_num)
print(f"the labels that client_id={c} has are: ", assigned_labels)
idx = np.concatenate([np.where(y == al)[0] for al in assigned_labels])
assigned_idx = random.sample(list(set(idx) - set(idx_used)), local_data_num)
temp_trainset = NumpyDataset(
X[assigned_idx], y[assigned_idx], transform=transform, return_idx=return_idx
)
temp_trainloader = torch.utils.data.DataLoader(
temp_trainset, batch_size=batch_size, shuffle=True, num_workers=0
)
trainloaders.append(temp_trainloader)
train_sizes.append(len(temp_trainset))
idx_used += assigned_idx
return X, y, trainloaders, testloader, train_sizes, idx_used