From 6a6a3ff1ea4fd5e059dab6edacd4b8287ea98e90 Mon Sep 17 00:00:00 2001 From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com> Date: Tue, 8 Nov 2022 16:58:09 +0800 Subject: [PATCH] argsort support n > 16384 and add argsort_grad op for xpu, test=kunlun (#47701) --- .../fluid/platform/device/xpu/xpu2_op_list.h | 4 + paddle/phi/kernels/xpu/argsort_grad_kernel.cc | 110 ++++++++ paddle/phi/kernels/xpu/argsort_kernel.cc | 240 ++++-------------- .../unittests/xpu/test_argsort_op_xpu.py | 3 + 4 files changed, 165 insertions(+), 192 deletions(-) create mode 100644 paddle/phi/kernels/xpu/argsort_grad_kernel.cc diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index f9e66631c1..79ee392405 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -38,6 +38,10 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"argsort_grad", + XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"argsort", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), diff --git a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc new file mode 100644 index 0000000000..371cc7d39c --- /dev/null +++ b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc @@ -0,0 +1,110 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/argsort_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void ArgsortGradKernel(const Context& dev_ctx, + const DenseTensor& indices, + const DenseTensor& input, + const DenseTensor& out_grad, + int axis, + bool descending, + DenseTensor* in_grad) { + auto in_dims = indices.dims(); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + dev_ctx.template Alloc(in_grad); + + int r = xpu::constant(dev_ctx.x_context(), + in_grad->data(), + in_grad->numel(), + static_cast(0.0)); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); + + if (out_grad.numel() == 0) return; + + bool is_need_transpose = true; + if (axis == -1 || axis + 1 == in_dims.size()) { + is_need_transpose = false; + } + int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis)); + int len_after = + phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size())); + int m = len_before * len_after; + int n = in_dims[axis]; + int len = m * n; + std::vector permute_vec{0, 2, 1}; + std::vector data_shape{len_before, n, len_after}; + std::vector data_shape_trans{len_before, len_after, n}; + + const int64_t* indices_data = indices.data(); + const T* out_grad_data = out_grad.data(); + T* in_grad_data = in_grad->data(); + + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + if (is_need_transpose) { + int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm(len); + PADDLE_ENFORCE_XDNN_NOT_NULL(indices_data_trans); + T* out_grad_data_trans = RAII_GUARD.alloc_l3_or_gm(len); + PADDLE_ENFORCE_XDNN_NOT_NULL(out_grad_data_trans); + T* in_grad_data_trans = RAII_GUARD.alloc_l3_or_gm(len); + PADDLE_ENFORCE_XDNN_NOT_NULL(in_grad_data_trans); + + r = xpu::transpose(dev_ctx.x_context(), + indices_data, + indices_data_trans, + data_shape, + permute_vec); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + + r = xpu::transpose(dev_ctx.x_context(), + out_grad_data, + out_grad_data_trans, + data_shape, + permute_vec); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + + indices_data = indices_data_trans; + out_grad_data = out_grad_data_trans; + in_grad_data = in_grad_data_trans; + } + + r = xpu::sort_grad( + dev_ctx.x_context(), out_grad_data, indices_data, in_grad_data, m, n); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sort_grad"); + + if (is_need_transpose) { + r = xpu::transpose(dev_ctx.x_context(), + in_grad_data, + in_grad->data(), + data_shape_trans, + permute_vec); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + } +} + +} // namespace phi +PD_REGISTER_KERNEL(argsort_grad, + XPU, + ALL_LAYOUT, + phi::ArgsortGradKernel, + float, + int, + int64_t) {} diff --git a/paddle/phi/kernels/xpu/argsort_kernel.cc b/paddle/phi/kernels/xpu/argsort_kernel.cc index 80db142e15..9a1cdd763b 100644 --- a/paddle/phi/kernels/xpu/argsort_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_kernel.cc @@ -14,171 +14,12 @@ #include "paddle/phi/kernels/argsort_kernel.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { -const int XPU_SORT_MAX_SIZE = 16384; - -template -static inline void xpu_argsort(xpu::Context* ctx, - const T* input_data, - T* output_data, - TID* indices_data, - int m, - int n, - bool descending) { - int ret = - xpu::sort(ctx, input_data, output_data, indices_data, m, n, descending); - PADDLE_ENFORCE_EQ( - ret, - XPU_SUCCESS, - errors::External("XPU sort kernel return wrong value[%d %s].", - ret, - XPUAPIErrorMsg[ret])); -} - -template -static inline void xpu_transpose(xpu::Context* ctx, - const T* x, - T* y, - const std::vector& xshape, - const std::vector& permute) { - int ret = xpu::transpose(ctx, x, y, xshape, permute); - PADDLE_ENFORCE_EQ( - ret, - XPU_SUCCESS, - errors::External("XPU transpose kernel return wrong value[%d %s]", - ret, - XPUAPIErrorMsg[ret])); -} - -template -static inline void xpu_cast(xpu::Context* ctx, const TX* x, TY* y, int len) { - int ret = xpu::cast_v2(ctx, x, y, len); - PADDLE_ENFORCE_EQ( - ret, - XPU_SUCCESS, - errors::External("XPU cast kernel return wrong value[%d %s]", - ret, - XPUAPIErrorMsg[ret])); -} - -template -struct XPUArgsort { - void operator()(xpu::Context* ctx, - const T* input_data, - T* output_data, - int64_t* indices_data, - const std::vector& data_shape, - const std::vector& permute, - bool descending) { - xpu::ctx_guard RAII_GUARD(ctx); - int m = data_shape[0] * data_shape[2]; - int n = data_shape[1]; - int len = data_shape[0] * data_shape[1] * data_shape[2]; - std::vector trans_data_shape{ - data_shape[0], data_shape[2], data_shape[1]}; - - T* input_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - T* output_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - - xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute); - xpu_argsort(ctx, - input_data_trans, - output_data_trans, - indices_data_trans, - m, - n, - descending); - xpu_transpose( - ctx, output_data_trans, output_data, trans_data_shape, permute); - xpu_transpose( - ctx, indices_data_trans, indices_data, trans_data_shape, permute); - } -}; - -template -struct XPUArgsort { - void operator()(xpu::Context* ctx, - const T* input_data, - T* output_data, - int64_t* indices_data, - const std::vector& data_shape, - const std::vector& permute, - bool descending) { - xpu::ctx_guard RAII_GUARD(ctx); - int m = data_shape[0] * data_shape[2]; - int n = data_shape[1]; - int len = data_shape[0] * data_shape[1] * data_shape[2]; - std::vector trans_data_shape{ - data_shape[0], data_shape[2], data_shape[1]}; - - T* input_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - T* output_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm(len); - - xpu_transpose(ctx, input_data, input_data_trans, data_shape, permute); - xpu_argsort(ctx, - input_data_trans, - output_data_trans, - indices_data_trans, - m, - n, - descending); - xpu_transpose( - ctx, output_data_trans, output_data, trans_data_shape, permute); - xpu_cast(ctx, indices_data_trans, cast_data_int64, len); - xpu_transpose( - ctx, cast_data_int64, indices_data, trans_data_shape, permute); - } -}; - -template <> -struct XPUArgsort { - void operator()(xpu::Context* ctx, - const int64_t* input_data, - int64_t* output_data, - int64_t* indices_data, - const std::vector& data_shape, - const std::vector& permute, - bool descending) { - xpu::ctx_guard RAII_GUARD(ctx); - int m = data_shape[0] * data_shape[2]; - int n = data_shape[1]; - int len = data_shape[0] * data_shape[1] * data_shape[2]; - std::vector trans_data_shape{ - data_shape[0], data_shape[2], data_shape[1]}; - - int* input_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int* output_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int* indices_data_trans = RAII_GUARD.alloc_l3_or_gm(len); - int* cast_data_int = RAII_GUARD.alloc_l3_or_gm(len); - int64_t* cast_data_int64 = RAII_GUARD.alloc_l3_or_gm(len); - - xpu_cast(ctx, input_data, cast_data_int, len); - xpu_transpose(ctx, cast_data_int, input_data_trans, data_shape, permute); - xpu_argsort(ctx, - input_data_trans, - output_data_trans, - indices_data_trans, - m, - n, - descending); - - xpu_cast(ctx, output_data_trans, cast_data_int64, len); - xpu_transpose(ctx, cast_data_int64, output_data, trans_data_shape, permute); - xpu_cast(ctx, indices_data_trans, cast_data_int64, len); - xpu_transpose( - ctx, cast_data_int64, indices_data, trans_data_shape, permute); - } -}; - template void ArgsortKernel(const Context& dev_ctx, const DenseTensor& input, @@ -190,52 +31,67 @@ void ArgsortKernel(const Context& dev_ctx, axis = (axis < 0) ? (in_dims.size() + axis) : axis; int n = in_dims[axis]; - PADDLE_ENFORCE_LT( - n, - XPU_SORT_MAX_SIZE, - errors::InvalidArgument( - "The axis dimension of Input should less than %d, but got %d.", - XPU_SORT_MAX_SIZE, - in_dims[axis])); - auto input_data = input.data(); auto output_data = dev_ctx.template Alloc(output); auto indices_data = dev_ctx.template Alloc(indices); + bool is_need_transpose = true; + if (axis == -1 || axis + 1 == in_dims.size()) { + is_need_transpose = false; + } int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis)); int len_after = phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size())); - bool int64_need_cast = - (std::is_same::value && n > (XPU_SORT_MAX_SIZE / 2)) ? true - : false; - bool index_need_cast = (n > (XPU_SORT_MAX_SIZE / 2)) ? true : false; + int m = len_before * len_after; + int len = m * n; std::vector permute_vec{0, 2, 1}; std::vector data_shape{len_before, n, len_after}; + std::vector data_shape_trans{len_before, len_after, n}; - if (int64_need_cast) { - XPUArgsort()(dev_ctx.x_context(), - input_data, - output_data, - indices_data, - data_shape, - permute_vec, - descending); - } else if (index_need_cast) { - XPUArgsort()(dev_ctx.x_context(), - input_data, - output_data, - indices_data, - data_shape, - permute_vec, - descending); - } else { - XPUArgsort()(dev_ctx.x_context(), + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + if (is_need_transpose) { + T* input_data_trans = RAII_GUARD.alloc_l3_or_gm(len); + PADDLE_ENFORCE_XDNN_NOT_NULL(input_data_trans); + T* output_data_trans = RAII_GUARD.alloc_l3_or_gm(len); + PADDLE_ENFORCE_XDNN_NOT_NULL(output_data_trans); + int64_t* indices_data_trans = RAII_GUARD.alloc_l3_or_gm(len); + PADDLE_ENFORCE_XDNN_NOT_NULL(indices_data_trans); + + int r = xpu::transpose(dev_ctx.x_context(), + input_data, + input_data_trans, + data_shape, + permute_vec); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + + input_data = input_data_trans; + output_data = output_data_trans; + indices_data = indices_data_trans; + } + + int ret = xpu::sort(dev_ctx.x_context(), input_data, output_data, indices_data, - data_shape, - permute_vec, + m, + n, descending); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sort"); + + if (is_need_transpose) { + int r = xpu::transpose(dev_ctx.x_context(), + output_data, + output->data(), + data_shape_trans, + permute_vec); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + + r = xpu::transpose(dev_ctx.x_context(), + indices_data, + indices->data(), + data_shape_trans, + permute_vec); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); } } diff --git a/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py index 8ee7716447..c16d0fdb5e 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_argsort_op_xpu.py @@ -94,6 +94,9 @@ class XPUTestArgsortOp(XPUOpTestWrapper): def test_check_output(self): self.check_output_with_place(self.place) + def test_check_grad(self): + self.check_grad_with_place(self.place, {'X'}, 'Out') + support_types = get_xpu_op_support_types('argsort') for stype in support_types: -- GitLab