diff --git a/mindarmour/diff_privacy/evaluation/membership_inference.py b/mindarmour/diff_privacy/evaluation/membership_inference.py index 9378fa54dd36ed0a68fa150486aaea5b32d26b0a..5e7f152b848961f1c093afaad6449ec99d909836 100755 --- a/mindarmour/diff_privacy/evaluation/membership_inference.py +++ b/mindarmour/diff_privacy/evaluation/membership_inference.py @@ -20,8 +20,6 @@ import numpy as np import mindspore as ms from mindspore.train import Model from mindspore.dataset.engine import Dataset -import mindspore.nn as nn -import mindspore.context as context from mindspore import Tensor from mindarmour.diff_privacy.evaluation.attacker import get_attack_model from mindarmour.utils.logger import LogUtil @@ -71,6 +69,22 @@ def _eval_info(pred, truth, option): raise ValueError(msg) +def _softmax_cross_entropy(logits, labels): + """ + Calculate the SoftmaxCrossEntropy result between logits and labels. + + Args: + logits (numpy.ndarray): Numpy array of shape(N, C). + labels (numpy.ndarray): Numpy array of shape(N, ) + + Returns: + numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits. + """ + labels = np.eye(logits.shape[1])[labels].astype(np.int32) + logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + return -1*np.sum(labels*np.log(logits), axis=1) + + class MembershipInference: """ Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack. @@ -192,8 +206,8 @@ class MembershipInference: raise TypeError(msg) metrics = set(metrics) - metrics_list = {"precision", "accruacy", "recall"} - if metrics > metrics_list: + metrics_list = {"precision", "accuracy", "recall"} + if not metrics <= metrics_list: msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics) LOGGER.error(TAG, msg) raise ValueError(msg) @@ -244,19 +258,12 @@ class MembershipInference: N is the number of sample. C = 1 + dim(logits). - numpy.ndarray, Labels for each sample, Shape is (N,). """ - if context.get_context("device_target") != "Ascend": - msg = "The target device must be Ascend, " \ - "but current is {}.".format(context.get_context("device_target")) - LOGGER.error(TAG, msg) - raise RuntimeError(msg) loss_logits = np.array([]) for batch in dataset_x.create_dict_iterator(): batch_data = Tensor(batch['image'], ms.float32) - batch_labels = Tensor(batch['label'], ms.int32) - batch_logits = self.model.predict(batch_data) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None) - batch_loss = loss(batch_logits, batch_labels).asnumpy() - batch_logits = batch_logits.asnumpy() + batch_labels = batch['label'].astype(np.int32) + batch_logits = self.model.predict(batch_data).asnumpy() + batch_loss = _softmax_cross_entropy(batch_logits, batch_labels) batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits)) if loss_logits.size == 0: