From 923f24589e13bcfdc8a411182f39ca111c8e75a1 Mon Sep 17 00:00:00 2001 From: xysheng-baidu <121540080+xysheng-baidu@users.noreply.github.com> Date: Tue, 10 Jan 2023 12:44:29 +0800 Subject: [PATCH] [Zero_Dim][unittest] add repeat_interleave unittest for zero_dim (#49596) --- .../unittests/test_repeat_interleave_op.py | 23 ++++++++++ .../tests/unittests/test_zero_dim_tensor.py | 46 +++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py b/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py index 093cb17b63..6b602fa741 100644 --- a/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py +++ b/python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py @@ -104,6 +104,7 @@ class TestRepeatInterleaveOp2(OpTest): class TestIndexSelectAPI(unittest.TestCase): def input_data(self): + self.data_zero_dim_x = np.array(0.5) self.data_x = np.array( [ [1.0, 2.0, 3.0, 4.0], @@ -170,6 +171,19 @@ class TestIndexSelectAPI(unittest.TestCase): expect_out = np.repeat(self.data_x, repeats, axis=0) np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) + # case 3 zero_dim: + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[]) + z = paddle.repeat_interleave(x, repeats) + exe = fluid.Executor(fluid.CPUPlace()) + (res,) = exe.run( + feed={'x': self.data_zero_dim_x}, + fetch_list=[z.name], + return_numpy=False, + ) + expect_out = np.repeat(self.data_zero_dim_x, repeats) + np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) + def test_dygraph_api(self): self.input_data() # case axis none @@ -220,6 +234,15 @@ class TestIndexSelectAPI(unittest.TestCase): expect_out = np.repeat(self.data_x, index, axis=0) np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + # case 3 zero_dim: + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(self.data_zero_dim_x) + index = 2 + z = paddle.repeat_interleave(x, index, None) + np_z = z.numpy() + expect_out = np.repeat(self.data_zero_dim_x, index, axis=None) + np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 5b18abdbc5..6a99725ca7 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -966,6 +966,34 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x1.grad.numpy(), 0) self.assertEqual(x2.grad.numpy(), 0) + def test_repeat_interleave(self): + places = ['cpu'] + if paddle.is_compiled_with_cuda(): + places.append('gpu') + for place in places: + paddle.set_device(place) + + x = paddle.randn(()) + x.stop_gradient = False + + out = paddle.repeat_interleave(x, 2, None) + out.backward() + + # check shape of output + self.assertEqual(out.shape, [2]) + + # check grad shape + self.assertEqual(x.grad.shape, []) + + repeats = paddle.to_tensor([3], dtype='int32') + out = paddle.repeat_interleave(x, repeats, None) + + # check shape of output with 1D repeats + self.assertEqual(out.shape, [3]) + + # check grad shape with 1D repeats + self.assertEqual(x.grad.shape, []) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -1380,6 +1408,24 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) + @prog_scope() + def test_repeat_interleave(self): + x = paddle.full([], 1.0, 'float32') + out = paddle.repeat_interleave(x, 2, None) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, (2,)) + + repeats = paddle.to_tensor([3], dtype='int32') + out = paddle.repeat_interleave(x, repeats, None) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, (3,)) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): -- GitLab