Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
dfb30f04
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dfb30f04
编写于
8月 27, 2020
作者:
L
liuluobin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fixed exception detection and added log printing
上级
dcc64ddd
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
62 addition
and
26 deletion
+62
-26
mindarmour/diff_privacy/evaluation/attacker.py
mindarmour/diff_privacy/evaluation/attacker.py
+9
-1
mindarmour/diff_privacy/evaluation/membership_inference.py
mindarmour/diff_privacy/evaluation/membership_inference.py
+53
-25
未找到文件。
mindarmour/diff_privacy/evaluation/attacker.py
浏览文件 @
dfb30f04
...
@@ -21,6 +21,11 @@ from sklearn.ensemble import RandomForestClassifier
...
@@ -21,6 +21,11 @@ from sklearn.ensemble import RandomForestClassifier
from
sklearn.model_selection
import
GridSearchCV
from
sklearn.model_selection
import
GridSearchCV
from
sklearn.model_selection
import
RandomizedSearchCV
from
sklearn.model_selection
import
RandomizedSearchCV
from
mindarmour.utils.logger
import
LogUtil
LOGGER
=
LogUtil
.
get_instance
()
TAG
=
"Attacker"
def
_attack_knn
(
features
,
labels
,
param_grid
):
def
_attack_knn
(
features
,
labels
,
param_grid
):
"""
"""
...
@@ -138,4 +143,7 @@ def get_attack_model(features, labels, config):
...
@@ -138,4 +143,7 @@ def get_attack_model(features, labels, config):
return
_attack_mlpc
(
features
,
labels
,
config
[
"params"
])
return
_attack_mlpc
(
features
,
labels
,
config
[
"params"
])
if
method
==
"rf"
:
if
method
==
"rf"
:
return
_attack_rf
(
features
,
labels
,
config
[
"params"
])
return
_attack_rf
(
features
,
labels
,
config
[
"params"
])
return
None
msg
=
"Method {} is not supported."
.
format
(
config
[
"method"
])
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
mindarmour/diff_privacy/evaluation/membership_inference.py
浏览文件 @
dfb30f04
...
@@ -24,6 +24,11 @@ import mindspore.nn as nn
...
@@ -24,6 +24,11 @@ import mindspore.nn as nn
import
mindspore.context
as
context
import
mindspore.context
as
context
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindarmour.diff_privacy.evaluation.attacker
import
get_attack_model
from
mindarmour.diff_privacy.evaluation.attacker
import
get_attack_model
from
mindarmour.utils.logger
import
LogUtil
LOGGER
=
LogUtil
.
get_instance
()
TAG
=
"MembershipInference"
def
_eval_info
(
pred
,
truth
,
option
):
def
_eval_info
(
pred
,
truth
,
option
):
"""
"""
...
@@ -43,7 +48,9 @@ def _eval_info(pred, truth, option):
...
@@ -43,7 +48,9 @@ def _eval_info(pred, truth, option):
ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
"""
"""
if
pred
.
size
==
0
or
truth
.
size
==
0
:
if
pred
.
size
==
0
or
truth
.
size
==
0
:
raise
ValueError
(
"Size of pred or truth is 0."
)
msg
=
"Size of pred or truth is 0."
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
if
option
==
"accuracy"
:
if
option
==
"accuracy"
:
count
=
np
.
sum
(
pred
==
truth
)
count
=
np
.
sum
(
pred
==
truth
)
...
@@ -59,7 +66,9 @@ def _eval_info(pred, truth, option):
...
@@ -59,7 +66,9 @@ def _eval_info(pred, truth, option):
return
-
1
return
-
1
return
count
/
np
.
sum
(
truth
)
return
count
/
np
.
sum
(
truth
)
raise
ValueError
(
"The metric value {} is undefined."
.
format
(
option
))
msg
=
"The metric value {} is undefined."
.
format
(
option
)
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
class
MembershipInference
:
class
MembershipInference
:
...
@@ -91,7 +100,10 @@ class MembershipInference:
...
@@ -91,7 +100,10 @@ class MembershipInference:
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
if
not
isinstance
(
model
,
Model
):
if
not
isinstance
(
model
,
Model
):
raise
TypeError
(
"Type of parameter 'model' must be Model, but got {}."
.
format
(
type
(
model
)))
msg
=
"Type of parameter 'model' must be Model, but got {}."
.
format
(
type
(
model
))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
self
.
model
=
model
self
.
model
=
model
self
.
method_list
=
[
"knn"
,
"lr"
,
"mlp"
,
"rf"
]
self
.
method_list
=
[
"knn"
,
"lr"
,
"mlp"
,
"rf"
]
self
.
attack_list
=
[]
self
.
attack_list
=
[]
...
@@ -117,26 +129,34 @@ class MembershipInference:
...
@@ -117,26 +129,34 @@ class MembershipInference:
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
"""
"""
if
not
isinstance
(
dataset_train
,
Dataset
):
if
not
isinstance
(
dataset_train
,
Dataset
):
raise
TypeError
(
"Type of parameter 'dataset_train' must be Dataset, "
msg
=
"Type of parameter 'dataset_train' must be Dataset, but got {}"
.
format
(
type
(
dataset_train
))
"but got {}"
.
format
(
type
(
dataset_train
)))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
not
isinstance
(
dataset_test
,
Dataset
):
if
not
isinstance
(
dataset_test
,
Dataset
):
raise
TypeError
(
"Type of parameter 'test_train' must be Dataset, "
msg
=
"Type of parameter 'test_train' must be Dataset, but got {}"
.
format
(
type
(
dataset_train
))
"but got {}"
.
format
(
type
(
dataset_train
)))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
not
isinstance
(
attack_config
,
list
):
if
not
isinstance
(
attack_config
,
list
):
raise
TypeError
(
"Type of parameter 'attack_config' must be list, "
msg
=
"Type of parameter 'attack_config' must be list, but got {}."
.
format
(
type
(
attack_config
))
"but got {}."
.
format
(
type
(
attack_config
)))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
for
config
in
attack_config
:
for
config
in
attack_config
:
if
not
isinstance
(
config
,
dict
):
if
not
isinstance
(
config
,
dict
):
raise
TypeError
(
"Type of each config in 'attack_config' must be dict, "
msg
=
"Type of each config in 'attack_config' must be dict, but got {}."
.
format
(
type
(
config
))
"but got {}."
.
format
(
type
(
config
)))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
{
"params"
,
"method"
}
!=
set
(
config
.
keys
()):
if
{
"params"
,
"method"
}
!=
set
(
config
.
keys
()):
raise
KeyError
(
"Each config in attack_config must have keys 'method' and 'params', "
msg
=
"Each config in attack_config must have keys 'method' and 'params',"
\
"but your key value is {}."
.
format
(
set
(
config
.
keys
())))
"but your key value is {}."
.
format
(
set
(
config
.
keys
()))
LOGGER
.
error
(
TAG
,
msg
)
raise
KeyError
(
msg
)
if
str
.
lower
(
config
[
"method"
])
not
in
self
.
method_list
:
if
str
.
lower
(
config
[
"method"
])
not
in
self
.
method_list
:
raise
ValueError
(
"Method {} is not support."
.
format
(
config
[
"method"
]))
msg
=
"Method {} is not support."
.
format
(
config
[
"method"
])
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
for
config
in
attack_config
:
for
config
in
attack_config
:
...
@@ -157,22 +177,26 @@ class MembershipInference:
...
@@ -157,22 +177,26 @@ class MembershipInference:
list, Each element contains an evaluation indicator for the attack model.
list, Each element contains an evaluation indicator for the attack model.
"""
"""
if
not
isinstance
(
dataset_train
,
Dataset
):
if
not
isinstance
(
dataset_train
,
Dataset
):
raise
TypeError
(
"Type of parameter 'dataset_train' must be Dataset, "
msg
=
"Type of parameter 'dataset_train' must be Dataset, but got {}"
.
format
(
type
(
dataset_train
))
"but got {}"
.
format
(
type
(
dataset_train
)))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
not
isinstance
(
dataset_test
,
Dataset
):
if
not
isinstance
(
dataset_test
,
Dataset
):
raise
TypeError
(
"Type of parameter 'test_train' must be Dataset, "
msg
=
"Type of parameter 'test_train' must be Dataset, but got {}"
.
format
(
type
(
dataset_train
))
"but got {}"
.
format
(
type
(
dataset_train
)))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
not
isinstance
(
metrics
,
(
list
,
tuple
)):
if
not
isinstance
(
metrics
,
(
list
,
tuple
)):
raise
TypeError
(
"Type of parameter 'config' must be Union[list, tuple], but got "
msg
=
"Type of parameter 'config' must be Union[list, tuple], but got {}."
.
format
(
type
(
metrics
))
"{}."
.
format
(
type
(
metrics
)))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
metrics
=
set
(
metrics
)
metrics
=
set
(
metrics
)
metrics_list
=
{
"precision"
,
"accruacy"
,
"recall"
}
metrics_list
=
{
"precision"
,
"accruacy"
,
"recall"
}
if
metrics
>
metrics_list
:
if
metrics
>
metrics_list
:
raise
ValueError
(
"Element in 'metrics' must be in {}, but got "
msg
=
"Element in 'metrics' must be in {}, but got {}."
.
format
(
metrics_list
,
metrics
)
"{}."
.
format
(
metrics_list
,
metrics
))
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
result
=
[]
result
=
[]
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
...
@@ -221,8 +245,10 @@ class MembershipInference:
...
@@ -221,8 +245,10 @@ class MembershipInference:
- numpy.ndarray, Labels for each sample, Shape is (N,).
- numpy.ndarray, Labels for each sample, Shape is (N,).
"""
"""
if
context
.
get_context
(
"device_target"
)
!=
"Ascend"
:
if
context
.
get_context
(
"device_target"
)
!=
"Ascend"
:
raise
RuntimeError
(
"The target device must be Ascend, "
msg
=
"The target device must be Ascend, "
\
"but current is {}."
.
format
(
context
.
get_context
(
"device_target"
)))
"but current is {}."
.
format
(
context
.
get_context
(
"device_target"
))
LOGGER
.
error
(
TAG
,
msg
)
raise
RuntimeError
(
msg
)
loss_logits
=
np
.
array
([])
loss_logits
=
np
.
array
([])
for
batch
in
dataset_x
.
create_dict_iterator
():
for
batch
in
dataset_x
.
create_dict_iterator
():
batch_data
=
Tensor
(
batch
[
'image'
],
ms
.
float32
)
batch_data
=
Tensor
(
batch
[
'image'
],
ms
.
float32
)
...
@@ -243,5 +269,7 @@ class MembershipInference:
...
@@ -243,5 +269,7 @@ class MembershipInference:
elif
label
==
0
:
elif
label
==
0
:
labels
=
np
.
zeros
(
len
(
loss_logits
),
np
.
int32
)
labels
=
np
.
zeros
(
len
(
loss_logits
),
np
.
int32
)
else
:
else
:
raise
ValueError
(
"The value of label must be 0 or 1, but got {}."
.
format
(
label
))
msg
=
"The value of label must be 0 or 1, but got {}."
.
format
(
label
)
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
return
loss_logits
,
labels
return
loss_logits
,
labels
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录