diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index f5c8a5b73045b8ae71047de6123029b3a8438df6..4e301a2120045d5a107cb1c1ae21bc2893163362 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -3361,24 +3361,14 @@ def switch_case(branch_index, branch_fns, default=None, name=None): helper = LayerHelper('switch_case', **locals()) def _check_args(branch_index, branch_fns, default): - if not isinstance(branch_index, Variable): - raise TypeError( - _error_message("The type", "branch_index", "switch_case", - "Variable", type(branch_index))) - if convert_dtype(branch_index.dtype) not in ["uint8", "int32", "int64"]: - raise TypeError( - _error_message("The data type", "branch_index", "switch_case", - "uint8, int32 or int64", - convert_dtype(branch_index.dtype))) + check_variable_and_dtype(branch_index, 'branch_index', + ['uint8', 'int32', 'int64'], 'switch_case') if convert_dtype(branch_index.dtype) != "int64": branch_index = cast(branch_index, "int64") - if not isinstance(branch_fns, (list, tuple, dict)): - raise TypeError( - _error_message("The type", "branch_fns", "switch_case", - "dict, tuple or list", type(branch_fns))) + check_type(branch_fns, 'branch_fns', (list, tuple, dict), 'switch_case') branch_fns = branch_fns.items() if isinstance(branch_fns, dict) else branch_fns @@ -3391,7 +3381,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None): if not isinstance(index_fn_pair, tuple): raise TypeError( _error_message("The elements' type", "branch_fns", - "switch_case", "tuple", type(branch_fns))) + "switch_case", tuple, type(branch_fns))) if len(index_fn_pair) != 2: raise TypeError( @@ -3404,7 +3394,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None): if not isinstance(key, int): raise TypeError( _error_message("The key's type", "branch_fns", - "switch_case", "int", type(key))) + "switch_case", int, type(key))) if key in keys_of_fns: raise ValueError(