未验证 提交 067134f1 编写于 作者: L liym27 提交者: GitHub

API (Switch) error message enhancement. test=develop (#23459)

* API (Switch) error message enhancement. 

* fix bug: dtype of out in api isfinite is set incorrectly. The dtype should be bool. 
上级 dc225ed2
......@@ -2173,7 +2173,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
def _error_message(what, arg_name, op_name, right_value, error_value):
error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \
error_message = "{what} of '{arg_name}' in {op_name} must be " \
"{right_value}, but received: {error_value}.".format(
what=what,
arg_name=arg_name,
......@@ -2309,7 +2309,7 @@ class Switch(object):
OP :ref:`api_fluid_layers_case` is easier to use and is called with less code but does the same thing as ``Switch`` .
Member Functions:
case(cond): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statement after the case branch will not be executed.
case(condition): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statement after the case branch will not be executed.
default(): The default branch of Switch. When cond of all case branches is False, the statement after default branch is executed.
......@@ -2372,6 +2372,10 @@ class Switch(object):
if not self.inside_scope:
raise ValueError("case should be called inside with")
check_variable_and_dtype(
condition, 'condition', ['bool'],
'the member function case of fluid.layers.Switch')
if len(self.pre_not_conditions) == 0:
cond_block = ConditionalBlock([condition], is_scalar_condition=True)
not_cond = logical_not(x=condition)
......
......@@ -1154,7 +1154,7 @@ def isfinite(x):
out = fluid.layers.isfinite(var)
"""
helper = LayerHelper("isfinite", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype='bool')
helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out})
return out
......
......@@ -26,7 +26,6 @@ from paddle.fluid.framework import default_startup_program
class TestSwitch(unittest.TestCase):
def check_switch(self, value):
x = layers.fill_constant(shape=[1], dtype='float32', value=value)
zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
two_var = layers.fill_constant(shape=[1], dtype='float32', value=2.0)
......@@ -62,5 +61,34 @@ class TestSwitch(unittest.TestCase):
self.assertEqual(result, expected_result)
class TestSwitchCaseError(unittest.TestCase):
def test_error(self):
main_program = framework.Program()
startup_program = framework.Program()
with framework.program_guard(main_program, startup_program):
cond = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
zero_var = layers.fill_constant(
shape=[1], dtype='float32', value=0.0)
result = layers.create_global_var(
shape=[1], value=-1.0, dtype='float32', persistable=True)
# 1. The type of 'condition' in case must be Variable.
def test_condition_type():
with layers.Switch() as switch:
with switch.case(1):
layers.assign(zero_var, result)
self.assertRaises(TypeError, test_condition_type)
# 2. The dtype of 'condition' in case must be 'bool'.
def test_condition_dtype():
with layers.Switch() as switch:
with switch.case(cond):
layers.assign(zero_var, result)
self.assertRaises(TypeError, test_condition_dtype)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册