提交 7b21a3ff 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!101 fixed exception detection

Merge pull request !101 from liuluobin/master
......@@ -20,8 +20,6 @@ import numpy as np
import mindspore as ms
from mindspore.train import Model
from mindspore.dataset.engine import Dataset
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
from mindarmour.utils.logger import LogUtil
......@@ -71,6 +69,22 @@ def _eval_info(pred, truth, option):
raise ValueError(msg)
def _softmax_cross_entropy(logits, labels):
"""
Calculate the SoftmaxCrossEntropy result between logits and labels.
Args:
logits (numpy.ndarray): Numpy array of shape(N, C).
labels (numpy.ndarray): Numpy array of shape(N, )
Returns:
numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits.
"""
labels = np.eye(logits.shape[1])[labels].astype(np.int32)
logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
return -1*np.sum(labels*np.log(logits), axis=1)
class MembershipInference:
"""
Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack.
......@@ -192,8 +206,8 @@ class MembershipInference:
raise TypeError(msg)
metrics = set(metrics)
metrics_list = {"precision", "accruacy", "recall"}
if metrics > metrics_list:
metrics_list = {"precision", "accuracy", "recall"}
if not metrics <= metrics_list:
msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics)
LOGGER.error(TAG, msg)
raise ValueError(msg)
......@@ -244,19 +258,12 @@ class MembershipInference:
N is the number of sample. C = 1 + dim(logits).
- numpy.ndarray, Labels for each sample, Shape is (N,).
"""
if context.get_context("device_target") != "Ascend":
msg = "The target device must be Ascend, " \
"but current is {}.".format(context.get_context("device_target"))
LOGGER.error(TAG, msg)
raise RuntimeError(msg)
loss_logits = np.array([])
for batch in dataset_x.create_dict_iterator():
batch_data = Tensor(batch['image'], ms.float32)
batch_labels = Tensor(batch['label'], ms.int32)
batch_logits = self.model.predict(batch_data)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None)
batch_loss = loss(batch_logits, batch_labels).asnumpy()
batch_logits = batch_logits.asnumpy()
batch_labels = batch['label'].astype(np.int32)
batch_logits = self.model.predict(batch_data).asnumpy()
batch_loss = _softmax_cross_entropy(batch_logits, batch_labels)
batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
if loss_logits.size == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册