diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 4e301a2120045d5a107cb1c1ae21bc2893163362..63029d2f1e549750746b0c36af0b267505b94d11 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -2173,7 +2173,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None): def _error_message(what, arg_name, op_name, right_value, error_value): - error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \ + error_message = "{what} of '{arg_name}' in {op_name} must be " \ "{right_value}, but received: {error_value}.".format( what=what, arg_name=arg_name, @@ -2309,7 +2309,7 @@ class Switch(object): OP :ref:`api_fluid_layers_case` is easier to use and is called with less code but does the same thing as ``Switch`` . Member Functions: - case(cond): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statement after the case branch will not be executed. + case(condition): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statement after the case branch will not be executed. default(): The default branch of Switch. When cond of all case branches is False, the statement after default branch is executed. @@ -2372,6 +2372,10 @@ class Switch(object): if not self.inside_scope: raise ValueError("case should be called inside with") + check_variable_and_dtype( + condition, 'condition', ['bool'], + 'the member function case of fluid.layers.Switch') + if len(self.pre_not_conditions) == 0: cond_block = ConditionalBlock([condition], is_scalar_condition=True) not_cond = logical_not(x=condition) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 2386d1f27cacea7e1ee4675831725401114af2c4..a7fef16814ae4cdcd6e2be10ac1a22d32efa2fe4 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1154,7 +1154,7 @@ def isfinite(x): out = fluid.layers.isfinite(var) """ helper = LayerHelper("isfinite", **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) + out = helper.create_variable_for_type_inference(dtype='bool') helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_switch.py b/python/paddle/fluid/tests/unittests/test_switch.py index 2a9c07a889ba5fe24fd1c098729a233cb8fbb16f..b9f3c804ef357aec9b9941b8a80903a3efa44c0f 100644 --- a/python/paddle/fluid/tests/unittests/test_switch.py +++ b/python/paddle/fluid/tests/unittests/test_switch.py @@ -26,7 +26,6 @@ from paddle.fluid.framework import default_startup_program class TestSwitch(unittest.TestCase): def check_switch(self, value): x = layers.fill_constant(shape=[1], dtype='float32', value=value) - zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0) one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0) two_var = layers.fill_constant(shape=[1], dtype='float32', value=2.0) @@ -62,5 +61,34 @@ class TestSwitch(unittest.TestCase): self.assertEqual(result, expected_result) +class TestSwitchCaseError(unittest.TestCase): + def test_error(self): + main_program = framework.Program() + startup_program = framework.Program() + with framework.program_guard(main_program, startup_program): + cond = layers.fill_constant(shape=[1], dtype='float32', value=0.0) + zero_var = layers.fill_constant( + shape=[1], dtype='float32', value=0.0) + + result = layers.create_global_var( + shape=[1], value=-1.0, dtype='float32', persistable=True) + + # 1. The type of 'condition' in case must be Variable. + def test_condition_type(): + with layers.Switch() as switch: + with switch.case(1): + layers.assign(zero_var, result) + + self.assertRaises(TypeError, test_condition_type) + + # 2. The dtype of 'condition' in case must be 'bool'. + def test_condition_dtype(): + with layers.Switch() as switch: + with switch.case(cond): + layers.assign(zero_var, result) + + self.assertRaises(TypeError, test_condition_dtype) + + if __name__ == '__main__': unittest.main()