diff --git a/mindarmour/privacy/evaluation/_check_config.py b/mindarmour/privacy/evaluation/_check_config.py index 86155e37ef88545a14ee5cb3d8365e5311d0fca3..2c65f616c81ef3d97c7e8ebba404ed3772563ce4 100644 --- a/mindarmour/privacy/evaluation/_check_config.py +++ b/mindarmour/privacy/evaluation/_check_config.py @@ -24,45 +24,39 @@ TAG = "check_config" def _is_positive_int(item): - """ - Verify that the value is a positive integer. - """ - if not isinstance(item, int) or item <= 0: + """Verify that the value is a positive integer.""" + if not isinstance(item, int): return False - return True - + return item > 0 def _is_non_negative_int(item): - """ - Verify that the value is a non-negative integer. - """ - if not isinstance(item, int) or item < 0: + """Verify that the value is a non-negative integer.""" + if not isinstance(item, int): return False - return True + return item >= 0 def _is_positive_float(item): - """ - Verify that value is a positive number. - """ - if not isinstance(item, (int, float)) or item <= 0: + """Verify that value is a positive number.""" + if not isinstance(item, (int, float)): return False - return True + return item > 0 def _is_non_negative_float(item): - """ - Verify that value is a non-negative number. - """ - if not isinstance(item, (int, float)) or item < 0: + """Verify that value is a non-negative number.""" + if not isinstance(item, (int, float)): return False - return True + return item >= 0 + +def _is_range_0_1_float(item): + if not isinstance(item, (int, float)): + return False + return 0 <= item < 1 def _is_positive_int_tuple(item): - """ - Verify that the input parameter is a positive integer tuple. - """ + """Verify that the input parameter is a positive integer tuple.""" if not isinstance(item, tuple): return False for i in item: @@ -72,21 +66,29 @@ def _is_positive_int_tuple(item): def _is_dict(item): - """ - Check whether the type is dict. - """ + """Check whether the type is dict.""" return isinstance(item, dict) +def _is_list(item): + """Check whether the type is list""" + return isinstance(item, list) + + +def _is_str(item): + """Check whether the type is str.""" + return isinstance(item, str) + + _VALID_CONFIG_CHECKLIST = { "knn": { "n_neighbors": [_is_positive_int], - "weights": [{"uniform", "distance"}], + "weights": [{"uniform", "distance"}, callable], "algorithm": [{"auto", "ball_tree", "kd_tree", "brute"}], "leaf_size": [_is_positive_int], "p": [_is_positive_int], - "metric": None, - "metric_params": None, + "metric": [_is_str, callable], + "metric_params": [_is_dict, {None}] }, "lr": { "penalty": [{"l1", "l2", "elasticnet", "none"}], @@ -102,7 +104,7 @@ _VALID_CONFIG_CHECKLIST = { "mlp": { "hidden_layer_sizes": [_is_positive_int_tuple], "activation": [{"identity", "logistic", "tanh", "relu"}], - "solver": {"lbfgs", "sgd", "adam"}, + "solver": [{"lbfgs", "sgd", "adam"}], "alpha": [_is_positive_float], "batch_size": [{"auto"}, _is_positive_int], "learning_rate": [{"constant", "invscaling", "adaptive"}], @@ -117,9 +119,9 @@ _VALID_CONFIG_CHECKLIST = { "momentum": [_is_positive_float], "nesterovs_momentum": [{True, False}], "early_stopping": [{True, False}], - "validation_fraction": [_is_positive_float], - "beta_1": [_is_positive_float], - "beta_2": [_is_positive_float], + "validation_fraction": [_is_range_0_1_float], + "beta_1": [_is_range_0_1_float], + "beta_2": [_is_range_0_1_float], "epsilon": [_is_positive_float], "n_iter_no_change": [_is_positive_int], "max_fun": [_is_positive_int] @@ -133,7 +135,7 @@ _VALID_CONFIG_CHECKLIST = { "min_weight_fraction_leaf": [_is_non_negative_float], "max_features": [{"auto", "sqrt", "log2", None}, _is_positive_float], "max_leaf_nodes": [_is_positive_int, {None}], - "min_impurity_decrease": {_is_non_negative_float}, + "min_impurity_decrease": [_is_non_negative_float], "min_impurity_split": [{None}, _is_positive_float], "bootstrap": [{True, False}], "oob_scroe": [{True, False}], @@ -141,9 +143,9 @@ _VALID_CONFIG_CHECKLIST = { "random_state": None, "verbose": [_is_non_negative_int], "warm_start": [{True, False}], - "class_weight": None, + "class_weight": [{"balanced", "balanced_subsample"}, _is_dict, _is_list], "ccp_alpha": [_is_non_negative_float], - "max_samples": [_is_positive_float] + "max_samples": [{None}, _is_positive_int, _is_range_0_1_float] } } diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 1b93ea2b9740b1cac0139fb5904bf87fea94c8b6..8984b47ed89bcfcb7a3facd6371d68173c40bc82 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -129,7 +129,7 @@ def check_model(model_name, model, model_type): model_type, type(model).__name__) LOGGER.error(TAG, msg) - raise ValueError(msg) + raise TypeError(msg) def check_numpy_param(arg_name, arg_value): diff --git a/tests/ut/python/detectors/test_region_based_detector.py b/tests/ut/python/detectors/test_region_based_detector.py index 845a99e5159c975adf7e331785dc976c906f0a8b..8d7ae16f7a20ddabf140d1684383ec71c4af3088 100644 --- a/tests/ut/python/detectors/test_region_based_detector.py +++ b/tests/ut/python/detectors/test_region_based_detector.py @@ -84,7 +84,7 @@ def test_value_error(): adv = np.random.rand(4, 4).astype(np.float32) model = Model(Net()) # model should be mindspore model - with pytest.raises(ValueError): + with pytest.raises(TypeError): assert RegionBasedDetector(Net()) with pytest.raises(ValueError):