Commit 2ded64d6 authored by liuluobin

Correct some docs errors. Modify the type-checking code.

Parent 8b142a22
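The thrust of the change: hand-rolled isinstance/log/raise blocks are replaced by the shared helpers from mindarmour.utils._check_param, and the attack-config names are tightened (VALID_PARAMS_DICT becomes _VALID_CONFIG_CHECKLIST, check_config_params becomes verify_config_params). A minimal sketch of the helper pattern, assuming the helpers behave like the inline checks they replace — the real implementations live in _check_param.py and may differ in detail:

# Sketch only: assumed behavior of the _check_param helpers used in this commit.
def check_param_type(arg_name, arg_value, valid_type):
    """Assumed behavior: raise TypeError unless arg_value is a valid_type."""
    if not isinstance(arg_value, valid_type):
        raise TypeError("Type of parameter '{}' must be {}, but got {}.".format(
            arg_name, valid_type, type(arg_value)))
    return arg_value

def check_param_multi_types(arg_name, arg_value, valid_types):
    """Assumed behavior: raise TypeError unless arg_value matches one of valid_types."""
    if not isinstance(arg_value, tuple(valid_types)):
        raise TypeError("Type of parameter '{}' must be one of {}, but got {}.".format(
            arg_name, valid_types, type(arg_value)))
    return arg_value

# Call sites shrink from four-line isinstance/log/raise blocks to one line each:
check_param_type("n_jobs", -1, int)
check_param_multi_types("attack_config", [{"method": "knn", "params": {}}], (list, tuple))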
@@ -15,11 +15,12 @@
 Verify attack config
 """
+from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils.logger import LogUtil

 LOGGER = LogUtil.get_instance()
-TAG = "check_params"
+TAG = "check_config"


 def _is_positive_int(item):
@@ -77,7 +78,7 @@ def _is_dict(item):
     return isinstance(item, dict)


-VALID_PARAMS_DICT = {
+_VALID_CONFIG_CHECKLIST = {
     "knn": {
         "n_neighbors": [_is_positive_int],
         "weights": [{"uniform", "distance"}],
@@ -126,7 +127,7 @@ VALID_PARAMS_DICT = {
     "rf": {
         "n_estimators": [_is_positive_int],
         "criterion": [{"gini", "entropy"}],
-        "max_depth": [_is_positive_int],
+        "max_depth": [{None}, _is_positive_int],
         "min_samples_split": [_is_positive_float],
         "min_samples_leaf": [_is_positive_float],
         "min_weight_fraction_leaf": [_is_non_negative_float],
@@ -148,24 +149,15 @@ VALID_PARAMS_DICT = {
-def _check_config(config_list, check_params):
+def _check_config(attack_config, config_checklist):
     """
     Verify that config_list is valid.
     Check_params is the valid value range of the parameter.
     """
-    if not isinstance(config_list, (list, tuple)):
-        msg = "Type of parameter 'config_list' must be list, but got {}.".format(type(config_list))
-        LOGGER.error(TAG, msg)
-        raise TypeError(msg)
-
-    for config in config_list:
-        if not isinstance(config, dict):
-            msg = "Type of each config in config_list must be dict, but got {}.".format(type(config))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
+    for config in attack_config:
+        check_param_type("config", config, dict)
         if set(config.keys()) != {"params", "method"}:
-            msg = "Keys of each config in config_list must be {}," \
+            msg = "Keys of each config in attack_config must be {}," \
                   "but got {}.".format({'method', 'params'}, set(config.keys()))
             LOGGER.error(TAG, msg)
             raise KeyError(msg)
@@ -173,27 +165,22 @@ def _check_config(config_list, check_params):
         method = str.lower(config["method"])
         params = config["params"]
-        if method not in check_params.keys():
+        if method not in config_checklist.keys():
             msg = "Method {} is not supported.".format(method)
             LOGGER.error(TAG, msg)
-            raise ValueError(msg)
+            raise NameError(msg)

-        if not params.keys() <= check_params[method].keys():
+        if not params.keys() <= config_checklist[method].keys():
             msg = "Params in method {} is not accepted, the parameters " \
-                  "that can be set are {}.".format(method, set(check_params[method].keys()))
+                  "that can be set are {}.".format(method, set(config_checklist[method].keys()))
             LOGGER.error(TAG, msg)
             raise KeyError(msg)

         for param_key in params.keys():
             param_value = params[param_key]
-            candidate_values = check_params[method][param_key]
-
-            if not isinstance(param_value, list):
-                msg = "The parameter '{}' in method '{}' setting must within the range of " \
-                      "changeable parameters.".format(param_key, method)
-                LOGGER.error(TAG, msg)
-                raise ValueError(msg)
-
+            candidate_values = config_checklist[method][param_key]
+            check_param_type('param_value', param_value, list)
             if candidate_values is None:
                 continue
@@ -204,7 +191,7 @@ def _check_config(config_list, check_params):
                     if isinstance(candidate_value, set) and item_value in candidate_value:
                         flag = True
                         break
-                    elif candidate_value(item_value):
+                    elif not isinstance(candidate_value, set) and candidate_value(item_value):
                         flag = True
                         break
@@ -213,8 +200,8 @@ def _check_config(config_list, check_params):
                 raise ValueError(msg)


-def check_config_params(config_list):
+def verify_config_params(attack_config):
     """
     External interfaces to verify attack config.
     """
-    _check_config(config_list, VALID_PARAMS_DICT)
+    _check_config(attack_config, _VALID_CONFIG_CHECKLIST)
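Each checklist entry maps a parameter to a list of candidates, where a candidate is either a set of allowed literals or a predicate function. The guarded elif added above also fixes a latent bug: the old `elif candidate_value(item_value)` would try to *call* a set whenever the membership test in the first branch failed. A self-contained sketch of the dispatch (names are illustrative, not the library's):

# Sketch of the checklist dispatch used by _check_config.
def _is_positive_int(item):
    return isinstance(item, int) and not isinstance(item, bool) and item > 0

# Each candidate is either a set of allowed literals or a predicate callable.
checklist = {"max_depth": [{None}, _is_positive_int]}

def value_allowed(param_key, item_value):
    for candidate in checklist[param_key]:
        if isinstance(candidate, set) and item_value in candidate:
            return True
        # Without this isinstance guard, a failed set-membership test would
        # fall through to candidate(item_value) and call the set: TypeError.
        if not isinstance(candidate, set) and candidate(item_value):
            return True
    return False

assert value_allowed("max_depth", None)    # matched by the literal set {None}
assert value_allowed("max_depth", 10)      # matched by the predicate
assert not value_allowed("max_depth", -3)  # rejected by both candidates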
@@ -153,4 +153,4 @@ def get_attack_model(features, labels, config, n_jobs=-1):
     msg = "Method {} is not supported.".format(config["method"])
     LOGGER.error(TAG, msg)
-    raise ValueError(msg)
+    raise NameError(msg)
@@ -23,8 +23,10 @@ from mindspore.train import Model
 from mindspore.dataset.engine import Dataset
 from mindspore import Tensor
 from mindarmour.utils.logger import LogUtil
+from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \
+    check_model, check_numpy_param
 from .attacker import get_attack_model
-from ._check_config import check_config_params
+from ._check_config import verify_config_params

 LOGGER = LogUtil.get_instance()
 TAG = "MembershipInference"
@@ -47,23 +49,21 @@ def _eval_info(pred, truth, option):
         ValueError, size of parameter pred or truth is 0.
         ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
     """
-    if pred.size == 0 or truth.size == 0:
-        msg = "Size of pred or truth is 0."
-        LOGGER.error(TAG, msg)
-        raise ValueError(msg)
+    check_numpy_param("pred", pred)
+    check_numpy_param("truth", truth)

     if option == "accuracy":
         count = np.sum(pred == truth)
         return count / len(pred)
     if option == "precision":
-        count = np.sum(pred & truth)
         if np.sum(pred) == 0:
             return -1
+        count = np.sum(pred & truth)
         return count / np.sum(pred)
     if option == "recall":
-        count = np.sum(pred & truth)
         if np.sum(truth) == 0:
             return -1
+        count = np.sum(pred & truth)
         return count / np.sum(truth)

     msg = "The metric value {} is undefined.".format(option)
@@ -107,9 +107,9 @@ class MembershipInference:
             otherwise the value of n_jobs must be a positive integer.

     Examples:
-        >>> train_1, train_2 are non-overlapping datasets from training dataset of target model.
-        >>> test_1, test_2 are non-overlapping datasets from test dataset of target model.
-        >>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
+        >>> # train_1, train_2 are non-overlapping datasets from training dataset of target model.
+        >>> # test_1, test_2 are non-overlapping datasets from test dataset of target model.
+        >>> # We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
         >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
         >>> inference_model = MembershipInference(model, n_jobs=-1)
         >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
@@ -124,65 +124,44 @@ class MembershipInference:
     """

     def __init__(self, model, n_jobs=-1):
-        if not isinstance(model, Model):
-            msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-        if not isinstance(n_jobs, int):
-            msg = "Type of parameter 'n_jobs' must be int, but got {}".format(type(n_jobs))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
+        check_param_type("n_jobs", n_jobs, int)
         if not (n_jobs == -1 or n_jobs > 0):
             msg = "Value of n_jobs must be either -1 or positive integer, but got {}.".format(n_jobs)
             LOGGER.error(TAG, msg)
             raise ValueError(msg)

-        self.model = model
-        self.n_jobs = min(n_jobs, cpu_count())
-        self.method_list = ["knn", "lr", "mlp", "rf"]
-        self.attack_list = []
+        self._model = check_model("model", model, Model)
+        self._n_jobs = min(n_jobs, cpu_count())
+        self._attack_list = []

     def train(self, dataset_train, dataset_test, attack_config):
         """
-        Depending on the configuration, use the incoming data set to train the attack model.
-        Save the attack model to self.attack_list.
+        Depending on the configuration, use the input data set to train the attack model.
+        Save the attack model to self._attack_list.

         Args:
             dataset_train (mindspore.dataset): The training dataset for the target model.
             dataset_test (mindspore.dataset): The test set for the target model.
-            attack_config (list): Parameter setting for the attack model. The format is
+            attack_config (Union[list, tuple]): Parameter setting for the attack model. The format is
                 [{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
                 {"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}].
-                The support methods list is in self.method_list, and the params of each method
+                The support methods are knn, lr, mlp and rf, and the params of each method
                 must within the range of changeable parameters. Tips of params implement
                 can be found in
                 "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".

         Raises:
-            KeyError: If each config in attack_config doesn't have keys {"method", "params"}
-            ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
+            KeyError: If any config in attack_config doesn't have keys {"method", "params"}
+            NameError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
         """
-        if not isinstance(dataset_train, Dataset):
-            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(dataset_test, Dataset):
-            msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(attack_config, list):
-            msg = "Type of parameter 'attack_config' must be list, but got {}.".format(type(attack_config))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        check_config_params(attack_config)  # Verify attack config.
+        check_param_type("dataset_train", dataset_train, Dataset)
+        check_param_type("dataset_test", dataset_test, Dataset)
+        check_param_multi_types("attack_config", attack_config, (list, tuple))
+        verify_config_params(attack_config)

         features, labels = self._transform(dataset_train, dataset_test)
         for config in attack_config:
-            self.attack_list.append(get_attack_model(features, labels, config, n_jobs=self.n_jobs))
+            self._attack_list.append(get_attack_model(features, labels, config, n_jobs=self._n_jobs))

     def eval(self, dataset_train, dataset_test, metrics):
@@ -199,20 +178,9 @@ class MembershipInference:
         Returns:
             list, Each element contains an evaluation indicator for the attack model.
         """
-        if not isinstance(dataset_train, Dataset):
-            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(dataset_test, Dataset):
-            msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(metrics, (list, tuple)):
-            msg = "Type of parameter 'config' must be Union[list, tuple], but got {}.".format(type(metrics))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
+        check_param_type("dataset_train", dataset_train, Dataset)
+        check_param_type("dataset_test", dataset_test, Dataset)
+        check_param_multi_types("metrics", metrics, (list, tuple))

         metrics = set(metrics)
         metrics_list = {"precision", "accuracy", "recall"}
@@ -223,7 +191,7 @@ class MembershipInference:
         result = []
         features, labels = self._transform(dataset_train, dataset_test)
-        for attacker in self.attack_list:
+        for attacker in self._attack_list:
             pred = attacker.predict(features)
             item = {}
             for option in metrics:
@@ -233,7 +201,7 @@ class MembershipInference:
     def _transform(self, dataset_train, dataset_test):
         """
-        Generate corresponding loss_logits feature and new label, and return after shuffle.
+        Generate corresponding loss_logits features and new label, and return after shuffle.

         Args:
             dataset_train: The training set for the target model.
@@ -255,13 +223,13 @@ class MembershipInference:
         return features, labels

-    def _generate(self, dataset_x, label):
+    def _generate(self, input_dataset, label):
         """
         Return a loss_logits features and labels for training attack model.

         Args:
-            dataset_x (mindspore.dataset): The dataset to be generate.
-            label (int32): Whether dataset_x belongs to the target model.
+            input_dataset (mindspore.dataset): The dataset to be generate.
+            label (int32): Whether input_dataset belongs to the target model.

         Returns:
             - numpy.ndarray, Loss_logits features for each sample. Shape is (N, C).
@@ -269,10 +237,10 @@ class MembershipInference:
             - numpy.ndarray, Labels for each sample, Shape is (N,).
         """
         loss_logits = np.array([])
-        for batch in dataset_x.create_dict_iterator():
+        for batch in input_dataset.create_dict_iterator():
             batch_data = Tensor(batch['image'], ms.float32)
             batch_labels = batch['label'].astype(np.int32)
-            batch_logits = self.model.predict(batch_data).asnumpy()
+            batch_logits = self._model.predict(batch_data).asnumpy()
             batch_loss = _softmax_cross_entropy(batch_logits, batch_labels)
             batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
......
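For orientation, an end-to-end usage sketch of the renamed API, assembled from the class docstring above; net, loss, opt, the four dataset splits, and the import path are assumptions, not part of this diff:

from mindspore.train import Model
# Import path assumed from the package layout visible in the tests below.
from mindarmour.privacy.evaluation import MembershipInference

# Target model under attack (net, loss, opt assumed to be defined elsewhere).
model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
inference_model = MembershipInference(model, n_jobs=-1)

# One attacker is trained per config entry; each params value is a list of
# grid-search candidates, validated against _VALID_CONFIG_CHECKLIST.
attack_config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}},
                 {"method": "LR", "params": {"C": [0.01, 0.1, 1.0]}}]
inference_model.train(train_1, test_1, attack_config)

# Evaluate every trained attacker on the held-out member/non-member splits.
result = inference_model.eval(train_2, test_2, ["precision", "accuracy", "recall"])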
@@ -27,12 +27,12 @@ from mindarmour.privacy.evaluation.attacker import get_attack_model
 @pytest.mark.env_onecard
 @pytest.mark.component_mindarmour
 def test_get_knn_model():
-    features = np.random.randint(0, 10, [10, 10])
-    labels = np.random.randint(0, 2, [10])
+    features = np.random.randint(0, 10, [100, 10])
+    labels = np.random.randint(0, 2, [100])
     config_knn = {
         "method": "KNN",
         "params": {
-            "n_neighbors": [3],
+            "n_neighbors": [3, 5, 7],
         }
     }
     knn_attacker = get_attack_model(features, labels, config_knn, -1)
@@ -46,8 +46,8 @@ def test_get_knn_model():
 @pytest.mark.env_onecard
 @pytest.mark.component_mindarmour
 def test_get_lr_model():
-    features = np.random.randint(0, 10, [10, 10])
-    labels = np.random.randint(0, 2, [10])
+    features = np.random.randint(0, 10, [100, 10])
+    labels = np.random.randint(0, 2, [100])
     config_lr = {
         "method": "LR",
         "params": {
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_get_mlp_model(): def test_get_mlp_model():
features = np.random.randint(0, 10, [10, 10]) features = np.random.randint(0, 10, [100, 10])
labels = np.random.randint(0, 2, [10]) labels = np.random.randint(0, 2, [100])
config_mlpc = { config_mlpc = {
"method": "MLP", "method": "MLP",
"params": { "params": {
...@@ -86,14 +86,14 @@ def test_get_mlp_model(): ...@@ -86,14 +86,14 @@ def test_get_mlp_model():
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_get_rf_model(): def test_get_rf_model():
features = np.random.randint(0, 10, [10, 10]) features = np.random.randint(0, 10, [100, 10])
labels = np.random.randint(0, 2, [10]) labels = np.random.randint(0, 2, [100])
config_rf = { config_rf = {
"method": "RF", "method": "RF",
"params": { "params": {
"n_estimators": [100], "n_estimators": [100],
"max_features": ["auto", "sqrt"], "max_features": ["auto", "sqrt"],
"max_depth": [5, 10, 20, None], "max_depth": [None, 5, 10, 20],
"min_samples_split": [2, 5, 10], "min_samples_split": [2, 5, 10],
"min_samples_leaf": [1, 2, 4], "min_samples_leaf": [1, 2, 4],
} }
......