From cd348dc467ca7848b528edb86308146f829ca122 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Fri, 3 Apr 2020 17:23:16 +0800 Subject: [PATCH] API (case) error message enhancement. test=develop (#23428) --- python/paddle/fluid/layers/control_flow.py | 9 +++------ python/paddle/fluid/tests/unittests/test_case.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 5356db4123b..f5c8a5b7304 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -26,7 +26,7 @@ import numpy import warnings import six from functools import reduce, partial -from ..data_feeder import convert_dtype, check_variable_and_dtype +from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type from ... import compat as cpt from ..backward import _infer_var_data_type_shape_ @@ -2251,16 +2251,13 @@ def case(pred_fn_pairs, default=None, name=None): ''' Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default. ''' - if not isinstance(pred_fn_pairs, (list, tuple)): - raise TypeError( - _error_message("The type", "pred_fn_pairs", "case", - "list or tuple", type(pred_fn_pairs))) + check_type(pred_fn_pairs, 'pred_fn_pairs', (list, tuple), 'case') for pred_fn in pred_fn_pairs: if not isinstance(pred_fn, tuple): raise TypeError( _error_message("The elements' type", "pred_fn_pairs", - "case", "tuple", type(pred_fn))) + "case", tuple, type(pred_fn))) if len(pred_fn) != 2: raise TypeError( _error_message("The tuple's size", "pred_fn_pairs", "case", diff --git a/python/paddle/fluid/tests/unittests/test_case.py b/python/paddle/fluid/tests/unittests/test_case.py index 722d5ef0862..6391435cc80 100644 --- a/python/paddle/fluid/tests/unittests/test_case.py +++ b/python/paddle/fluid/tests/unittests/test_case.py @@ -187,7 +187,7 @@ class TestAPICase_Error(unittest.TestCase): z = layers.fill_constant(shape=[1], dtype='float32', value=0.2) pred_1 = layers.less_than(z, x) # true - # The type of 'pred_fn_pairs' in case must be list or tuple + # The type of 'pred_fn_pairs' in case must be list or tuple def type_error_pred_fn_pairs(): layers.case(pred_fn_pairs=1, default=fn_1) -- GitLab