Source code for aijack.collaborative.core.client
from abc import abstractmethod
import torch
class BaseClient(torch.nn.Module):
    """Abstract class for the client of collaborative learning.

    Subclasses are expected to implement ``upload``, ``download`` and
    ``local_train``. Note that ``torch.nn.Module`` does not use
    ``abc.ABCMeta`` as its metaclass, so the ``@abstractmethod`` markers
    below are NOT enforced at instantiation time; an un-overridden method
    simply returns ``None``.

    Args:
        model (torch.nn.Module): a local model
        user_id (int, optional): id of this client. Defaults to 0.
    """

    def __init__(self, model: torch.nn.Module, user_id: int = 0):
        """Initialize BaseClient"""
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule, so
        # parameters(), state_dict(), to(device) and train()/eval() reach it.
        self.model = model
        self.user_id = user_id

    def forward(self, x):
        """Run the local model on ``x`` and return its output."""
        return self.model(x)

    def train(self, mode: bool = True):
        """Set this client (and its local model) to training mode.

        Keeps the ``torch.nn.Module.train(mode=True)`` signature so that
        parent modules which recursively call ``child.train(mode)`` (e.g.
        an ``nn.ModuleList`` of clients) work, and so ``train(False)`` is
        accepted.

        Args:
            mode (bool, optional): ``True`` for training mode, ``False`` for
                evaluation mode. Defaults to True.

        Returns:
            BaseClient: ``self``, matching ``torch.nn.Module.train``.
        """
        # nn.Module.train updates self.training and propagates `mode` to
        # every registered submodule, including self.model.
        return super().train(mode)

    def eval(self):
        """Set this client (and its local model) to evaluation mode.

        Returns:
            BaseClient: ``self``, matching ``torch.nn.Module.eval``.
        """
        return self.train(False)

    def backward(self, loss):
        """Execute backward mode automatic differentiation with the given loss.

        Args:
            loss (torch.Tensor): the value of calculated loss.
        """
        loss.backward()

    @abstractmethod
    def upload(self):
        """Upload the locally learned information to the server."""
        pass

    @abstractmethod
    def download(self):
        """Download the global model from the server."""
        pass

    @abstractmethod
    def local_train(self):
        """Train the local model on this client's data (subclass hook)."""
        pass