diff --git a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc index 371cc7d39c2900c55f12d1508a2e277fa4b5db7f..00c679f0ab999e4539d4fc8886ddec1cb475c301 100644 --- a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -29,6 +30,7 @@ void ArgsortGradKernel(const Context& dev_ctx, bool descending, DenseTensor* in_grad) { auto in_dims = indices.dims(); + auto rank = in_dims.size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; dev_ctx.template Alloc(in_grad); @@ -40,6 +42,11 @@ void ArgsortGradKernel(const Context& dev_ctx, if (out_grad.numel() == 0) return; + if (rank == 0) { + phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + return; + } + bool is_need_transpose = true; if (axis == -1 || axis + 1 == in_dims.size()) { is_need_transpose = false; diff --git a/paddle/phi/kernels/xpu/argsort_kernel.cc b/paddle/phi/kernels/xpu/argsort_kernel.cc index 0a71ec71463d4173f29dfd0bde7e9255e465f71e..4fdb42f69fd87779418abf4fc9ef23169392704e 100644 --- a/paddle/phi/kernels/xpu/argsort_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -171,6 +172,7 @@ void ArgsortKernel(const Context& dev_ctx, DenseTensor* output, DenseTensor* indices) { auto in_dims = input.dims(); + auto rank = in_dims.size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; int n = in_dims[axis]; @@ -178,6 +180,12 @@ void ArgsortKernel(const Context& dev_ctx, auto output_data = dev_ctx.template Alloc(output); auto indices_data = dev_ctx.template Alloc(indices); + if (rank == 0) { + phi::Copy(dev_ctx, input, dev_ctx.GetPlace(), false, output); + phi::funcs::set_constant(dev_ctx, indices, 0); + return; + } + int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis)); int len_after = phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size()));