提交 5541b704 编写于 作者: P pkuliuliu

add third-party dependence and pytest mark

上级 0d903144
...@@ -121,10 +121,10 @@ def get_attack_model(features, labels, config): ...@@ -121,10 +121,10 @@ def get_attack_model(features, labels, config):
method = str.lower(config["method"]) method = str.lower(config["method"])
if method == "knn": if method == "knn":
return _attack_knn(features, labels, config["params"]) return _attack_knn(features, labels, config["params"])
if method == "LR": if method in ["lr", "logitic regression"]:
return _attack_lr(features, labels, config["params"]) return _attack_lr(features, labels, config["params"])
if method == "MLP": if method == "mlp":
return _attack_mlpc(features, labels, config["params"]) return _attack_mlpc(features, labels, config["params"])
if method == "RF": if method in ["rf", "random forest"]:
return _attack_rf(features, labels, config["params"]) return _attack_rf(features, labels, config["params"])
raise ValueError("Method {} is not support.".format(config["method"])) raise ValueError("Method {} is not support.".format(config["method"]))
...@@ -5,3 +5,4 @@ Pillow >= 2.0.0 ...@@ -5,3 +5,4 @@ Pillow >= 2.0.0
pytest >= 4.3.1 pytest >= 4.3.1
wheel >= 0.32.0 wheel >= 0.32.0
setuptools >= 40.8.0 setuptools >= 40.8.0
scikit-learn >= 0.21.2
...@@ -105,7 +105,8 @@ setup( ...@@ -105,7 +105,8 @@ setup(
'scipy >= 1.3.3', 'scipy >= 1.3.3',
'numpy >= 1.17.0', 'numpy >= 1.17.0',
'matplotlib >= 3.2.1', 'matplotlib >= 3.2.1',
'Pillow >= 2.0.0' 'Pillow >= 2.0.0',
'scikit-learn >= 0.21.2'
], ],
classifiers=[ classifiers=[
'License :: OSI Approved :: Apache Software License' 'License :: OSI Approved :: Apache Software License'
......
...@@ -18,16 +18,12 @@ import pytest ...@@ -18,16 +18,12 @@ import pytest
import numpy as np import numpy as np
from sklearn.neighbors import KNeighborsClassifier as knn
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_get_knn_model(): def test_get_knn_model():
...@@ -36,17 +32,17 @@ def test_get_knn_model(): ...@@ -36,17 +32,17 @@ def test_get_knn_model():
config_knn = { config_knn = {
"method": "KNN", "method": "KNN",
"params": { "params": {
"n_neighbors": [3, 5, 7], "n_neighbors": [3],
} }
} }
knn_attacker = get_attack_model(features, labels, config_knn) knn_attacker = get_attack_model(features, labels, config_knn)
assert isinstance(knn_attacker, knn)
pred = knn_attacker.predict(features) pred = knn_attacker.predict(features)
assert pred is not None assert pred is not None
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_get_lr_model(): def test_get_lr_model():
...@@ -59,13 +55,13 @@ def test_get_lr_model(): ...@@ -59,13 +55,13 @@ def test_get_lr_model():
} }
} }
lr_attacker = get_attack_model(features, labels, config_lr) lr_attacker = get_attack_model(features, labels, config_lr)
assert isinstance(lr_attacker, LogisticRegression)
pred = lr_attacker.predict(features) pred = lr_attacker.predict(features)
assert pred is not None assert pred is not None
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_get_mlp_model(): def test_get_mlp_model():
...@@ -80,13 +76,13 @@ def test_get_mlp_model(): ...@@ -80,13 +76,13 @@ def test_get_mlp_model():
} }
} }
mlpc_attacker = get_attack_model(features, labels, config_mlpc) mlpc_attacker = get_attack_model(features, labels, config_mlpc)
assert isinstance(mlpc_attacker, MLPClassifier)
pred = mlpc_attacker.predict(features) pred = mlpc_attacker.predict(features)
assert pred is not None assert pred is not None
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_get_rf_model(): def test_get_rf_model():
...@@ -103,6 +99,5 @@ def test_get_rf_model(): ...@@ -103,6 +99,5 @@ def test_get_rf_model():
} }
} }
rf_attacker = get_attack_model(features, labels, config_rf) rf_attacker = get_attack_model(features, labels, config_rf)
assert isinstance(rf_attacker, RandomForestClassifier)
pred = rf_attacker.predict(features) pred = rf_attacker.predict(features)
assert pred is not None assert pred is not None
...@@ -27,6 +27,7 @@ from mindarmour.diff_privacy import ClipMechanismsFactory ...@@ -27,6 +27,7 @@ from mindarmour.diff_privacy import ClipMechanismsFactory
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_graph_factory(): def test_graph_factory():
...@@ -53,6 +54,7 @@ def test_graph_factory(): ...@@ -53,6 +54,7 @@ def test_graph_factory():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_pynative_factory(): def test_pynative_factory():
...@@ -79,6 +81,7 @@ def test_pynative_factory(): ...@@ -79,6 +81,7 @@ def test_pynative_factory():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_pynative_gaussian(): def test_pynative_gaussian():
...@@ -105,6 +108,7 @@ def test_pynative_gaussian(): ...@@ -105,6 +108,7 @@ def test_pynative_gaussian():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_graph_ada_gaussian(): def test_graph_ada_gaussian():
...@@ -125,6 +129,7 @@ def test_graph_ada_gaussian(): ...@@ -125,6 +129,7 @@ def test_graph_ada_gaussian():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_pynative_ada_gaussian(): def test_pynative_ada_gaussian():
...@@ -145,6 +150,7 @@ def test_pynative_ada_gaussian(): ...@@ -145,6 +150,7 @@ def test_pynative_ada_gaussian():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_graph_exponential(): def test_graph_exponential():
...@@ -166,6 +172,7 @@ def test_graph_exponential(): ...@@ -166,6 +172,7 @@ def test_graph_exponential():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_pynative_exponential(): def test_pynative_exponential():
...@@ -187,6 +194,7 @@ def test_pynative_exponential(): ...@@ -187,6 +194,7 @@ def test_pynative_exponential():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_ada_clip_gaussian_random_pynative(): def test_ada_clip_gaussian_random_pynative():
...@@ -217,6 +225,7 @@ def test_ada_clip_gaussian_random_pynative(): ...@@ -217,6 +225,7 @@ def test_ada_clip_gaussian_random_pynative():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_ada_clip_gaussian_random_graph(): def test_ada_clip_gaussian_random_graph():
...@@ -247,6 +256,7 @@ def test_ada_clip_gaussian_random_graph(): ...@@ -247,6 +256,7 @@ def test_ada_clip_gaussian_random_graph():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_pynative_clip_mech_factory(): def test_pynative_clip_mech_factory():
...@@ -269,6 +279,7 @@ def test_pynative_clip_mech_factory(): ...@@ -269,6 +279,7 @@ def test_pynative_clip_mech_factory():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_graph_clip_mech_factory(): def test_graph_clip_mech_factory():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册