From cb22a5c7524bc768ef90e74ac9b3e8f55097b74b Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Sat, 31 Dec 2022 00:38:47 +0800 Subject: [PATCH] support flip 0D (#49460) --- paddle/phi/kernels/gpu/flip_kernel.cu | 3 +++ .../tests/unittests/test_zero_dim_tensor.py | 21 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/paddle/phi/kernels/gpu/flip_kernel.cu b/paddle/phi/kernels/gpu/flip_kernel.cu index 6e9dbf37a9..7945d6c8fc 100644 --- a/paddle/phi/kernels/gpu/flip_kernel.cu +++ b/paddle/phi/kernels/gpu/flip_kernel.cu @@ -101,6 +101,9 @@ void FlipKernel(const Context& dev_ctx, DenseTensor* out) { const size_t total_dims = x.dims().size(); switch (total_dims) { + case 0: + LaunchFlipCudaKernel(dev_ctx, x, axis, out); + break; case 1: LaunchFlipCudaKernel(dev_ctx, x, axis, out); break; diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index b8a1151048..887a04f10c 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -454,6 +454,15 @@ class TestSundryAPI(unittest.TestCase): paddle.disable_static() self.x = paddle.rand([]) + def test_flip(self): + x = paddle.rand([]) + x.stop_gradient = False + out = paddle.flip(x, axis=[]) + out.backward() + self.assertEqual(x.shape, []) + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + def test_linear(self): x = paddle.randn([3, 2]) w = paddle.full(shape=[2, 4], fill_value=0.5) @@ -753,6 +762,18 @@ class TestSundryAPIStatic(unittest.TestCase): paddle.enable_static() self.exe = paddle.static.Executor() + @prog_scope() + def test_flip(self): + x = paddle.rand([]) + x.stop_gradient = False + out = paddle.flip(x, axis=[]) + paddle.static.append_backward(out) + + program = paddle.static.default_main_program() + res1, res2 = self.exe.run(program, fetch_list=[x, out]) + self.assertEqual(res1.shape, ()) + self.assertEqual(res2.shape, ()) + @prog_scope() def test_pow_factor(self): x = paddle.rand([]) -- GitLab