# Source code for aijack.collaborative.core.client

from abc import abstractmethod

import torch


class BaseClient(torch.nn.Module):
    """Abstract class for the client of collaborative learning.

    Wraps a local model and declares the hooks (``upload``, ``download``,
    ``local_train``) that concrete collaborative-learning clients implement.

    Args:
        model (torch.nn.Module): a local model
        user_id (int, optional): id of this client. Defaults to 0.

    Note:
        ``BaseClient`` does not use ``abc.ABCMeta`` as its metaclass, so the
        ``@abstractmethod`` markers below are not enforced when instantiating;
        they only mark which methods subclasses are expected to override.
    """

    def __init__(self, model, user_id=0):
        """Initialize BaseClient"""
        super().__init__()
        self.model = model
        self.user_id = user_id

    def forward(self, x):
        """Run the local model on ``x`` and return its output."""
        return self.model(x)

    def train(self, mode=True):
        """Set training (or evaluation) mode on this client and its local model.

        Accepts the same ``mode`` argument as ``torch.nn.Module.train`` so that
        a parent module propagating ``child.train(mode)`` (as its own
        ``train()``/``eval()`` do) does not fail, and keeps ``self.training``
        in sync with the wrapped model.

        Args:
            mode (bool, optional): ``True`` for training mode, ``False`` for
                evaluation mode. Defaults to True.

        Returns:
            BaseClient: ``self``, matching ``torch.nn.Module.train``.
        """
        self.training = mode
        # Keep the original no-argument calls on the wrapped model for the
        # two standard modes.
        if mode:
            self.model.train()
        else:
            self.model.eval()
        return self

    def eval(self):
        """Set this client and its local model to evaluation mode.

        Implemented directly rather than via ``self.train(False)`` so that it
        keeps working when a subclass overrides ``train`` without a ``mode``
        parameter.

        Returns:
            BaseClient: ``self``, matching ``torch.nn.Module.eval``.
        """
        self.training = False
        self.model.eval()
        return self

    def backward(self, loss):
        """Execute backward mode automatic differentiation with the given loss.

        Args:
            loss (torch.Tensor): the value of calculated loss.
        """
        loss.backward()

    @abstractmethod
    def upload(self):
        """Upload the locally learned information to the server."""
        pass

    @abstractmethod
    def download(self):
        """Download the global model from the server."""
        pass

    @abstractmethod
    def local_train(self):
        """Train the local model (to be implemented by subclasses)."""
        pass