未验证 提交 cd348dc4 编写于 作者: L liym27 提交者: GitHub

API (case) error message enhancement. test=develop (#23428)

上级 4fe9ca69
...@@ -26,7 +26,7 @@ import numpy ...@@ -26,7 +26,7 @@ import numpy
import warnings import warnings
import six import six
from functools import reduce, partial from functools import reduce, partial
from ..data_feeder import convert_dtype, check_variable_and_dtype from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type
from ... import compat as cpt from ... import compat as cpt
from ..backward import _infer_var_data_type_shape_ from ..backward import _infer_var_data_type_shape_
...@@ -2251,16 +2251,13 @@ def case(pred_fn_pairs, default=None, name=None): ...@@ -2251,16 +2251,13 @@ def case(pred_fn_pairs, default=None, name=None):
''' '''
Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default. Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default.
''' '''
if not isinstance(pred_fn_pairs, (list, tuple)): check_type(pred_fn_pairs, 'pred_fn_pairs', (list, tuple), 'case')
raise TypeError(
_error_message("The type", "pred_fn_pairs", "case",
"list or tuple", type(pred_fn_pairs)))
for pred_fn in pred_fn_pairs: for pred_fn in pred_fn_pairs:
if not isinstance(pred_fn, tuple): if not isinstance(pred_fn, tuple):
raise TypeError( raise TypeError(
_error_message("The elements' type", "pred_fn_pairs", _error_message("The elements' type", "pred_fn_pairs",
"case", "tuple", type(pred_fn))) "case", tuple, type(pred_fn)))
if len(pred_fn) != 2: if len(pred_fn) != 2:
raise TypeError( raise TypeError(
_error_message("The tuple's size", "pred_fn_pairs", "case", _error_message("The tuple's size", "pred_fn_pairs", "case",
......
...@@ -187,7 +187,7 @@ class TestAPICase_Error(unittest.TestCase): ...@@ -187,7 +187,7 @@ class TestAPICase_Error(unittest.TestCase):
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2) z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_1 = layers.less_than(z, x) # true pred_1 = layers.less_than(z, x) # true
# The type of 'pred_fn_pairs' in case must be list or tuple # The type of 'pred_fn_pairs' in case must be list or tuple
def type_error_pred_fn_pairs(): def type_error_pred_fn_pairs():
layers.case(pred_fn_pairs=1, default=fn_1) layers.case(pred_fn_pairs=1, default=fn_1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册