From e6254e35563db38edad5bc3e51054257c35062ca Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 3 Jan 2023 11:25:43 +0000 Subject: [PATCH] support 0D tensor for paddle.sort/argsort in xpu --- paddle/phi/kernels/xpu/argsort_grad_kernel.cc | 7 +++++++ paddle/phi/kernels/xpu/argsort_kernel.cc | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc index 371cc7d39c..00c679f0ab 100644 --- a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -29,6 +30,7 @@ void ArgsortGradKernel(const Context& dev_ctx, bool descending, DenseTensor* in_grad) { auto in_dims = indices.dims(); + auto rank = in_dims.size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; dev_ctx.template Alloc(in_grad); @@ -40,6 +42,11 @@ void ArgsortGradKernel(const Context& dev_ctx, if (out_grad.numel() == 0) return; + if (rank == 0) { + phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + return; + } + bool is_need_transpose = true; if (axis == -1 || axis + 1 == in_dims.size()) { is_need_transpose = false; diff --git a/paddle/phi/kernels/xpu/argsort_kernel.cc b/paddle/phi/kernels/xpu/argsort_kernel.cc index 0a71ec7146..4fdb42f69f 100644 --- a/paddle/phi/kernels/xpu/argsort_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -171,6 +172,7 @@ void ArgsortKernel(const Context& dev_ctx, DenseTensor* output, DenseTensor* indices) { auto in_dims = input.dims(); + auto rank = in_dims.size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; int n = in_dims[axis]; @@ -178,6 +180,12 @@ void ArgsortKernel(const Context& dev_ctx, auto output_data = dev_ctx.template Alloc(output); auto indices_data = dev_ctx.template Alloc(indices); + if (rank == 0) { + phi::Copy(dev_ctx, input, dev_ctx.GetPlace(), false, output); + phi::funcs::set_constant(dev_ctx, indices, 0); + return; + } + int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis)); int len_after = phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size())); -- GitLab