Unverified commit 6a6a3ff1, authored by zhangyikun02, committed by GitHub

argsort support n > 16384 and add argsort_grad op for xpu, test=kunlun (#47701)

Parent 793c35ef
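
This change removes the old 16384-element limit on the sort axis of the XPU argsort kernel and registers an argsort_grad kernel (FP32, INT32, INT64) so the op can be differentiated on Kunlun devices. The following Python sketch is not part of the commit; it only illustrates how the op might be exercised, assuming the usual paddle.sort / paddle.argsort dygraph APIs and an available XPU device.

# Hypothetical usage sketch (not part of this commit): exercises the XPU
# argsort forward and backward paths on a sort axis longer than 16384.
import paddle

paddle.set_device("xpu")  # assumes a Kunlun/XPU device is available

x = paddle.rand([2, 20000], dtype="float32")
x.stop_gradient = False

out = paddle.sort(x, axis=-1)     # sorted values (differentiable)
idx = paddle.argsort(x, axis=-1)  # int64 indices

# Backward through the sorted values exercises the new argsort_grad kernel.
out.sum().backward()
print(x.grad.shape)  # [2, 20000]
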
@@ -38,6 +38,10 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"argsort_grad",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"argsort",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/argsort_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ArgsortGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const DenseTensor& input,
const DenseTensor& out_grad,
int axis,
bool descending,
DenseTensor* in_grad) {
auto in_dims = indices.dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
dev_ctx.template Alloc<T>(in_grad);
int r = xpu::constant<T>(dev_ctx.x_context(),
in_grad->data<T>(),
in_grad->numel(),
static_cast<T>(0.0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
if (out_grad.numel() == 0) return;
bool is_need_transpose = true;
if (axis == -1 || axis + 1 == in_dims.size()) {
is_need_transpose = false;
}
int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis));
int len_after =
phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size()));
int m = len_before * len_after;
int n = in_dims[axis];
int len = m * n;
std::vector<int> permute_vec{0, 2, 1};
std::vector<int> data_shape{len_before, n, len_after};
std::vector<int> data_shape_trans{len_before, len_after, n};
const int64_t* indices_data = indices.data<int64_t>();
const T* out_grad_data = out_grad.data<T>();
T* in_grad_data = in_grad->data<T>();
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
if (is_need_transpose) {
int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(indices_data_trans);
T* out_grad_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(out_grad_data_trans);
T* in_grad_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(in_grad_data_trans);
r = xpu::transpose<int64_t>(dev_ctx.x_context(),
indices_data,
indices_data_trans,
data_shape,
permute_vec);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
r = xpu::transpose<T>(dev_ctx.x_context(),
out_grad_data,
out_grad_data_trans,
data_shape,
permute_vec);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
indices_data = indices_data_trans;
out_grad_data = out_grad_data_trans;
in_grad_data = in_grad_data_trans;
}
r = xpu::sort_grad<T, int64_t>(
dev_ctx.x_context(), out_grad_data, indices_data, in_grad_data, m, n);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sort_grad");
if (is_need_transpose) {
r = xpu::transpose<T>(dev_ctx.x_context(),
in_grad_data,
in_grad->data<T>(),
data_shape_trans,
permute_vec);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
}
}
} // namespace phi
PD_REGISTER_KERNEL(argsort_grad,
XPU,
ALL_LAYOUT,
phi::ArgsortGradKernel,
float,
int,
int64_t) {}
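
The kernel above follows the same pattern as the forward kernel: when the sort axis is not the innermost one, it transposes the data so that axis becomes last, runs xpu::sort_grad over m = len_before * len_after independent rows of length n, and transposes the result back. The scatter semantics of sort_grad can be summarized with a small NumPy sketch (an illustration only, not the device code): the input gradient is the output gradient routed back through the indices saved in the forward pass.

# NumPy sketch of the per-row semantics of sort_grad (illustration only):
# in_grad[indices[j]] accumulates out_grad[j] for every position j.
import numpy as np

def sort_grad_row(out_grad, indices):
    in_grad = np.zeros_like(out_grad)
    np.add.at(in_grad, indices, out_grad)  # scatter-add along the sorted axis
    return in_grad

og = np.array([0.1, 0.2, 0.3, 0.4])
idx = np.array([2, 0, 3, 1])       # argsort indices from the forward pass
print(sort_grad_row(og, idx))      # [0.2 0.4 0.1 0.3]
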
@@ -14,171 +14,12 @@
#include "paddle/phi/kernels/argsort_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
const int XPU_SORT_MAX_SIZE = 16384;
template <typename T, typename TID>
static inline void xpu_argsort(xpu::Context* ctx,
const T* input_data,
T* output_data,
TID* indices_data,
int m,
int n,
bool descending) {
int ret =
xpu::sort(ctx, input_data, output_data, indices_data, m, n, descending);
PADDLE_ENFORCE_EQ(
ret,
XPU_SUCCESS,
errors::External("XPU sort kernel return wrong value[%d %s].",
ret,
XPUAPIErrorMsg[ret]));
}
template <typename T>
static inline void xpu_transpose(xpu::Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& permute) {
int ret = xpu::transpose(ctx, x, y, xshape, permute);
PADDLE_ENFORCE_EQ(
ret,
XPU_SUCCESS,
errors::External("XPU transpose kernel return wrong value[%d %s]",
ret,
XPUAPIErrorMsg[ret]));
}
template <typename TX, typename TY>
static inline void xpu_cast(xpu::Context* ctx, const TX* x, TY* y, int len) {
int ret = xpu::cast_v2(ctx, x, y, len);
PADDLE_ENFORCE_EQ(
ret,
XPU_SUCCESS,
errors::External("XPU cast kernel return wrong value[%d %s]",
ret,
XPUAPIErrorMsg[ret]));
}
template <typename T,
bool VALUE_NEED_CAST = false,
bool INDEX_NEED_CAST = false>
struct XPUArgsort {
void operator()(xpu::Context* ctx,
const T* input_data,
T* output_data,
int64_t* indices_data,
const std::vector<int>& data_shape,
const std::vector<int>& permute,
bool descending) {
xpu::ctx_guard RAII_GUARD(ctx);
int m = data_shape[0] * data_shape[2];
int n = data_shape[1];
int len = data_shape[0] * data_shape[1] * data_shape[2];
std::vector<int> trans_data_shape{
data_shape[0], data_shape[2], data_shape[1]};
T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
xpu_argsort(ctx,
input_data_trans,
output_data_trans,
indices_data_trans,
m,
n,
descending);
xpu_transpose(
ctx, output_data_trans, output_data, trans_data_shape, permute);
xpu_transpose(
ctx, indices_data_trans, indices_data, trans_data_shape, permute);
}
};
template <typename T>
struct XPUArgsort<T, false, true> {
void operator()(xpu::Context* ctx,
const T* input_data,
T* output_data,
int64_t* indices_data,
const std::vector<int>& data_shape,
const std::vector<int>& permute,
bool descending) {
xpu::ctx_guard RAII_GUARD(ctx);
int m = data_shape[0] * data_shape[2];
int n = data_shape[1];
int len = data_shape[0] * data_shape[1] * data_shape[2];
std::vector<int> trans_data_shape{
data_shape[0], data_shape[2], data_shape[1]};
T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute);
xpu_argsort(ctx,
input_data_trans,
output_data_trans,
indices_data_trans,
m,
n,
descending);
xpu_transpose(
ctx, output_data_trans, output_data, trans_data_shape, permute);
xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
xpu_transpose(
ctx, cast_data_int64, indices_data, trans_data_shape, permute);
}
};
template <>
struct XPUArgsort<int64_t, true, true> {
void operator()(xpu::Context* ctx,
const int64_t* input_data,
int64_t* output_data,
int64_t* indices_data,
const std::vector<int>& data_shape,
const std::vector<int>& permute,
bool descending) {
xpu::ctx_guard RAII_GUARD(ctx);
int m = data_shape[0] * data_shape[2];
int n = data_shape[1];
int len = data_shape[0] * data_shape[1] * data_shape[2];
std::vector<int> trans_data_shape{
data_shape[0], data_shape[2], data_shape[1]};
int* input_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
int* output_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int>(len);
int* cast_data_int = RAII_GUARD.alloc_l3_or_gm<int>(len);
int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
xpu_cast(ctx, input_data, cast_data_int, len);
xpu_transpose(ctx, cast_data_int, input_data_trans, data_shape, permute);
xpu_argsort(ctx,
input_data_trans,
output_data_trans,
indices_data_trans,
m,
n,
descending);
xpu_cast(ctx, output_data_trans, cast_data_int64, len);
xpu_transpose(ctx, cast_data_int64, output_data, trans_data_shape, permute);
xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
xpu_transpose(
ctx, cast_data_int64, indices_data, trans_data_shape, permute);
}
};
template <typename T, typename Context>
void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
@@ -190,52 +31,67 @@ void ArgsortKernel(const Context& dev_ctx,
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
int n = in_dims[axis];
PADDLE_ENFORCE_LT(
n,
XPU_SORT_MAX_SIZE,
errors::InvalidArgument(
"The axis dimension of Input should less than %d, but got %d.",
XPU_SORT_MAX_SIZE,
in_dims[axis]));
auto input_data = input.data<T>();
auto output_data = dev_ctx.template Alloc<T>(output);
auto indices_data = dev_ctx.template Alloc<int64_t>(indices);
bool is_need_transpose = true;
if (axis == -1 || axis + 1 == in_dims.size()) {
is_need_transpose = false;
}
int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis));
int len_after =
phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size()));
bool int64_need_cast =
(std::is_same<T, int64_t>::value && n > (XPU_SORT_MAX_SIZE / 2)) ? true
: false;
bool index_need_cast = (n > (XPU_SORT_MAX_SIZE / 2)) ? true : false;
int m = len_before * len_after;
int len = m * n;
std::vector<int> permute_vec{0, 2, 1};
std::vector<int> data_shape{len_before, n, len_after};
std::vector<int> data_shape_trans{len_before, len_after, n};
if (int64_need_cast) {
XPUArgsort<T, true, true>()(dev_ctx.x_context(),
input_data,
output_data,
indices_data,
data_shape,
permute_vec,
descending);
} else if (index_need_cast) {
XPUArgsort<T, false, true>()(dev_ctx.x_context(),
input_data,
output_data,
indices_data,
data_shape,
permute_vec,
descending);
  } else {
    XPUArgsort<T, false, false>()(dev_ctx.x_context(),
                                  input_data,
                                  output_data,
                                  indices_data,
                                  data_shape,
                                  permute_vec,
                                  descending);
  }
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
if (is_need_transpose) {
T* input_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(input_data_trans);
T* output_data_trans = RAII_GUARD.alloc_l3_or_gm<T>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(output_data_trans);
int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm<int64_t>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(indices_data_trans);
int r = xpu::transpose<T>(dev_ctx.x_context(),
input_data,
input_data_trans,
data_shape,
permute_vec);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
input_data = input_data_trans;
output_data = output_data_trans;
indices_data = indices_data_trans;
}
  int ret = xpu::sort<T, int64_t>(dev_ctx.x_context(),
                                  input_data,
                                  output_data,
                                  indices_data,
                                  m,
                                  n,
                                  descending);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sort");
if (is_need_transpose) {
int r = xpu::transpose<T>(dev_ctx.x_context(),
output_data,
output->data<T>(),
data_shape_trans,
permute_vec);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
r = xpu::transpose<int64_t>(dev_ctx.x_context(),
indices_data,
indices->data<int64_t>(),
data_shape_trans,
permute_vec);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
}
}
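
With the size limit gone, the forward kernel now always takes the same route: transpose the sort axis to the innermost position, call xpu::sort once over m = len_before * len_after rows of length n, and transpose values and indices back, which is what lets n exceed 16384. The shape bookkeeping (data_shape = [len_before, n, len_after], permute_vec = [0, 2, 1]) can be mirrored in a short NumPy sketch; this illustrates the layout logic only, not the device implementation.

# NumPy sketch of the transpose -> sort -> transpose-back layout used above.
import numpy as np

def argsort_along_axis(x, axis, descending=False):
    dims = x.shape
    axis = axis if axis >= 0 else axis + x.ndim
    len_before = int(np.prod(dims[:axis], dtype=np.int64))
    len_after = int(np.prod(dims[axis + 1:], dtype=np.int64))
    n = dims[axis]

    # data_shape = [len_before, n, len_after]; permute_vec = [0, 2, 1]
    xt = x.reshape(len_before, n, len_after).transpose(0, 2, 1)
    key = -xt if descending else xt
    idx = np.argsort(key, axis=-1, kind="stable")  # m rows of length n
    out = np.take_along_axis(xt, idx, axis=-1)

    # transpose back from data_shape_trans to the original layout
    out = out.transpose(0, 2, 1).reshape(dims)
    idx = idx.transpose(0, 2, 1).reshape(dims)
    return out, idx

x = np.random.rand(3, 20000, 2).astype(np.float32)  # n > 16384 on axis 1
out, idx = argsort_along_axis(x, axis=1)
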
......
@@ -94,6 +94,9 @@ class XPUTestArgsortOp(XPUOpTestWrapper):
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, {'X'}, 'Out')
support_types = get_xpu_op_support_types('argsort')
for stype in support_types:
......