Source code for aijack.collaborative.splitnn.api

from ..core.api import BaseFedAPI


class SplitNNAPI(BaseFedAPI):
    """Orchestrate split-learning training over an ordered chain of clients.

    Each client holds one segment of the model. The forward pass feeds the
    input through the clients in order. The backward pass walks them in
    reverse, handing each client the gradient received from its successor.

    Args:
        clients: ordered sequence of client objects. This class calls their
            ``upload``, ``download``, ``distribute``, ``train`` and ``eval``
            methods. NOTE(review): assumed to be ``torch.nn.Module``-based
            clients; confirm.
        optimizers: iterable of optimizers that are zeroed and stepped
            together on every batch.
        dataloader: iterable yielding ``(inputs, labels)`` pairs.
        criterion: callable ``criterion(outputs, labels)`` returning a loss
            object that supports ``.backward()`` and ``.item()``.
        num_epoch: number of passes over ``dataloader`` made by ``run()``.
    """

    def __init__(self, clients, optimizers, dataloader, criterion, num_epoch):
        super().__init__()
        self.clients = clients
        self.optimizers = optimizers
        self.dataloader = dataloader
        self.criterion = criterion
        self.num_epoch = num_epoch
        self.num_clients = len(clients)
        # Output of the most recent forward pass; backward() reads its .grad.
        self.recent_output = None
        # Scalar loss of every training batch, appended in order.
        self.loss_log = []

    def local_train(self):
        """Run one epoch over ``self.dataloader`` and record each batch loss."""
        for inputs, labels in self.dataloader:
            self.zero_grad()
            outputs = self(inputs)
            loss = self.criterion(outputs, labels)
            self.backward(loss)
            self.step()
            self.loss_log.append(loss.item())

    def run(self):
        """Put all clients in training mode and train for ``num_epoch`` epochs."""
        self.train()
        for _ in range(self.num_epoch):
            self.local_train()

    def __call__(self, *args, **kwds):
        return self.forward(*args, **kwds)

    def forward(self, x):
        """Pass ``x`` through every client in order and return the final output.

        The final output is also cached in ``self.recent_output`` so that
        ``backward()`` can read its gradient.
        """
        intermediate = x
        for client in self.clients:
            # Each client's uploaded output is the next client's input.
            intermediate = client.upload(intermediate)
        self.recent_output = intermediate
        return intermediate

    def backward(self, loss):
        """Backpropagate ``loss`` through the whole chain of clients.

        ``loss.backward()`` is expected to populate ``self.recent_output.grad``.
        NOTE(review): this requires the last client's output to be a tensor
        that keeps its ``.grad`` (e.g. a leaf created by detaching and
        re-enabling grad); confirm against the client implementation.
        That gradient is then propagated back through the clients.

        Returns:
            The value returned by ``backward_gradient()``.
        """
        loss.backward()
        return self.backward_gradient(self.recent_output.grad)

    def backward_gradient(self, grads_outputs):
        """Propagate ``grads_outputs`` from the last client back to the first.

        Clients are visited from last to first. Each receives the current
        gradient through ``download``. Every client except the first then
        supplies, through ``distribute``, the gradient for its predecessor.

        Returns:
            The gradient handed to ``clients[0]``, or ``grads_outputs``
            unchanged when there are no clients.
        """
        grad = grads_outputs
        for i in reversed(range(self.num_clients)):
            self.clients[i].download(grad)
            # The first client has no predecessor, so it distributes nothing.
            if i != 0:
                grad = self.clients[i].distribute()
        return grad

    def train(self):
        """Put every client into training mode."""
        for client in self.clients:
            client.train()

    def eval(self):
        """Put every client into evaluation mode."""
        # Must call eval(), not train(); otherwise clients stay in training
        # mode while the caller believes they are being evaluated.
        for client in self.clients:
            client.eval()

    def zero_grad(self):
        """Clear the gradients tracked by every optimizer."""
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        """Apply one update step with every optimizer."""
        for opt in self.optimizers:
            opt.step()