提交 0921714a 编写于 作者: L liuluobin

Fix param check of metric

上级 3a3ff173
...@@ -20,8 +20,6 @@ import numpy as np ...@@ -20,8 +20,6 @@ import numpy as np
import mindspore as ms import mindspore as ms
from mindspore.train import Model from mindspore.train import Model
from mindspore.dataset.engine import Dataset from mindspore.dataset.engine import Dataset
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
...@@ -71,6 +69,22 @@ def _eval_info(pred, truth, option): ...@@ -71,6 +69,22 @@ def _eval_info(pred, truth, option):
raise ValueError(msg) raise ValueError(msg)
def _softmax_cross_entropy(logits, labels):
"""
Calculate the SoftmaxCrossEntropy result between logits and labels.
Args:
logits (numpy.ndarray): Numpy array of shape(N, C).
labels (numpy.ndarray): Numpy array of shape(N, )
Returns:
numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits.
"""
labels = np.eye(logits.shape[1])[labels].astype(np.int32)
logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
return -1*np.sum(labels*np.log(logits), axis=1)
class MembershipInference: class MembershipInference:
""" """
Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack. Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack.
...@@ -192,8 +206,8 @@ class MembershipInference: ...@@ -192,8 +206,8 @@ class MembershipInference:
raise TypeError(msg) raise TypeError(msg)
metrics = set(metrics) metrics = set(metrics)
metrics_list = {"precision", "accruacy", "recall"} metrics_list = {"precision", "accuracy", "recall"}
if metrics > metrics_list: if not metrics <= metrics_list:
msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics) msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise ValueError(msg)
...@@ -244,19 +258,12 @@ class MembershipInference: ...@@ -244,19 +258,12 @@ class MembershipInference:
N is the number of sample. C = 1 + dim(logits). N is the number of sample. C = 1 + dim(logits).
- numpy.ndarray, Labels for each sample, Shape is (N,). - numpy.ndarray, Labels for each sample, Shape is (N,).
""" """
if context.get_context("device_target") != "Ascend":
msg = "The target device must be Ascend, " \
"but current is {}.".format(context.get_context("device_target"))
LOGGER.error(TAG, msg)
raise RuntimeError(msg)
loss_logits = np.array([]) loss_logits = np.array([])
for batch in dataset_x.create_dict_iterator(): for batch in dataset_x.create_dict_iterator():
batch_data = Tensor(batch['image'], ms.float32) batch_data = Tensor(batch['image'], ms.float32)
batch_labels = Tensor(batch['label'], ms.int32) batch_labels = batch['label'].astype(np.int32)
batch_logits = self.model.predict(batch_data) batch_logits = self.model.predict(batch_data).asnumpy()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None) batch_loss = _softmax_cross_entropy(batch_logits, batch_labels)
batch_loss = loss(batch_logits, batch_labels).asnumpy()
batch_logits = batch_logits.asnumpy()
batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits)) batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
if loss_logits.size == 0: if loss_logits.size == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册