未验证 提交 cd348dc4 编写于 作者: L liym27 提交者: GitHub

API (case) error message enhancement. test=develop (#23428)

上级 4fe9ca69
......@@ -26,7 +26,7 @@ import numpy
import warnings
import six
from functools import reduce, partial
from ..data_feeder import convert_dtype, check_variable_and_dtype
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type
from ... import compat as cpt
from ..backward import _infer_var_data_type_shape_
......@@ -2251,16 +2251,13 @@ def case(pred_fn_pairs, default=None, name=None):
'''
Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default.
'''
if not isinstance(pred_fn_pairs, (list, tuple)):
raise TypeError(
_error_message("The type", "pred_fn_pairs", "case",
"list or tuple", type(pred_fn_pairs)))
check_type(pred_fn_pairs, 'pred_fn_pairs', (list, tuple), 'case')
for pred_fn in pred_fn_pairs:
if not isinstance(pred_fn, tuple):
raise TypeError(
_error_message("The elements' type", "pred_fn_pairs",
"case", "tuple", type(pred_fn)))
"case", tuple, type(pred_fn)))
if len(pred_fn) != 2:
raise TypeError(
_error_message("The tuple's size", "pred_fn_pairs", "case",
......
......@@ -187,7 +187,7 @@ class TestAPICase_Error(unittest.TestCase):
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_1 = layers.less_than(z, x) # true
# The type of 'pred_fn_pairs' in case must be list or tuple
# The type of 'pred_fn_pairs' in case must be list or tuple
def type_error_pred_fn_pairs():
layers.case(pred_fn_pairs=1, default=fn_1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册