Source code for aijack.defense.paillier.fed_wrapper

import numpy as np

from ...manager import BaseManager
from .torch_wrapper import PaillierTensor


def attach_paillier_to_client_for_encrypted_grad(cls, pk, sk):
    """Makes the client class communicate the encrypted gradients with paillier encryption scheme.

    The returned class subclasses ``cls``:
    - It overrides ``upload_gradients`` to encrypt every element of every
      uploaded gradient with ``pk``.
    - It overrides ``download`` so that, once the client is initialized, any
      ``PaillierTensor`` in the received gradients is decrypted with ``sk``
      before being handed to the parent's ``download``.

    Args:
        cls: client class
        pk: public key
        sk: secret key

    Returns:
        PaillierClientWrapper: a subclass of ``cls`` with encrypted gradient
        upload and decrypting download.
    """

    class PaillierClientWrapper(cls):
        def __init__(self, *args, **kwargs):
            super(PaillierClientWrapper, self).__init__(*args, **kwargs)

        def upload_gradients(self):
            """Uploads encrypted gradients

            Returns:
                list of PaillierTensor, one per gradient returned by the
                parent's ``upload_gradients``, each wrapping an array of
                element-wise ciphertexts.
            """
            pt_grads = super().upload_gradients()
            # otypes=[object]: ciphertexts are Python objects. Fixing the
            # output type stops np.vectorize from evaluating pk.encrypt on the
            # first element an extra time just to infer the dtype. It also
            # lets zero-size gradients through instead of raising.
            encrypt = np.vectorize(lambda x: pk.encrypt(x), otypes=[object])
            # .cpu() lets gradients that live on an accelerator be converted
            # to numpy; it is a no-op for tensors already on the CPU.
            return [
                PaillierTensor(encrypt(grad.detach().cpu().numpy()))
                for grad in pt_grads
            ]

        def download(self, global_grad):
            """Downloads and decrypt the received global gradients"""
            if not self.initialized:
                # initial parameters are not encrypted
                return super().download(global_grad)
            return super().download(self.decrypt_grad(global_grad))

        def decrypt_grad(self, global_grad):
            """Decrypts every ``PaillierTensor`` in ``global_grad`` with ``sk``.

            Decryption targets ``self.device``. Entries that are not
            ``PaillierTensor`` instances are passed through unchanged.

            Returns:
                list with the same length and order as ``global_grad``.
            """
            return [
                grad.decrypt(sk, self.device)
                if isinstance(grad, PaillierTensor)
                else grad
                for grad in global_grad
            ]

    return PaillierClientWrapper
class PaillierGradientClientManager(BaseManager):
    """Client Manager for secure aggregation with Paillier Encryption

    ``attach`` forwards the positional and keyword arguments stored on the
    manager (``self.args`` / ``self.kwargs``) to
    ``attach_paillier_to_client_for_encrypted_grad``. That function takes
    ``(cls, pk, sk)``, so these are expected to supply the Paillier public
    key and secret key.

    NOTE(review): ``self.args`` / ``self.kwargs`` are presumably populated by
    ``BaseManager.__init__``; confirm against that class.
    """

    def attach(self, cls):
        """Wraps ``cls`` so that it exchanges Paillier-encrypted gradients.

        Args:
            cls: client class to wrap

        Returns:
            the wrapper subclass built by
            ``attach_paillier_to_client_for_encrypted_grad``.
        """
        wrapped_cls = attach_paillier_to_client_for_encrypted_grad(
            cls, *self.args, **self.kwargs
        )
        return wrapped_cls