diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 5356db4123b075a67f78502edb29f15ba8e314d2..f5c8a5b73045b8ae71047de6123029b3a8438df6 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -26,7 +26,7 @@ import numpy import warnings import six from functools import reduce, partial -from ..data_feeder import convert_dtype, check_variable_and_dtype +from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type from ... import compat as cpt from ..backward import _infer_var_data_type_shape_ @@ -2251,16 +2251,13 @@ def case(pred_fn_pairs, default=None, name=None): ''' Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default. ''' - if not isinstance(pred_fn_pairs, (list, tuple)): - raise TypeError( - _error_message("The type", "pred_fn_pairs", "case", - "list or tuple", type(pred_fn_pairs))) + check_type(pred_fn_pairs, 'pred_fn_pairs', (list, tuple), 'case') for pred_fn in pred_fn_pairs: if not isinstance(pred_fn, tuple): raise TypeError( _error_message("The elements' type", "pred_fn_pairs", - "case", "tuple", type(pred_fn))) + "case", tuple, type(pred_fn))) if len(pred_fn) != 2: raise TypeError( _error_message("The tuple's size", "pred_fn_pairs", "case", diff --git a/python/paddle/fluid/tests/unittests/test_case.py b/python/paddle/fluid/tests/unittests/test_case.py index 722d5ef0862905a717ce6b20ac62b83cb4e7b1eb..6391435cc80955513995b20bbb16ca6fdc9b38e2 100644 --- a/python/paddle/fluid/tests/unittests/test_case.py +++ b/python/paddle/fluid/tests/unittests/test_case.py @@ -187,7 +187,7 @@ class TestAPICase_Error(unittest.TestCase): z = layers.fill_constant(shape=[1], dtype='float32', value=0.2) pred_1 = layers.less_than(z, x) # true - # The type of 'pred_fn_pairs' in case must be list or tuple + # The type of 'pred_fn_pairs' in case must be list or tuple def type_error_pred_fn_pairs(): layers.case(pred_fn_pairs=1, default=fn_1)