From 067134f1b33002f62b5a588b8cdaa70c02bf2c7f Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Wed, 8 Apr 2020 22:59:53 +0800 Subject: [PATCH] API (Switch) error message enhancement. test=develop (#23459) * API (Switch) error message enhancement. * fix bug: dtype of out in api isfinite is set incorrectly. The dtype should be bool. --- python/paddle/fluid/layers/control_flow.py | 8 +++-- python/paddle/fluid/layers/tensor.py | 2 +- .../fluid/tests/unittests/test_switch.py | 30 ++++++++++++++++++- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 4e301a21200..63029d2f1e5 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -2173,7 +2173,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None): def _error_message(what, arg_name, op_name, right_value, error_value): - error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \ + error_message = "{what} of '{arg_name}' in {op_name} must be " \ "{right_value}, but received: {error_value}.".format( what=what, arg_name=arg_name, @@ -2309,7 +2309,7 @@ class Switch(object): OP :ref:`api_fluid_layers_case` is easier to use and is called with less code but does the same thing as ``Switch`` . Member Functions: - case(cond): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statement after the case branch will not be executed. + case(condition): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statement after the case branch will not be executed. default(): The default branch of Switch. When cond of all case branches is False, the statement after default branch is executed. @@ -2372,6 +2372,10 @@ class Switch(object): if not self.inside_scope: raise ValueError("case should be called inside with") + check_variable_and_dtype( + condition, 'condition', ['bool'], + 'the member function case of fluid.layers.Switch') + if len(self.pre_not_conditions) == 0: cond_block = ConditionalBlock([condition], is_scalar_condition=True) not_cond = logical_not(x=condition) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 2386d1f27ca..a7fef16814a 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1154,7 +1154,7 @@ def isfinite(x): out = fluid.layers.isfinite(var) """ helper = LayerHelper("isfinite", **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) + out = helper.create_variable_for_type_inference(dtype='bool') helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_switch.py b/python/paddle/fluid/tests/unittests/test_switch.py index 2a9c07a889b..b9f3c804ef3 100644 --- a/python/paddle/fluid/tests/unittests/test_switch.py +++ b/python/paddle/fluid/tests/unittests/test_switch.py @@ -26,7 +26,6 @@ from paddle.fluid.framework import default_startup_program class TestSwitch(unittest.TestCase): def check_switch(self, value): x = layers.fill_constant(shape=[1], dtype='float32', value=value) - zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0) one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0) two_var = layers.fill_constant(shape=[1], dtype='float32', value=2.0) @@ -62,5 +61,34 @@ class TestSwitch(unittest.TestCase): self.assertEqual(result, expected_result) +class TestSwitchCaseError(unittest.TestCase): + def test_error(self): + main_program = framework.Program() + startup_program = framework.Program() + with framework.program_guard(main_program, startup_program): + cond = layers.fill_constant(shape=[1], dtype='float32', value=0.0) + zero_var = layers.fill_constant( + shape=[1], dtype='float32', value=0.0) + + result = layers.create_global_var( + shape=[1], value=-1.0, dtype='float32', persistable=True) + + # 1. The type of 'condition' in case must be Variable. + def test_condition_type(): + with layers.Switch() as switch: + with switch.case(1): + layers.assign(zero_var, result) + + self.assertRaises(TypeError, test_condition_type) + + # 2. The dtype of 'condition' in case must be 'bool'. + def test_condition_dtype(): + with layers.Switch() as switch: + with switch.case(cond): + layers.assign(zero_var, result) + + self.assertRaises(TypeError, test_condition_dtype) + + if __name__ == '__main__': unittest.main() -- GitLab