diff --git a/mindarmour/privacy/evaluation/_check_config.py b/mindarmour/privacy/evaluation/_check_config.py
index d6e0a2b87bc3553cd0139ac99b1cafa450546e22..86155e37ef88545a14ee5cb3d8365e5311d0fca3 100644
--- a/mindarmour/privacy/evaluation/_check_config.py
+++ b/mindarmour/privacy/evaluation/_check_config.py
@@ -15,11 +15,12 @@
 Verify attack config
 """
+from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils.logger import LogUtil
 
 LOGGER = LogUtil.get_instance()
 
-TAG = "check_params"
+TAG = "check_config"
 
 
 def _is_positive_int(item):
@@ -77,7 +78,7 @@ def _is_dict(item):
     return isinstance(item, dict)
 
 
-VALID_PARAMS_DICT = {
+_VALID_CONFIG_CHECKLIST = {
     "knn": {
         "n_neighbors": [_is_positive_int],
         "weights": [{"uniform", "distance"}],
@@ -126,7 +127,7 @@ VALID_PARAMS_DICT = {
     "rf": {
         "n_estimators": [_is_positive_int],
         "criterion": [{"gini", "entropy"}],
-        "max_depth": [_is_positive_int],
+        "max_depth": [{None}, _is_positive_int],
         "min_samples_split": [_is_positive_float],
         "min_samples_leaf": [_is_positive_float],
         "min_weight_fraction_leaf": [_is_non_negative_float],
@@ -148,24 +149,15 @@
 }
 
 
-def _check_config(config_list, check_params):
+def _check_config(attack_config, config_checklist):
     """
     Verify that config_list is valid.
     Check_params is the valid value range of the parameter.
     """
-    if not isinstance(config_list, (list, tuple)):
-        msg = "Type of parameter 'config_list' must be list, but got {}.".format(type(config_list))
-        LOGGER.error(TAG, msg)
-        raise TypeError(msg)
-
-    for config in config_list:
-        if not isinstance(config, dict):
-            msg = "Type of each config in config_list must be dict, but got {}.".format(type(config))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
+    for config in attack_config:
+        check_param_type("config", config, dict)
         if set(config.keys()) != {"params", "method"}:
-            msg = "Keys of each config in config_list must be {}," \
+            msg = "Keys of each config in attack_config must be {}," \
                   "but got {}.".format({'method', 'params'}, set(config.keys()))
             LOGGER.error(TAG, msg)
             raise KeyError(msg)
@@ -173,27 +165,22 @@
         method = str.lower(config["method"])
         params = config["params"]
 
-        if method not in check_params.keys():
+        if method not in config_checklist.keys():
             msg = "Method {} is not supported.".format(method)
             LOGGER.error(TAG, msg)
-            raise ValueError(msg)
+            raise NameError(msg)
 
-        if not params.keys() <= check_params[method].keys():
+        if not params.keys() <= config_checklist[method].keys():
             msg = "Params in method {} is not accepted, the parameters " \
-                  "that can be set are {}.".format(method, set(check_params[method].keys()))
+                  "that can be set are {}.".format(method, set(config_checklist[method].keys()))
             LOGGER.error(TAG, msg)
             raise KeyError(msg)
 
         for param_key in params.keys():
             param_value = params[param_key]
-            candidate_values = check_params[method][param_key]
-
-            if not isinstance(param_value, list):
-                msg = "The parameter '{}' in method '{}' setting must within the range of " \
-                      "changeable parameters.".format(param_key, method)
-                LOGGER.error(TAG, msg)
-                raise ValueError(msg)
+            candidate_values = config_checklist[method][param_key]
+            check_param_type('param_value', param_value, list)
 
             if candidate_values is None:
                 continue
@@ -204,7 +191,7 @@
                     if isinstance(candidate_value, set) and item_value in candidate_value:
                         flag = True
                         break
-                    elif candidate_value(item_value):
+                    elif not isinstance(candidate_value, set) and candidate_value(item_value):
                         flag = True
                         break
 
@@ -213,8 +200,8 @@
                     raise ValueError(msg)
 
 
-def check_config_params(config_list):
+def verify_config_params(attack_config):
     """
     External interfaces to verify attack config.
     """
-    _check_config(config_list, VALID_PARAMS_DICT)
+    _check_config(attack_config, _VALID_CONFIG_CHECKLIST)
diff --git a/mindarmour/privacy/evaluation/attacker.py b/mindarmour/privacy/evaluation/attacker.py
index 3b337e0b9e7e052fa634f0632337ce6242ea3a83..5733a8338a886f16ada77b97e072cb0d6baf787e 100644
--- a/mindarmour/privacy/evaluation/attacker.py
+++ b/mindarmour/privacy/evaluation/attacker.py
@@ -153,4 +153,4 @@ def get_attack_model(features, labels, config, n_jobs=-1):
 
     msg = "Method {} is not supported.".format(config["method"])
     LOGGER.error(TAG, msg)
-    raise ValueError(msg)
+    raise NameError(msg)
diff --git a/mindarmour/privacy/evaluation/membership_inference.py b/mindarmour/privacy/evaluation/membership_inference.py
index 5de6d5be0089c40bc53b784c3c9de54e28e4d746..3d4f88a0464ff001531e52d9577d54ebc00cca0c 100644
--- a/mindarmour/privacy/evaluation/membership_inference.py
+++ b/mindarmour/privacy/evaluation/membership_inference.py
@@ -23,8 +23,10 @@ from mindspore.train import Model
 from mindspore.dataset.engine import Dataset
 from mindspore import Tensor
 from mindarmour.utils.logger import LogUtil
+from mindarmour.utils._check_param import check_param_type, check_param_multi_types, \
+    check_model, check_numpy_param
 from .attacker import get_attack_model
-from ._check_config import check_config_params
+from ._check_config import verify_config_params
 
 LOGGER = LogUtil.get_instance()
 TAG = "MembershipInference"
@@ -47,23 +49,21 @@ def _eval_info(pred, truth, option):
         ValueError, size of parameter pred or truth is 0.
         ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
     """
-    if pred.size == 0 or truth.size == 0:
-        msg = "Size of pred or truth is 0."
-        LOGGER.error(TAG, msg)
-        raise ValueError(msg)
+    check_numpy_param("pred", pred)
+    check_numpy_param("truth", truth)
 
     if option == "accuracy":
         count = np.sum(pred == truth)
         return count / len(pred)
     if option == "precision":
-        count = np.sum(pred & truth)
         if np.sum(pred) == 0:
             return -1
+        count = np.sum(pred & truth)
         return count / np.sum(pred)
     if option == "recall":
-        count = np.sum(pred & truth)
         if np.sum(truth) == 0:
             return -1
+        count = np.sum(pred & truth)
         return count / np.sum(truth)
 
     msg = "The metric value {} is undefined.".format(option)
@@ -107,9 +107,9 @@
             otherwise the value of n_jobs must be a positive integer.
 
     Examples:
-        >>> train_1, train_2 are non-overlapping datasets from training dataset of target model.
-        >>> test_1, test_2 are non-overlapping datasets from test dataset of target model.
-        >>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
+        >>> # train_1, train_2 are non-overlapping datasets from training dataset of target model.
+        >>> # test_1, test_2 are non-overlapping datasets from test dataset of target model.
+        >>> # We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
         >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
         >>> inference_model = MembershipInference(model, n_jobs=-1)
         >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
@@ -124,65 +124,44 @@
     """
 
     def __init__(self, model, n_jobs=-1):
-        if not isinstance(model, Model):
-            msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-        if not isinstance(n_jobs, int):
-            msg = "Type of parameter 'n_jobs' must be int, but got {}".format(type(n_jobs))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
+        check_param_type("n_jobs", n_jobs, int)
         if not (n_jobs == -1 or n_jobs > 0):
             msg = "Value of n_jobs must be either -1 or positive integer, but got {}.".format(n_jobs)
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
 
-        self.model = model
-        self.n_jobs = min(n_jobs, cpu_count())
-        self.method_list = ["knn", "lr", "mlp", "rf"]
-        self.attack_list = []
+        self._model = check_model("model", model, Model)
+        self._n_jobs = min(n_jobs, cpu_count())
+        self._attack_list = []
 
     def train(self, dataset_train, dataset_test, attack_config):
         """
-        Depending on the configuration, use the incoming data set to train the attack model.
-        Save the attack model to self.attack_list.
+        Depending on the configuration, use the input data set to train the attack model.
+        Save the attack model to self._attack_list.
 
         Args:
             dataset_train (mindspore.dataset): The training dataset for the target model.
             dataset_test (mindspore.dataset): The test set for the target model.
-            attack_config (list): Parameter setting for the attack model. The format is
+            attack_config (Union[list, tuple]): Parameter setting for the attack model. The format is
                 [{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
                  {"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}].
-                The support methods list is in self.method_list, and the params of each method
+                The support methods are knn, lr, mlp and rf, and the params of each method
                 must within the range of changeable parameters. Tips of params implement
                 can be found in "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
 
         Raises:
-            KeyError: If each config in attack_config doesn't have keys {"method", "params"}
-            ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
+            KeyError: If any config in attack_config doesn't have keys {"method", "params"}
+            NameError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
         """
-        if not isinstance(dataset_train, Dataset):
-            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(dataset_test, Dataset):
-            msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(attack_config, list):
-            msg = "Type of parameter 'attack_config' must be list, but got {}.".format(type(attack_config))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        check_config_params(attack_config)  # Verify attack config.
+        check_param_type("dataset_train", dataset_train, Dataset)
+        check_param_type("dataset_test", dataset_test, Dataset)
+        check_param_multi_types("attack_config", attack_config, (list, tuple))
+        verify_config_params(attack_config)
 
         features, labels = self._transform(dataset_train, dataset_test)
-
         for config in attack_config:
-            self.attack_list.append(get_attack_model(features, labels, config, n_jobs=self.n_jobs))
+            self._attack_list.append(get_attack_model(features, labels, config, n_jobs=self._n_jobs))
 
     def eval(self, dataset_train, dataset_test, metrics):
@@ -199,20 +178,9 @@
         Returns:
             list, Each element contains an evaluation indicator for the attack model.
         """
-        if not isinstance(dataset_train, Dataset):
-            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(dataset_test, Dataset):
-            msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
-
-        if not isinstance(metrics, (list, tuple)):
-            msg = "Type of parameter 'config' must be Union[list, tuple], but got {}.".format(type(metrics))
-            LOGGER.error(TAG, msg)
-            raise TypeError(msg)
+        check_param_type("dataset_train", dataset_train, Dataset)
+        check_param_type("dataset_test", dataset_test, Dataset)
+        check_param_multi_types("metrics", metrics, (list, tuple))
 
         metrics = set(metrics)
         metrics_list = {"precision", "accuracy", "recall"}
@@ -223,7 +191,7 @@
         result = []
         features, labels = self._transform(dataset_train, dataset_test)
 
-        for attacker in self.attack_list:
+        for attacker in self._attack_list:
             pred = attacker.predict(features)
             item = {}
             for option in metrics:
@@ -233,7 +201,7 @@
 
     def _transform(self, dataset_train, dataset_test):
         """
-        Generate corresponding loss_logits feature and new label, and return after shuffle.
+        Generate corresponding loss_logits features and new label, and return after shuffle.
 
         Args:
             dataset_train: The training set for the target model.
@@ -255,13 +223,13 @@
 
         return features, labels
 
-    def _generate(self, dataset_x, label):
+    def _generate(self, input_dataset, label):
         """
         Return a loss_logits features and labels for training attack model.
 
         Args:
-            dataset_x (mindspore.dataset): The dataset to be generate.
-            label (int32): Whether dataset_x belongs to the target model.
+            input_dataset (mindspore.dataset): The dataset to be generate.
+            label (int32): Whether input_dataset belongs to the target model.
 
         Returns:
             - numpy.ndarray, Loss_logits features for each sample. Shape is (N, C).
@@ -269,10 +237,10 @@
             - numpy.ndarray, Labels for each sample, Shape is (N,).
         """
         loss_logits = np.array([])
-        for batch in dataset_x.create_dict_iterator():
+        for batch in input_dataset.create_dict_iterator():
             batch_data = Tensor(batch['image'], ms.float32)
             batch_labels = batch['label'].astype(np.int32)
-            batch_logits = self.model.predict(batch_data).asnumpy()
+            batch_logits = self._model.predict(batch_data).asnumpy()
             batch_loss = _softmax_cross_entropy(batch_logits, batch_labels)
             batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
 
diff --git a/tests/ut/python/diff_privacy/test_attacker.py b/tests/ut/python/diff_privacy/test_attacker.py
index 013b4918c4651f171bfae3c91eff2f9063398da7..9e374b0db3dcaec0e620fe1de498837937267282 100644
--- a/tests/ut/python/diff_privacy/test_attacker.py
+++ b/tests/ut/python/diff_privacy/test_attacker.py
@@ -27,12 +27,12 @@ from mindarmour.privacy.evaluation.attacker import get_attack_model
 @pytest.mark.env_onecard
 @pytest.mark.component_mindarmour
 def test_get_knn_model():
-    features = np.random.randint(0, 10, [10, 10])
-    labels = np.random.randint(0, 2, [10])
+    features = np.random.randint(0, 10, [100, 10])
+    labels = np.random.randint(0, 2, [100])
     config_knn = {
         "method": "KNN",
         "params": {
-            "n_neighbors": [3],
+            "n_neighbors": [3, 5, 7],
         }
     }
     knn_attacker = get_attack_model(features, labels, config_knn, -1)
@@ -46,8 +46,8 @@
 @pytest.mark.env_onecard
 @pytest.mark.component_mindarmour
 def test_get_lr_model():
-    features = np.random.randint(0, 10, [10, 10])
-    labels = np.random.randint(0, 2, [10])
+    features = np.random.randint(0, 10, [100, 10])
+    labels = np.random.randint(0, 2, [100])
     config_lr = {
         "method": "LR",
         "params": {
@@ -65,8 +65,8 @@
 @pytest.mark.env_onecard
 @pytest.mark.component_mindarmour
 def test_get_mlp_model():
-    features = np.random.randint(0, 10, [10, 10])
-    labels = np.random.randint(0, 2, [10])
+    features = np.random.randint(0, 10, [100, 10])
+    labels = np.random.randint(0, 2, [100])
     config_mlpc = {
         "method": "MLP",
         "params": {
@@ -86,14 +86,14 @@
 @pytest.mark.env_onecard
 @pytest.mark.component_mindarmour
 def test_get_rf_model():
-    features = np.random.randint(0, 10, [10, 10])
-    labels = np.random.randint(0, 2, [10])
+    features = np.random.randint(0, 10, [100, 10])
+    labels = np.random.randint(0, 2, [100])
     config_rf = {
         "method": "RF",
         "params": {
             "n_estimators": [100],
             "max_features": ["auto", "sqrt"],
-            "max_depth": [5, 10, 20, None],
+            "max_depth": [None, 5, 10, 20],
             "min_samples_split": [2, 5, 10],
             "min_samples_leaf": [1, 2, 4],
         }