"""Server-side gradient inversion attack hooks (module aijack.attack.inversion.gradientinversion_server)."""

from ...manager import BaseManager
from .gradientinversion import GradientInversion_Attack


def _default_gradinent_inversion_attack_on_receive(self):
    tmp_result = []
    for s in range(self.num_trial_per_communication):
        self.reset_seed(s)
        try:
            tmp_result.append(self.attack())
        except OverflowError:
            continue
    self.attack_results.append(tmp_result)


def attach_gradient_inversion_attack_to_server(
    cls,
    x_shape,
    attack_function_on_receive=_default_gradinent_inversion_attack_on_receive,
    num_trial_per_communication=1,
    target_client_id=0,
    **gradinvattack_kwargs,
):
    """Wraps the given class in GradientInversionServerWrapper.

    Args:
        cls: server class to extend. The wrapper reads ``server_model``,
            ``clients`` and ``uploaded_gradients`` from its instances and
            calls its ``receive`` method.
        x_shape: input shape of target_model.
        attack_function_on_receive (function, optional): a function to execute
            attack called after receiving the local gradients. It is called
            with the wrapped server instance as its only argument. Defaults
            to _default_gradinent_inversion_attack_on_receive.
        num_trial_per_communication (int, optional): number of attack trials
            executed per communication. Defaults to 1.
        target_client_id (int, optional): id of target client. Defaults to 0.
        gradinvattack_kwargs: kwargs for GradientInversion_Attack

    Returns:
        cls: GradientInversionServerWrapper, a subclass of ``cls``.
    """

    def _detached_target_gradient(server):
        # Gradients uploaded by the currently targeted client, detached so
        # the attack does not track them in the autograd graph.
        uploaded = server.uploaded_gradients[server.target_client_id]
        return [grad.detach() for grad in uploaded]

    class GradientInversionServerWrapper(cls):
        """``cls`` extended with a gradient inversion attacker.

        After every ``receive`` the configured ``attack_function_on_receive``
        hook is run against this server instance.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.target_client_id = target_client_id
            self.num_trial_per_communication = num_trial_per_communication
            # Built on the server's model, which ``cls.__init__`` is expected
            # to have set as ``self.server_model``.
            self.attacker = GradientInversion_Attack(
                self.server_model, x_shape, **gradinvattack_kwargs
            )
            # Storage for attack outputs; the default on-receive hook appends
            # one list of trial results per communication.
            self.attack_results = []

        def change_target_client_id(self, target_client_id):
            """Switch the attacked client and point the attacker at it."""
            self.target_client_id = target_client_id
            self.attacker.target_model = self.clients[self.target_client_id]

        def receive(self, *args, **kwargs):
            """Receive as ``cls`` does, then run the on-receive attack hook."""
            super().receive(*args, **kwargs)
            attack_function_on_receive(self)

        def attack(self, **kwargs):
            """Run ``attacker.attack`` on the target client's gradients."""
            return self.attacker.attack(_detached_target_gradient(self), **kwargs)

        def group_attack(self, **kwargs):
            """Run ``attacker.group_attack`` on the target client's gradients."""
            return self.attacker.group_attack(
                _detached_target_gradient(self), **kwargs
            )

        def reset_seed(self, seed):
            """Forward ``seed`` to the attacker's ``reset_seed``."""
            self.attacker.reset_seed(seed)

    return GradientInversionServerWrapper
class GradientInversionAttackServerManager(BaseManager):
    """Manager class for Gradient-based model inversion attack.

    The constructor only records its arguments. ``attach`` later forwards
    them, after the class being wrapped, to
    ``attach_gradient_inversion_attack_to_server``.
    """

    def __init__(self, *args, **kwargs):
        # Configuration is stored verbatim and consumed only in ``attach``.
        self.args = args
        self.kwargs = kwargs

    def attach(self, cls):
        """Wraps the given class in GradientInversionServerWrapper.

        Args:
            cls: server class to extend with the gradient inversion attack.

        Returns:
            cls: GradientInversionServerWrapper
        """
        stored_args = self.args
        stored_kwargs = self.kwargs
        return attach_gradient_inversion_attack_to_server(
            cls, *stored_args, **stored_kwargs
        )