From 773767273cf5d3c94fc52269970ae6c6f2c0785e Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Wed, 18 Jan 2023 08:51:47 +0800 Subject: [PATCH] Enrich 0d Tensor Dygraph and Shape Unit Test for `case` and `switch_case` (#49889) Followed PR https://github.com/PaddlePaddle/Paddle/pull/49842 , added Digraph and Shape unit test for `case` and `switch_case`. This PR only contained test changes because `case` and `switch_case` call `cond`. The PR https://github.com/PaddlePaddle/Paddle/pull/49842 has already solved the 0d tensor support. --- .../paddle/fluid/tests/unittests/test_case.py | 95 ++++++++++++ .../fluid/tests/unittests/test_switch_case.py | 138 +++++++++++++++++- 2 files changed, 225 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_case.py b/python/paddle/fluid/tests/unittests/test_case.py index 9123b4b009d..675b51cf0a0 100644 --- a/python/paddle/fluid/tests/unittests/test_case.py +++ b/python/paddle/fluid/tests/unittests/test_case.py @@ -22,6 +22,7 @@ import paddle.fluid as fluid import paddle.fluid.core as core import paddle.fluid.layers as layers import paddle.fluid.optimizer as optimizer +from paddle.fluid.backward import append_backward from paddle.fluid.framework import Program, program_guard paddle.enable_static() @@ -145,10 +146,101 @@ class TestAPICase(unittest.TestCase): ) np.testing.assert_allclose(res[0], 1, rtol=1e-05) + self.assertEqual(res[0].shape, ()) np.testing.assert_allclose(res[1], 2, rtol=1e-05) + self.assertEqual(res[1].shape, ()) np.testing.assert_allclose(res[2], 3, rtol=1e-05) + self.assertEqual(res[2].shape, ()) np.testing.assert_allclose(res[3], 2, rtol=1e-05) + self.assertEqual(res[3].shape, ()) np.testing.assert_allclose(res[4], 2, rtol=1e-05) + self.assertEqual(res[4].shape, ()) + + def test_0d_tensor_backward(self): + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = paddle.full(shape=[], dtype='float32', fill_value=-2.0) + x.stop_gradient = False + pred = paddle.full(shape=[], dtype='bool', fill_value=0) + # pred is False, so out = -x + out = paddle.static.nn.case( + pred_fn_pairs=[(pred, lambda: x)], default=lambda: -x + ) + append_backward(out) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + + res = exe.run(main_program, fetch_list=[out.name, x.grad_name]) + np.testing.assert_allclose( + np.asarray(res[0]), np.array(2.0), rtol=1e-05 + ) + self.assertEqual(res[0].shape, ()) + np.testing.assert_allclose( + np.asarray(res[1]), np.array(-1.0), rtol=1e-05 + ) + self.assertEqual(res[1].shape, ()) + + def test_0d_tensor_dygraph(self): + paddle.disable_static() + + def fn_1(): + return paddle.full(shape=[], dtype='int32', fill_value=1) + + def fn_2(): + return paddle.full(shape=[], dtype='int32', fill_value=2) + + def fn_3(): + return paddle.full(shape=[], dtype='int32', fill_value=3) + + x = paddle.full(shape=[], dtype='float32', fill_value=0.3) + y = paddle.full(shape=[], dtype='float32', fill_value=0.1) + z = paddle.full(shape=[], dtype='float32', fill_value=0.2) + pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1 + pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3 + + # call fn_1 + out_0 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3 + ) + + # call fn_2 + out_1 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3 + ) + + # call default fn_3 + out_2 = paddle.static.nn.control_flow.case( + pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3 + ) + + # no default, call fn_2 + out_3 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_1, fn_2)] + ) + + # no default, call fn_2. but pred_2 is false + out_4 = paddle.static.nn.control_flow.case( + pred_fn_pairs=[(pred_2, fn_2)] + ) + + np.testing.assert_allclose(out_0, 1, rtol=1e-05) + self.assertEqual(out_0.shape, []) + np.testing.assert_allclose(out_1, 2, rtol=1e-05) + self.assertEqual(out_1.shape, []) + np.testing.assert_allclose(out_2, 3, rtol=1e-05) + self.assertEqual(out_2.shape, []) + np.testing.assert_allclose(out_3, 2, rtol=1e-05) + self.assertEqual(out_3.shape, []) + np.testing.assert_allclose(out_4, 2, rtol=1e-05) + self.assertEqual(out_4.shape, []) + + paddle.enable_static() def test_return_var_tuple(self): def fn_1(): @@ -394,8 +486,11 @@ class TestAPICase_Nested(unittest.TestCase): res = exe.run(main_program, fetch_list=[out_1, out_2, out_3]) np.testing.assert_allclose(res[0], 1, rtol=1e-05) + self.assertEqual(res[0].shape, ()) np.testing.assert_allclose(res[1], 2, rtol=1e-05) + self.assertEqual(res[1].shape, ()) np.testing.assert_allclose(res[2], 3, rtol=1e-05) + self.assertEqual(res[2].shape, ()) class TestAPICase_Error(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_switch_case.py b/python/paddle/fluid/tests/unittests/test_switch_case.py index 2ddbd0f7ff0..3fad3bdfd0c 100644 --- a/python/paddle/fluid/tests/unittests/test_switch_case.py +++ b/python/paddle/fluid/tests/unittests/test_switch_case.py @@ -21,6 +21,7 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core import paddle.fluid.layers as layers +from paddle.fluid.backward import append_backward from paddle.fluid.framework import Program, program_guard paddle.enable_static() @@ -93,25 +94,25 @@ class TestAPISwitchCase(unittest.TestCase): res[1], 2, rtol=1e-05, - err_msg='result is {} but answer is {}'.format(res[0], 2), + err_msg='result is {} but answer is {}'.format(res[1], 2), ) np.testing.assert_allclose( res[2], 3, rtol=1e-05, - err_msg='result is {} but answer is {}'.format(res[0], 3), + err_msg='result is {} but answer is {}'.format(res[2], 3), ) np.testing.assert_allclose( res[3], 2, rtol=1e-05, - err_msg='result is {} but answer is {}'.format(res[0], 2), + err_msg='result is {} but answer is {}'.format(res[3], 2), ) np.testing.assert_allclose( res[4], 2, rtol=1e-05, - err_msg='result is {} but answer is {}'.format(res[0], 2), + err_msg='result is {} but answer is {}'.format(res[4], 2), ) def test_0d_tensor(self): @@ -176,30 +177,148 @@ class TestAPISwitchCase(unittest.TestCase): rtol=1e-05, err_msg='result is {} but answer is {}'.format(res[0], 1), ) + self.assertEqual(res[0].shape, ()) np.testing.assert_allclose( res[1], 2, rtol=1e-05, - err_msg='result is {} but answer is {}'.format(res[0], 2), + err_msg='result is {} but answer is {}'.format(res[1], 2), ) + self.assertEqual(res[1].shape, ()) np.testing.assert_allclose( res[2], 3, rtol=1e-05, - err_msg='result is {} but answer is {}'.format(res[0], 3), + err_msg='result is {} but answer is {}'.format(res[2], 3), ) + self.assertEqual(res[2].shape, ()) np.testing.assert_allclose( res[3], 2, rtol=1e-05, - err_msg='result is {} but answer is {}'.format(res[0], 2), + err_msg='result is {} but answer is {}'.format(res[3], 2), ) + self.assertEqual(res[3].shape, ()) np.testing.assert_allclose( res[4], 2, rtol=1e-05, - err_msg='result is {} but answer is {}'.format(res[0], 2), + err_msg='result is {} but answer is {}'.format(res[4], 2), ) + self.assertEqual(res[4].shape, ()) + + def test_0d_tensor_backward(self): + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = paddle.full(shape=[], dtype='float32', fill_value=-2.0) + x.stop_gradient = False + pred = paddle.full(shape=[], dtype='int32', fill_value=2) + # pred is 2, so out = 2 * x + out = paddle.static.nn.switch_case( + branch_index=pred, + branch_fns=[(1, lambda: x), (2, lambda: 2 * x)], + default=lambda: -x, + ) + append_backward(out) + + place = ( + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() + else fluid.CPUPlace() + ) + exe = fluid.Executor(place) + + res = exe.run(main_program, fetch_list=[out.name, x.grad_name]) + np.testing.assert_allclose( + np.asarray(res[0]), np.array(-4.0), rtol=1e-05 + ) + self.assertEqual(res[0].shape, ()) + np.testing.assert_allclose( + np.asarray(res[1]), np.array(2.0), rtol=1e-05 + ) + self.assertEqual(res[1].shape, ()) + + def test_0d_tensor_dygraph(self): + paddle.disable_static() + + def fn_1(): + return paddle.full(shape=[], dtype='int32', fill_value=1) + + def fn_2(): + return paddle.full(shape=[], dtype='int32', fill_value=2) + + def fn_3(): + return paddle.full(shape=[], dtype='int32', fill_value=3) + + index_1 = paddle.full(shape=[], dtype='int32', fill_value=1) + index_2 = paddle.full(shape=[], dtype='int32', fill_value=2) + index_5 = paddle.full(shape=[], dtype='int32', fill_value=5) + + # call fn_1 + out_0 = paddle.static.nn.switch_case( + branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3} + ) + + # call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3} + out_1 = paddle.static.nn.switch_case( + branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3) + ) + + # call default fn_3 + out_2 = paddle.static.nn.switch_case( + branch_index=index_5, + branch_fns=((1, fn_1), (2, fn_2)), + default=fn_3, + ) + + # no default, call fn_2 + out_3 = paddle.static.nn.switch_case( + branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)] + ) + + # no default, call fn_2 but branch_index is 5 + out_4 = paddle.static.nn.switch_case( + branch_index=index_5, + branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)], + ) + np.testing.assert_allclose( + out_0, + 1, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(out_0, 1), + ) + self.assertEqual(out_0.shape, []) + np.testing.assert_allclose( + out_1, + 2, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(out_1, 2), + ) + self.assertEqual(out_1.shape, []) + np.testing.assert_allclose( + out_2, + 3, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(out_2, 3), + ) + self.assertEqual(out_2.shape, []) + np.testing.assert_allclose( + out_3, + 2, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(out_3, 2), + ) + self.assertEqual(out_3.shape, []) + np.testing.assert_allclose( + out_4, + 2, + rtol=1e-05, + err_msg='result is {} but answer is {}'.format(out_4, 2), + ) + self.assertEqual(out_4.shape, []) + + paddle.enable_static() def test_return_var_tuple(self): def fn_1(): @@ -426,18 +545,21 @@ class TestAPISwitchCase_Nested(unittest.TestCase): rtol=1e-05, err_msg='result is {} but answer is {}'.format(res[0], 1), ) + self.assertEqual(res[0].shape, ()) np.testing.assert_allclose( res[1], 2, rtol=1e-05, err_msg='result is {} but answer is {}'.format(res[1], 2), ) + self.assertEqual(res[1].shape, ()) np.testing.assert_allclose( res[2], 3, rtol=1e-05, err_msg='result is {} but answer is {}'.format(res[2], 3), ) + self.assertEqual(res[2].shape, ()) # test TypeError and ValueError of api switch_case -- GitLab