Source code for aijack.collaborative.fedkd.client
import torch.nn.functional as F
from ..fedavg import FedAVGClient
def _adaptive_distillation_loss(y_pred_1, y_pred_2, task_loss_1, task_loss_2):
    """Return the KL distillation term scaled by the sum of two task losses.

    Computes ``KL(softmax(y_pred_2) || softmax(y_pred_1))`` via ``F.kl_div``
    with its default ``"mean"`` reduction (averaged over every element), then
    divides it by ``task_loss_1 + task_loss_2``. The distillation term is
    therefore down-weighted when the task losses are large.

    Args:
        y_pred_1: unnormalized scores whose softmax (over dim 1) is the KL
            input distribution.
        y_pred_2: unnormalized scores whose softmax (over dim 1) is the KL
            target distribution.
        task_loss_1: task loss associated with ``y_pred_1``.
        task_loss_2: task loss associated with ``y_pred_2``.

    Returns:
        The scaled KL divergence (a scalar tensor when the task losses are
        scalars).
    """
    # log_softmax is the numerically stable form of softmax(...).log(): the
    # latter maps probabilities that underflow to 0 onto -inf, which turns
    # the loss into inf/nan.
    log_probs_1 = y_pred_1.log_softmax(dim=1)
    probs_2 = y_pred_2.softmax(dim=1)
    return F.kl_div(log_probs_1, probs_2) / (task_loss_1 + task_loss_2)
class FedKDClient(FedAVGClient):
    """Implementation of FedKD (https://arxiv.org/abs/2108.13323)

    Each client holds a ``student_model`` and a ``teacher_model``. The student
    model is passed to ``FedAVGClient.__init__`` as the client's model
    (presumably the model exchanged with the server — confirm in
    ``FedAVGClient``); the teacher model is only stored on this instance.
    ``loss`` returns one loss per model, built from the task loss plus the
    optional adaptive distillation and adaptive hidden-state terms.

    Args:
        student_model: model forwarded to ``FedAVGClient`` and stored as
            ``self.student_model``.
        teacher_model: local teacher model, stored as ``self.teacher_model``.
        task_lossfn: callable ``task_lossfn(y_pred, y)`` returning a task loss.
        student_lr: forwarded to ``FedAVGClient`` as ``lr``.
        teacher_lr: stored as ``self.teacher_lr`` (not used in this class).
        adaptive_distillation_losses: if True, ``loss`` adds the adaptive
            KL distillation term to both losses.
        adaptive_hidden_losses: if True, ``loss`` adds the adaptive
            hidden-state MSE term to both losses; both models must then
            provide a ``get_hidden_states`` method.
        gradient_compression_ratio: stored as
            ``self.gradient_compression_ratio`` (not used in this class).
        user_id: forwarded to ``FedAVGClient``.
        send_gradient: forwarded to ``FedAVGClient``.

    Raises:
        AttributeError: if ``adaptive_hidden_losses`` is True and either model
            lacks a ``get_hidden_states`` method.
    """

    def __init__(
        self,
        student_model,
        teacher_model,
        task_lossfn,
        student_lr=0.1,
        teacher_lr=0.1,
        adaptive_distillation_losses=True,
        adaptive_hidden_losses=True,
        gradient_compression_ratio=1.0,
        user_id=0,
        send_gradient=True,
    ):
        super().__init__(
            student_model, user_id=user_id, lr=student_lr, send_gradient=send_gradient
        )
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.teacher_lr = teacher_lr
        self.task_lossfn = task_lossfn
        self.adaptive_distillation_losses = adaptive_distillation_losses
        self.adaptive_hidden_losses = adaptive_hidden_losses
        self.gradient_compression_ratio = gradient_compression_ratio
        self._is_valid_models()

    def _is_valid_models(self):
        """Check that both models support the enabled loss terms.

        Raises:
            AttributeError: if ``adaptive_hidden_losses`` is enabled and the
                teacher (checked first) or the student model has no
                ``get_hidden_states`` method.
        """
        if not self.adaptive_hidden_losses:
            return
        for role, model in (
            ("teacher", self.teacher_model),
            ("student", self.student_model),
        ):
            if not hasattr(model, "get_hidden_states"):
                # Implicit concatenation instead of a backslash continuation
                # inside the literal, which leaked indentation into the text.
                raise AttributeError(
                    "If adaptive_hidden_losses=True, "
                    f"{role}_model must have `get_hidden_states` method"
                )

    def loss(self, x, y):
        """Compute the teacher and student losses for one batch.

        Args:
            x: input passed unchanged to both models.
            y: target passed to ``task_lossfn`` for both models.

        Returns:
            tuple: ``(teacher_loss, student_loss)``.

        Raises:
            TypeError: if adaptive hidden losses are enabled and either
                model's ``get_hidden_states`` does not return a list.
        """
        y_pred_teacher = self.teacher_model(x)
        y_pred_student = self.student_model(x)

        # Task losses. Accumulation below is out-of-place so the task-loss
        # tensors, which are reused as denominators, are never mutated.
        task_loss_teacher = self.task_lossfn(y_pred_teacher, y)
        task_loss_student = self.task_lossfn(y_pred_student, y)
        teacher_loss = 0 + task_loss_teacher
        student_loss = 0 + task_loss_student

        # Adaptive distillation losses: each model's term uses its own
        # prediction as the KL target. NOTE(review): nothing is detached, so
        # gradients flow into both models from each term.
        if self.adaptive_distillation_losses:
            adaptive_distillation_loss_teacher = _adaptive_distillation_loss(
                y_pred_student, y_pred_teacher, task_loss_student, task_loss_teacher
            )
            adaptive_distillation_loss_student = _adaptive_distillation_loss(
                y_pred_teacher, y_pred_student, task_loss_teacher, task_loss_student
            )
            teacher_loss = teacher_loss + adaptive_distillation_loss_teacher
            student_loss = student_loss + adaptive_distillation_loss_student

        # Adaptive hidden losses: summed MSE between paired hidden states,
        # scaled by the sum of both task losses and added to both losses.
        if self.adaptive_hidden_losses:
            hidden_states_teacher = self.teacher_model.get_hidden_states()
            hidden_states_student = self.student_model.get_hidden_states()
            for hidden_states in (hidden_states_teacher, hidden_states_student):
                if not isinstance(hidden_states, list):
                    raise TypeError(
                        "get_hidden_states should return a list of torch.Tensors"
                    )
            # zip stops at the shorter list: if the models return different
            # numbers of hidden states, only the leading pairs are compared.
            adaptive_hidden_loss = sum(
                (
                    F.mse_loss(hst, hss)
                    for hst, hss in zip(hidden_states_teacher, hidden_states_student)
                ),
                0,
            )
            scaled_hidden_loss = adaptive_hidden_loss / (
                task_loss_student + task_loss_teacher
            )
            teacher_loss = teacher_loss + scaled_hidden_loss
            student_loss = student_loss + scaled_hidden_loss

        return teacher_loss, student_loss