Source code for aijack.collaborative.splitnn.client

from ..core import BaseClient


class SplitNNClient(BaseClient):
    """Client that runs one model segment of a SplitNN pipeline.

    The client caches the tensor it received, the output of its own model and
    the gradient handed back by the next party, so that back propagation can
    be resumed locally and the gradient w.r.t. its input can be passed on.

    Args:
        model: callable applied to the incoming tensor in ``forward``
            (presumably a ``torch.nn.Module``; it is handed to ``BaseClient``).
        user_id (int, optional): identifier forwarded to ``BaseClient``.
            Defaults to 0.
    """

    def __init__(self, model, user_id=0):
        super().__init__(model, user_id=user_id)

        # NOTE: the "intermidiate" spelling is kept on purpose; these are
        # public attributes and code outside this class may read them.
        self.own_intermidiate = None  # output of self.model from the last forward
        self.prev_intermidiate = None  # input received in the last forward
        self.grad_from_next_client = None  # gradient received in the last download

    def forward(self, prev_intermediate):
        """Apply the client-side model and return a detached copy of its output.

        Args:
            prev_intermediate (torch.Tensor): the input of the client-side
                model (raw data or the output of the previous client).

        Returns:
            torch.Tensor: the output of the client-side model, detached from
            this client's graph and with ``requires_grad`` enabled, to be sent
            to the next client or the server.
        """
        self.prev_intermidiate = prev_intermediate
        self.own_intermidiate = self.model(prev_intermediate)
        # Detach so the receiver builds its own graph; the non-detached output
        # stays cached here for the later backward pass.
        return self.own_intermidiate.detach().requires_grad_()

    def upload(self, x):
        """Alias of ``forward``: compute and return the tensor to send onward.

        Args:
            x (torch.Tensor): the input of the client-side model.

        Returns:
            torch.Tensor: same value as ``forward(x)``.
        """
        return self.forward(x)

    def download(self, grad_from_next_client):
        """Receive the gradient from the next party and back propagate it.

        Args:
            grad_from_next_client: gradient with respect to the tensor returned
                by the last ``forward`` call.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        self._client_backward(grad_from_next_client)

    def _client_backward(self, grad_from_next_client):
        """Client-side back propagation.

        Args:
            grad_from_next_client: gradient which the next client (or the
                server) sent to this client.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.own_intermidiate is None:
            raise RuntimeError(
                "forward() must be called before back propagation: "
                "no intermediate output is cached"
            )
        self.grad_from_next_client = grad_from_next_client
        self.own_intermidiate.backward(grad_from_next_client)

    def distribute(self):
        """Return a copy of the gradient with respect to the last input.

        Returns:
            torch.Tensor: clone of ``prev_intermidiate.grad``.

        Raises:
            RuntimeError: if ``forward`` has not been called yet, or if no
                gradient is available for the cached input (back propagation
                has not run, or the input does not accumulate gradients).
        """
        if self.prev_intermidiate is None:
            raise RuntimeError(
                "forward() must be called before distribute(): "
                "no input tensor is cached"
            )
        grad = self.prev_intermidiate.grad
        if grad is None:
            raise RuntimeError(
                "no gradient is available for the cached input; call "
                "download() first and make sure the input requires grad"
            )
        return grad.clone()