未验证 提交 03deb41d 编写于 作者: L liym27 提交者: GitHub

API (switch_case) error message enhancement. test=develop (#23429)

上级 cd348dc4
...@@ -3361,24 +3361,14 @@ def switch_case(branch_index, branch_fns, default=None, name=None): ...@@ -3361,24 +3361,14 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
helper = LayerHelper('switch_case', **locals()) helper = LayerHelper('switch_case', **locals())
def _check_args(branch_index, branch_fns, default): def _check_args(branch_index, branch_fns, default):
if not isinstance(branch_index, Variable):
raise TypeError(
_error_message("The type", "branch_index", "switch_case",
"Variable", type(branch_index)))
if convert_dtype(branch_index.dtype) not in ["uint8", "int32", "int64"]: check_variable_and_dtype(branch_index, 'branch_index',
raise TypeError( ['uint8', 'int32', 'int64'], 'switch_case')
_error_message("The data type", "branch_index", "switch_case",
"uint8, int32 or int64",
convert_dtype(branch_index.dtype)))
if convert_dtype(branch_index.dtype) != "int64": if convert_dtype(branch_index.dtype) != "int64":
branch_index = cast(branch_index, "int64") branch_index = cast(branch_index, "int64")
if not isinstance(branch_fns, (list, tuple, dict)): check_type(branch_fns, 'branch_fns', (list, tuple, dict), 'switch_case')
raise TypeError(
_error_message("The type", "branch_fns", "switch_case",
"dict, tuple or list", type(branch_fns)))
branch_fns = branch_fns.items() if isinstance(branch_fns, branch_fns = branch_fns.items() if isinstance(branch_fns,
dict) else branch_fns dict) else branch_fns
...@@ -3391,7 +3381,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None): ...@@ -3391,7 +3381,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
if not isinstance(index_fn_pair, tuple): if not isinstance(index_fn_pair, tuple):
raise TypeError( raise TypeError(
_error_message("The elements' type", "branch_fns", _error_message("The elements' type", "branch_fns",
"switch_case", "tuple", type(branch_fns))) "switch_case", tuple, type(branch_fns)))
if len(index_fn_pair) != 2: if len(index_fn_pair) != 2:
raise TypeError( raise TypeError(
...@@ -3404,7 +3394,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None): ...@@ -3404,7 +3394,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
if not isinstance(key, int): if not isinstance(key, int):
raise TypeError( raise TypeError(
_error_message("The key's type", "branch_fns", _error_message("The key's type", "branch_fns",
"switch_case", "int", type(key))) "switch_case", int, type(key)))
if key in keys_of_fns: if key in keys_of_fns:
raise ValueError( raise ValueError(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册