Source code for aijack.attack.poison.mapf

import torch

from ...manager import BaseManager


def attach_mapf_to_client(cls, lam, base_model_parameters=None):
    """Attaches a MAPF attack to a client.

    The wrapped client implements the fake-client model poisoning attack of
    https://arxiv.org/pdf/2203.08669.pdf. Instead of uploading honest local
    gradients, it uploads ``lam * (base_param - param)`` for every parameter
    of its model.

    Args:
        cls: The client class to wrap. Instances must expose a ``model``
            attribute with a ``parameters()`` method once ``cls.__init__``
            has run.
        lam (float): Scaling factor applied to the uploaded difference.
        base_model_parameters (iterable of torch.Tensor, optional): Base
            (target) model parameters, one per parameter of ``self.model``
            and in the same order. Any iterable is accepted, including a
            one-shot generator such as ``other_model.parameters()``; it is
            materialized into a list once per client. If None, each client
            instance draws its own random base parameters with
            ``torch.randn_like``. Defaults to None.

    Returns:
        class: A subclass of ``cls`` whose ``upload_gradients`` performs the
        MAPF attack.

    Raises:
        ValueError: When a wrapped client is constructed and
            ``base_model_parameters`` does not hold exactly one tensor per
            model parameter.
    """

    class MAPFClientWrapper(cls):
        """Implementation of MAPF proposed in https://arxiv.org/pdf/2203.08669.pdf"""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            if base_model_parameters is None:
                # NOTE(review): every instance gets its own random base model;
                # the paper uses one base model shared by all fake clients.
                # Confirm whether per-instance randomness is intended.
                self.base_model_parameters = [
                    torch.randn_like(p) for p in self.model.parameters()
                ]
            else:
                # Materialize so a generator (e.g. ``model.parameters()``) is
                # not exhausted after the first upload.
                self.base_model_parameters = list(base_model_parameters)

            num_model_params = len(list(self.model.parameters()))
            if len(self.base_model_parameters) != num_model_params:
                # zip() would otherwise silently truncate the upload.
                raise ValueError(
                    f"base_model_parameters has {len(self.base_model_parameters)} "
                    f"tensors but the model has {num_model_params} parameters"
                )

        def upload_gradients(self):
            """Upload the poisoned gradients ``lam * (base_param - param)``.

            Returns:
                list: One tensor per model parameter, detached from autograd.
            """
            # The upload is plain data: without no_grad each tensor would
            # carry an autograd graph back to the live model parameters.
            # NOTE(review): whether this pushes the global model toward or
            # away from the base model depends on the server's update sign
            # convention; confirm against the aggregating server.
            with torch.no_grad():
                return [
                    (base_param - param) * lam
                    for param, base_param in zip(
                        self.model.parameters(), self.base_model_parameters
                    )
                ]

    return MAPFClientWrapper
class MAPFClientWrapper(BaseManager):
    """Manager that attaches the MAPF attack to a client class.

    The arguments held on the manager (``self.args`` / ``self.kwargs``,
    presumably populated by ``BaseManager.__init__`` — confirm) are forwarded
    to ``attach_mapf_to_client`` after the client class.
    """

    def attach(self, cls):
        """Return ``cls`` wrapped with the MAPF attack.

        Args:
            cls: The client class to wrap.

        Returns:
            class: The wrapped client class produced by
            ``attach_mapf_to_client``.
        """
        positional = self.args
        keyword = self.kwargs
        return attach_mapf_to_client(cls, *positional, **keyword)