"""Shadow-model membership inference attack (module ``aijack.attack.membership.membership_inference``)."""

from ..base_attack import BaseAttacker
from .utils import AttackerModel, ShadowModels


class ShadowMembershipInferenceAttack(BaseAttacker):
    """Shadow-model membership inference attack.

    Implementation of the membership inference attack from
    https://arxiv.org/abs/1610.05820. Shadow models are trained on data the
    attacker controls, their output is used to train attack models, and the
    attack models then classify whether a prediction of the victim model was
    made on one of its training samples.
    """

    def __init__(
        self,
        target_model,
        shadow_models,
        attack_models,
    ):
        """Set up the attacker.

        Args:
            target_model: the model of the victim. ``attack`` calls its
                ``predict_proba`` method.
            shadow_models: shadow model(s) for the attack; wrapped in
                ``ShadowModels``.
            attack_models: attacker model(s) for the attack; wrapped in
                ``AttackerModel``.
        """
        super().__init__(target_model)
        self.sm = ShadowModels(shadow_models)
        self.am = AttackerModel(attack_models)
        # Filled in by ``train_shadow`` and consumed by ``train_attacker``.
        self.shadow_result = None

    def fit(self, X, y):
        """Train the shadow models on (X, y), then train the attack models.

        Args:
            X (np.array): training data for the shadow models
            y (np.array): training labels for the shadow models
        """
        self.train_shadow(X, y)
        self.train_attacker()

    def train_shadow(self, X, y):
        """Train the shadow models and store their output for the attacker.

        Args:
            X (np.array): training data for the shadow models
            y (np.array): training labels for the shadow models
        """
        self.shadow_result = self.sm.fit_transform(X, y)

    def train_attacker(self):
        """Train the attack models on the stored shadow-model output.

        Raises:
            RuntimeError: if ``train_shadow`` (or ``fit``) has not been
                called yet, so there is no shadow output to train on.
        """
        if self.shadow_result is None:
            # Fail loudly instead of handing ``None`` to the attack models.
            raise RuntimeError(
                "train_shadow must be called before train_attacker"
            )
        self.am.fit(self.shadow_result)

    def attack(self, x, y, proba=False):
        """Attack the victim model.

        Queries the victim model with ``x`` and lets the attack models decide,
        for each sample, whether it was part of the victim's training data.

        Args:
            x: target data which the attacker wants to classify
            y: target labels which the attacker wants to classify
            proba (bool): if True, return membership probabilities
                (``predict_proba``); otherwise return binary membership
                labels (``predict``).

        Returns:
            the output of ``predict_proba`` or ``predict``, depending on
            ``proba``.
        """
        # ``self.target_model`` is presumably stored by
        # ``BaseAttacker.__init__`` — NOTE(review): confirm in base_attack.
        prediction_of_target_model = self.target_model.predict_proba(x)
        if proba:
            return self.predict_proba(prediction_of_target_model, y)
        return self.predict(prediction_of_target_model, y)

    def predict(self, pred, label):
        """Predict whether the given prediction came from training data or not.

        Args:
            pred (torch.Tensor): predicted probabilities on the data
            label (torch.Tensor): true label of the data which ``pred`` is
                predicted on

        Returns:
            predicted binary labels, as produced by the attack models
        """
        return self.am.predict(pred, label)

    def predict_proba(self, pred, label):
        """Get probabilities of whether the given prediction came from training data.

        Args:
            pred (torch.Tensor): predicted probabilities on the data
            label (torch.Tensor): true label of the data which ``pred`` is
                predicted on

        Returns:
            predicted probabilities, as produced by the attack models
        """
        return self.am.predict_proba(pred, label)