Source code for aijack.defense.sparse.topk
import torch
from ...manager import BaseManager
def attach_sparse_gradient_to_client(cls, k):
    """Wraps the client class so that it uploads only the top-k gradient entries.

    Args:
        cls: client class
        k (float): fraction of gradient entries to keep (sparsity strength)

    Returns:
        cls: a class wrapped in SparseGradientClientWrapper
    """

    class SparseGradientClientWrapper(cls):
        def __init__(self, *args, **kwargs):
            super(SparseGradientClientWrapper, self).__init__(*args, **kwargs)

        def upload_gradients(self):
            """Upload sparse gradients."""
            vanilla_gradients = super().upload_gradients()
            sparse_gradients = []
            sparse_indices = []
            for vanilla_grad in vanilla_gradients:
                temp_grad = vanilla_grad.reshape(-1)
                # send only the top-k largest-magnitude gradient entries
                topk_indices = torch.topk(
                    torch.abs(temp_grad), k=int(len(temp_grad) * k)
                ).indices
                sparse_gradients.append(temp_grad[topk_indices].tolist())
                sparse_indices.append(topk_indices.tolist())
            return sparse_gradients, sparse_indices

    return SparseGradientClientWrapper
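
# Example (illustrative sketch, not part of the library): for a flattened
# gradient of 10 entries and k = 0.3, the wrapper keeps int(10 * 0.3) = 3
# values, chosen by absolute magnitude:
#
#     flat    = [0.1, -2.0, 0.05, 1.5, -0.2, 0.0, 3.0, -0.01, 0.4, 0.9]
#     kept    = [3.0, -2.0, 1.5]      # values with the largest |.|
#     indices = [6, 1, 3]             # their positions in the flat tensor
#
# Only (kept, indices) are uploaded for each parameter tensor, so the
# communication cost per tensor drops from 10 floats to 3 floats + 3 ints.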


def attach_sparse_gradient_to_server(cls):
    """Wraps the server class so that it can receive sparse gradients.

    Args:
        cls: server class

    Returns:
        cls: a class wrapped in SparseGradientServerWrapper
    """

    class SparseGradientServerWrapper(cls):
        def __init__(self, *args, **kwargs):
            super(SparseGradientServerWrapper, self).__init__(*args, **kwargs)

        def _preprocess_local_gradients(self, uploaded_grad):
            """Reconstructs dense gradients from the received sparse gradients.

            Args:
                uploaded_grad (tuple(list, list)): tuple of the non-zero gradient
                    values and their flattened positions, one list per parameter

            Returns:
                List[torch.Tensor]: list of recovered dense gradients
            """
            sparse_gradients_flattened, sparse_indices = uploaded_grad
            gradients_reshaped = []
            for param, grad, idx in zip(
                self.server_model.parameters(),
                sparse_gradients_flattened,
                sparse_indices,
            ):
                temp_grad = torch.zeros_like(param).reshape(-1)
                temp_grad[idx] = torch.Tensor(grad).to(self.device)
                gradients_reshaped.append(temp_grad.reshape(param.shape))
            return gradients_reshaped

    return SparseGradientServerWrapper
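
# Example (illustrative sketch): reconstructing the dense gradient from the
# pair shown in the client-side example above, for a parameter with 10
# entries when flattened:
#
#     values  = [3.0, -2.0, 1.5]
#     indices = [6, 1, 3]
#     dense   = [0.0, -2.0, 0.0, 1.5, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0]
#
# i.e. a zero tensor of the parameter's shape is filled at the received
# positions, and every entry that was not in the top-k stays zero.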


class SparseGradientClientManager(BaseManager):
    """Client-side manager that attaches sparse-gradient upload to a client class."""

    def attach(self, cls):
        return attach_sparse_gradient_to_client(cls, *self.args, **self.kwargs)


class SparseGradientServerManager(BaseManager):
    """Server-side manager that attaches sparse-gradient handling to a server class."""

    def attach(self, cls):
        return attach_sparse_gradient_to_server(cls, *self.args, **self.kwargs)
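

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): DummyClient and DummyServer below are
# hypothetical stand-ins for real collaborative client/server classes; they
# expose only the attributes the wrappers rely on (upload_gradients,
# server_model, device). Adapt the names to your actual setup.
if __name__ == "__main__":
    import torch.nn as nn

    class DummyClient:
        def __init__(self, model):
            self.model = model

        def upload_gradients(self):
            # pretend every parameter already holds a gradient
            return [torch.randn_like(p) for p in self.model.parameters()]

    class DummyServer:
        def __init__(self, model):
            self.server_model = model
            self.device = "cpu"

    model = nn.Linear(10, 4)

    # keep the top 30% of entries of each parameter's gradient
    SparseClient = SparseGradientClientManager(k=0.3).attach(DummyClient)
    SparseServer = SparseGradientServerManager().attach(DummyServer)

    client = SparseClient(model)
    server = SparseServer(model)

    values, indices = client.upload_gradients()
    dense = server._preprocess_local_gradients((values, indices))
    assert all(d.shape == p.shape for d, p in zip(dense, model.parameters()))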