diff --git a/python/paddle/fluid/tests/unittests/test_case.py b/python/paddle/fluid/tests/unittests/test_case.py index e5980abea5d1e81e82eabbd1774e5c6bd679ab1b..9123b4b009d1866d7ce9385777e4ddc5ffc0d5e0 100644 --- a/python/paddle/fluid/tests/unittests/test_case.py +++ b/python/paddle/fluid/tests/unittests/test_case.py @@ -89,6 +89,67 @@ class TestAPICase(unittest.TestCase): np.testing.assert_allclose(res[3], 2, rtol=1e-05) np.testing.assert_allclose(res[4], 2, rtol=1e-05) + def test_0d_tensor(self): + def fn_1(): + return paddle.full(shape=[], dtype='int32', fill_value=1) + + def fn_2(): + return paddle.full(shape=[], dtype='int32', fill_value=2) + + def fn_3(): + return paddle.full(shape=[], dtype='int32', fill_value=3) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = paddle.full(shape=[], dtype='float32', fill_value=0.3) + y = paddle.full(shape=[], dtype='float32', fill_value=0.1) + z = paddle.full(shape=[], dtype='float32', fill_value=0.2) + pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1 + pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3 + + # call fn_1 + out_0 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3 + ) + + # call fn_2 + out_1 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3 + ) + + # call default fn_3 + out_2 = paddle.static.nn.control_flow.case( + pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3 + ) + + # no default, call fn_2 + out_3 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_1, fn_2)] + ) + + # no default, call fn_2. but pred_2 is false + out_4 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_2, fn_2)] + ) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + + res = exe.run( + main_program, fetch_list=[out_0, out_1, out_2, out_3, out_4] + ) + + np.testing.assert_allclose(res[0], 1, rtol=1e-05) + np.testing.assert_allclose(res[1], 2, rtol=1e-05) + np.testing.assert_allclose(res[2], 3, rtol=1e-05) + np.testing.assert_allclose(res[3], 2, rtol=1e-05) + np.testing.assert_allclose(res[4], 2, rtol=1e-05) + def test_return_var_tuple(self): def fn_1(): return layers.fill_constant( @@ -236,6 +297,106 @@ class TestAPICase_Nested(unittest.TestCase): np.testing.assert_allclose(res[1], 2, rtol=1e-05) np.testing.assert_allclose(res[2], 3, rtol=1e-05) + def test_nested_0d_tensor(self): + def fn_1(x=1): + var_5 = paddle.full(shape=[], dtype='int32', fill_value=5) + var_6 = paddle.full(shape=[], dtype='int32', fill_value=6) + out = paddle.static.nn.control_flow.case( + pred_fn_pairs=[ + ( + var_5 < var_6, + partial( + paddle.full, + shape=[], + dtype='int32', + fill_value=x, + ), + ), + ( + var_5 == var_6, + partial( + paddle.full, + shape=[], + dtype='int32', + fill_value=x, + ), + ), + ] + ) + return out + + def fn_2(x=2): + var_5 = paddle.full(shape=[], dtype='int32', fill_value=5) + var_6 = paddle.full(shape=[], dtype='int32', fill_value=6) + out = paddle.static.nn.control_flow.case( + pred_fn_pairs=[ + (var_5 < var_6, partial(fn_1, x=x)), + ( + var_5 == var_6, + partial( + paddle.full, + shape=[], + dtype='int32', + fill_value=x, + ), + ), + ] + ) + return out + + def fn_3(): + var_5 = paddle.full(shape=[], dtype='int32', fill_value=5) + var_6 = paddle.full(shape=[], dtype='int32', fill_value=6) + out = paddle.static.nn.control_flow.case( + pred_fn_pairs=[ + (var_5 < var_6, partial(fn_2, x=3)), + ( + var_5 == var_6, + partial( + paddle.full, + shape=[], + dtype='int32', + fill_value=7, + ), + ), + ] + ) + return out + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = paddle.full(shape=[], dtype='float32', fill_value=0.3) + y = paddle.full(shape=[], dtype='float32', fill_value=0.1) + z = paddle.full(shape=[], dtype='float32', fill_value=0.2) + pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1 + pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3 + + out_1 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3 + ) + + out_2 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3 + ) + + out_3 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(x == y, fn_1), (x == z, fn_2)], default=fn_3 + ) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + + res = exe.run(main_program, fetch_list=[out_1, out_2, out_3]) + + np.testing.assert_allclose(res[0], 1, rtol=1e-05) + np.testing.assert_allclose(res[1], 2, rtol=1e-05) + np.testing.assert_allclose(res[2], 3, rtol=1e-05) + class TestAPICase_Error(unittest.TestCase): def test_error(self): diff --git a/python/paddle/fluid/tests/unittests/test_cond.py b/python/paddle/fluid/tests/unittests/test_cond.py index 3176ace0a381366b138cdaa99f04602706bd82ec..9769aa8df430e70e46ae84fc3a12ea7221c1abe7 100644 --- a/python/paddle/fluid/tests/unittests/test_cond.py +++ b/python/paddle/fluid/tests/unittests/test_cond.py @@ -68,6 +68,115 @@ class TestCondInputOutput(unittest.TestCase): np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05 ) + def test_return_0d_tensor(self): + """ + pseudocode: + + if 0.23 >= 0.1: + return 2 + else: + return -1 + """ + + paddle.enable_static() + + def true_func(): + return paddle.full(shape=[], dtype='int32', fill_value=2) + + def false_func(): + return paddle.full(shape=[], dtype='int32', fill_value=-1) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = paddle.full(shape=[1], dtype='float32', fill_value=0.1) + y = paddle.full(shape=[1], dtype='float32', fill_value=0.23) + pred = paddle.greater_equal(y, x) + out = paddle.static.nn.cond(pred, true_func, false_func) + # out is one tensor + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + (ret,) = exe.run(main_program, fetch_list=[out.name]) + np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05) + + def test_0d_tensor_as_cond(self): + """ + pseudocode: + + if 0.23 >= 0.1: + return 2 + else: + return -1 + """ + + paddle.enable_static() + + def true_func(): + return paddle.full(shape=[3, 3], dtype='int32', fill_value=2) + + def false_func(): + return paddle.full(shape=[3, 3], dtype='int32', fill_value=-1) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = paddle.full(shape=[], dtype='float32', fill_value=0.1) + y = paddle.full(shape=[], dtype='float32', fill_value=0.23) + pred = paddle.greater_equal(y, x) + out = paddle.static.nn.cond(pred, true_func, false_func) + # out is one tensor + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + (ret,) = exe.run(main_program, fetch_list=[out.name]) + np.testing.assert_allclose( + np.asarray(ret), np.full((3, 3), 2, np.int32), rtol=1e-05 + ) + + def test_0d_tensor_backward(self): + """ + pseudocode: + + a = -2.0 + if a >= 0: + return a + else: + return -a + """ + + paddle.enable_static() + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + a = paddle.full(shape=[], dtype='float32', fill_value=-2.0) + a.stop_gradient = False + out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a) + append_backward(out) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + ret = exe.run(main_program, fetch_list=[out.name, a.grad_name]) + np.testing.assert_allclose( + np.asarray(ret[0]), np.array(2.0), rtol=1e-05 + ) + np.testing.assert_allclose( + np.asarray(ret[1]), np.array(-1.0), rtol=1e-05 + ) + def test_return_var_tuple(self): """ pseudocode: @@ -358,6 +467,70 @@ class TestCondNestedControlFlow(unittest.TestCase): self.assertEqual(ret[0][0], expected_ret) self.assertEqual(ret[1][0], expected_a_grad) + def test_cond_inside_cond_0d_tensor(self): + """ + pseudocode: + i = 3.0 + a = 2 * i + if i < 5: + if i >= 3: + return a + 1 + else: + return 1 - a + else: + if i < 8: + return a * 2 + else: + return a / 2 + """ + + paddle.enable_static() + + def less_than_branch(i, a): + return paddle.static.nn.cond( + i >= 3.0, + lambda: a + 1, + lambda: 1 - a, + ) + + def greater_equal_branch(i, a): + return paddle.static.nn.cond( + i < 8.0, + lambda: a * 2, + lambda: a / 2, + ) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + i = paddle.full(fill_value=3.0, shape=[], dtype='float32') + i.stop_gradient = False + a = 2.0 * i + out = paddle.static.nn.cond( + i < 5.0, + lambda: less_than_branch(i, a), + lambda: greater_equal_branch(i, a), + ) + mean = paddle.mean(out) + append_backward(out) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + ret = exe.run( + main_program, + fetch_list=[out.name, i.grad_name], + ) + np.testing.assert_allclose( + np.asarray(ret[0]), np.array(7.0), rtol=1e-05 + ) + np.testing.assert_allclose( + np.asarray(ret[1]), np.array(2.0), rtol=1e-05 + ) + def test_cond_op_in_condition(self): paddle.enable_static() main_program = fluid.Program() diff --git a/python/paddle/fluid/tests/unittests/test_switch_case.py b/python/paddle/fluid/tests/unittests/test_switch_case.py index 119b5ac285f7347632d8b52c1a5563ed1996adc7..2ddbd0f7ff051e246d3e6e3bb54cc2bddab25e4a 100644 --- a/python/paddle/fluid/tests/unittests/test_switch_case.py +++ b/python/paddle/fluid/tests/unittests/test_switch_case.py @@ -114,6 +114,93 @@ class TestAPISwitchCase(unittest.TestCase): err_msg='result is {} but answer is {}'.format(res[0], 2), ) + def test_0d_tensor(self): + def fn_1(): + return paddle.full(shape=[], dtype='int32', fill_value=1) + + def fn_2(): + return paddle.full(shape=[], dtype='int32', fill_value=2) + + def fn_3(): + return paddle.full(shape=[], dtype='int32', fill_value=3) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + index_1 = paddle.full(shape=[], dtype='int32', fill_value=1) + index_2 = paddle.full(shape=[], dtype='int32', fill_value=2) + index_5 = paddle.full(shape=[], dtype='int32', fill_value=5) + + # call fn_1 + out_0 = paddle.static.nn.switch_case( + branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3} + ) + + # call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3} + out_1 = paddle.static.nn.switch_case( + branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3) + ) + + # call default fn_3 + out_2 = paddle.static.nn.switch_case( + branch_index=index_5, + branch_fns=((1, fn_1), (2, fn_2)), + default=fn_3, + ) + + # no default, call fn_2 + out_3 = paddle.static.nn.switch_case( + branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)] + ) + + # no default, call fn_2 but branch_index is 5 + out_4 = paddle.static.nn.switch_case( + branch_index=index_5, + branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)], + ) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + + res = exe.run( + main_program, fetch_list=[out_0, out_1, out_2, out_3, out_4] + ) + + np.testing.assert_allclose( + res[0], + 1, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(res[0], 1), + ) + np.testing.assert_allclose( + res[1], + 2, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(res[0], 2), + ) + np.testing.assert_allclose( + res[2], + 3, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(res[0], 3), + ) + np.testing.assert_allclose( + res[3], + 2, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(res[0], 2), + ) + np.testing.assert_allclose( + res[4], + 2, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(res[0], 2), + ) + def test_return_var_tuple(self): def fn_1(): return layers.fill_constant( @@ -257,6 +344,101 @@ class TestAPISwitchCase_Nested(unittest.TestCase): err_msg='result is {} but answer is {}'.format(res[2], 3), ) + def test_nested_switch_0d_tensor(self): + def fn_1(x=1): + out = paddle.static.nn.switch_case( + branch_index=paddle.full(shape=[], dtype='int32', fill_value=x), + branch_fns={ + 1: partial( + paddle.full, shape=[], dtype='int32', fill_value=1 + ), + x: partial( + paddle.full, shape=[], dtype='int32', fill_value=x + ), + }, + ) + return out + + def fn_2(x=2): + out = paddle.static.nn.switch_case( + branch_index=paddle.full(shape=[], dtype='int32', fill_value=2), + branch_fns={ + 1: partial( + paddle.full, + shape=[], + dtype='int32', + fill_value=1, + ), + 2: partial(fn_1, x=x), + }, + ) + return out + + def fn_3(): + out = paddle.static.nn.switch_case( + branch_index=paddle.full(shape=[], dtype='int32', fill_value=3), + branch_fns={ + 1: partial( + paddle.full, + shape=[], + dtype='int32', + fill_value=1, + ), + 3: partial(fn_2, x=3), + }, + ) + return out + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + index_1 = fluid.data(name="index_1", shape=[1], dtype='uint8') + index_2 = paddle.full(shape=[], dtype='int32', fill_value=2) + index_3 = paddle.full(shape=[], dtype='int64', fill_value=3) + + out_1 = paddle.static.nn.switch_case( + branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3} + ) + out_2 = paddle.static.nn.switch_case( + branch_index=index_2, branch_fns={1: fn_1, 2: fn_2, 3: fn_3} + ) + + out_3 = paddle.static.nn.switch_case( + branch_index=index_3, branch_fns={1: fn_1, 2: fn_2, 3: fn_3} + ) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + + res = exe.run( + main_program, + feed={"index_1": np.array([1], dtype="uint8")}, + fetch_list=[out_1, out_2, out_3], + ) + + np.testing.assert_allclose( + res[0], + 1, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(res[0], 1), + ) + np.testing.assert_allclose( + res[1], + 2, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(res[1], 2), + ) + np.testing.assert_allclose( + res[2], + 3, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(res[2], 3), + ) + # test TypeError and ValueError of api switch_case class TestAPISwitchCase_Error(unittest.TestCase):