Source code for aijack.attack.freerider.freerider
import copy
import torch
from ...manager import BaseManager
def attach_freerider_to_client(cls, mu, sigma):
    """Wraps the given class in FreeRiderClientWrapper.

    Args:
        cls: client class to wrap. The wrapper reads ``self.model.parameters()``
            and ``self.lr`` and calls ``cls.download(new_global_model)``, so the
            class must provide them.
        mu (float): mean of the gaussian distribution used to generate fake gradients
        sigma (float): standard deviation of the gaussian distribution used to generate fake gradients

    Returns:
        cls: a class wrapped in FreeRiderClientWrapper
    """

    class FreeRiderClientWrapper(cls):
        """Implementation of Free Rider Attack (https://arxiv.org/abs/1911.12560)"""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Snapshot of the local parameters taken right before the most
            # recent download; stays None until download() is first called.
            self.prev_parameters_to_generate_fake_gradients = None

        def upload_gradients(self):
            """Uploads the fake gradients.

            Returns:
                list: one tensor per model parameter. Before any download the
                entries are pure gaussian noise (``sigma * N(0, 1) + mu``);
                afterwards they are ``(prev_param - param) / lr`` plus the
                same kind of noise.
            """
            prev_parameters = self.prev_parameters_to_generate_fake_gradients

            # The fabricated gradients are plain data to upload. Building them
            # without autograd keeps both branches consistent (no grad_fn,
            # requires_grad=False) and avoids retaining a graph that
            # references the parameter snapshot.
            with torch.no_grad():
                if prev_parameters is None:
                    return [
                        sigma * torch.randn_like(param) + mu
                        for param in self.model.parameters()
                    ]

                return [
                    (prev_param - param) / self.lr
                    + (sigma * torch.randn_like(param) + mu)
                    for param, prev_param in zip(
                        self.model.parameters(), prev_parameters
                    )
                ]

        def download(self, new_global_model):
            """Downloads the new global model.

            Args:
                new_global_model: forwarded unchanged to the wrapped class'
                    ``download``.
            """
            # Deep-copy first so the snapshot is unaffected by whatever the
            # parent's download() does to the model afterwards.
            self.prev_parameters_to_generate_fake_gradients = copy.deepcopy(
                list(self.model.parameters())
            )
            super().download(new_global_model)

    return FreeRiderClientWrapper
class FreeRiderClientManager(BaseManager):
    """Manager class for Free-Rider Attack (https://arxiv.org/abs/1911.12560)"""

    def attach(self, cls):
        """Wraps the given class in FreeRiderClientWrapper.

        Args:
            cls: the client class to wrap.

        Returns:
            cls: a class wrapped in FreeRiderClientWrapper
        """
        # self.args / self.kwargs are not set in this class; presumably
        # BaseManager stores the manager's constructor arguments there
        # (expected to supply mu and sigma) -- confirm against BaseManager.
        positional = self.args
        keyword = self.kwargs
        wrapped_cls = attach_freerider_to_client(cls, *positional, **keyword)
        return wrapped_cls