Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
604eeb97
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
604eeb97
编写于
9月 09, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 09, 2020
浏览文件
操作
浏览文件
下载
差异文件
!110 Correct some docs error. Modify the type detection code.
Merge pull request !110 from liuluobin/master
上级
8b142a22
2ded64d6
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
63 addition
and
108 deletion
+63
-108
mindarmour/privacy/evaluation/_check_config.py
mindarmour/privacy/evaluation/_check_config.py
+17
-30
mindarmour/privacy/evaluation/attacker.py
mindarmour/privacy/evaluation/attacker.py
+1
-1
mindarmour/privacy/evaluation/membership_inference.py
mindarmour/privacy/evaluation/membership_inference.py
+35
-67
tests/ut/python/diff_privacy/test_attacker.py
tests/ut/python/diff_privacy/test_attacker.py
+10
-10
未找到文件。
mindarmour/privacy/evaluation/_check_config.py
浏览文件 @
604eeb97
...
@@ -15,11 +15,12 @@
...
@@ -15,11 +15,12 @@
Verify attack config
Verify attack config
"""
"""
from
mindarmour.utils._check_param
import
check_param_type
from
mindarmour.utils.logger
import
LogUtil
from
mindarmour.utils.logger
import
LogUtil
LOGGER
=
LogUtil
.
get_instance
()
LOGGER
=
LogUtil
.
get_instance
()
TAG
=
"check_
params
"
TAG
=
"check_
config
"
def
_is_positive_int
(
item
):
def
_is_positive_int
(
item
):
...
@@ -77,7 +78,7 @@ def _is_dict(item):
...
@@ -77,7 +78,7 @@ def _is_dict(item):
return
isinstance
(
item
,
dict
)
return
isinstance
(
item
,
dict
)
VALID_PARAMS_DIC
T
=
{
_VALID_CONFIG_CHECKLIS
T
=
{
"knn"
:
{
"knn"
:
{
"n_neighbors"
:
[
_is_positive_int
],
"n_neighbors"
:
[
_is_positive_int
],
"weights"
:
[{
"uniform"
,
"distance"
}],
"weights"
:
[{
"uniform"
,
"distance"
}],
...
@@ -126,7 +127,7 @@ VALID_PARAMS_DICT = {
...
@@ -126,7 +127,7 @@ VALID_PARAMS_DICT = {
"rf"
:
{
"rf"
:
{
"n_estimators"
:
[
_is_positive_int
],
"n_estimators"
:
[
_is_positive_int
],
"criterion"
:
[{
"gini"
,
"entropy"
}],
"criterion"
:
[{
"gini"
,
"entropy"
}],
"max_depth"
:
[
_is_positive_int
],
"max_depth"
:
[
{
None
},
_is_positive_int
],
"min_samples_split"
:
[
_is_positive_float
],
"min_samples_split"
:
[
_is_positive_float
],
"min_samples_leaf"
:
[
_is_positive_float
],
"min_samples_leaf"
:
[
_is_positive_float
],
"min_weight_fraction_leaf"
:
[
_is_non_negative_float
],
"min_weight_fraction_leaf"
:
[
_is_non_negative_float
],
...
@@ -148,24 +149,15 @@ VALID_PARAMS_DICT = {
...
@@ -148,24 +149,15 @@ VALID_PARAMS_DICT = {
def
_check_config
(
config_list
,
check_params
):
def
_check_config
(
attack_config
,
config_checklist
):
"""
"""
Verify that config_list is valid.
Verify that config_list is valid.
Check_params is the valid value range of the parameter.
Check_params is the valid value range of the parameter.
"""
"""
if
not
isinstance
(
config_list
,
(
list
,
tuple
)):
for
config
in
attack_config
:
msg
=
"Type of parameter 'config_list' must be list, but got {}."
.
format
(
type
(
config_list
))
check_param_type
(
"config"
,
config
,
dict
)
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
for
config
in
config_list
:
if
not
isinstance
(
config
,
dict
):
msg
=
"Type of each config in config_list must be dict, but got {}."
.
format
(
type
(
config
))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
set
(
config
.
keys
())
!=
{
"params"
,
"method"
}:
if
set
(
config
.
keys
())
!=
{
"params"
,
"method"
}:
msg
=
"Keys of each config in
config_list
must be {},"
\
msg
=
"Keys of each config in
attack_config
must be {},"
\
"but got {}."
.
format
({
'method'
,
'params'
},
set
(
config
.
keys
()))
"but got {}."
.
format
({
'method'
,
'params'
},
set
(
config
.
keys
()))
LOGGER
.
error
(
TAG
,
msg
)
LOGGER
.
error
(
TAG
,
msg
)
raise
KeyError
(
msg
)
raise
KeyError
(
msg
)
...
@@ -173,27 +165,22 @@ def _check_config(config_list, check_params):
...
@@ -173,27 +165,22 @@ def _check_config(config_list, check_params):
method
=
str
.
lower
(
config
[
"method"
])
method
=
str
.
lower
(
config
[
"method"
])
params
=
config
[
"params"
]
params
=
config
[
"params"
]
if
method
not
in
c
heck_params
.
keys
():
if
method
not
in
c
onfig_checklist
.
keys
():
msg
=
"Method {} is not supported."
.
format
(
method
)
msg
=
"Method {} is not supported."
.
format
(
method
)
LOGGER
.
error
(
TAG
,
msg
)
LOGGER
.
error
(
TAG
,
msg
)
raise
Valu
eError
(
msg
)
raise
Nam
eError
(
msg
)
if
not
params
.
keys
()
<=
c
heck_params
[
method
].
keys
():
if
not
params
.
keys
()
<=
c
onfig_checklist
[
method
].
keys
():
msg
=
"Params in method {} is not accepted, the parameters "
\
msg
=
"Params in method {} is not accepted, the parameters "
\
"that can be set are {}."
.
format
(
method
,
set
(
c
heck_params
[
method
].
keys
()))
"that can be set are {}."
.
format
(
method
,
set
(
c
onfig_checklist
[
method
].
keys
()))
LOGGER
.
error
(
TAG
,
msg
)
LOGGER
.
error
(
TAG
,
msg
)
raise
KeyError
(
msg
)
raise
KeyError
(
msg
)
for
param_key
in
params
.
keys
():
for
param_key
in
params
.
keys
():
param_value
=
params
[
param_key
]
param_value
=
params
[
param_key
]
candidate_values
=
check_params
[
method
][
param_key
]
candidate_values
=
config_checklist
[
method
][
param_key
]
check_param_type
(
'param_value'
,
param_value
,
list
)
if
not
isinstance
(
param_value
,
list
):
msg
=
"The parameter '{}' in method '{}' setting must within the range of "
\
"changeable parameters."
.
format
(
param_key
,
method
)
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
if
candidate_values
is
None
:
if
candidate_values
is
None
:
continue
continue
...
@@ -204,7 +191,7 @@ def _check_config(config_list, check_params):
...
@@ -204,7 +191,7 @@ def _check_config(config_list, check_params):
if
isinstance
(
candidate_value
,
set
)
and
item_value
in
candidate_value
:
if
isinstance
(
candidate_value
,
set
)
and
item_value
in
candidate_value
:
flag
=
True
flag
=
True
break
break
elif
candidate_value
(
item_value
):
elif
not
isinstance
(
candidate_value
,
set
)
and
candidate_value
(
item_value
):
flag
=
True
flag
=
True
break
break
...
@@ -213,8 +200,8 @@ def _check_config(config_list, check_params):
...
@@ -213,8 +200,8 @@ def _check_config(config_list, check_params):
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
def
check_config_params
(
config_list
):
def
verify_config_params
(
attack_config
):
"""
"""
External interfaces to verify attack config.
External interfaces to verify attack config.
"""
"""
_check_config
(
config_list
,
VALID_PARAMS_DIC
T
)
_check_config
(
attack_config
,
_VALID_CONFIG_CHECKLIS
T
)
mindarmour/privacy/evaluation/attacker.py
浏览文件 @
604eeb97
...
@@ -153,4 +153,4 @@ def get_attack_model(features, labels, config, n_jobs=-1):
...
@@ -153,4 +153,4 @@ def get_attack_model(features, labels, config, n_jobs=-1):
msg
=
"Method {} is not supported."
.
format
(
config
[
"method"
])
msg
=
"Method {} is not supported."
.
format
(
config
[
"method"
])
LOGGER
.
error
(
TAG
,
msg
)
LOGGER
.
error
(
TAG
,
msg
)
raise
Valu
eError
(
msg
)
raise
Nam
eError
(
msg
)
mindarmour/privacy/evaluation/membership_inference.py
浏览文件 @
604eeb97
...
@@ -23,8 +23,10 @@ from mindspore.train import Model
...
@@ -23,8 +23,10 @@ from mindspore.train import Model
from
mindspore.dataset.engine
import
Dataset
from
mindspore.dataset.engine
import
Dataset
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindarmour.utils.logger
import
LogUtil
from
mindarmour.utils.logger
import
LogUtil
from
mindarmour.utils._check_param
import
check_param_type
,
check_param_multi_types
,
\
check_model
,
check_numpy_param
from
.attacker
import
get_attack_model
from
.attacker
import
get_attack_model
from
._check_config
import
check
_config_params
from
._check_config
import
verify
_config_params
LOGGER
=
LogUtil
.
get_instance
()
LOGGER
=
LogUtil
.
get_instance
()
TAG
=
"MembershipInference"
TAG
=
"MembershipInference"
...
@@ -47,23 +49,21 @@ def _eval_info(pred, truth, option):
...
@@ -47,23 +49,21 @@ def _eval_info(pred, truth, option):
ValueError, size of parameter pred or truth is 0.
ValueError, size of parameter pred or truth is 0.
ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
"""
"""
if
pred
.
size
==
0
or
truth
.
size
==
0
:
check_numpy_param
(
"pred"
,
pred
)
msg
=
"Size of pred or truth is 0."
check_numpy_param
(
"truth"
,
truth
)
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
if
option
==
"accuracy"
:
if
option
==
"accuracy"
:
count
=
np
.
sum
(
pred
==
truth
)
count
=
np
.
sum
(
pred
==
truth
)
return
count
/
len
(
pred
)
return
count
/
len
(
pred
)
if
option
==
"precision"
:
if
option
==
"precision"
:
count
=
np
.
sum
(
pred
&
truth
)
if
np
.
sum
(
pred
)
==
0
:
if
np
.
sum
(
pred
)
==
0
:
return
-
1
return
-
1
count
=
np
.
sum
(
pred
&
truth
)
return
count
/
np
.
sum
(
pred
)
return
count
/
np
.
sum
(
pred
)
if
option
==
"recall"
:
if
option
==
"recall"
:
count
=
np
.
sum
(
pred
&
truth
)
if
np
.
sum
(
truth
)
==
0
:
if
np
.
sum
(
truth
)
==
0
:
return
-
1
return
-
1
count
=
np
.
sum
(
pred
&
truth
)
return
count
/
np
.
sum
(
truth
)
return
count
/
np
.
sum
(
truth
)
msg
=
"The metric value {} is undefined."
.
format
(
option
)
msg
=
"The metric value {} is undefined."
.
format
(
option
)
...
@@ -107,9 +107,9 @@ class MembershipInference:
...
@@ -107,9 +107,9 @@ class MembershipInference:
otherwise the value of n_jobs must be a positive integer.
otherwise the value of n_jobs must be a positive integer.
Examples:
Examples:
>>> train_1, train_2 are non-overlapping datasets from training dataset of target model.
>>>
#
train_1, train_2 are non-overlapping datasets from training dataset of target model.
>>> test_1, test_2 are non-overlapping datasets from test dataset of target model.
>>>
#
test_1, test_2 are non-overlapping datasets from test dataset of target model.
>>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
>>>
#
We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
>>> inference_model = MembershipInference(model, n_jobs=-1)
>>> inference_model = MembershipInference(model, n_jobs=-1)
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
...
@@ -124,65 +124,44 @@ class MembershipInference:
...
@@ -124,65 +124,44 @@ class MembershipInference:
"""
"""
def
__init__
(
self
,
model
,
n_jobs
=-
1
):
def
__init__
(
self
,
model
,
n_jobs
=-
1
):
if
not
isinstance
(
model
,
Model
):
check_param_type
(
"n_jobs"
,
n_jobs
,
int
)
msg
=
"Type of parameter 'model' must be Model, but got {}."
.
format
(
type
(
model
))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
not
isinstance
(
n_jobs
,
int
):
msg
=
"Type of parameter 'n_jobs' must be int, but got {}"
.
format
(
type
(
n_jobs
))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
not
(
n_jobs
==
-
1
or
n_jobs
>
0
):
if
not
(
n_jobs
==
-
1
or
n_jobs
>
0
):
msg
=
"Value of n_jobs must be either -1 or positive integer, but got {}."
.
format
(
n_jobs
)
msg
=
"Value of n_jobs must be either -1 or positive integer, but got {}."
.
format
(
n_jobs
)
LOGGER
.
error
(
TAG
,
msg
)
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
self
.
model
=
model
self
.
_model
=
check_model
(
"model"
,
model
,
Model
)
self
.
n_jobs
=
min
(
n_jobs
,
cpu_count
())
self
.
_n_jobs
=
min
(
n_jobs
,
cpu_count
())
self
.
method_list
=
[
"knn"
,
"lr"
,
"mlp"
,
"rf"
]
self
.
_attack_list
=
[]
self
.
attack_list
=
[]
def
train
(
self
,
dataset_train
,
dataset_test
,
attack_config
):
def
train
(
self
,
dataset_train
,
dataset_test
,
attack_config
):
"""
"""
Depending on the configuration, use the in
coming
data set to train the attack model.
Depending on the configuration, use the in
put
data set to train the attack model.
Save the attack model to self.attack_list.
Save the attack model to self.
_
attack_list.
Args:
Args:
dataset_train (mindspore.dataset): The training dataset for the target model.
dataset_train (mindspore.dataset): The training dataset for the target model.
dataset_test (mindspore.dataset): The test set for the target model.
dataset_test (mindspore.dataset): The test set for the target model.
attack_config (
list
): Parameter setting for the attack model. The format is
attack_config (
Union[list, tuple]
): Parameter setting for the attack model. The format is
[{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
[{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
{"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}].
{"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}].
The support methods
list is in self.method_list
, and the params of each method
The support methods
are knn, lr, mlp and rf
, and the params of each method
must within the range of changeable parameters. Tips of params implement
must within the range of changeable parameters. Tips of params implement
can be found in
can be found in
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
Raises:
Raises:
KeyError: If
each
config in attack_config doesn't have keys {"method", "params"}
KeyError: If
any
config in attack_config doesn't have keys {"method", "params"}
Valu
eError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
Nam
eError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
"""
"""
if
not
isinstance
(
dataset_train
,
Dataset
):
check_param_type
(
"dataset_train"
,
dataset_train
,
Dataset
)
msg
=
"Type of parameter 'dataset_train' must be Dataset, but got {}"
.
format
(
type
(
dataset_train
))
check_param_type
(
"dataset_test"
,
dataset_test
,
Dataset
)
LOGGER
.
error
(
TAG
,
msg
)
check_param_multi_types
(
"attack_config"
,
attack_config
,
(
list
,
tuple
))
raise
TypeError
(
msg
)
verify_config_params
(
attack_config
)
if
not
isinstance
(
dataset_test
,
Dataset
):
msg
=
"Type of parameter 'test_train' must be Dataset, but got {}"
.
format
(
type
(
dataset_train
))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
not
isinstance
(
attack_config
,
list
):
msg
=
"Type of parameter 'attack_config' must be list, but got {}."
.
format
(
type
(
attack_config
))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
check_config_params
(
attack_config
)
# Verify attack config.
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
for
config
in
attack_config
:
for
config
in
attack_config
:
self
.
attack_list
.
append
(
get_attack_model
(
features
,
labels
,
config
,
n_jobs
=
self
.
n_jobs
))
self
.
_attack_list
.
append
(
get_attack_model
(
features
,
labels
,
config
,
n_jobs
=
self
.
_
n_jobs
))
def
eval
(
self
,
dataset_train
,
dataset_test
,
metrics
):
def
eval
(
self
,
dataset_train
,
dataset_test
,
metrics
):
...
@@ -199,20 +178,9 @@ class MembershipInference:
...
@@ -199,20 +178,9 @@ class MembershipInference:
Returns:
Returns:
list, Each element contains an evaluation indicator for the attack model.
list, Each element contains an evaluation indicator for the attack model.
"""
"""
if
not
isinstance
(
dataset_train
,
Dataset
):
check_param_type
(
"dataset_train"
,
dataset_train
,
Dataset
)
msg
=
"Type of parameter 'dataset_train' must be Dataset, but got {}"
.
format
(
type
(
dataset_train
))
check_param_type
(
"dataset_test"
,
dataset_test
,
Dataset
)
LOGGER
.
error
(
TAG
,
msg
)
check_param_multi_types
(
"metrics"
,
metrics
,
(
list
,
tuple
))
raise
TypeError
(
msg
)
if
not
isinstance
(
dataset_test
,
Dataset
):
msg
=
"Type of parameter 'test_train' must be Dataset, but got {}"
.
format
(
type
(
dataset_train
))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
if
not
isinstance
(
metrics
,
(
list
,
tuple
)):
msg
=
"Type of parameter 'config' must be Union[list, tuple], but got {}."
.
format
(
type
(
metrics
))
LOGGER
.
error
(
TAG
,
msg
)
raise
TypeError
(
msg
)
metrics
=
set
(
metrics
)
metrics
=
set
(
metrics
)
metrics_list
=
{
"precision"
,
"accuracy"
,
"recall"
}
metrics_list
=
{
"precision"
,
"accuracy"
,
"recall"
}
...
@@ -223,7 +191,7 @@ class MembershipInference:
...
@@ -223,7 +191,7 @@ class MembershipInference:
result
=
[]
result
=
[]
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
features
,
labels
=
self
.
_transform
(
dataset_train
,
dataset_test
)
for
attacker
in
self
.
attack_list
:
for
attacker
in
self
.
_
attack_list
:
pred
=
attacker
.
predict
(
features
)
pred
=
attacker
.
predict
(
features
)
item
=
{}
item
=
{}
for
option
in
metrics
:
for
option
in
metrics
:
...
@@ -233,7 +201,7 @@ class MembershipInference:
...
@@ -233,7 +201,7 @@ class MembershipInference:
def
_transform
(
self
,
dataset_train
,
dataset_test
):
def
_transform
(
self
,
dataset_train
,
dataset_test
):
"""
"""
Generate corresponding loss_logits feature and new label, and return after shuffle.
Generate corresponding loss_logits feature
s
and new label, and return after shuffle.
Args:
Args:
dataset_train: The training set for the target model.
dataset_train: The training set for the target model.
...
@@ -255,13 +223,13 @@ class MembershipInference:
...
@@ -255,13 +223,13 @@ class MembershipInference:
return
features
,
labels
return
features
,
labels
def
_generate
(
self
,
dataset_x
,
label
):
def
_generate
(
self
,
input_dataset
,
label
):
"""
"""
Return a loss_logits features and labels for training attack model.
Return a loss_logits features and labels for training attack model.
Args:
Args:
dataset_x
(mindspore.dataset): The dataset to be generate.
input_dataset
(mindspore.dataset): The dataset to be generate.
label (int32): Whether
dataset_x
belongs to the target model.
label (int32): Whether
input_dataset
belongs to the target model.
Returns:
Returns:
- numpy.ndarray, Loss_logits features for each sample. Shape is (N, C).
- numpy.ndarray, Loss_logits features for each sample. Shape is (N, C).
...
@@ -269,10 +237,10 @@ class MembershipInference:
...
@@ -269,10 +237,10 @@ class MembershipInference:
- numpy.ndarray, Labels for each sample, Shape is (N,).
- numpy.ndarray, Labels for each sample, Shape is (N,).
"""
"""
loss_logits
=
np
.
array
([])
loss_logits
=
np
.
array
([])
for
batch
in
dataset_x
.
create_dict_iterator
():
for
batch
in
input_dataset
.
create_dict_iterator
():
batch_data
=
Tensor
(
batch
[
'image'
],
ms
.
float32
)
batch_data
=
Tensor
(
batch
[
'image'
],
ms
.
float32
)
batch_labels
=
batch
[
'label'
].
astype
(
np
.
int32
)
batch_labels
=
batch
[
'label'
].
astype
(
np
.
int32
)
batch_logits
=
self
.
model
.
predict
(
batch_data
).
asnumpy
()
batch_logits
=
self
.
_
model
.
predict
(
batch_data
).
asnumpy
()
batch_loss
=
_softmax_cross_entropy
(
batch_logits
,
batch_labels
)
batch_loss
=
_softmax_cross_entropy
(
batch_logits
,
batch_labels
)
batch_feature
=
np
.
hstack
((
batch_loss
.
reshape
(
-
1
,
1
),
batch_logits
))
batch_feature
=
np
.
hstack
((
batch_loss
.
reshape
(
-
1
,
1
),
batch_logits
))
...
...
tests/ut/python/diff_privacy/test_attacker.py
浏览文件 @
604eeb97
...
@@ -27,12 +27,12 @@ from mindarmour.privacy.evaluation.attacker import get_attack_model
...
@@ -27,12 +27,12 @@ from mindarmour.privacy.evaluation.attacker import get_attack_model
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_get_knn_model
():
def
test_get_knn_model
():
features
=
np
.
random
.
randint
(
0
,
10
,
[
10
,
10
])
features
=
np
.
random
.
randint
(
0
,
10
,
[
10
0
,
10
])
labels
=
np
.
random
.
randint
(
0
,
2
,
[
10
])
labels
=
np
.
random
.
randint
(
0
,
2
,
[
10
0
])
config_knn
=
{
config_knn
=
{
"method"
:
"KNN"
,
"method"
:
"KNN"
,
"params"
:
{
"params"
:
{
"n_neighbors"
:
[
3
],
"n_neighbors"
:
[
3
,
5
,
7
],
}
}
}
}
knn_attacker
=
get_attack_model
(
features
,
labels
,
config_knn
,
-
1
)
knn_attacker
=
get_attack_model
(
features
,
labels
,
config_knn
,
-
1
)
...
@@ -46,8 +46,8 @@ def test_get_knn_model():
...
@@ -46,8 +46,8 @@ def test_get_knn_model():
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_get_lr_model
():
def
test_get_lr_model
():
features
=
np
.
random
.
randint
(
0
,
10
,
[
10
,
10
])
features
=
np
.
random
.
randint
(
0
,
10
,
[
10
0
,
10
])
labels
=
np
.
random
.
randint
(
0
,
2
,
[
10
])
labels
=
np
.
random
.
randint
(
0
,
2
,
[
10
0
])
config_lr
=
{
config_lr
=
{
"method"
:
"LR"
,
"method"
:
"LR"
,
"params"
:
{
"params"
:
{
...
@@ -65,8 +65,8 @@ def test_get_lr_model():
...
@@ -65,8 +65,8 @@ def test_get_lr_model():
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_get_mlp_model
():
def
test_get_mlp_model
():
features
=
np
.
random
.
randint
(
0
,
10
,
[
10
,
10
])
features
=
np
.
random
.
randint
(
0
,
10
,
[
10
0
,
10
])
labels
=
np
.
random
.
randint
(
0
,
2
,
[
10
])
labels
=
np
.
random
.
randint
(
0
,
2
,
[
10
0
])
config_mlpc
=
{
config_mlpc
=
{
"method"
:
"MLP"
,
"method"
:
"MLP"
,
"params"
:
{
"params"
:
{
...
@@ -86,14 +86,14 @@ def test_get_mlp_model():
...
@@ -86,14 +86,14 @@ def test_get_mlp_model():
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_get_rf_model
():
def
test_get_rf_model
():
features
=
np
.
random
.
randint
(
0
,
10
,
[
10
,
10
])
features
=
np
.
random
.
randint
(
0
,
10
,
[
10
0
,
10
])
labels
=
np
.
random
.
randint
(
0
,
2
,
[
10
])
labels
=
np
.
random
.
randint
(
0
,
2
,
[
10
0
])
config_rf
=
{
config_rf
=
{
"method"
:
"RF"
,
"method"
:
"RF"
,
"params"
:
{
"params"
:
{
"n_estimators"
:
[
100
],
"n_estimators"
:
[
100
],
"max_features"
:
[
"auto"
,
"sqrt"
],
"max_features"
:
[
"auto"
,
"sqrt"
],
"max_depth"
:
[
5
,
10
,
20
,
None
],
"max_depth"
:
[
None
,
5
,
10
,
20
],
"min_samples_split"
:
[
2
,
5
,
10
],
"min_samples_split"
:
[
2
,
5
,
10
],
"min_samples_leaf"
:
[
1
,
2
,
4
],
"min_samples_leaf"
:
[
1
,
2
,
4
],
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录