# Source code for aijack.collaborative.fedmd.server
from ...manager import BaseManager
from ..core import BaseServer
from ..core.utils import GLOBAL_LOGIT_TAG, LOCAL_LOGIT_TAG
class FedMDServer(BaseServer):
    """Server that averages the logits uploaded by its clients.

    One round (see ``action``) collects one upload from every client,
    averages the uploads into ``self.consensus`` and hands that consensus
    back to every client.

    Args:
        clients: iterable of client objects; ``receive`` calls ``upload()``
            and ``distribute`` calls ``download(consensus)`` on each one.
        server_model: optional callable used by ``forward``; forwarded to
            ``BaseServer``.
        server_id: identifier forwarded to ``BaseServer``.
        device: stored as ``self.device``; not otherwise used in this class.
    """

    def __init__(self, clients, server_model=None, server_id=0, device="cpu"):
        super().__init__(clients, server_model, server_id=server_id)
        # Filled by receive(); holds the latest upload of every client.
        self.uploaded_logits = []
        self.device = device

    def forward(self, x):
        """Apply the server model to ``x``; return ``None`` when no model is set."""
        if self.server_model is None:
            return None
        return self.server_model(x)

    def action(self):
        """Run one full round: receive uploads, average them, distribute."""
        self.receive()
        self.update()
        self.distribute()

    def receive(self):
        """Collect ``upload()`` from every client into ``self.uploaded_logits``."""
        collected = []
        for participant in self.clients:
            collected.append(participant.upload())
        self.uploaded_logits = collected

    def update(self):
        """Average the uploaded logits into ``self.consensus``.

        Each upload is divided by ``len(self.clients)`` and the scaled
        uploads are summed. Indexing the first upload raises ``IndexError``
        when nothing has been uploaded.
        """
        num_clients = len(self.clients)
        head, tail = self.uploaded_logits[0], self.uploaded_logits[1:]
        # The division creates a fresh object, so the in-place additions
        # below accumulate into the consensus rather than into the upload.
        self.consensus = head / num_clients
        for later_logit in tail:
            self.consensus += later_logit / num_clients

    def distribute(self):
        """Send the consensus logits of the public dataset to each client."""
        for participant in self.clients:
            participant.download(self.consensus)
def attach_mpi_to_fedmdserver(cls):
    """Create a subclass of ``cls`` that exchanges logits through ``comm``.

    Args:
        cls: server class to extend. Its constructor must set
            ``self.clients`` and it must provide ``update()``, which is
            expected to set ``self.consensus``.

    Returns:
        The ``MPIFedMDServerWrapper`` class derived from ``cls``.
    """

    class MPIFedMDServerWrapper(cls):
        """Variant of ``cls`` whose round uses ``comm.recv`` / ``comm.send``.

        ``comm`` is the first constructor argument; all remaining arguments
        go to ``cls``. It is presumably an mpi4py communicator - confirm
        against the caller.
        """

        def __init__(self, comm, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.comm = comm
            # Number of completed calls to action().
            self.round = 0
            # Number of messages to wait for in every round.
            self.num_clients = len(self.clients)

        def action(self):
            """Run one round over ``comm`` and increment ``self.round``."""
            self.mpi_receive()
            self.update()
            self.mpi_distribute()
            self.round += 1

        def mpi_receive(self):
            """Receive this round's local logits from the clients."""
            self.mpi_receive_local_logits()

        def mpi_receive_local_logits(self):
            """Reset ``self.uploaded_logits`` and refill it with received logits.

            Calls ``comm.recv`` with ``LOCAL_LOGIT_TAG`` until
            ``self.num_clients`` messages have been collected. No source is
            given to ``recv``, so messages are stored in arrival order.
            """
            collected = []
            self.uploaded_logits = collected
            while len(collected) < self.num_clients:
                collected.append(self.comm.recv(tag=LOCAL_LOGIT_TAG))

        def mpi_distribute(self):
            """Send ``self.consensus`` to every entry of ``self.clients``.

            Each entry is passed to ``comm.send`` as ``dest``, so the entries
            are presumably destination ranks - verify against the caller.
            """
            for destination in self.clients:
                self.comm.send(
                    self.consensus, dest=destination, tag=GLOBAL_LOGIT_TAG
                )

    return MPIFedMDServerWrapper
class MPIFedMDServerManager(BaseManager):
    """Manager that turns a FedMD server class into its MPI-enabled variant."""

    def attach(self, cls):
        """Return the class built by ``attach_mpi_to_fedmdserver`` for ``cls``.

        ``self.args`` and ``self.kwargs`` are not set in this class
        (presumably by ``BaseManager`` - confirm) and are forwarded after
        ``cls``.
        """
        # NOTE(review): attach_mpi_to_fedmdserver as defined in this file
        # accepts only ``cls``; non-empty args/kwargs would raise TypeError.
        extra_args, extra_kwargs = self.args, self.kwargs
        return attach_mpi_to_fedmdserver(cls, *extra_args, **extra_kwargs)