Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
ce15e781
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ce15e781
编写于
8月 24, 2020
作者:
L
liuluobin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Append the parameter verification of class MembershipInference.
上级
29e303a8
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
48 addition
and
8 deletion
+48
-8
example/membership_inference_demo/train.py
example/membership_inference_demo/train.py
+3
-2
mindarmour/diff_privacy/evaluation/attacker.py
mindarmour/diff_privacy/evaluation/attacker.py
+7
-3
mindarmour/diff_privacy/evaluation/membership_inference.py
mindarmour/diff_privacy/evaluation/membership_inference.py
+38
-3
未找到文件。
example/membership_inference_demo/train.py
浏览文件 @
ce15e781
...
...
@@ -27,7 +27,7 @@ import mindspore.nn as nn
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore.nn.optim.momentum
import
Momentum
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
,
LossMonitor
from
mindspore.train.model
import
Model
from
mindspore.train.serialization
import
load_param_into_net
,
load_checkpoint
from
mindarmour.utils
import
LogUtil
...
...
@@ -187,12 +187,13 @@ if __name__ == '__main__':
amp_level
=
"O2"
,
keep_batchnorm_fp32
=
False
,
loss_scale_manager
=
None
)
# checkpoint save
callbacks
=
[
LossMonitor
()]
if
args
.
rank_save_ckpt_flag
:
ckpt_config
=
CheckpointConfig
(
save_checkpoint_steps
=
args
.
ckpt_interval
*
args
.
steps_per_epoch
,
keep_checkpoint_max
=
args
.
ckpt_save_max
)
ckpt_cb
=
ModelCheckpoint
(
config
=
ckpt_config
,
directory
=
args
.
outputs_dir
,
prefix
=
'{}'
.
format
(
args
.
rank
))
callbacks
=
ckpt_cb
callbacks
.
append
(
ckpt_cb
)
model
.
train
(
args
.
max_epoch
,
dataset
,
callbacks
=
callbacks
)
mindarmour/diff_privacy/evaluation/attacker.py
100644 → 100755
浏览文件 @
ce15e781
...
...
@@ -22,6 +22,9 @@ from sklearn.model_selection import GridSearchCV
from
sklearn.model_selection
import
RandomizedSearchCV
method_list
=
[
"lr"
,
"knn"
,
"rf"
,
"mlp"
]
def
_attack_knn
(
features
,
labels
,
param_grid
):
"""
Train and return a KNN model.
...
...
@@ -119,12 +122,13 @@ def get_attack_model(features, labels, config):
sklearn.BaseEstimator, trained model specify by config["method"].
"""
method
=
str
.
lower
(
config
[
"method"
])
if
method
==
"knn"
:
return
_attack_knn
(
features
,
labels
,
config
[
"params"
])
if
method
in
[
"lr"
,
"logitic regression"
]
:
if
method
==
"lr"
:
return
_attack_lr
(
features
,
labels
,
config
[
"params"
])
if
method
==
"mlp"
:
return
_attack_mlpc
(
features
,
labels
,
config
[
"params"
])
if
method
in
[
"rf"
,
"random forest"
]
:
if
method
==
"rf"
:
return
_attack_rf
(
features
,
labels
,
config
[
"params"
])
r
aise
ValueError
(
"Method {} is not support."
.
format
(
config
[
"method"
]))
r
eturn
None
mindarmour/diff_privacy/evaluation/membership_inference.py
100644 → 100755
浏览文件 @
ce15e781
...
...
@@ -19,10 +19,11 @@ import numpy as np
import
mindspore
as
ms
from
mindspore.train
import
Model
from
mindspore.dataset.engine
import
Dataset
import
mindspore.nn
as
nn
import
mindspore.context
as
context
from
mindspore
import
Tensor
from
mindarmour.diff_privacy.evaluation.attacker
import
get_attack_model
from
mindarmour.diff_privacy.evaluation.attacker
import
get_attack_model
,
method_list
def
_eval_info
(
pred
,
truth
,
option
):
"""
...
...
@@ -89,7 +90,7 @@ class MembershipInference:
def
__init__
(
self
,
model
):
if
not
isinstance
(
model
,
Model
):
raise
TypeError
(
"Type of
model must be {}, but got {}."
.
format
(
type
(
Model
),
type
(
model
)))
raise
TypeError
(
"Type of
parameter 'model' must be Model, but got {}."
.
format
(
type
(
model
)))
self
.
model
=
model
self
.
attack_list
=
[]
...
...
@@ -104,8 +105,24 @@ class MembershipInference:
attack_config (list): Parameter setting for the attack model.
Raises:
ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"].
KeyError: If each config in attack_config doesn't have keys {"method", "params"}
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
"""
if
not
isinstance
(
dataset_train
,
Dataset
):
raise
TypeError
(
"Type of parameter 'dataset_train' must be Dataset, "
"but got {}"
.
format
(
type
(
dataset_train
)))
if
not
isinstance
(
dataset_test
,
Dataset
):
raise
TypeError
(
"Type of parameter 'test_train' must be Dataset, "
"but got {}"
.
format
(
type
(
dataset_train
)))
for
config
in
attack_config
:
if
{
"params"
,
"method"
}
!=
set
(
config
.
keys
()):
raise
KeyError
(
"Each config in attack_config must have keys 'method' and 'params', "
"but your key value is {}."
.
format
(
set
(
config
.
keys
())))
if
str
.
lower
(
config
[
"method"
])
not
in
method_list
:
raise
ValueError
(
"Method {} is not support."
.
format
(
config
[
"method"
]))
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
for
config
in
attack_config
:
self
.
attack_list
.
append
(
get_attack_model
(
features
,
labels
,
config
))
...
...
@@ -124,6 +141,24 @@ class MembershipInference:
Returns:
list, Each element contains an evaluation indicator for the attack model.
"""
if
not
isinstance
(
dataset_train
,
Dataset
):
raise
TypeError
(
"Type of parameter 'dataset_train' must be Dataset, "
"but got {}"
.
format
(
type
(
dataset_train
)))
if
not
isinstance
(
dataset_test
,
Dataset
):
raise
TypeError
(
"Type of parameter 'test_train' must be Dataset, "
"but got {}"
.
format
(
type
(
dataset_train
)))
if
not
isinstance
(
metrics
,
(
list
,
tuple
)):
raise
TypeError
(
"Type of parameter 'config' must be Union[list, tuple], but got "
"{}."
.
format
(
type
(
metrics
)))
metrics
=
set
(
metrics
)
metrics_list
=
{
"precision"
,
"accruacy"
,
"recall"
}
if
metrics
>
metrics_list
:
raise
ValueError
(
"Element in 'metrics' must be in {}, but got "
"{}."
.
format
(
metrics_list
,
metrics
))
result
=
[]
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
for
attacker
in
self
.
attack_list
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录