From 336bc20b18cc3f30dea0819ee99ddc47109a4282 Mon Sep 17 00:00:00 2001 From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com> Date: Tue, 25 Apr 2023 13:44:29 +0800 Subject: [PATCH] tile op support 0D input for xpu (#53237) --- paddle/phi/kernels/xpu/tile_kernel.cc | 41 +++++++++++------------- test/xpu/test_tile_op_xpu.py | 46 +++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/paddle/phi/kernels/xpu/tile_kernel.cc b/paddle/phi/kernels/xpu/tile_kernel.cc index 419ff72e640..f6bc716a7d5 100644 --- a/paddle/phi/kernels/xpu/tile_kernel.cc +++ b/paddle/phi/kernels/xpu/tile_kernel.cc @@ -31,13 +31,15 @@ void TileKernel(const Context& dev_ctx, DenseTensor* out) { using XPUType = typename XPUTypeTrait::Type; auto rank = x.dims().size(); - PADDLE_ENFORCE_GE( - rank, - 1, - errors::InvalidArgument( - "The rank of the input 'x' for tile op must be a positive " - "integer, but the value received is %d.", - rank)); + std::vector repeat_times = repeat_times_arr.GetData(); + int repeat_times_size = repeat_times.size(); + rank = std::max(rank, repeat_times_size); + PADDLE_ENFORCE_GE(rank, + 0, + errors::InvalidArgument( + "The rank of the input 'x' for tile op must be a >=0 " + "integer, but the value received is %d.", + rank)); PADDLE_ENFORCE_LE( rank, MAX_RANK_SUPPORTED, @@ -46,14 +48,12 @@ void TileKernel(const Context& dev_ctx, "must be less than or equal to %d, but the value received is %d.", MAX_RANK_SUPPORTED, rank)); - std::vector repeat_times = repeat_times_arr.GetData(); - int repeat_times_size = repeat_times.size(); PADDLE_ENFORCE_GE( repeat_times_size, - 1, + 0, errors::InvalidArgument( "The number of elements of the input 'repeat_times' for tile " - "op must be positive, but the value received is %d.", + "op must be >=0, but the value received is %d.", repeat_times_size)); PADDLE_ENFORCE_LE( repeat_times_size, @@ -102,20 +102,15 @@ void TileKernel(const Context& dev_ctx, dev_ctx.template Alloc(out); std::vector temp(repeat_times.size(), 1); - if (repeat_times == temp) { + if (rank == 0 || repeat_times == temp) { out->Resize(x.dims()); dev_ctx.template Alloc(out); - if (std::is_same::value) { - int r = xpu::copy(dev_ctx.x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(out->data()), - 8 * x.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - } else { - int r = xpu::copy( - dev_ctx.x_context(), x.data(), out->data(), x.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - } + int64_t count = x.numel() * sizeof(T); + int r = xpu::copy(dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + count); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); return; } diff --git a/test/xpu/test_tile_op_xpu.py b/test/xpu/test_tile_op_xpu.py index dc2b0d7f0ed..2e661199a09 100644 --- a/test/xpu/test_tile_op_xpu.py +++ b/test/xpu/test_tile_op_xpu.py @@ -90,6 +90,21 @@ class XPUTestTileOpRank1(XPUOpTestWrapper): self.ori_shape = (2, 4, 5, 7) self.repeat_times = (3, 2, 1, 2) + class TestTileOpRank_ZeroDim1(TestTileOpRank1): + def init_data(self): + self.ori_shape = [] + self.repeat_times = [] + + class TestTileOpRank_ZeroDim2(TestTileOpRank1): + def init_data(self): + self.ori_shape = [] + self.repeat_times = [2] + + class TestTileOpRank_ZeroDim3(TestTileOpRank1): + def init_data(self): + self.ori_shape = [] + self.repeat_times = [2, 3] + # Situation 2: repeat_times is a list (with tensor) class XPUTestTileOpRank1_tensor_attr(XPUOpTestWrapper): @@ -209,5 +224,36 @@ class TestTileAPI(unittest.TestCase): assert np.array_equal(out_3.numpy(), np.tile(np_x, (2, 3))) +class TestTileAPI_ZeroDim(unittest.TestCase): + def test_dygraph(self): + paddle.disable_static() + + x = paddle.rand([]) + x.stop_gradient = False + + out = paddle.tile(x, []) + out.retain_grads() + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, []) + + out = paddle.tile(x, [3]) + out.retain_grads() + out.backward() + self.assertEqual(out.shape, [3]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, [3]) + + out = paddle.tile(x, [2, 3]) + out.retain_grads() + out.backward() + self.assertEqual(out.shape, [2, 3]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, [2, 3]) + + paddle.enable_static() + + if __name__ == "__main__": unittest.main() -- GitLab