Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
ce15e781
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ce15e781
编写于
8月 24, 2020
作者:
L
liuluobin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Append the parameter verification of class MembershipInference.
上级
29e303a8
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
48 addition
and
8 deletion
+48
-8
example/membership_inference_demo/train.py
example/membership_inference_demo/train.py
+3
-2
mindarmour/diff_privacy/evaluation/attacker.py
mindarmour/diff_privacy/evaluation/attacker.py
+7
-3
mindarmour/diff_privacy/evaluation/membership_inference.py
mindarmour/diff_privacy/evaluation/membership_inference.py
+38
-3
未找到文件。
example/membership_inference_demo/train.py
浏览文件 @
ce15e781
...
@@ -27,7 +27,7 @@ import mindspore.nn as nn
...
@@ -27,7 +27,7 @@ import mindspore.nn as nn
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore
import
context
from
mindspore.nn.optim.momentum
import
Momentum
from
mindspore.nn.optim.momentum
import
Momentum
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
,
LossMonitor
from
mindspore.train.model
import
Model
from
mindspore.train.model
import
Model
from
mindspore.train.serialization
import
load_param_into_net
,
load_checkpoint
from
mindspore.train.serialization
import
load_param_into_net
,
load_checkpoint
from
mindarmour.utils
import
LogUtil
from
mindarmour.utils
import
LogUtil
...
@@ -187,12 +187,13 @@ if __name__ == '__main__':
...
@@ -187,12 +187,13 @@ if __name__ == '__main__':
amp_level
=
"O2"
,
keep_batchnorm_fp32
=
False
,
loss_scale_manager
=
None
)
amp_level
=
"O2"
,
keep_batchnorm_fp32
=
False
,
loss_scale_manager
=
None
)
# checkpoint save
# checkpoint save
callbacks
=
[
LossMonitor
()]
if
args
.
rank_save_ckpt_flag
:
if
args
.
rank_save_ckpt_flag
:
ckpt_config
=
CheckpointConfig
(
save_checkpoint_steps
=
args
.
ckpt_interval
*
args
.
steps_per_epoch
,
ckpt_config
=
CheckpointConfig
(
save_checkpoint_steps
=
args
.
ckpt_interval
*
args
.
steps_per_epoch
,
keep_checkpoint_max
=
args
.
ckpt_save_max
)
keep_checkpoint_max
=
args
.
ckpt_save_max
)
ckpt_cb
=
ModelCheckpoint
(
config
=
ckpt_config
,
ckpt_cb
=
ModelCheckpoint
(
config
=
ckpt_config
,
directory
=
args
.
outputs_dir
,
directory
=
args
.
outputs_dir
,
prefix
=
'{}'
.
format
(
args
.
rank
))
prefix
=
'{}'
.
format
(
args
.
rank
))
callbacks
=
ckpt_cb
callbacks
.
append
(
ckpt_cb
)
model
.
train
(
args
.
max_epoch
,
dataset
,
callbacks
=
callbacks
)
model
.
train
(
args
.
max_epoch
,
dataset
,
callbacks
=
callbacks
)
mindarmour/diff_privacy/evaluation/attacker.py
100644 → 100755
浏览文件 @
ce15e781
...
@@ -22,6 +22,9 @@ from sklearn.model_selection import GridSearchCV
...
@@ -22,6 +22,9 @@ from sklearn.model_selection import GridSearchCV
from
sklearn.model_selection
import
RandomizedSearchCV
from
sklearn.model_selection
import
RandomizedSearchCV
method_list
=
[
"lr"
,
"knn"
,
"rf"
,
"mlp"
]
def
_attack_knn
(
features
,
labels
,
param_grid
):
def
_attack_knn
(
features
,
labels
,
param_grid
):
"""
"""
Train and return a KNN model.
Train and return a KNN model.
...
@@ -119,12 +122,13 @@ def get_attack_model(features, labels, config):
...
@@ -119,12 +122,13 @@ def get_attack_model(features, labels, config):
sklearn.BaseEstimator, trained model specify by config["method"].
sklearn.BaseEstimator, trained model specify by config["method"].
"""
"""
method
=
str
.
lower
(
config
[
"method"
])
method
=
str
.
lower
(
config
[
"method"
])
if
method
==
"knn"
:
if
method
==
"knn"
:
return
_attack_knn
(
features
,
labels
,
config
[
"params"
])
return
_attack_knn
(
features
,
labels
,
config
[
"params"
])
if
method
in
[
"lr"
,
"logitic regression"
]
:
if
method
==
"lr"
:
return
_attack_lr
(
features
,
labels
,
config
[
"params"
])
return
_attack_lr
(
features
,
labels
,
config
[
"params"
])
if
method
==
"mlp"
:
if
method
==
"mlp"
:
return
_attack_mlpc
(
features
,
labels
,
config
[
"params"
])
return
_attack_mlpc
(
features
,
labels
,
config
[
"params"
])
if
method
in
[
"rf"
,
"random forest"
]
:
if
method
==
"rf"
:
return
_attack_rf
(
features
,
labels
,
config
[
"params"
])
return
_attack_rf
(
features
,
labels
,
config
[
"params"
])
r
aise
ValueError
(
"Method {} is not support."
.
format
(
config
[
"method"
]))
r
eturn
None
mindarmour/diff_privacy/evaluation/membership_inference.py
100644 → 100755
浏览文件 @
ce15e781
...
@@ -19,10 +19,11 @@ import numpy as np
...
@@ -19,10 +19,11 @@ import numpy as np
import
mindspore
as
ms
import
mindspore
as
ms
from
mindspore.train
import
Model
from
mindspore.train
import
Model
from
mindspore.dataset.engine
import
Dataset
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
import
mindspore.context
as
context
import
mindspore.context
as
context
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindarmour.diff_privacy.evaluation.attacker
import
get_attack_model
from
mindarmour.diff_privacy.evaluation.attacker
import
get_attack_model
,
method_list
def
_eval_info
(
pred
,
truth
,
option
):
def
_eval_info
(
pred
,
truth
,
option
):
"""
"""
...
@@ -89,7 +90,7 @@ class MembershipInference:
...
@@ -89,7 +90,7 @@ class MembershipInference:
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
if
not
isinstance
(
model
,
Model
):
if
not
isinstance
(
model
,
Model
):
raise
TypeError
(
"Type of
model must be {}, but got {}."
.
format
(
type
(
Model
),
type
(
model
)))
raise
TypeError
(
"Type of
parameter 'model' must be Model, but got {}."
.
format
(
type
(
model
)))
self
.
model
=
model
self
.
model
=
model
self
.
attack_list
=
[]
self
.
attack_list
=
[]
...
@@ -104,8 +105,24 @@ class MembershipInference:
...
@@ -104,8 +105,24 @@ class MembershipInference:
attack_config (list): Parameter setting for the attack model.
attack_config (list): Parameter setting for the attack model.
Raises:
Raises:
ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"].
KeyError: If each config in attack_config doesn't have keys {"method", "params"}
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
"""
"""
if
not
isinstance
(
dataset_train
,
Dataset
):
raise
TypeError
(
"Type of parameter 'dataset_train' must be Dataset, "
"but got {}"
.
format
(
type
(
dataset_train
)))
if
not
isinstance
(
dataset_test
,
Dataset
):
raise
TypeError
(
"Type of parameter 'test_train' must be Dataset, "
"but got {}"
.
format
(
type
(
dataset_train
)))
for
config
in
attack_config
:
if
{
"params"
,
"method"
}
!=
set
(
config
.
keys
()):
raise
KeyError
(
"Each config in attack_config must have keys 'method' and 'params', "
"but your key value is {}."
.
format
(
set
(
config
.
keys
())))
if
str
.
lower
(
config
[
"method"
])
not
in
method_list
:
raise
ValueError
(
"Method {} is not support."
.
format
(
config
[
"method"
]))
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
for
config
in
attack_config
:
for
config
in
attack_config
:
self
.
attack_list
.
append
(
get_attack_model
(
features
,
labels
,
config
))
self
.
attack_list
.
append
(
get_attack_model
(
features
,
labels
,
config
))
...
@@ -124,6 +141,24 @@ class MembershipInference:
...
@@ -124,6 +141,24 @@ class MembershipInference:
Returns:
Returns:
list, Each element contains an evaluation indicator for the attack model.
list, Each element contains an evaluation indicator for the attack model.
"""
"""
if
not
isinstance
(
dataset_train
,
Dataset
):
raise
TypeError
(
"Type of parameter 'dataset_train' must be Dataset, "
"but got {}"
.
format
(
type
(
dataset_train
)))
if
not
isinstance
(
dataset_test
,
Dataset
):
raise
TypeError
(
"Type of parameter 'test_train' must be Dataset, "
"but got {}"
.
format
(
type
(
dataset_train
)))
if
not
isinstance
(
metrics
,
(
list
,
tuple
)):
raise
TypeError
(
"Type of parameter 'config' must be Union[list, tuple], but got "
"{}."
.
format
(
type
(
metrics
)))
metrics
=
set
(
metrics
)
metrics_list
=
{
"precision"
,
"accruacy"
,
"recall"
}
if
metrics
>
metrics_list
:
raise
ValueError
(
"Element in 'metrics' must be in {}, but got "
"{}."
.
format
(
metrics_list
,
metrics
))
result
=
[]
result
=
[]
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
for
attacker
in
self
.
attack_list
:
for
attacker
in
self
.
attack_list
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录