提交 2ffdf100 编写于 作者: L liuluobin

Adjust config check. Modify the exception type to TypeError in function check_model.

上级 2ded64d6
......@@ -24,45 +24,39 @@ TAG = "check_config"
def _is_positive_int(item):
"""
Verify that the value is a positive integer.
"""
if not isinstance(item, int) or item <= 0:
"""Verify that the value is a positive integer."""
if not isinstance(item, int):
return False
return True
return item > 0
def _is_non_negative_int(item):
"""
Verify that the value is a non-negative integer.
"""
if not isinstance(item, int) or item < 0:
"""Verify that the value is a non-negative integer."""
if not isinstance(item, int):
return False
return True
return item >= 0
def _is_positive_float(item):
"""
Verify that value is a positive number.
"""
if not isinstance(item, (int, float)) or item <= 0:
"""Verify that value is a positive number."""
if not isinstance(item, (int, float)):
return False
return True
return item > 0
def _is_non_negative_float(item):
"""
Verify that value is a non-negative number.
"""
if not isinstance(item, (int, float)) or item < 0:
"""Verify that value is a non-negative number."""
if not isinstance(item, (int, float)):
return False
return True
return item >= 0
def _is_range_0_1_float(item):
if not isinstance(item, (int, float)):
return False
return 0 <= item < 1
def _is_positive_int_tuple(item):
"""
Verify that the input parameter is a positive integer tuple.
"""
"""Verify that the input parameter is a positive integer tuple."""
if not isinstance(item, tuple):
return False
for i in item:
......@@ -72,21 +66,29 @@ def _is_positive_int_tuple(item):
def _is_dict(item):
"""
Check whether the type is dict.
"""
"""Check whether the type is dict."""
return isinstance(item, dict)
def _is_list(item):
"""Check whether the type is list"""
return isinstance(item, list)
def _is_str(item):
"""Check whether the type is str."""
return isinstance(item, str)
_VALID_CONFIG_CHECKLIST = {
"knn": {
"n_neighbors": [_is_positive_int],
"weights": [{"uniform", "distance"}],
"weights": [{"uniform", "distance"}, callable],
"algorithm": [{"auto", "ball_tree", "kd_tree", "brute"}],
"leaf_size": [_is_positive_int],
"p": [_is_positive_int],
"metric": None,
"metric_params": None,
"metric": [_is_str, callable],
"metric_params": [_is_dict, {None}]
},
"lr": {
"penalty": [{"l1", "l2", "elasticnet", "none"}],
......@@ -102,7 +104,7 @@ _VALID_CONFIG_CHECKLIST = {
"mlp": {
"hidden_layer_sizes": [_is_positive_int_tuple],
"activation": [{"identity", "logistic", "tanh", "relu"}],
"solver": {"lbfgs", "sgd", "adam"},
"solver": [{"lbfgs", "sgd", "adam"}],
"alpha": [_is_positive_float],
"batch_size": [{"auto"}, _is_positive_int],
"learning_rate": [{"constant", "invscaling", "adaptive"}],
......@@ -117,9 +119,9 @@ _VALID_CONFIG_CHECKLIST = {
"momentum": [_is_positive_float],
"nesterovs_momentum": [{True, False}],
"early_stopping": [{True, False}],
"validation_fraction": [_is_positive_float],
"beta_1": [_is_positive_float],
"beta_2": [_is_positive_float],
"validation_fraction": [_is_range_0_1_float],
"beta_1": [_is_range_0_1_float],
"beta_2": [_is_range_0_1_float],
"epsilon": [_is_positive_float],
"n_iter_no_change": [_is_positive_int],
"max_fun": [_is_positive_int]
......@@ -133,7 +135,7 @@ _VALID_CONFIG_CHECKLIST = {
"min_weight_fraction_leaf": [_is_non_negative_float],
"max_features": [{"auto", "sqrt", "log2", None}, _is_positive_float],
"max_leaf_nodes": [_is_positive_int, {None}],
"min_impurity_decrease": {_is_non_negative_float},
"min_impurity_decrease": [_is_non_negative_float],
"min_impurity_split": [{None}, _is_positive_float],
"bootstrap": [{True, False}],
"oob_scroe": [{True, False}],
......@@ -141,9 +143,9 @@ _VALID_CONFIG_CHECKLIST = {
"random_state": None,
"verbose": [_is_non_negative_int],
"warm_start": [{True, False}],
"class_weight": None,
"class_weight": [{"balanced", "balanced_subsample"}, _is_dict, _is_list],
"ccp_alpha": [_is_non_negative_float],
"max_samples": [_is_positive_float]
"max_samples": [{None}, _is_positive_int, _is_range_0_1_float]
}
}
......
......@@ -129,7 +129,7 @@ def check_model(model_name, model, model_type):
model_type,
type(model).__name__)
LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)
def check_numpy_param(arg_name, arg_value):
......
......@@ -84,7 +84,7 @@ def test_value_error():
adv = np.random.rand(4, 4).astype(np.float32)
model = Model(Net())
# model should be mindspore model
with pytest.raises(ValueError):
with pytest.raises(TypeError):
assert RegionBasedDetector(Net())
with pytest.raises(ValueError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册