未验证 提交 03deb41d 编写于 作者: L liym27 提交者: GitHub

API (switch_case) error message enhancement. test=develop (#23429)

上级 cd348dc4
......@@ -3361,24 +3361,14 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
helper = LayerHelper('switch_case', **locals())
def _check_args(branch_index, branch_fns, default):
if not isinstance(branch_index, Variable):
raise TypeError(
_error_message("The type", "branch_index", "switch_case",
"Variable", type(branch_index)))
if convert_dtype(branch_index.dtype) not in ["uint8", "int32", "int64"]:
raise TypeError(
_error_message("The data type", "branch_index", "switch_case",
"uint8, int32 or int64",
convert_dtype(branch_index.dtype)))
check_variable_and_dtype(branch_index, 'branch_index',
['uint8', 'int32', 'int64'], 'switch_case')
if convert_dtype(branch_index.dtype) != "int64":
branch_index = cast(branch_index, "int64")
if not isinstance(branch_fns, (list, tuple, dict)):
raise TypeError(
_error_message("The type", "branch_fns", "switch_case",
"dict, tuple or list", type(branch_fns)))
check_type(branch_fns, 'branch_fns', (list, tuple, dict), 'switch_case')
branch_fns = branch_fns.items() if isinstance(branch_fns,
dict) else branch_fns
......@@ -3391,7 +3381,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
if not isinstance(index_fn_pair, tuple):
raise TypeError(
_error_message("The elements' type", "branch_fns",
"switch_case", "tuple", type(branch_fns)))
"switch_case", tuple, type(branch_fns)))
if len(index_fn_pair) != 2:
raise TypeError(
......@@ -3404,7 +3394,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
if not isinstance(key, int):
raise TypeError(
_error_message("The key's type", "branch_fns",
"switch_case", "int", type(key)))
"switch_case", int, type(key)))
if key in keys_of_fns:
raise ValueError(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册