提交 e6254e35 编写于 作者: D DesmonDay

Support 0-D tensors for paddle.sort/argsort on XPU

上级 9443bc19
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -29,6 +30,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
bool descending,
DenseTensor* in_grad) {
auto in_dims = indices.dims();
auto rank = in_dims.size();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
dev_ctx.template Alloc<T>(in_grad);
......@@ -40,6 +42,11 @@ void ArgsortGradKernel(const Context& dev_ctx,
if (out_grad.numel() == 0) return;
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, in_grad, 1.0);
return;
}
bool is_need_transpose = true;
if (axis == -1 || axis + 1 == in_dims.size()) {
is_need_transpose = false;
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -171,6 +172,7 @@ void ArgsortKernel(const Context& dev_ctx,
DenseTensor* output,
DenseTensor* indices) {
auto in_dims = input.dims();
auto rank = in_dims.size();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
int n = in_dims[axis];
......@@ -178,6 +180,12 @@ void ArgsortKernel(const Context& dev_ctx,
auto output_data = dev_ctx.template Alloc<T>(output);
auto indices_data = dev_ctx.template Alloc<int64_t>(indices);
if (rank == 0) {
phi::Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, output);
phi::funcs::set_constant(dev_ctx, indices, 0);
return;
}
int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis));
int len_after =
phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size()));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册