Source code for aijack.attack.inversion.gradientinversion_server
from ...manager import BaseManager
from .gradientinversion import GradientInversion_Attack
def _default_gradinent_inversion_attack_on_receive(self):
tmp_result = []
for s in range(self.num_trial_per_communication):
self.reset_seed(s)
try:
tmp_result.append(self.attack())
except OverflowError:
continue
self.attack_results.append(tmp_result)
[docs]def attach_gradient_inversion_attack_to_server(
cls,
x_shape,
attack_function_on_receive=_default_gradinent_inversion_attack_on_receive,
num_trial_per_communication=1,
target_client_id=0,
**gradinvattack_kwargs,
):
"""Wraps the given class in GradientInversionServerWrapper.
Args:
x_shape: input shape of target_model.
attack_function_on_receive (function, optional): a function to execute attack called after
receving the local gradients. Defaults to _default_gradinent_inversion_attack_on_receive.
num_trial_per_communication (int, optional): number of attack trials executed per
communication. Defaults to 1.
target_client_id (int, optional): id of target client. Default to 0.
gradinvattack_kwargs: kwargs for GradientInversion_Attack
Returns:
cls: GradientInversionServerWrapper
"""
class GradientInversionServerWrapper(cls):
def __init__(self, *args, **kwargs):
super(GradientInversionServerWrapper, self).__init__(*args, **kwargs)
self.target_client_id = target_client_id
self.num_trial_per_communication = num_trial_per_communication
self.attacker = GradientInversion_Attack(
self.server_model, x_shape, **gradinvattack_kwargs
)
self.attack_results = []
def change_target_client_id(self, target_client_id):
self.target_client_id = target_client_id
self.attacker.target_model = self.clients[self.target_client_id]
def receive(self, *args, **kwargs):
super(GradientInversionServerWrapper, self).receive(*args, **kwargs)
attack_function_on_receive(self)
def attack(self, **kwargs):
received_gradient = self.uploaded_gradients[self.target_client_id]
received_gradient = [cg.detach() for cg in received_gradient]
return self.attacker.attack(received_gradient, **kwargs)
def group_attack(self, **kwargs):
received_gradient = self.uploaded_gradients[self.target_client_id]
received_gradient = [cg.detach() for cg in received_gradient]
return self.attacker.group_attack(received_gradient, **kwargs)
def reset_seed(self, seed):
self.attacker.reset_seed(seed)
return GradientInversionServerWrapper
[docs]class GradientInversionAttackServerManager(BaseManager):
"""Manager class for Gradient-based model inversion attack"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
[docs] def attach(self, cls):
"""Wraps the given class in GradientInversionServerWrapper.
Returns:
cls: GradientInversionServerWrapper
"""
return attach_gradient_inversion_attack_to_server(
cls, *self.args, **self.kwargs
)