From 9262a93c2ef51db2f0b80e6539c5b3e54f2fc6f4 Mon Sep 17 00:00:00 2001
From: Linjie Chen <40840292+linjieccc@users.noreply.github.com>
Date: Thu, 10 Mar 2022 10:53:42 +0800
Subject: [PATCH] [phi] move argsort to phi (#40151)

* move argsort to phi
* refine files
* remove mutable_data
* fix unittest
* fix unittest
* remove infershape
* update infershape
* fix ci
* fix ci
* fix ci
* fix
* fix
* fix
* fix
* fix
* fix
* fix
---
 paddle/fluid/operators/argsort_op.cc          |  46 +-
 paddle/fluid/operators/argsort_op.cu          | 430 ------------------
 paddle/fluid/operators/argsort_op.h           | 243 ----------
 paddle/fluid/operators/argsort_op_npu.cc      |   2 +-
 paddle/fluid/operators/argsort_op_xpu.cc      |   2 +-
 paddle/phi/infermeta/unary.cc                 |  28 ++
 paddle/phi/infermeta/unary.h                  |   6 +
 paddle/phi/kernels/argsort_grad_kernel.h      |  30 ++
 paddle/phi/kernels/argsort_kernel.h           |  29 ++
 paddle/phi/kernels/cpu/argsort_grad_kernel.cc | 133 ++++++
 paddle/phi/kernels/cpu/argsort_kernel.cc      | 143 ++++++
 paddle/phi/kernels/gpu/argsort_grad_kernel.cu | 217 +++++++++
 paddle/phi/kernels/gpu/argsort_kernel.cu      | 310 +++++++++++++
 paddle/phi/ops/compat/argsort_sig.cc          |  29 ++
 14 files changed, 936 insertions(+), 712 deletions(-)
 delete mode 100644 paddle/fluid/operators/argsort_op.cu
 delete mode 100644 paddle/fluid/operators/argsort_op.h
 create mode 100644 paddle/phi/kernels/argsort_grad_kernel.h
 create mode 100644 paddle/phi/kernels/argsort_kernel.h
 create mode 100644 paddle/phi/kernels/cpu/argsort_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/argsort_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/argsort_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/argsort_kernel.cu
 create mode 100644 paddle/phi/ops/compat/argsort_sig.cc

diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc
index 9e525c2033..1a8aca7773 100644
--- a/paddle/fluid/operators/argsort_op.cc
+++ b/paddle/fluid/operators/argsort_op.cc
@@ -12,40 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/ -#include "paddle/fluid/operators/argsort_op.h" #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { class ArgsortOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "argsort"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "argsort"); - OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "argsort"); - - auto in_dims = ctx->GetInputDim("X"); - int axis = ctx->Attrs().Get("axis"); - - auto num_dims = in_dims.size(); - PADDLE_ENFORCE_GE(axis, -num_dims, - platform::errors::InvalidArgument( - "'axis'(%d) must be greater than or equal to" - " -num_dims(%d).", - axis, -num_dims)); - PADDLE_ENFORCE_LT( - axis, num_dims, - platform::errors::InvalidArgument( - "'axis'(%d) must be less than num_dims(%d).", axis, num_dims)); - - ctx->ShareDim("X", "Out"); - ctx->ShareDim("X", "Indices"); - ctx->ShareLoD("X", "Out"); - ctx->ShareLoD("X", "Indices"); - } }; class ArgsortGradOp : public framework::OperatorWithKernel { @@ -122,18 +101,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ArgsortGradNoNeedBufferVarsInferer, "X"); } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(argsort, ArgsortInferShapeFunctor, + PD_INFER_META(phi::ArgsortInferMeta)); REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker, ops::ArgsortGradOpMaker, - ops::ArgsortGradOpMaker); + ops::ArgsortGradOpMaker, + ArgsortInferShapeFunctor); REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp, ops::ArgsortGradNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL(argsort, - ops::ArgsortKernel, - ops::ArgsortKernel, - ops::ArgsortKernel, - ops::ArgsortKernel); -REGISTER_OP_CPU_KERNEL( - argsort_grad, ops::ArgsortGradientKernel, - ops::ArgsortGradientKernel, - ops::ArgsortGradientKernel, - ops::ArgsortGradientKernel); diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu deleted file mode 100644 index 8b7a0b3ead..0000000000 --- a/paddle/fluid/operators/argsort_op.cu +++ /dev/null @@ -1,430 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include -#include -#include -#include -#ifdef __NVCC__ -#include "cub/cub.cuh" -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/argsort_op.h" -#include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -#ifdef __HIPCC__ -namespace rocprim { -namespace detail { -template <> -struct radix_key_codec_base - : radix_key_codec_integral {}; -} // namespace detail -} // namespace rocprim -#else -// set cub base traits in order to handle float16 -namespace cub { -template <> -struct NumericTraits - : BaseTraits {}; -} // namespace cub -#endif - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -// Iter for move to next row -struct SegmentOffsetIter { - EIGEN_DEVICE_FUNC - explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} - - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const { - return idx * num_cols_; - } - - int num_cols_; -}; - -template -static __global__ void FillIndex(T* indices, T num_rows, T num_cols) { - int col_id = threadIdx.x; - int row_id = blockIdx.x; - - for (T j = row_id; j < num_rows; j += gridDim.x) { - for (T i = col_id; i < num_cols; i += blockDim.x) { - indices[j * num_cols + i] = i; - } - } -} - -template -static __global__ void FillFlattenGrad(const T* dO, const IndType* indices, - int64_t size, T* dX) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - int stride = blockDim.x * gridDim.x; - for (int i = index; i < size; i += stride) { - dX[indices[i]] = dO[i]; - } -} - -template -static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX, - IndType num_rows, IndType num_cols) { - int col_id = threadIdx.x; - int row_id = blockIdx.x; - - for (IndType j = row_id; j < num_rows; j += gridDim.x) { - for (IndType i = col_id; i < num_cols; i += blockDim.x) { - dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i]; - } - } -} - -// Sort by flag descending, True: descending. False: Ascending. -// Default is false. -template -void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input, - Tensor* output, Tensor* indices, const IndType num_rows, - const IndType num_cols, const bool descending) { - auto cu_stream = ctx.stream(); - - Tensor input_indices; - - const std::vector dims = {num_rows, num_cols}; - auto dim = phi::make_ddim(dims); - input_indices.Resize(dim); - input_indices.mutable_data(ctx.GetPlace()); - - size_t temp_storage_bytes = -1; - - auto ComputeBlockSize = [](IndType col) { - if (col > 512) - return 1024; - else if (col > 256 && col <= 512) - return 512; - else if (col > 128 && col <= 256) - return 256; - else if (col > 64 && col <= 128) - return 128; - else - return 64; - }; - - int block_size = ComputeBlockSize(num_cols); - - int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0]; - // actually, int num_rows < max_grid_size - int grid_size = num_rows < maxGridDimX ? 
num_rows : maxGridDimX; - // Init a index array - FillIndex<<>>( - input_indices.data(), num_rows, num_cols); - - T* sorted_out_ptr; - IndType* sorted_indices_ptr; - - const T* inp = input->data(); - T* out = output->mutable_data(ctx.GetPlace()); - IndType* ind = indices->mutable_data(ctx.GetPlace()); - - sorted_out_ptr = out; - sorted_indices_ptr = ind; - - // create iter for counting input - cub::CountingInputIterator counting_iter(0); - // segment_offset is used for move to next row - cub::TransformInputIterator> - segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols)); - - gpuError_t err; - if (descending) { - err = cub::DeviceSegmentedRadixSort::SortPairsDescending( - nullptr, temp_storage_bytes, inp, sorted_out_ptr, - input_indices.data(), sorted_indices_ptr, num_cols * num_rows, - num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, - cu_stream); - } else { - err = cub::DeviceSegmentedRadixSort::SortPairs( - nullptr, temp_storage_bytes, inp, sorted_out_ptr, - input_indices.data(), sorted_indices_ptr, num_cols * num_rows, - num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, - cu_stream); - } - PADDLE_ENFORCE_GPU_SUCCESS(err); - - Tensor temp_storage; - temp_storage.mutable_data(ctx.GetPlace(), temp_storage_bytes); - - if (descending) { - err = cub::DeviceSegmentedRadixSort::SortPairsDescending( - temp_storage.data(), temp_storage_bytes, inp, sorted_out_ptr, - input_indices.data(), sorted_indices_ptr, num_cols * num_rows, - num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, - cu_stream); - } else { - err = cub::DeviceSegmentedRadixSort::SortPairs( - temp_storage.data(), temp_storage_bytes, inp, sorted_out_ptr, - input_indices.data(), sorted_indices_ptr, num_cols * num_rows, - num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, - cu_stream); - } - - PADDLE_ENFORCE_GPU_SUCCESS(err); -} - -template -void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO, - const Tensor* indices, Tensor* dX, const IndType num_rows, - const IndType num_cols) { - auto cu_stream = ctx.stream(); - - auto ComputeBlockSize = [](IndType col) { - if (col > 512) - return 1024; - else if (col > 256 && col <= 512) - return 512; - else if (col > 128 && col <= 256) - return 256; - else if (col > 64 && col <= 128) - return 128; - else - return 64; - }; - - int block_size = ComputeBlockSize(num_cols); - - int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0]; - // actually, int num_rows < max_grid_size - int grid_size = num_rows < maxGridDimX ? 
num_rows : maxGridDimX; - FillGrad<<>>( - dO->data(), indices->data(), dX->data(), num_rows, - num_cols); -} - -template -void ArgFlattenAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO, - const Tensor* indices, int64_t size, Tensor* dX) { - auto cu_stream = ctx.stream(); - - const int64_t block_size = - std::min(size, static_cast(ctx.GetMaxThreadsPerBlock())); - int64_t max_threads = ctx.GetMaxPhysicalThreadCount(); - const int64_t max_blocks = - std::max(((max_threads - 1) / block_size + 1), static_cast(1)); - const int64_t grid_size = - std::min(max_blocks, (size + block_size - 1) / block_size); - - FillFlattenGrad<<>>( - dO->data(), indices->data(), size, dX->data()); -} - -template -class ArgsortOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - auto* indices = ctx.Output("Indices"); - int axis = ctx.Attr("axis"); - bool descending = ctx.Attr("descending"); - - auto in_dims = input->dims(); - axis = (axis < 0) ? (in_dims.size() + axis) : axis; - - const T* in_data = input->data(); - auto size = input->numel(); - T* out_data = output->mutable_data(ctx.GetPlace()); - int64_t* ids_data = indices->mutable_data(ctx.GetPlace()); - - // Use thrust for parallel acceleration when the input size is equal to the - // length of the ‘axis’ dimension. - // Compared to the following 'Special case for full sort', ascending sort is - // 34 times faster and descending sort is 31 times faster. - if (size == in_dims[axis]) { - thrust::sequence(thrust::device, ids_data, ids_data + size); - thrust::copy(thrust::device, in_data, in_data + size, out_data); - thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data); - if (descending) { - thrust::reverse(thrust::device, out_data, out_data + size); - thrust::reverse(thrust::device, ids_data, ids_data + size); - } - return; - } - - // Special case for full sort, speedup ~190x. 
- if (axis == -1 || axis + 1 == in_dims.size()) { - const int64_t input_height = - phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); - const int64_t input_width = in_dims[in_dims.size() - 1]; - const auto& dev_ctx = ctx.cuda_device_context(); - ArgFullSort(dev_ctx, input, output, indices, input_height, - input_width, descending); - } else { - // if not full sort, do transpose first - std::vector trans; - for (int i = 0; i < axis; i++) { - trans.push_back(i); - } - trans.push_back(in_dims.size() - 1); - for (int i = axis + 1; i < in_dims.size() - 1; i++) { - trans.push_back(i); - } - trans.push_back(axis); - framework::DDim trans_dims(in_dims); - for (int i = 0; i < trans.size(); i++) { - trans_dims[i] = in_dims[trans[i]]; - } - - Tensor trans_inp; - T* trans_inp_data = trans_inp.mutable_data(trans_dims, ctx.GetPlace()); - int ndims = trans.size(); - const auto& dev_ctx = ctx.cuda_device_context(); - // Do transpose - TransCompute(ndims, dev_ctx, *input, - &trans_inp, trans); - - const int64_t input_height = - phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); - const int64_t input_width = trans_dims[trans_dims.size() - 1]; - - Tensor tmp_out; - tmp_out.mutable_data(trans_dims, ctx.GetPlace()); - T* out_data = output->mutable_data(ctx.GetPlace()); - - Tensor tmp_indices; - // temp indices for sorting - tmp_indices.mutable_data(trans_dims, ctx.GetPlace()); - indices->mutable_data(ctx.GetPlace()); - - ArgFullSort(dev_ctx, &trans_inp, &tmp_out, &tmp_indices, - input_height, input_width, descending); - - TransCompute( - ndims, dev_ctx, tmp_indices, indices, trans); - // transpose back - TransCompute(ndims, dev_ctx, tmp_out, - output, trans); - return; - } - } -}; - -template -class ArgsortGradOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* indices = ctx.Input("Indices"); - auto* dX = ctx.Output(framework::GradVarName("X")); - auto* dO = ctx.Input(framework::GradVarName("Out")); - int axis = ctx.Attr("axis"); - - dX->mutable_data(ctx.GetPlace()); - if (dO->numel() == 0) return; - - auto in_dims = dX->dims(); - axis = (axis < 0) ? (in_dims.size() + axis) : axis; - - int64_t size = dX->numel(); - const auto& dev_ctx = ctx.cuda_device_context(); - - // Parallel acceleration when the input size is equal to the length of the - // ‘axis’ dimension. - // Compared to 'special case for full sort' below, the gradient calculation - // is 10 times faster. - if (size == in_dims[axis]) { - ArgFlattenAssign(dev_ctx, dO, indices, size, dX); - return; - } - - // Special case for full sort, speedup ~190x. 
- if (axis == -1 || axis + 1 == in_dims.size()) { - const int64_t input_height = - phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); - const int64_t input_width = in_dims[in_dims.size() - 1]; - ArgFullAssign(dev_ctx, dO, indices, dX, input_height, - input_width); - } else { - // if not full sort, do transpose first - std::vector trans; - for (int i = 0; i < axis; i++) { - trans.push_back(i); - } - trans.push_back(in_dims.size() - 1); - for (int i = axis + 1; i < in_dims.size() - 1; i++) { - trans.push_back(i); - } - trans.push_back(axis); - framework::DDim trans_dims(in_dims); - for (int i = 0; i < trans.size(); i++) { - trans_dims[i] = in_dims[trans[i]]; - } - - Tensor trans_dO; - trans_dO.mutable_data(trans_dims, ctx.GetPlace()); - Tensor trans_ind; - trans_ind.mutable_data(trans_dims, ctx.GetPlace()); - int ndims = trans.size(); - // Do transpose - TransCompute(ndims, dev_ctx, *dO, - &trans_dO, trans); - TransCompute( - ndims, dev_ctx, *indices, &trans_ind, trans); - - const int64_t input_height = - phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); - const int64_t input_width = trans_dims[trans_dims.size() - 1]; - - Tensor tmp_out; - tmp_out.mutable_data(trans_dims, ctx.GetPlace()); - - ArgFullAssign(dev_ctx, &trans_dO, &trans_ind, &tmp_out, - input_height, input_width); - - // transpose back - TransCompute(ndims, dev_ctx, tmp_out, dX, - trans); - return; - } - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL( - argsort, - paddle::operators::ArgsortOpCUDAKernel, - paddle::operators::ArgsortOpCUDAKernel, - paddle::operators::ArgsortOpCUDAKernel, - paddle::operators::ArgsortOpCUDAKernel, - paddle::operators::ArgsortOpCUDAKernel); -REGISTER_OP_CUDA_KERNEL( - argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel, - paddle::operators::ArgsortGradOpCUDAKernel, - paddle::operators::ArgsortGradOpCUDAKernel, - paddle::operators::ArgsortGradOpCUDAKernel, - paddle::operators::ArgsortGradOpCUDAKernel); diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h deleted file mode 100644 index d850e51a4b..0000000000 --- a/paddle/fluid/operators/argsort_op.h +++ /dev/null @@ -1,243 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include -#include -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/transpose_op.h" - -namespace paddle { -namespace operators { - -template -using EigenMatrix = framework::EigenMatrix; - -template -using EigenVector = framework::EigenVector; - -using Tensor = framework::Tensor; - -template -static void FullSort(Type input_height, Type input_width, int input_dim, - const framework::Tensor* input, T* t_out, Type* t_indices, - bool descending) { -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (Type i = 0; i < input_height; ++i) { - std::vector> col_vec; - col_vec.reserve(input_width); - if (input_dim == 1) { - auto e_input = EigenVector::Flatten(*input); - for (Type j = 0; j < input_width; ++j) { - col_vec.push_back(std::pair(e_input(j), j)); - } - } else { - auto e_input = EigenMatrix::Reshape(*input, input_dim - 1); - for (Type j = 0; j < input_width; ++j) { - col_vec.push_back(std::pair(e_input(i, j), j)); - } - } - std::sort(col_vec.begin(), col_vec.end(), - [&](const std::pair& l, const std::pair& r) { - if (descending) - return l.first > r.first; - else - return l.first < r.first; - }); - - for (Type j = 0; j < input_width; ++j) { - t_out[i * input_width + j] = col_vec[j].first; - t_indices[i * input_width + j] = col_vec[j].second; - } - } -} - -template -static void FullAssign(Type input_height, Type input_width, int input_dim, - const framework::Tensor* input, - const framework::Tensor* indices, T* t_out) { -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (Type i = 0; i < input_height; ++i) { - if (input_dim == 1) { - auto e_input = EigenVector::Flatten(*input); - auto e_indices = EigenVector::Flatten(*indices); - for (Type j = 0; j < input_width; ++j) { - t_out[i * input_width + e_indices(j)] = e_input(j); - } - } else { - auto e_input = EigenMatrix::Reshape(*input, input_dim - 1); - auto e_indices = EigenMatrix::Reshape(*indices, input_dim - 1); - for (Type j = 0; j < input_width; ++j) { - t_out[i * input_width + e_indices(i, j)] = e_input(i, j); - } - } - } -} - -template -class ArgsortKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - auto* indices = ctx.Output("Indices"); - int axis = ctx.Attr("axis"); - bool descending = ctx.Attr("descending"); - - auto in_dims = input->dims(); - axis = (axis < 0) ? 
(in_dims.size() + axis) : axis; - - T* out_data = output->mutable_data(ctx.GetPlace()); - - // Do full sort - if (axis == -1 || axis + 1 == in_dims.size()) { - const int64_t input_height = - phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); - const int64_t input_width = in_dims[in_dims.size() - 1]; - - int64_t* ids_data = indices->mutable_data(ctx.GetPlace()); - FullSort(input_height, input_width, in_dims.size(), input, - out_data, ids_data, descending); - } else { - // If not full sort do transpose - std::vector trans; - for (int i = 0; i < axis; i++) { - trans.push_back(i); - } - trans.push_back(in_dims.size() - 1); - for (int i = axis + 1; i < in_dims.size() - 1; i++) { - trans.push_back(i); - } - trans.push_back(axis); - framework::DDim trans_dims(in_dims); - for (size_t i = 0; i < trans.size(); i++) { - trans_dims[i] = in_dims[trans[i]]; - } - - Tensor trans_inp; - trans_inp.mutable_data(trans_dims, ctx.GetPlace()); - int ndims = trans.size(); - auto& dev_ctx = ctx.template device_context(); - // Do transpose - TransCompute(ndims, dev_ctx, *input, - &trans_inp, trans); - - const int64_t input_height = - phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); - const int64_t input_width = trans_dims[trans_dims.size() - 1]; - - Tensor tmp_out; - T* t_out = tmp_out.mutable_data(trans_dims, ctx.GetPlace()); - output->mutable_data(ctx.GetPlace()); - - Tensor tmp_indices; - - auto* t_ind = - tmp_indices.mutable_data(trans_dims, ctx.GetPlace()); - - FullSort(input_height, input_width, in_dims.size(), - &trans_inp, t_out, t_ind, descending); - - indices->mutable_data(ctx.GetPlace()); - TransCompute( - ndims, dev_ctx, tmp_indices, indices, trans); - // transpose back - TransCompute(ndims, dev_ctx, tmp_out, - output, trans); - } - } -}; - -template -class ArgsortGradientKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* indices = ctx.Input("Indices"); - auto* dX = ctx.Output(framework::GradVarName("X")); - auto* dO = ctx.Input(framework::GradVarName("Out")); - int axis = ctx.Attr("axis"); - - auto in_dims = indices->dims(); - axis = (axis < 0) ? 
(in_dims.size() + axis) : axis; - - dX->mutable_data(ctx.GetPlace()); - auto dxt = framework::EigenVector::Flatten(*dX); - auto& place = *ctx.template device_context() - .eigen_device(); - dxt.device(place) = dxt.constant(static_cast(0)); - if (dO->numel() == 0) return; - - // Do full assign - if (axis == -1 || axis + 1 == in_dims.size()) { - const int64_t input_height = - phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); - const int64_t input_width = in_dims[in_dims.size() - 1]; - - FullAssign(input_height, input_width, in_dims.size(), dO, - indices, dX->data()); - } else { - // If not full assign do transpose - std::vector trans; - for (int i = 0; i < axis; i++) { - trans.push_back(i); - } - trans.push_back(in_dims.size() - 1); - for (int i = axis + 1; i < in_dims.size() - 1; i++) { - trans.push_back(i); - } - trans.push_back(axis); - framework::DDim trans_dims(in_dims); - for (size_t i = 0; i < trans.size(); i++) { - trans_dims[i] = in_dims[trans[i]]; - } - - Tensor trans_dO; - trans_dO.mutable_data(trans_dims, ctx.GetPlace()); - Tensor trans_ind; - trans_ind.mutable_data(trans_dims, ctx.GetPlace()); - int ndims = trans.size(); - auto& dev_ctx = ctx.template device_context(); - // Do transpose - TransCompute(ndims, dev_ctx, *dO, - &trans_dO, trans); - TransCompute( - ndims, dev_ctx, *indices, &trans_ind, trans); - - const int64_t input_height = - phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); - const int64_t input_width = trans_dims[trans_dims.size() - 1]; - - Tensor tmp_out; - T* t_out = tmp_out.mutable_data(trans_dims, ctx.GetPlace()); - - FullAssign(input_height, input_width, in_dims.size(), - &trans_dO, &trans_ind, t_out); - - // transpose back - TransCompute(ndims, dev_ctx, tmp_out, dX, - trans); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/argsort_op_npu.cc b/paddle/fluid/operators/argsort_op_npu.cc index 077be715be..c927eec00b 100644 --- a/paddle/fluid/operators/argsort_op_npu.cc +++ b/paddle/fluid/operators/argsort_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/argsort_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/argsort_op_xpu.cc b/paddle/fluid/operators/argsort_op_xpu.cc index 18e81936a1..359b00fcf8 100644 --- a/paddle/fluid/operators/argsort_op_xpu.cc +++ b/paddle/fluid/operators/argsort_op_xpu.cc @@ -14,7 +14,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/argsort_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 8c2707e1d2..f7f8612632 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -26,6 +26,34 @@ limitations under the License. 
*/ namespace phi { +void ArgsortInferMeta(const MetaTensor& input, + int axis, + bool descending, + MetaTensor* output, + MetaTensor* indices) { + auto in_dims = input.dims(); + auto num_dims = in_dims.size(); + PADDLE_ENFORCE_GE( + axis, + -num_dims, + phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to" + " -num_dims(%d).", + axis, + -num_dims)); + PADDLE_ENFORCE_LT( + axis, + num_dims, + phi::errors::InvalidArgument( + "'axis'(%d) must be less than num_dims(%d).", axis, num_dims)); + + output->share_dims(input); + output->set_dtype(input.dtype()); + indices->share_dims(input); + indices->set_dtype(DataType::INT64); + output->share_lod(input); + indices->share_lod(input); +} + void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) { out->share_meta(x); } diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index df9258644a..08d05db1e5 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -32,6 +32,12 @@ class MetaConfig; // Because functions in this file not only can infer shape, but also need // infer lod or other useful data. +void ArgsortInferMeta(const MetaTensor& input, + int axis, + bool descending, + MetaTensor* output, + MetaTensor* indices); + void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out); // meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1] diff --git a/paddle/phi/kernels/argsort_grad_kernel.h b/paddle/phi/kernels/argsort_grad_kernel.h new file mode 100644 index 0000000000..b91bd69911 --- /dev/null +++ b/paddle/phi/kernels/argsort_grad_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ArgsortGradKernel(const Context& dev_ctx, + const DenseTensor& indices, + const DenseTensor& input, + const DenseTensor& out_grad, + int axis, + bool descending, + DenseTensor* in_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/argsort_kernel.h b/paddle/phi/kernels/argsort_kernel.h new file mode 100644 index 0000000000..683e8631d2 --- /dev/null +++ b/paddle/phi/kernels/argsort_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ArgsortKernel(const Context& dev_ctx, + const DenseTensor& input, + int axis, + bool descending, + DenseTensor* output, + DenseTensor* indices); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/argsort_grad_kernel.cc b/paddle/phi/kernels/cpu/argsort_grad_kernel.cc new file mode 100644 index 0000000000..1e60847232 --- /dev/null +++ b/paddle/phi/kernels/cpu/argsort_grad_kernel.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/argsort_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +template +static void FullAssign(Type input_height, + Type input_width, + int input_dim, + const DenseTensor* input, + const DenseTensor* indices, + T* t_out) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (Type i = 0; i < input_height; ++i) { + if (input_dim == 1) { + auto e_input = EigenVector::Flatten(*input); + auto e_indices = EigenVector::Flatten(*indices); + for (Type j = 0; j < input_width; ++j) { + t_out[i * input_width + e_indices(j)] = e_input(j); + } + } else { + auto e_input = EigenMatrix::Reshape(*input, input_dim - 1); + auto e_indices = EigenMatrix::Reshape(*indices, input_dim - 1); + for (Type j = 0; j < input_width; ++j) { + t_out[i * input_width + e_indices(i, j)] = e_input(i, j); + } + } + } +} + +template +void ArgsortGradKernel(const Context& dev_ctx, + const DenseTensor& indices, + const DenseTensor& input, + const DenseTensor& out_grad, + int axis, + bool descending, + DenseTensor* in_grad) { + auto in_dims = indices.dims(); + axis = (axis < 0) ? 
(in_dims.size() + axis) : axis; + dev_ctx.template Alloc(in_grad); + auto dxt = EigenVector::Flatten(*in_grad); + auto& place = *dev_ctx.eigen_device(); + dxt.device(place) = dxt.constant(static_cast(0)); + if (out_grad.numel() == 0) return; + + // Do full assign + if (axis == -1 || axis + 1 == in_dims.size()) { + const int64_t input_height = + phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t input_width = in_dims[in_dims.size() - 1]; + + FullAssign(input_height, + input_width, + in_dims.size(), + &out_grad, + &indices, + in_grad->data()); + } else { + // If not full assign do transpose + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + phi::DDim trans_dims(in_dims); + for (size_t i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + + DenseTensor trans_dO; + trans_dO.Resize(trans_dims); + dev_ctx.template Alloc(&trans_dO); + DenseTensor trans_ind; + trans_ind.Resize(trans_dims); + dev_ctx.template Alloc(&trans_ind); + TransposeKernel(dev_ctx, out_grad, trans, &trans_dO); + TransposeKernel(dev_ctx, indices, trans, &trans_ind); + + const int64_t input_height = + phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + DenseTensor tmp_out; + tmp_out.Resize(trans_dims); + T* t_out = dev_ctx.template Alloc(&tmp_out); + + FullAssign(input_height, + input_width, + in_dims.size(), + &trans_dO, + &trans_ind, + t_out); + + // transpose back + TransposeKernel(dev_ctx, tmp_out, trans, in_grad); + } +} + +} // namespace phi +PD_REGISTER_KERNEL(argsort_grad, + CPU, + ALL_LAYOUT, + phi::ArgsortGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/argsort_kernel.cc b/paddle/phi/kernels/cpu/argsort_kernel.cc new file mode 100644 index 0000000000..0e69afe38c --- /dev/null +++ b/paddle/phi/kernels/cpu/argsort_kernel.cc @@ -0,0 +1,143 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/argsort_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +template +static void FullSort(Type input_height, + Type input_width, + int input_dim, + const DenseTensor* input, + T* t_out, + Type* t_indices, + bool descending) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (Type i = 0; i < input_height; ++i) { + std::vector> col_vec; + col_vec.reserve(input_width); + if (input_dim == 1) { + auto e_input = EigenVector::Flatten(*input); + for (Type j = 0; j < input_width; ++j) { + col_vec.push_back(std::pair(e_input(j), j)); + } + } else { + auto e_input = EigenMatrix::Reshape(*input, input_dim - 1); + for (Type j = 0; j < input_width; ++j) { + col_vec.push_back(std::pair(e_input(i, j), j)); + } + } + std::sort(col_vec.begin(), + col_vec.end(), + [&](const std::pair& l, const std::pair& r) { + if (descending) + return l.first > r.first; + else + return l.first < r.first; + }); + + for (Type j = 0; j < input_width; ++j) { + t_out[i * input_width + j] = col_vec[j].first; + t_indices[i * input_width + j] = col_vec[j].second; + } + } +} + +template +void ArgsortKernel(const Context& dev_ctx, + const DenseTensor& input, + int axis, + bool descending, + DenseTensor* output, + DenseTensor* indices) { + auto in_dims = input.dims(); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + T* out_data = dev_ctx.template Alloc(output); + + // Do full sort + if (axis == -1 || axis + 1 == in_dims.size()) { + const int64_t input_height = + phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t input_width = in_dims[in_dims.size() - 1]; + int64_t* ids_data = dev_ctx.template Alloc(indices); + FullSort(input_height, + input_width, + in_dims.size(), + &input, + out_data, + ids_data, + descending); + } else { + // If not full sort do transpose + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + phi::DDim trans_dims(in_dims); + for (size_t i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + + DenseTensor trans_inp; + trans_inp.Resize(trans_dims); + dev_ctx.template Alloc(&trans_inp); + // Do transpose + TransposeKernel(dev_ctx, input, trans, &trans_inp); + + const int64_t input_height = + phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + DenseTensor tmp_out; + tmp_out.Resize(trans_dims); + T* t_out = dev_ctx.template Alloc(&tmp_out); + + DenseTensor tmp_indices; + tmp_indices.Resize(trans_dims); + auto* t_ind = dev_ctx.template Alloc(&tmp_indices); + + FullSort(input_height, + input_width, + in_dims.size(), + &trans_inp, + t_out, + t_ind, + descending); + + dev_ctx.template Alloc(indices); + TransposeKernel(dev_ctx, tmp_indices, trans, indices); + // transpose back + TransposeKernel(dev_ctx, tmp_out, trans, output); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + argsort, CPU, ALL_LAYOUT, phi::ArgsortKernel, float, double, int, int64_t) { +} diff --git a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu new file mode 100644 index 0000000000..15bca474f5 --- /dev/null +++ 
b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu @@ -0,0 +1,217 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/argsort_kernel.h" + +#include +#include +#include +#include +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +#ifdef __HIPCC__ +namespace rocprim { +namespace detail { +template <> +struct radix_key_codec_base + : radix_key_codec_integral {}; +} // namespace detail +} // namespace rocprim +#else +// set cub base traits in order to handle float16 +namespace cub { +template <> +struct NumericTraits + : BaseTraits {}; +} // namespace cub +#endif + +namespace phi { + +template +static __global__ void FillFlattenGrad(const T* dO, + const IndType* indices, + int64_t size, + T* dX) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (int i = index; i < size; i += stride) { + dX[indices[i]] = dO[i]; + } +} + +template +static __global__ void FillGrad(const T* dO, + const IndType* indices, + T* dX, + IndType num_rows, + IndType num_cols) { + int col_id = threadIdx.x; + int row_id = blockIdx.x; + + for (IndType j = row_id; j < num_rows; j += gridDim.x) { + for (IndType i = col_id; i < num_cols; i += blockDim.x) { + dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i]; + } + } +} + +template +void ArgFullAssign(const phi::GPUContext& ctx, + const DenseTensor* dO, + const DenseTensor* indices, + DenseTensor* dX, + const IndType num_rows, + const IndType num_cols) { + auto cu_stream = ctx.stream(); + + auto ComputeBlockSize = [](IndType col) { + if (col > 512) + return 1024; + else if (col > 256 && col <= 512) + return 512; + else if (col > 128 && col <= 256) + return 256; + else if (col > 64 && col <= 128) + return 128; + else + return 64; + }; + + int block_size = ComputeBlockSize(num_cols); + + int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0]; + // actually, int num_rows < max_grid_size + int grid_size = num_rows < maxGridDimX ? 
num_rows : maxGridDimX; + FillGrad<<>>(dO->data(), + indices->data(), + dX->data(), + num_rows, + num_cols); +} + +template +void ArgFlattenAssign(const phi::GPUContext& ctx, + const DenseTensor* dO, + const DenseTensor* indices, + int64_t size, + DenseTensor* dX) { + auto cu_stream = ctx.stream(); + + const int64_t block_size = + std::min(size, static_cast(ctx.GetMaxThreadsPerBlock())); + int64_t max_threads = ctx.GetMaxPhysicalThreadCount(); + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (size + block_size - 1) / block_size); + + FillFlattenGrad<<>>( + dO->data(), indices->data(), size, dX->data()); +} + +template +void ArgsortGradKernel(const Context& dev_ctx, + const DenseTensor& indices, + const DenseTensor& input, + const DenseTensor& out_grad, + int axis, + bool descending, + DenseTensor* in_grad) { + dev_ctx.template Alloc(in_grad); + if (out_grad.numel() == 0) return; + auto in_dims = in_grad->dims(); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + int64_t size = in_grad->numel(); + + // Parallel acceleration when the input size is equal to the length of the + // ‘axis’ dimension. + // Compared to 'special case for full sort' below, the gradient calculation + // is 10 times faster. + if (size == in_dims[axis]) { + ArgFlattenAssign(dev_ctx, &out_grad, &indices, size, in_grad); + return; + } + + // Special case for full sort, speedup ~190x. + if (axis == -1 || axis + 1 == in_dims.size()) { + const int64_t input_height = + phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t input_width = in_dims[in_dims.size() - 1]; + ArgFullAssign( + dev_ctx, &out_grad, &indices, in_grad, input_height, input_width); + } else { + // if not full sort, do transpose first + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + phi::DDim trans_dims(in_dims); + for (int i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + + DenseTensor trans_dO; + trans_dO.Resize(trans_dims); + dev_ctx.template Alloc(&trans_dO); + DenseTensor trans_ind; + trans_ind.Resize(trans_dims); + dev_ctx.template Alloc(&trans_ind); + TransposeKernel(dev_ctx, out_grad, trans, &trans_dO); + TransposeKernel(dev_ctx, indices, trans, &trans_ind); + + const int64_t input_height = + phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + DenseTensor tmp_out; + tmp_out.Resize(trans_dims); + dev_ctx.template Alloc(&tmp_out); + + ArgFullAssign( + dev_ctx, &trans_dO, &trans_ind, &tmp_out, input_height, input_width); + + // transpose back + TransposeKernel(dev_ctx, tmp_out, trans, in_grad); + return; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(argsort_grad, + GPU, + ALL_LAYOUT, + phi::ArgsortGradKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/argsort_kernel.cu b/paddle/phi/kernels/gpu/argsort_kernel.cu new file mode 100644 index 0000000000..6a9c1e2759 --- /dev/null +++ b/paddle/phi/kernels/gpu/argsort_kernel.cu @@ -0,0 +1,310 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/argsort_kernel.h" + +#include +#include +#include +#include +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +#ifdef __HIPCC__ +namespace rocprim { +namespace detail { +template <> +struct radix_key_codec_base + : radix_key_codec_integral {}; +} // namespace detail +} // namespace rocprim +#else +// set cub base traits in order to handle float16 +namespace cub { +template <> +struct NumericTraits + : BaseTraits {}; +} // namespace cub +#endif + +namespace phi { + +// Iter for move to next row +struct SegmentOffsetIter { + EIGEN_DEVICE_FUNC + explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const { + return idx * num_cols_; + } + + int num_cols_; +}; + +template +static __global__ void FillIndex(T* indices, T num_rows, T num_cols) { + int col_id = threadIdx.x; + int row_id = blockIdx.x; + + for (T j = row_id; j < num_rows; j += gridDim.x) { + for (T i = col_id; i < num_cols; i += blockDim.x) { + indices[j * num_cols + i] = i; + } + } +} + +// Sort by flag descending, True: descending. False: Ascending. +// Default is false. +template +void ArgFullSort(const phi::GPUContext& ctx, + const DenseTensor* input, + DenseTensor* output, + DenseTensor* indices, + const IndType num_rows, + const IndType num_cols, + const bool descending) { + auto cu_stream = ctx.stream(); + DenseTensor input_indices; + const std::vector dims = {num_rows, num_cols}; + auto dim = phi::make_ddim(dims); + input_indices.Resize(dim); + ctx.template Alloc(&input_indices); + size_t temp_storage_bytes = -1; + + auto ComputeBlockSize = [](IndType col) { + if (col > 512) + return 1024; + else if (col > 256 && col <= 512) + return 512; + else if (col > 128 && col <= 256) + return 256; + else if (col > 64 && col <= 128) + return 128; + else + return 64; + }; + + int block_size = ComputeBlockSize(num_cols); + int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0]; + // actually, int num_rows < max_grid_size + int grid_size = num_rows < maxGridDimX ? 
num_rows : maxGridDimX; + // Init a index array + FillIndex<<>>( + input_indices.data(), num_rows, num_cols); + + T* sorted_out_ptr; + IndType* sorted_indices_ptr; + const T* inp = input->data(); + T* out = ctx.template Alloc(output); + IndType* ind = ctx.template Alloc(indices); + sorted_out_ptr = out; + sorted_indices_ptr = ind; + + // create iter for counting input + cub::CountingInputIterator counting_iter(0); + // segment_offset is used for move to next row + cub::TransformInputIterator> + segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols)); + + gpuError_t err; + if (descending) { + err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, + temp_storage_bytes, + inp, + sorted_out_ptr, + input_indices.data(), + sorted_indices_ptr, + num_cols * num_rows, + num_rows, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + cu_stream); + } else { + err = + cub::DeviceSegmentedRadixSort::SortPairs(nullptr, + temp_storage_bytes, + inp, + sorted_out_ptr, + input_indices.data(), + sorted_indices_ptr, + num_cols * num_rows, + num_rows, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + cu_stream); + } + PADDLE_ENFORCE_GPU_SUCCESS(err); + + DenseTensor temp_storage; + int64_t temp_size = temp_storage_bytes; + temp_storage.Resize({temp_size}); + ctx.template Alloc(&temp_storage); + + if (descending) { + err = cub::DeviceSegmentedRadixSort::SortPairsDescending( + temp_storage.data(), + temp_storage_bytes, + inp, + sorted_out_ptr, + input_indices.data(), + sorted_indices_ptr, + num_cols * num_rows, + num_rows, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + cu_stream); + } else { + err = + cub::DeviceSegmentedRadixSort::SortPairs(temp_storage.data(), + temp_storage_bytes, + inp, + sorted_out_ptr, + input_indices.data(), + sorted_indices_ptr, + num_cols * num_rows, + num_rows, + segment_offsets_t, + segment_offsets_t + 1, + 0, + sizeof(T) * 8, + cu_stream); + } + + PADDLE_ENFORCE_GPU_SUCCESS(err); +} + +template +void ArgsortKernel(const Context& dev_ctx, + const DenseTensor& input, + int axis, + bool descending, + DenseTensor* output, + DenseTensor* indices) { + auto in_dims = input.dims(); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + const T* in_data = input.data(); + auto size = input.numel(); + T* out_data = dev_ctx.template Alloc(output); + int64_t* ids_data = dev_ctx.template Alloc(indices); + + // Use thrust for parallel acceleration when the input size is equal to the + // length of the ‘axis’ dimension. + // Compared to the following 'Special case for full sort', ascending sort is + // 34 times faster and descending sort is 31 times faster. + if (size == in_dims[axis]) { + thrust::sequence(thrust::device, ids_data, ids_data + size); + thrust::copy(thrust::device, in_data, in_data + size, out_data); + thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data); + if (descending) { + thrust::reverse(thrust::device, out_data, out_data + size); + thrust::reverse(thrust::device, ids_data, ids_data + size); + } + return; + } + + // Special case for full sort, speedup ~190x. 
+ if (axis == -1 || axis + 1 == in_dims.size()) { + const int64_t input_height = + phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t input_width = in_dims[in_dims.size() - 1]; + ArgFullSort(dev_ctx, + &input, + output, + indices, + input_height, + input_width, + descending); + } else { + // if not full sort, do transpose first + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + phi::DDim trans_dims(in_dims); + for (int i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + + DenseTensor trans_inp; + trans_inp.Resize(trans_dims); + T* trans_inp_data = dev_ctx.template Alloc(&trans_inp); + // Do transpose + TransposeKernel(dev_ctx, input, trans, &trans_inp); + + const int64_t input_height = + phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + DenseTensor tmp_out; + tmp_out.Resize(trans_dims); + dev_ctx.template Alloc(&tmp_out); + + DenseTensor tmp_indices; + // temp indices for sorting + tmp_indices.Resize(trans_dims); + dev_ctx.template Alloc(&tmp_indices); + dev_ctx.template Alloc(indices); + + ArgFullSort(dev_ctx, + &trans_inp, + &tmp_out, + &tmp_indices, + input_height, + input_width, + descending); + + TransposeKernel(dev_ctx, tmp_indices, trans, indices); + // transpose back + TransposeKernel(dev_ctx, tmp_out, trans, output); + return; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(argsort, + GPU, + ALL_LAYOUT, + phi::ArgsortKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/ops/compat/argsort_sig.cc b/paddle/phi/ops/compat/argsort_sig.cc new file mode 100644 index 0000000000..62133a441f --- /dev/null +++ b/paddle/phi/ops/compat/argsort_sig.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature ArgsortGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("argsort_grad", + {"Indices", "X", GradVarName("Out")}, + {"axis", "descending"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(argsort_grad, phi::ArgsortGradOpArgumentMapping); -- GitLab
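
Note on the compat layer: the patch registers an argument mapping only for argsort_grad. A forward mapping is presumably unnecessary because the op's inputs, attributes, and outputs ("X", "axis", "descending" -> "Out", "Indices") already match the phi kernel signature. For illustration only, a forward mapping written in the same style as ArgsortGradOpArgumentMapping above would look roughly like the sketch below; ArgsortOpArgumentMapping is a hypothetical name and this code is not part of the patch.

// Hypothetical sketch mirroring ArgsortGradOpArgumentMapping; not shipped in
// this PR, presumably because the default fluid->phi signature already covers
// the forward op.
#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature ArgsortOpArgumentMapping(const ArgumentMappingContext& ctx) {
  // Map fluid op inputs/attrs/outputs onto the phi "argsort" kernel.
  return KernelSignature(
      "argsort", {"X"}, {"axis", "descending"}, {"Out", "Indices"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(argsort, phi::ArgsortOpArgumentMapping);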