diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 6bc9d29a905a1b54a225f91cf2711b55efe48ee8..bf0a2a28364998bacfe43162a337a86938328b9a 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -2035,13 +2035,35 @@ def cond(pred, true_fn=None, false_fn=None, name=None): # [ True True True]] """ + if in_dygraph_mode(): + assert isinstance(pred, Variable), "The pred in cond must be Variable" + assert pred.numpy().size == 1, "condition input's numel should be 1" + pred = pred.numpy()[0] + if pred: + if true_fn is not None: + if not callable(true_fn): + raise TypeError( + "The true_fn in cond must be callable, but received {}". + format(type(true_fn).__name__)) + return true_fn() + else: + if false_fn is not None: + if not callable(false_fn): + raise TypeError( + "The false_fn in cond must be callable, but received {}". + format(type(false_fn).__name__)) + return false_fn() + return None + helper = LayerHelper('cond', **locals()) true_output = None false_output = None copy_to_parent_func = lambda var: copy_var_to_parent_block(var, helper) if true_fn is not None: if not callable(true_fn): - raise TypeError("The true_fn in cond must be callable") + raise TypeError( + "The true_fn in cond must be callable, but received {}".format( + type(true_fn).__name__)) true_cond_block = ConditionalBlock([pred], is_scalar_condition=True) with true_cond_block.block(): origin_true_output = true_fn() @@ -2050,7 +2072,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None): origin_true_output) if false_fn is not None: if not callable(false_fn): - raise TypeError("The false_fn in cond must be callable") + raise TypeError( + "The false_fn in cond must be callable, but received {}".format( + type(false_fn).__name__)) false_cond_block = ConditionalBlock( [logical_not(pred)], is_scalar_condition=True) with false_cond_block.block(): diff --git a/python/paddle/fluid/tests/unittests/test_cond.py b/python/paddle/fluid/tests/unittests/test_cond.py index b3632a5f5f3c6df569d872785c177a5bea36b1ac..7358de2c50b6293b9987abe926e61a93102541c3 100644 --- a/python/paddle/fluid/tests/unittests/test_cond.py +++ b/python/paddle/fluid/tests/unittests/test_cond.py @@ -192,15 +192,11 @@ class TestCondInputOutput(unittest.TestCase): with program_guard(main_program, startup_program): i = fluid.data(name="i", shape=[1], dtype='int32') pred = ((i % 2) == 0) - with self.assertRaises(Exception) as e: + with self.assertRaises(TypeError): out = layers.cond(pred, i, func_return_one_tensor) - self.assertEqual("The true_fn in cond must be callable", - str(e.exception)) - with self.assertRaises(Exception) as e: + with self.assertRaises(TypeError): out = layers.cond(pred, func_return_one_tensor, np.asarray([3])) - self.assertEqual("The false_fn in cond must be callable", - str(e.exception)) with self.assertRaises(Exception) as e: out = layers.cond(pred, func_return_none, diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 91b17cc5273f1499b6b19b3bc6cd25b7f313f288..c0d4284dcc131390fa78ef220b3f88a269b69b70 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1527,6 +1527,43 @@ class TestLayer(LayerTest): for i in range(len(static_ret5)): self.assertTrue(dcond5.numpy()[i] == static_ret5[i]) + def test_cond(self): + def less_than_branch(a, b): + return fluid.layers.elementwise_add(a, b) + + def greater_equal_branch(a, b): + return fluid.layers.elementwise_sub(a, b) + + with self.static_graph(): + a = fluid.layers.fill_constant( + shape=[1], dtype='float32', value=0.1) + b = fluid.layers.fill_constant( + shape=[1], dtype='float32', value=0.23) + out = fluid.layers.cond(a >= b, lambda: greater_equal_branch(a, b), + lambda: less_than_branch(a, b)) + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + ret = exe.run(fetch_list=[out]) + static_res = ret[0] + + with self.dynamic_graph(): + a = fluid.dygraph.to_variable(np.array([0.1]).astype('float32')) + b = fluid.dygraph.to_variable(np.array([0.23]).astype('float32')) + out = layers.cond(a < b, lambda: less_than_branch(a, b), + lambda: greater_equal_branch(a, b)) + out2 = layers.cond(a >= b, lambda: greater_equal_branch(a, b), + lambda: less_than_branch(a, b)) + dynamic_res = out.numpy() + dynamic_res2 = out2.numpy() + self.assertTrue(np.array_equal(dynamic_res, dynamic_res2)) + with self.assertRaises(TypeError): + layers.cond(a < b, 'str', 'str') + with self.assertRaises(TypeError): + layers.cond(a >= b, 'str', 'str') + + self.assertTrue(np.array_equal(static_res, dynamic_res)) + def test_crop_tensor(self): with self.static_graph(): x = fluid.layers.data(name="x1", shape=[6, 5, 8])