import torch
from torch import nn
from ..core import BaseServer
from ..fedmd.client import initialize_global_logit
[docs]class FedGEMSServer(BaseServer):
def __init__(
self,
clients,
global_model,
len_public_dataloader,
output_dim=1,
self_evaluation_func=None,
base_loss_func=nn.CrossEntropyLoss(),
kldiv_loss_func=nn.KLDivLoss(),
server_id=0,
lr=0.1,
epsilon=0.75,
device="cpu",
):
super(FedGEMSServer, self).__init__(clients, global_model, server_id=server_id)
self.len_public_dataloader = len_public_dataloader
self.lr = lr
self.epsilon = epsilon
self.self_evaluation_func = self_evaluation_func
self.base_loss_func = base_loss_func
self.kldiv_loss_func = kldiv_loss_func
self.output_dim = output_dim
self.device = device
self.global_pool_of_logits = initialize_global_logit(
len_public_dataloader, output_dim, self.device
)
self.predicted_values = initialize_global_logit(
len_public_dataloader, output_dim, self.device
)
[docs] def action(self):
self.distribute()
[docs] def update(self, idxs, x):
"""Register the predicted logits to self.predicted_values"""
self.predicted_values[idxs] = self.server_model(x).detach().to(self.device)
[docs] def distribute(self):
"""Distribute the logits of public dataset to each client."""
for client in self.clients:
client.download(self.predicted_values)
[docs] def self_evaluation_on_the_public_dataset(self, idxs, x, y):
"""Execute self evaluation on the public dataset
Args:
idxs (torch.Tensor): indexs of x
x (torch.Tensor): input data
y (torch.Tensr): labels of x
Returns:
the loss
"""
y_pred = self.server_model(x)
correct_idx, incorrect_idx = self.self_evaluation_func(y_pred, y)
loss_s1 = 0
loss_s2 = 0
loss_s3 = 0
# for each sample that the server predicts correctly
if len(correct_idx) != 0:
loss_s1 += self.base_loss_func(
y_pred[correct_idx], y[correct_idx].to(torch.int64)
)
self.global_pool_of_logits[idxs[correct_idx]] = y_pred[correct_idx].detach()
# for each sample that the server predicts wrongly
s_incorrect_not_star_idx = [
iid.item()
for iid in incorrect_idx
if self.global_pool_of_logits[idxs[iid]][0].item() != float("inf")
]
if len(s_incorrect_not_star_idx) != 0:
loss_s2 += self.epsilon * self.base_loss_func(
y_pred[s_incorrect_not_star_idx],
y[s_incorrect_not_star_idx].to(torch.int64),
) + (1 - self.epsilon) * self.kldiv_loss_func(
self.global_pool_of_logits[idxs[s_incorrect_not_star_idx]]
.softmax(dim=-1)
.log(),
y_pred[s_incorrect_not_star_idx].softmax(dim=-1),
)
s_incorrect_star_idx = list(
set(incorrect_idx.cpu().tolist()) - set(s_incorrect_not_star_idx)
)
if len(s_incorrect_star_idx) != 0:
loss_s3 += self.epsilon * self.base_loss_func(
y_pred[s_incorrect_star_idx], y[s_incorrect_star_idx]
) + (1 - self.epsilon) * self.kldiv_loss_func(
self._get_knowledge_from_clients(
x[s_incorrect_star_idx], y[s_incorrect_star_idx]
)
.softmax(dim=-1)
.log(),
y_pred[s_incorrect_star_idx].softmax(dim=-1),
)
loss = loss_s1 + loss_s2 + loss_s3
return loss
def _get_knowledge_from_clients(self, x, y):
client_weight = torch.zeros(self.num_clients, y.shape[0]).to(self.device)
client_knowledge = torch.zeros(
self.num_clients, y.shape[0], self.output_dim
).to(self.device)
for cid, client in enumerate(self.clients):
y_pred = client.upload(x).to(self.device)
client_knowledge[cid] = y_pred
correct_idx, _ = self.self_evaluation_func(y_pred, y)
if len(correct_idx) != 0:
ep = torch.zeros((y_pred.shape[0])).to(self.device)
ep[correct_idx] += -1 * torch.sum(
y_pred[correct_idx].softmax(dim=-1)
* torch.log(y_pred[correct_idx].softmax(dim=-1)),
dim=1,
)
client_weight[cid, correct_idx] = 1 / ep[correct_idx]
client_weight = (
client_weight.softmax(dim=0)
.reshape(self.num_clients, y.shape[0], 1)
.expand(self.num_clients, y.shape[0], self.output_dim)
)
ensembled_knowledge = torch.sum(
client_weight * client_knowledge,
dim=0,
)
return ensembled_knowledge.detach()