Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
0921714a
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0921714a
编写于
8月 27, 2020
作者:
L
liuluobin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix param check of metric
上级
3a3ff173
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
21 addition
and
14 deletion
+21
-14
mindarmour/diff_privacy/evaluation/membership_inference.py
mindarmour/diff_privacy/evaluation/membership_inference.py
+21
-14
未找到文件。
mindarmour/diff_privacy/evaluation/membership_inference.py
浏览文件 @
0921714a
...
...
@@ -20,8 +20,6 @@ import numpy as np
import
mindspore
as
ms
from
mindspore.train
import
Model
from
mindspore.dataset.engine
import
Dataset
import
mindspore.nn
as
nn
import
mindspore.context
as
context
from
mindspore
import
Tensor
from
mindarmour.diff_privacy.evaluation.attacker
import
get_attack_model
from
mindarmour.utils.logger
import
LogUtil
...
...
@@ -71,6 +69,22 @@ def _eval_info(pred, truth, option):
raise
ValueError
(
msg
)
def
_softmax_cross_entropy
(
logits
,
labels
):
"""
Calculate the SoftmaxCrossEntropy result between logits and labels.
Args:
logits (numpy.ndarray): Numpy array of shape(N, C).
labels (numpy.ndarray): Numpy array of shape(N, )
Returns:
numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits.
"""
labels
=
np
.
eye
(
logits
.
shape
[
1
])[
labels
].
astype
(
np
.
int32
)
logits
=
np
.
exp
(
logits
)
/
np
.
sum
(
np
.
exp
(
logits
),
axis
=
1
,
keepdims
=
True
)
return
-
1
*
np
.
sum
(
labels
*
np
.
log
(
logits
),
axis
=
1
)
class
MembershipInference
:
"""
Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack.
...
...
@@ -192,8 +206,8 @@ class MembershipInference:
raise
TypeError
(
msg
)
metrics
=
set
(
metrics
)
metrics_list
=
{
"precision"
,
"acc
ru
acy"
,
"recall"
}
if
metrics
>
metrics_list
:
metrics_list
=
{
"precision"
,
"acc
ur
acy"
,
"recall"
}
if
not
metrics
<=
metrics_list
:
msg
=
"Element in 'metrics' must be in {}, but got {}."
.
format
(
metrics_list
,
metrics
)
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
...
...
@@ -244,19 +258,12 @@ class MembershipInference:
N is the number of sample. C = 1 + dim(logits).
- numpy.ndarray, Labels for each sample, Shape is (N,).
"""
if
context
.
get_context
(
"device_target"
)
!=
"Ascend"
:
msg
=
"The target device must be Ascend, "
\
"but current is {}."
.
format
(
context
.
get_context
(
"device_target"
))
LOGGER
.
error
(
TAG
,
msg
)
raise
RuntimeError
(
msg
)
loss_logits
=
np
.
array
([])
for
batch
in
dataset_x
.
create_dict_iterator
():
batch_data
=
Tensor
(
batch
[
'image'
],
ms
.
float32
)
batch_labels
=
Tensor
(
batch
[
'label'
],
ms
.
int32
)
batch_logits
=
self
.
model
.
predict
(
batch_data
)
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
sparse
=
True
,
is_grad
=
False
,
reduction
=
None
)
batch_loss
=
loss
(
batch_logits
,
batch_labels
).
asnumpy
()
batch_logits
=
batch_logits
.
asnumpy
()
batch_labels
=
batch
[
'label'
].
astype
(
np
.
int32
)
batch_logits
=
self
.
model
.
predict
(
batch_data
).
asnumpy
()
batch_loss
=
_softmax_cross_entropy
(
batch_logits
,
batch_labels
)
batch_feature
=
np
.
hstack
((
batch_loss
.
reshape
(
-
1
,
1
),
batch_logits
))
if
loss_logits
.
size
==
0
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录