提交 2ffdf100 编写于 作者: L liuluobin

Adjust config check. Modify the exception type to TypeError in function check_model.

上级 2ded64d6
...@@ -24,45 +24,39 @@ TAG = "check_config" ...@@ -24,45 +24,39 @@ TAG = "check_config"
def _is_positive_int(item): def _is_positive_int(item):
""" """Verify that the value is a positive integer."""
Verify that the value is a positive integer. if not isinstance(item, int):
"""
if not isinstance(item, int) or item <= 0:
return False return False
return True return item > 0
def _is_non_negative_int(item): def _is_non_negative_int(item):
""" """Verify that the value is a non-negative integer."""
Verify that the value is a non-negative integer. if not isinstance(item, int):
"""
if not isinstance(item, int) or item < 0:
return False return False
return True return item >= 0
def _is_positive_float(item): def _is_positive_float(item):
""" """Verify that value is a positive number."""
Verify that value is a positive number. if not isinstance(item, (int, float)):
"""
if not isinstance(item, (int, float)) or item <= 0:
return False return False
return True return item > 0
def _is_non_negative_float(item): def _is_non_negative_float(item):
""" """Verify that value is a non-negative number."""
Verify that value is a non-negative number. if not isinstance(item, (int, float)):
"""
if not isinstance(item, (int, float)) or item < 0:
return False return False
return True return item >= 0
def _is_range_0_1_float(item):
if not isinstance(item, (int, float)):
return False
return 0 <= item < 1
def _is_positive_int_tuple(item): def _is_positive_int_tuple(item):
""" """Verify that the input parameter is a positive integer tuple."""
Verify that the input parameter is a positive integer tuple.
"""
if not isinstance(item, tuple): if not isinstance(item, tuple):
return False return False
for i in item: for i in item:
...@@ -72,21 +66,29 @@ def _is_positive_int_tuple(item): ...@@ -72,21 +66,29 @@ def _is_positive_int_tuple(item):
def _is_dict(item): def _is_dict(item):
""" """Check whether the type is dict."""
Check whether the type is dict.
"""
return isinstance(item, dict) return isinstance(item, dict)
def _is_list(item):
"""Check whether the type is list"""
return isinstance(item, list)
def _is_str(item):
"""Check whether the type is str."""
return isinstance(item, str)
_VALID_CONFIG_CHECKLIST = { _VALID_CONFIG_CHECKLIST = {
"knn": { "knn": {
"n_neighbors": [_is_positive_int], "n_neighbors": [_is_positive_int],
"weights": [{"uniform", "distance"}], "weights": [{"uniform", "distance"}, callable],
"algorithm": [{"auto", "ball_tree", "kd_tree", "brute"}], "algorithm": [{"auto", "ball_tree", "kd_tree", "brute"}],
"leaf_size": [_is_positive_int], "leaf_size": [_is_positive_int],
"p": [_is_positive_int], "p": [_is_positive_int],
"metric": None, "metric": [_is_str, callable],
"metric_params": None, "metric_params": [_is_dict, {None}]
}, },
"lr": { "lr": {
"penalty": [{"l1", "l2", "elasticnet", "none"}], "penalty": [{"l1", "l2", "elasticnet", "none"}],
...@@ -102,7 +104,7 @@ _VALID_CONFIG_CHECKLIST = { ...@@ -102,7 +104,7 @@ _VALID_CONFIG_CHECKLIST = {
"mlp": { "mlp": {
"hidden_layer_sizes": [_is_positive_int_tuple], "hidden_layer_sizes": [_is_positive_int_tuple],
"activation": [{"identity", "logistic", "tanh", "relu"}], "activation": [{"identity", "logistic", "tanh", "relu"}],
"solver": {"lbfgs", "sgd", "adam"}, "solver": [{"lbfgs", "sgd", "adam"}],
"alpha": [_is_positive_float], "alpha": [_is_positive_float],
"batch_size": [{"auto"}, _is_positive_int], "batch_size": [{"auto"}, _is_positive_int],
"learning_rate": [{"constant", "invscaling", "adaptive"}], "learning_rate": [{"constant", "invscaling", "adaptive"}],
...@@ -117,9 +119,9 @@ _VALID_CONFIG_CHECKLIST = { ...@@ -117,9 +119,9 @@ _VALID_CONFIG_CHECKLIST = {
"momentum": [_is_positive_float], "momentum": [_is_positive_float],
"nesterovs_momentum": [{True, False}], "nesterovs_momentum": [{True, False}],
"early_stopping": [{True, False}], "early_stopping": [{True, False}],
"validation_fraction": [_is_positive_float], "validation_fraction": [_is_range_0_1_float],
"beta_1": [_is_positive_float], "beta_1": [_is_range_0_1_float],
"beta_2": [_is_positive_float], "beta_2": [_is_range_0_1_float],
"epsilon": [_is_positive_float], "epsilon": [_is_positive_float],
"n_iter_no_change": [_is_positive_int], "n_iter_no_change": [_is_positive_int],
"max_fun": [_is_positive_int] "max_fun": [_is_positive_int]
...@@ -133,7 +135,7 @@ _VALID_CONFIG_CHECKLIST = { ...@@ -133,7 +135,7 @@ _VALID_CONFIG_CHECKLIST = {
"min_weight_fraction_leaf": [_is_non_negative_float], "min_weight_fraction_leaf": [_is_non_negative_float],
"max_features": [{"auto", "sqrt", "log2", None}, _is_positive_float], "max_features": [{"auto", "sqrt", "log2", None}, _is_positive_float],
"max_leaf_nodes": [_is_positive_int, {None}], "max_leaf_nodes": [_is_positive_int, {None}],
"min_impurity_decrease": {_is_non_negative_float}, "min_impurity_decrease": [_is_non_negative_float],
"min_impurity_split": [{None}, _is_positive_float], "min_impurity_split": [{None}, _is_positive_float],
"bootstrap": [{True, False}], "bootstrap": [{True, False}],
"oob_scroe": [{True, False}], "oob_scroe": [{True, False}],
...@@ -141,9 +143,9 @@ _VALID_CONFIG_CHECKLIST = { ...@@ -141,9 +143,9 @@ _VALID_CONFIG_CHECKLIST = {
"random_state": None, "random_state": None,
"verbose": [_is_non_negative_int], "verbose": [_is_non_negative_int],
"warm_start": [{True, False}], "warm_start": [{True, False}],
"class_weight": None, "class_weight": [{"balanced", "balanced_subsample"}, _is_dict, _is_list],
"ccp_alpha": [_is_non_negative_float], "ccp_alpha": [_is_non_negative_float],
"max_samples": [_is_positive_float] "max_samples": [{None}, _is_positive_int, _is_range_0_1_float]
} }
} }
......
...@@ -129,7 +129,7 @@ def check_model(model_name, model, model_type): ...@@ -129,7 +129,7 @@ def check_model(model_name, model, model_type):
model_type, model_type,
type(model).__name__) type(model).__name__)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg) raise TypeError(msg)
def check_numpy_param(arg_name, arg_value): def check_numpy_param(arg_name, arg_value):
......
...@@ -84,7 +84,7 @@ def test_value_error(): ...@@ -84,7 +84,7 @@ def test_value_error():
adv = np.random.rand(4, 4).astype(np.float32) adv = np.random.rand(4, 4).astype(np.float32)
model = Model(Net()) model = Model(Net())
# model should be mindspore model # model should be mindspore model
with pytest.raises(ValueError): with pytest.raises(TypeError):
assert RegionBasedDetector(Net()) assert RegionBasedDetector(Net())
with pytest.raises(ValueError): with pytest.raises(ValueError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册