# Source code for aijack.collaborative.fedkd.client

import torch.nn.functional as F

from ..fedavg import FedAVGClient


def _adaptive_distillation_loss(y_pred_1, y_pred_2, task_loss_1, task_loss_2):
    return F.kl_div(y_pred_1.softmax(dim=1).log(), y_pred_2.softmax(dim=1)) / (
        task_loss_1 + task_loss_2
    )


class FedKDClient(FedAVGClient):
    """Implementation of FedKD (https://arxiv.org/abs/2108.13323).

    Trains a small student model alongside a larger local teacher model.
    Knowledge flows both ways via mutual adaptive distillation (and
    optionally hidden-state matching); only the student model — the one
    handed to ``FedAVGClient`` — takes part in federated aggregation.
    """

    def __init__(
        self,
        student_model,
        teacher_model,
        task_lossfn,
        student_lr=0.1,
        teacher_lr=0.1,
        adaptive_distillation_losses=True,
        adaptive_hidden_losses=True,
        gradient_compression_ratio=1.0,
        user_id=0,
        send_gradient=True,
    ):
        """Initialize a FedKD client.

        Args:
            student_model: small model shared with the server.
            teacher_model: larger model kept locally.
            task_lossfn: supervised loss function, called as
                ``task_lossfn(y_pred, y)``.
            student_lr: learning rate passed to ``FedAVGClient`` for the
                student model.
            teacher_lr: learning rate for the teacher model (stored;
                presumably consumed by the training loop — not used here).
            adaptive_distillation_losses: if True, add the KL-based mutual
                distillation term, down-weighted by the task losses.
            adaptive_hidden_losses: if True, also match intermediate hidden
                states; both models must expose ``get_hidden_states``.
            gradient_compression_ratio: compression ratio (stored; applied
                elsewhere in the FedKD pipeline).
            user_id: identifier of this client.
            send_gradient: whether the client sends gradients instead of
                parameters (forwarded to ``FedAVGClient``).

        Raises:
            AttributeError: if ``adaptive_hidden_losses=True`` and either
                model lacks a ``get_hidden_states`` method.
        """
        super().__init__(
            student_model, user_id=user_id, lr=student_lr, send_gradient=send_gradient
        )
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.teacher_lr = teacher_lr
        self.task_lossfn = task_lossfn
        self.adaptive_distillation_losses = adaptive_distillation_losses
        self.adaptive_hidden_losses = adaptive_hidden_losses
        self.gradient_compression_ratio = gradient_compression_ratio
        self._is_valid_models()

    def _is_valid_models(self):
        """Validate that the models expose the interface the enabled loss
        terms require.

        Raises:
            AttributeError: if hidden-state matching is enabled but a model
                does not implement ``get_hidden_states``.
        """
        if self.adaptive_hidden_losses:
            checks = (
                ("teacher_model", self.teacher_model),
                ("student_model", self.student_model),
            )
            for name, model in checks:
                if not hasattr(model, "get_hidden_states"):
                    raise AttributeError(
                        f"If adaptive_hidden_losses=True, {name} must have "
                        "`get_hidden_states` method"
                    )

    def loss(self, x, y):
        """Compute the FedKD training losses for one batch.

        Args:
            x: input batch fed to both models.
            y: ground-truth labels for ``task_lossfn``.

        Returns:
            Tuple ``(teacher_loss, student_loss)``. Each is the model's task
            loss plus, when enabled, the adaptive distillation term and the
            (shared) adaptive hidden-state term.

        Raises:
            TypeError: if ``get_hidden_states`` does not return a list.
        """
        y_pred_teacher = self.teacher_model(x)
        y_pred_student = self.student_model(x)

        # Supervised task losses.
        task_loss_teacher = self.task_lossfn(y_pred_teacher, y)
        task_loss_student = self.task_lossfn(y_pred_student, y)
        teacher_loss = task_loss_teacher
        student_loss = task_loss_student

        # Mutual adaptive distillation: each model distills from the other,
        # scaled down while the task losses are still large.
        if self.adaptive_distillation_losses:
            teacher_loss = teacher_loss + _adaptive_distillation_loss(
                y_pred_student, y_pred_teacher, task_loss_student, task_loss_teacher
            )
            student_loss = student_loss + _adaptive_distillation_loss(
                y_pred_teacher, y_pred_student, task_loss_teacher, task_loss_student
            )

        # Adaptive hidden-state matching; the same scaled MSE term is added
        # to both models' losses.
        if self.adaptive_hidden_losses:
            hidden_states_teacher = self.teacher_model.get_hidden_states()
            hidden_states_student = self.student_model.get_hidden_states()
            if not isinstance(hidden_states_teacher, list):
                raise TypeError(
                    "get_hidden_states should return a list of torch.Tensors"
                )
            if not isinstance(hidden_states_student, list):
                raise TypeError(
                    "get_hidden_states should return a list of torch.Tensors"
                )
            hidden_loss = 0
            # NOTE(review): zip silently truncates to the shorter list —
            # assumes both models expose the same number of (same-shaped)
            # hidden states; confirm against the model implementations.
            for hst, hss in zip(hidden_states_teacher, hidden_states_student):
                hidden_loss += F.mse_loss(hst, hss)
            hidden_loss = hidden_loss / (task_loss_student + task_loss_teacher)
            teacher_loss = teacher_loss + hidden_loss
            student_loss = student_loss + hidden_loss

        return teacher_loss, student_loss