Source code for aijack.defense.foolsgold.server
import numpy as np
import torch
import torch.nn.functional as F
from ...manager import BaseManager
# Small positive constant: passed as ``eps`` to the cosine-similarity call and
# added to the denominator when re-scaling client weights, so neither divides
# by zero.
EPS = 1.0e-08
def calculate_cs(cs, num_clients, aggregate_historical_gradients, eps=None):
    """Fill ``cs`` with pairwise cosine similarities of client histories.

    Only off-diagonal entries of the leading ``num_clients`` x ``num_clients``
    block are written. The diagonal and any rows/columns beyond
    ``num_clients`` keep whatever ``cs`` already held. Similarity is
    symmetric, so each unordered pair is computed once and mirrored.

    Args:
        cs: 2-D array (e.g. a numpy array) updated in place.
        num_clients: number of leading clients to compare.
        aggregate_historical_gradients: indexable sequence of 1-D tensors,
            one per client.
        eps: small value forwarded to ``F.cosine_similarity`` to avoid
            division by zero. ``None`` (the default) uses the module-level
            ``EPS``.

    Returns:
        The same ``cs`` object, mutated in place.
    """
    if eps is None:
        eps = EPS
    for i_idx in range(num_clients):
        for j_idx in range(i_idx + 1, num_clients):
            # ``.item()`` turns the 0-d result into a Python float. numpy
            # cannot ingest a CUDA tensor directly, so without this the
            # assignment fails when the histories live on a GPU.
            similarity = F.cosine_similarity(
                aggregate_historical_gradients[i_idx],
                aggregate_historical_gradients[j_idx],
                dim=0,
                eps=eps,
            ).item()
            cs[i_idx][j_idx] = similarity
            cs[j_idx][i_idx] = similarity
    return cs
def normalize_cs(cs, v, num_clients):
    """Re-scale similarity entries in place according to ``v``.

    For every ordered pair ``(row, col)`` among the first ``num_clients``
    clients where ``v[col]`` is strictly greater than ``v[row]``, the entry
    ``cs[row][col]`` is multiplied by ``v[row] / v[col]``. All other entries
    are left as they are.

    Args:
        cs: square similarity matrix, modified in place.
        v: one reference value per client (the wrapper passes each row's
            maximum of ``cs``).
        num_clients: how many leading clients to process.

    Returns:
        The same ``cs`` object.
    """
    clients = range(num_clients)
    for row in clients:
        v_row = v[row]
        for col in clients:
            v_col = v[col]
            if v_row < v_col:
                cs[row][col] = cs[row][col] * (v_row / v_col)
    return cs
def attach_foolsgold_to_server(cls):
    """Wraps the given class in FoolsGoldServerWrapper.

    The wrapper keeps, for every client, the running sum of all gradients it
    has uploaded. Before each aggregation it recomputes per-client weights
    from the pairwise cosine similarity of those histories, following
    Algorithm 1 of https://arxiv.org/abs/1808.04866.

    The wrapped class must provide ``server_model``, ``device``, ``clients``,
    ``uploaded_gradients`` and ``update_from_gradients``. The wrapper writes
    the resulting weights to ``self.weight``.

    Args:
        cls: server class to extend.

    Returns:
        cls: a class wrapped in FoolsGoldServerWrapper
    """

    class FoolsGoldServerWrapper(cls):
        """Implementation of https://arxiv.org/abs/1808.04866"""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            num_clients = len(self.clients)
            # Template with the size/dtype of the flattened model; detached so
            # the accumulators never track autograd history.
            flat_params = torch.cat(
                [p.detach().reshape(-1) for p in self.server_model.parameters()]
            ).to(self.device)
            self.aggregate_historical_gradients = [
                torch.zeros_like(flat_params) for _ in range(num_clients)
            ]
            self.cs = np.zeros((num_clients, num_clients))
            self.v = np.zeros(num_clients)
            self.alpha = np.zeros(num_clients)

        def update(self):
            """Recompute client weights, then run the base class aggregation."""
            self.update_weight()
            self.update_from_gradients()

        def update_weight(self):
            """Updates weight for each client given the received local gradients."""
            # NOTE(review): assumes uploaded_gradients[i] comes from client i
            # and that element [1] holds that client's per-parameter gradients.
            for i, local_gradient in enumerate(self.uploaded_gradients):
                flat_gradient = torch.cat(
                    [
                        g.detach().to(self.device).reshape(-1)
                        for g in local_gradient[1]
                    ]
                )
                self.aggregate_historical_gradients[i] += flat_gradient

            num_clients = len(self.uploaded_gradients)

            # These helpers are module-level functions, not methods.
            self.cs = calculate_cs(
                self.cs, num_clients, self.aggregate_historical_gradients
            )
            self.v = np.max(self.cs, axis=1)
            self.cs = normalize_cs(self.cs, self.v, num_clients)

            # FoolsGold gives the LOWEST weight to clients whose history is
            # most similar to someone else's (likely sybils): alpha = 1 - max.
            alpha = np.clip(1.0 - np.max(self.cs, axis=1), 0.0, 1.0)
            alpha = alpha / (np.max(alpha) + EPS)
            # Logit re-scaling (kappa = 1). Cap at 0.99 so the ratio stays
            # finite; alpha == 0 yields -inf, which the final clip maps to 0.
            alpha = np.minimum(alpha, 0.99)
            with np.errstate(divide="ignore"):
                alpha = np.log(alpha / (1.0 - alpha)) + 0.5
            self.alpha = np.clip(alpha, 0.0, 1.0)
            self.weight = self.alpha

    return FoolsGoldServerWrapper
class FoolsGoldServerManager(BaseManager):
    """Manager that equips a server class with FoolsGold.

    FoolsGold is proposed in https://arxiv.org/abs/1808.04866.
    """

    def attach(self, cls):
        """Return ``cls`` wrapped in FoolsGoldServerWrapper.

        The positional and keyword arguments stored on this manager
        (``self.args`` / ``self.kwargs``) are forwarded unchanged to
        ``attach_foolsgold_to_server``.

        Args:
            cls: server class to extend.

        Returns:
            cls: a class wrapped in FoolsGoldServerWrapper
        """
        stored_args = self.args
        stored_kwargs = self.kwargs
        return attach_foolsgold_to_server(cls, *stored_args, **stored_kwargs)