From 329b095ee1f9baed065b1100bb0ce3959b5b1e15 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Thu, 10 Mar 2022 14:42:57 +0800 Subject: [PATCH] [Phi]Move topk kernel to phi (#40064) * first commit * cpu kernel * first version * fix compile error * fix compile error * delete v2 * fix * fix * add alias * fix * fix * fix * fix error * fix * fix * fix * fix format --- paddle/fluid/operators/kthvalue_op.cu | 1 - paddle/fluid/operators/mode_op.cu | 1 - paddle/fluid/operators/top_k_function_cuda.h | 13 + paddle/fluid/operators/top_k_v2_op.cc | 15 +- paddle/fluid/operators/top_k_v2_op.cu | 296 ---------------- paddle/fluid/operators/top_k_v2_op.h | 335 ------------------- paddle/fluid/operators/top_k_v2_op_mlu.cc | 2 +- paddle/fluid/operators/top_k_v2_op_npu.cc | 2 +- paddle/fluid/operators/top_k_v2_op_xpu.cc | 2 +- paddle/phi/core/compat/op_utils.h | 4 +- paddle/phi/kernels/cpu/top_k_grad_kernel.cc | 151 +++++++++ paddle/phi/kernels/cpu/top_k_kernel.cc | 230 +++++++++++++ paddle/phi/kernels/funcs/math_function.h | 38 +++ paddle/phi/kernels/gpu/top_k_grad_kernel.cu | 87 +++++ paddle/phi/kernels/gpu/top_k_kernel.cu | 264 +++++++++++++++ paddle/phi/kernels/top_k_grad_kernel.h | 32 ++ paddle/phi/kernels/top_k_kernel.h | 32 ++ paddle/phi/ops/compat/top_k_sig.cc | 42 +++ 18 files changed, 897 insertions(+), 650 deletions(-) delete mode 100644 paddle/fluid/operators/top_k_v2_op.cu delete mode 100644 paddle/fluid/operators/top_k_v2_op.h create mode 100644 paddle/phi/kernels/cpu/top_k_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/top_k_kernel.cc create mode 100644 paddle/phi/kernels/gpu/top_k_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/top_k_kernel.cu create mode 100644 paddle/phi/kernels/top_k_grad_kernel.h create mode 100644 paddle/phi/kernels/top_k_kernel.h create mode 100644 paddle/phi/ops/compat/top_k_sig.cc diff --git a/paddle/fluid/operators/kthvalue_op.cu b/paddle/fluid/operators/kthvalue_op.cu index 4f30c58d375..f6f56f70f1a 100644 --- a/paddle/fluid/operators/kthvalue_op.cu +++ b/paddle/fluid/operators/kthvalue_op.cu @@ -16,7 +16,6 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/kthvalue_op.h" #include "paddle/fluid/operators/top_k_function_cuda.h" -#include "paddle/fluid/operators/top_k_v2_op.h" #ifdef __NVCC__ #include "cub/cub.cuh" #endif diff --git a/paddle/fluid/operators/mode_op.cu b/paddle/fluid/operators/mode_op.cu index afb949d3374..2bacda8afb0 100644 --- a/paddle/fluid/operators/mode_op.cu +++ b/paddle/fluid/operators/mode_op.cu @@ -24,7 +24,6 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/mode_op.h" #include "paddle/fluid/operators/top_k_function_cuda.h" -#include "paddle/fluid/operators/top_k_v2_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/top_k_function_cuda.h b/paddle/fluid/operators/top_k_function_cuda.h index d60976928e0..80c9935057c 100644 --- a/paddle/fluid/operators/top_k_function_cuda.h +++ b/paddle/fluid/operators/top_k_function_cuda.h @@ -51,6 +51,19 @@ namespace operators { using Tensor = framework::Tensor; +inline void GetDims(const phi::DDim& dim, int axis, int* pre, int* n, + int* post) { + *pre = 1; + *post = 1; + *n = dim[axis]; + for (int i = 0; i < axis; ++i) { + (*pre) *= dim[i]; + } + for (int i = axis + 1; i < dim.size(); ++i) { + (*post) *= dim[i]; + } +} + struct SegmentOffsetIter { EIGEN_DEVICE_FUNC explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} diff --git 
a/paddle/fluid/operators/top_k_v2_op.cc b/paddle/fluid/operators/top_k_v2_op.cc index 810afc901df..d1add111e1d 100644 --- a/paddle/fluid/operators/top_k_v2_op.cc +++ b/paddle/fluid/operators/top_k_v2_op.cc @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/top_k_v2_op.h" #include +#include "paddle/fluid/framework/op_registry.h" + namespace paddle { namespace operators { @@ -173,15 +174,3 @@ REGISTER_OPERATOR(top_k_v2, ops::TopkV2Op, ops::TopkV2OpMaker, ops::TopkV2GradOpMaker); REGISTER_OPERATOR(top_k_v2_grad, ops::TopkV2OpGrad); - -REGISTER_OP_CPU_KERNEL(top_k_v2, - ops::TopkV2Kernel, - ops::TopkV2Kernel, - ops::TopkV2Kernel, - ops::TopkV2Kernel) - -REGISTER_OP_CPU_KERNEL( - top_k_v2_grad, ops::TopkV2GradKernel, - ops::TopkV2GradKernel, - ops::TopkV2GradKernel, - ops::TopkV2GradKernel) diff --git a/paddle/fluid/operators/top_k_v2_op.cu b/paddle/fluid/operators/top_k_v2_op.cu deleted file mode 100644 index 84d8eef53bf..00000000000 --- a/paddle/fluid/operators/top_k_v2_op.cu +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/top_k_function_cuda.h" -#include "paddle/fluid/operators/top_k_v2_op.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -#define FIXED_BLOCK_DIM_BASE(dim, ...) \ - case (dim): { \ - constexpr auto kBlockDim = (dim); \ - __VA_ARGS__; \ - } break - -#define FIXED_BLOCK_DIM(...) 
\ - FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) - -template -class TopkV2OpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), true, - platform::errors::InvalidArgument( - "It must use CUDAPlace, you must check your device set.")); - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - auto* indices = ctx.Output("Indices"); - - // get the attributes - int k = static_cast(ctx.Attr("k")); - int axis = static_cast(ctx.Attr("axis")); - const bool& sorted = static_cast(ctx.Attr("sorted")); - const bool& largest = static_cast(ctx.Attr("largest")); - - // get the input dims - const auto& in_dims = input->dims(); - // calcluate the real axis - if (axis < 0) axis += in_dims.size(); - - auto* k_t = ctx.Input("K"); - if (k_t) { - Tensor k_host; - framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host); - k = k_host.data()[0]; - framework::DDim output_dims = output->dims(); - output_dims[axis] = k; - output->Resize(output_dims); - indices->Resize(output_dims); - } - - const auto& out_dims = output->dims(); - - const T* input_data = input->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); - int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); - - if (axis == in_dims.size() - 1) { - // if get the topK from the last axis - const int64_t& input_height = - phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); - const int64_t& input_width = in_dims[in_dims.size() - 1]; - const auto& dev_ctx = ctx.cuda_device_context(); - - if (k > input_width) k = input_width; - - // The conclusion is drawn from the data through multiple sets of - // statistics - if (input_width >= 128 && k >= input_width * 0.75) { - if (SortTopk(dev_ctx, input, input_width, input_height, k, output, - indices, largest)) { - // Successed, return. - return; - } else { - LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use " - "default topk kernel."; - } - } - - // NOTE: pass lds and dim same to input width. - // NOTE: old matrix implementation of stride is different to eigen. - const int kMaxHeight = 2048; - int gridx = input_height < kMaxHeight ? 
input_height : kMaxHeight; - switch (GetDesiredBlockDim(input_width)) { -#ifdef PADDLE_WITH_HIP - FIXED_BLOCK_DIM( - KeMatrixTopK<<>>( - output_data, k, indices_data, input_data, input_width, - input_width, static_cast(k), gridx, input_height, - largest)); -#else - FIXED_BLOCK_DIM( - KeMatrixTopK<<>>( - output_data, k, indices_data, input_data, input_width, - input_width, static_cast(k), gridx, input_height, - largest)); -#endif - default: - PADDLE_THROW(platform::errors::Fatal( - "the input data shape has error in the topk cuda kernel.")); - } - } else { - // if get topK not from the last axis, will tranpose the tensor and get - // TopK - - // first step, prepare the trans args for the tranpose - std::vector trans; - for (int i = 0; i < axis; i++) { - trans.emplace_back(i); - } - trans.emplace_back(in_dims.size() - 1); - for (int i = axis + 1; i < in_dims.size() - 1; i++) { - trans.emplace_back(i); - } - trans.emplace_back(axis); - - framework::DDim trans_dims(in_dims); - framework::DDim trans_out_dims(output->dims()); - for (int i = 0; i < trans.size(); i++) { - trans_dims[i] = in_dims[trans[i]]; - trans_out_dims[i] = out_dims[trans[i]]; - } - // second step, tranpose the input - Tensor trans_input; - trans_input.mutable_data(trans_dims, ctx.GetPlace()); - int ndims = trans.size(); - const auto& dev_ctx = ctx.cuda_device_context(); - TransCompute(ndims, dev_ctx, *input, - &trans_input, trans); - // third step, calcluate the topk - // allocate the tmp cuda memory for the tmp result - Tensor trans_ind; - trans_ind.mutable_data(trans_out_dims, ctx.GetPlace()); - Tensor trans_out; - trans_out.mutable_data(trans_out_dims, ctx.GetPlace()); - - const int64_t input_height = - phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); - const int64_t input_width = trans_dims[trans_dims.size() - 1]; - - if (k > input_width) k = input_width; - - // The conclusion is drawn from the data through multiple sets of - // statistics - if (input_width >= 128 && k >= input_width * 0.75) { - if (SortTopk(dev_ctx, &trans_input, input_width, input_height, k, - &trans_out, &trans_ind, largest)) { - // last step, tranpose back the indices and output - TransCompute( - ndims, dev_ctx, trans_ind, indices, trans); - TransCompute( - ndims, dev_ctx, trans_out, output, trans); - return; - } else { - LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use " - "default topk kernel."; - } - } - - const int kMaxHeight = 2048; - int gridx = input_height < kMaxHeight ? 
input_height : kMaxHeight; - switch (GetDesiredBlockDim(input_width)) { -#ifdef PADDLE_WITH_HIP - FIXED_BLOCK_DIM( - KeMatrixTopK<<>>( - trans_out.data(), k, trans_ind.data(), - trans_input.data(), input_width, input_width, - static_cast(k), gridx, input_height, largest)); -#else - FIXED_BLOCK_DIM( - KeMatrixTopK<<>>( - trans_out.data(), k, trans_ind.data(), - trans_input.data(), input_width, input_width, - static_cast(k), gridx, input_height, largest)); -#endif - default: - PADDLE_THROW(platform::errors::Fatal( - "the input data shape has error in the topk cuda kernel.")); - } - - // last step, tranpose back the indices and output - TransCompute( - ndims, dev_ctx, trans_ind, indices, trans); - TransCompute(ndims, dev_ctx, trans_out, - output, trans); - } - } -}; - -#undef FIXED_BLOCK_DIM_BASE -#undef FIXED_BLOCK_DIM -template -class TopkV2OpGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(context.GetPlace()), true, - platform::errors::InvalidArgument( - "It must use CUDAPlace, you must check your device set.")); - auto* x = context.Input("X"); - auto* out_grad = context.Input(framework::GradVarName("Out")); - auto* indices = context.Input("Indices"); - auto* x_grad = context.Output(framework::GradVarName("X")); - int axis = context.Attr("axis"); - - const auto& in_dims = x->dims(); - const auto& out_dims = indices->dims(); - - // get the real the axis and the k - if (axis < 0) axis += in_dims.size(); - const int& k = out_dims[axis]; - const int& raw_height = in_dims[axis]; - - // allocate the cuda memory for the x_grad - T* x_grad_data = x_grad->mutable_data(context.GetPlace()); - const T* out_grad_data = out_grad->data(); - const int64_t* indices_data = indices->data(); - - int pre, n, post; - GetDims(in_dims, axis, &pre, &n, &post); - - // calcluate the block and grid num - auto& dev_ctx = context.cuda_device_context(); - auto ComputeBlockSize = [](int col) { - if (col > 512) - return 1024; - else if (col > 256 && col <= 512) - return 512; - else if (col > 128 && col <= 256) - return 256; - else if (col > 64 && col <= 128) - return 128; - else - return 64; - }; - int block_size = ComputeBlockSize(post * k); - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1); - int grid_size = std::min(max_blocks, pre); - - // lanuch the cuda kernel to assign the grad - AssignGradWithAxis<<>>( - out_grad_data, indices_data, x_grad_data, pre, post, n, k); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL( - top_k_v2, - paddle::operators::TopkV2OpCUDAKernel, - paddle::operators::TopkV2OpCUDAKernel, - paddle::operators::TopkV2OpCUDAKernel, - paddle::operators::TopkV2OpCUDAKernel, - paddle::operators::TopkV2OpCUDAKernel); - -REGISTER_OP_CUDA_KERNEL( - top_k_v2_grad, paddle::operators::TopkV2OpGradCUDAKernel< - paddle::platform::CUDADeviceContext, float>, - paddle::operators::TopkV2OpGradCUDAKernel< - paddle::platform::CUDADeviceContext, double>, - paddle::operators::TopkV2OpGradCUDAKernel< - paddle::platform::CUDADeviceContext, int>, - paddle::operators::TopkV2OpGradCUDAKernel< - paddle::platform::CUDADeviceContext, int64_t>, - paddle::operators::TopkV2OpGradCUDAKernel< - paddle::platform::CUDADeviceContext, paddle::platform::float16>); diff --git a/paddle/fluid/operators/top_k_v2_op.h b/paddle/fluid/operators/top_k_v2_op.h deleted file mode 100644 index 
a808207476f..00000000000 --- a/paddle/fluid/operators/top_k_v2_op.h +++ /dev/null @@ -1,335 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -/* - The reason why we need the topk v2 is because the compatibility. We redefine - the NaN is maximum value - in the process of comparing. If do not add the topk v2, will affect the - inference result of model that traing - by the older version paddlepaddle. -*/ - -#pragma once -#include -#include -#include -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/top_k_op.h" -#include "paddle/fluid/operators/transpose_op.h" - -namespace paddle { -namespace operators { - -inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n, - int* post) { - *pre = 1; - *post = 1; - *n = dim[axis]; - for (int i = 0; i < axis; ++i) { - (*pre) *= dim[i]; - } - for (int i = axis + 1; i < dim.size(); ++i) { - (*post) *= dim[i]; - } -} - -template -static void FullTopK(Type input_height, Type input_width, int input_dim, - const framework::Tensor* input, T* t_out, Type* t_indices, - const int& k, const bool& largest, const bool& sorted) { - // when the k is small, will the partial sort - bool partial_sort_flag = (k * 64) < input_width; - -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - // Eigen::DSizes flat2dims(input_height, input_width); - for (Type i = 0; i < input_height; ++i) { - std::vector> col_vec; - col_vec.reserve(input_width); - if (input_dim == 1) { - auto e_input = framework::EigenVector::Flatten(*input); - for (Type j = 0; j < input_width; ++j) { - col_vec.emplace_back(std::pair(e_input(j), j)); - } - } else { - auto e_input = framework::EigenMatrix::Reshape(*input, input_dim - 1); - for (Type j = 0; j < input_width; ++j) { - col_vec.emplace_back(std::pair(e_input(i, j), j)); - } - } - if (partial_sort_flag) { - std::partial_sort( - col_vec.begin(), col_vec.begin() + k, col_vec.end(), - [&largest](const std::pair& l, const std::pair& r) { - if (largest) { - return (std::isnan(static_cast(l.first)) && - !std::isnan(static_cast(r.first))) || - (l.first > r.first); - } else { - return (!std::isnan(static_cast(l.first)) && - std::isnan(static_cast(r.first))) || - (l.first < r.first); - } - }); - } else { - // use the nth-element to get the K-larger or K-small element - if (largest) { - std::nth_element( - col_vec.begin(), col_vec.begin() + k - 1, col_vec.end(), - [](const std::pair& l, const std::pair& r) { - return (std::isnan(static_cast(l.first)) && - !std::isnan(static_cast(r.first))) || - (l.first > r.first); - }); - // the nth-element will get the unorder elements, sort the element - if (sorted) { - std::sort(col_vec.begin(), col_vec.begin() + k - 1, - [&largest](const std::pair& l, - const std::pair& r) { - return (std::isnan(static_cast(l.first)) && - !std::isnan(static_cast(r.first))) || - (l.first > r.first); - }); - } - } else { - std::nth_element( - col_vec.begin(), col_vec.begin() + k 
- 1, col_vec.end(), - [](const std::pair& l, const std::pair& r) { - return (!std::isnan(static_cast(l.first)) && - std::isnan(static_cast(r.first))) || - (l.first < r.first); - }); - // the nth-element will get the unorder elements, sort the element - if (sorted) { - std::sort( - col_vec.begin(), col_vec.begin() + k - 1, - [](const std::pair& l, const std::pair& r) { - return (!std::isnan(static_cast(l.first)) && - std::isnan(static_cast(r.first))) || - (l.first < r.first); - }); - } - } - } - for (Type j = 0; j < k; ++j) { - t_out[i * k + j] = col_vec[j].first; - t_indices[i * k + j] = col_vec[j].second; - } - } -} - -template -static void FullTopKAssign(const Type& input_height, const Type& input_width, - const int& input_dim, const framework::Tensor* input, - const framework::Tensor* indices, T* output_data, - const int& k) { -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (Type i = 0; i < input_height; ++i) { - if (input_dim == 1) { - auto e_input = framework::EigenVector::Flatten(*input); - auto e_indices = framework::EigenVector::Flatten(*indices); - for (Type j = 0; j < k; ++j) { - output_data[i * input_width + e_indices(j)] = e_input(j); - } - } else { - auto e_input = framework::EigenMatrix::Reshape(*input, input_dim - 1); - auto e_indices = - framework::EigenMatrix::Reshape(*indices, input_dim - 1); - for (Type j = 0; j < k; ++j) { - output_data[i * input_width + e_indices(i, j)] = e_input(i, j); - } - } - } -} - -template -class TopkV2Kernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - // Get the top k elements of each row of input tensor - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - auto* indices = context.Output("Indices"); - const auto& in_dims = input->dims(); - int k = static_cast(context.Attr("k")); - const auto& sorted = static_cast(context.Attr("sorted")); - const auto& largest = static_cast(context.Attr("largest")); - - // axis < 0, cacluate the real axis - int axis = static_cast(context.Attr("axis")); - if (axis < 0) axis += in_dims.size(); - - // if K tensor is not null, will the use K tesnor as k - auto* k_t = context.Input("K"); - if (k_t) { - k = k_t->data()[0]; - framework::DDim output_dims = output->dims(); - // accroding to axis to set K value in the dim - output_dims[axis] = k; - output->Resize(output_dims); - indices->Resize(output_dims); - } - - T* output_data = output->mutable_data(context.GetPlace()); - int64_t* indices_data = indices->mutable_data(context.GetPlace()); - const auto& out_dims = output->dims(); - if (axis + 1 == in_dims.size()) { - const int64_t& input_height = - phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); - const int64_t& input_width = in_dims[in_dims.size() - 1]; - FullTopK(input_height, input_width, in_dims.size(), input, - output_data, indices_data, k, largest, sorted); - } else { - // if the topk dims is not last dim, will tranpose and do topk - std::vector trans; - for (int i = 0; i < axis; i++) { - trans.emplace_back(i); - } - trans.push_back(in_dims.size() - 1); - for (int i = axis + 1; i < in_dims.size() - 1; i++) { - trans.emplace_back(i); - } - trans.emplace_back(axis); - - // get the trans input_dims, out_dims - framework::DDim trans_dims(in_dims); - framework::DDim trans_out_dims(output->dims()); - for (size_t i = 0; i < trans.size(); i++) { - trans_dims[i] = in_dims[trans[i]]; - } - for (size_t i = 0; i < trans.size(); i++) { - trans_out_dims[i] = out_dims[trans[i]]; - } - - Tensor 
trans_inp; - trans_inp.mutable_data(trans_dims, context.GetPlace()); - int ndims = trans.size(); - auto& dev_context = - context.template device_context(); - - // transpose the input value - TransCompute(ndims, dev_context, *input, - &trans_inp, trans); - - const int64_t input_height = - phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); - const int64_t input_width = trans_dims[trans_dims.size() - 1]; - - // Allocate the temp tensor to the save the topk indices, values - Tensor tmp_out; - T* t_out = tmp_out.mutable_data(trans_out_dims, context.GetPlace()); - Tensor tmp_indices; - auto* t_ind = - tmp_indices.mutable_data(trans_out_dims, context.GetPlace()); - - // get the TopK value - FullTopK(input_height, input_width, in_dims.size(), - &trans_inp, t_out, t_ind, k, largest, sorted); - // transpose back - TransCompute( - ndims, dev_context, tmp_indices, indices, trans); - TransCompute(ndims, dev_context, tmp_out, - output, trans); - } - } -}; - -template -class TopkV2GradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* out_grad = context.Input(framework::GradVarName("Out")); - auto* indices = context.Input("Indices"); - auto* x_grad = context.Output(framework::GradVarName("X")); - int axis = static_cast(context.Attr("axis")); - - const auto& in_dims = x->dims(); - const auto& out_dims = indices->dims(); - - // axis < 0, get the real axis - axis = (axis < 0) ? (in_dims.size() + axis) : axis; - const size_t& k = out_dims[axis]; - - T* x_grad_data = x_grad->mutable_data(context.GetPlace()); - if (axis + 1 == in_dims.size()) { - // allocate the memory for the input_grad - - // assign the out_grad to input_grad directly - const int64_t input_height = - phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); - const int64_t input_width = in_dims[in_dims.size() - 1]; - - // init the output grad with 0, because some input elements has no grad - memset(x_grad_data, 0, x_grad->numel() * sizeof(T)); - // Assign the output_grad to input_grad - FullTopKAssign(input_height, input_width, in_dims.size(), out_grad, - indices, x_grad_data, k); - } else { - // can not assign grad to input_grad, must do the transpose - std::vector trans; - for (int i = 0; i < axis; i++) { - trans.emplace_back(i); - } - trans.emplace_back(out_dims.size() - 1); - for (int i = axis + 1; i < out_dims.size() - 1; i++) { - trans.emplace_back(i); - } - trans.emplace_back(axis); - framework::DDim trans_dims(out_dims); - framework::DDim trans_in_dims(in_dims); - for (size_t i = 0; i < trans.size(); i++) { - trans_dims[i] = out_dims[trans[i]]; - trans_in_dims[i] = in_dims[trans[i]]; - } - // transpose the out_grad, indices - Tensor trans_dO; - trans_dO.mutable_data(trans_dims, context.GetPlace()); - Tensor trans_ind; - trans_ind.mutable_data(trans_dims, context.GetPlace()); - int ndims = trans.size(); - auto& dev_context = - context.template device_context(); - - // Do transpose - TransCompute(ndims, dev_context, *out_grad, - &trans_dO, trans); - TransCompute( - ndims, dev_context, *indices, &trans_ind, trans); - const int64_t input_height = phi::product( - phi::slice_ddim(trans_in_dims, 0, trans_in_dims.size() - 1)); - const int64_t input_width = trans_in_dims[trans_in_dims.size() - 1]; - - // Assign the out_grad to tranpose input_grad - Tensor tmp_out; - T* t_out = tmp_out.mutable_data(trans_in_dims, context.GetPlace()); - memset(t_out, 0, x_grad->numel() * sizeof(T)); - - 
FullTopKAssign(input_height, input_width, in_dims.size(), - &trans_dO, &trans_ind, t_out, k); - - // Transpose back - TransCompute(ndims, dev_context, tmp_out, - x_grad, trans); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/top_k_v2_op_mlu.cc b/paddle/fluid/operators/top_k_v2_op_mlu.cc index 5b8a6b3e754..caaae02124c 100644 --- a/paddle/fluid/operators/top_k_v2_op_mlu.cc +++ b/paddle/fluid/operators/top_k_v2_op_mlu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/top_k_v2_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" namespace paddle { diff --git a/paddle/fluid/operators/top_k_v2_op_npu.cc b/paddle/fluid/operators/top_k_v2_op_npu.cc index e1107063883..dff5c2d3f39 100644 --- a/paddle/fluid/operators/top_k_v2_op_npu.cc +++ b/paddle/fluid/operators/top_k_v2_op_npu.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/top_k_v2_op.h" #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/top_k_v2_op_xpu.cc b/paddle/fluid/operators/top_k_v2_op_xpu.cc index 49daac2ff0d..4d9c39be92e 100644 --- a/paddle/fluid/operators/top_k_v2_op_xpu.cc +++ b/paddle/fluid/operators/top_k_v2_op_xpu.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include -#include "paddle/fluid/operators/top_k_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/transpose_op.h" #include "xpu/refactor/math.h" diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index 1ab718c0794..fea79766a6b 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -52,7 +52,9 @@ const std::unordered_set deprecated_op_names({"diag", "reshape_grad", "expand", "expand_grad", - "sum"}); + "sum", + "top_k", + "top_k_grad"}); class DefaultKernelSignatureMap { public: diff --git a/paddle/phi/kernels/cpu/top_k_grad_kernel.cc b/paddle/phi/kernels/cpu/top_k_grad_kernel.cc new file mode 100644 index 00000000000..582ee1157cc --- /dev/null +++ b/paddle/phi/kernels/cpu/top_k_grad_kernel.cc @@ -0,0 +1,151 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/top_k_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +static void FullTopKAssign(const Type& input_height, + const Type& input_width, + const int& input_dim, + const DenseTensor* input, + const DenseTensor* indices, + T* output_data, + const int& k) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (Type i = 0; i < input_height; ++i) { + if (input_dim == 1) { + auto e_input = EigenVector::Flatten(*input); + auto e_indices = EigenVector::Flatten(*indices); + for (Type j = 0; j < k; ++j) { + output_data[i * input_width + e_indices(j)] = e_input(j); + } + } else { + auto e_input = EigenMatrix::Reshape(*input, input_dim - 1); + auto e_indices = EigenMatrix::Reshape(*indices, input_dim - 1); + for (Type j = 0; j < k; ++j) { + output_data[i * input_width + e_indices(i, j)] = e_input(i, j); + } + } + } +} + +template +void TopkGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& indices, + int k, + int axis, + bool largest, + bool sorted, + DenseTensor* x_grad) { + const auto& in_dims = x.dims(); + const auto& out_dims = indices.dims(); + + // axis < 0, get the real axis + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + + T* x_grad_data = dev_ctx.template Alloc(x_grad); + if (axis + 1 == in_dims.size()) { + // allocate the memory for the input_grad + + // assign the out_grad to input_grad directly + const int64_t input_height = + phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t input_width = in_dims[in_dims.size() - 1]; + + // init the output grad with 0, because some input elements has no grad + memset(x_grad_data, 0, x_grad->numel() * sizeof(T)); + // Assign the output_grad to input_grad + FullTopKAssign(input_height, + input_width, + in_dims.size(), + &out_grad, + &indices, + x_grad_data, + k); + } else { + // can not assign grad to input_grad, must do the transpose + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.emplace_back(i); + } + trans.emplace_back(out_dims.size() - 1); + for (int i = axis + 1; i < out_dims.size() - 1; i++) { + trans.emplace_back(i); + } + trans.emplace_back(axis); + phi::DDim trans_dims(out_dims); + phi::DDim trans_in_dims(in_dims); + for (size_t i = 0; i < trans.size(); i++) { + trans_dims[i] = out_dims[trans[i]]; + trans_in_dims[i] = in_dims[trans[i]]; + } + // transpose the out_grad, indices + DenseTensor trans_dO; + DenseTensor trans_ind; + trans_dO.Resize(trans_dims); + trans_ind.Resize(trans_dims); + dev_ctx.template Alloc(&trans_dO); + dev_ctx.template Alloc(&trans_ind); + int ndims = trans.size(); + + // Do transpose + funcs::TransCompute( + ndims, dev_ctx, out_grad, &trans_dO, trans); + funcs::TransCompute( + ndims, dev_ctx, indices, &trans_ind, trans); + const int64_t input_height = phi::product( + phi::slice_ddim(trans_in_dims, 0, trans_in_dims.size() - 1)); + const int64_t input_width = trans_in_dims[trans_in_dims.size() - 1]; + + // Assign the out_grad to tranpose input_grad + DenseTensor tmp_out; + tmp_out.Resize(trans_in_dims); + T* t_out = dev_ctx.template Alloc(&tmp_out); + memset(t_out, 0, x_grad->numel() * sizeof(T)); + + FullTopKAssign(input_height, + input_width, + in_dims.size(), + &trans_dO, + &trans_ind, + t_out, + k); + + // Transpose back + funcs::TransCompute( + ndims, dev_ctx, tmp_out, 
x_grad, trans); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(top_k_grad, + CPU, + ALL_LAYOUT, + phi::TopkGradKernel, + float, + double, + int32_t, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/top_k_kernel.cc b/paddle/phi/kernels/cpu/top_k_kernel.cc new file mode 100644 index 00000000000..4ac16667ce2 --- /dev/null +++ b/paddle/phi/kernels/cpu/top_k_kernel.cc @@ -0,0 +1,230 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/top_k_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +static void FullTopK(Type input_height, + Type input_width, + int input_dim, + const DenseTensor* input, + T* t_out, + Type* t_indices, + const int& k, + const bool& largest, + const bool& sorted) { + // when the k is small, will the partial sort + bool partial_sort_flag = (k * 64) < input_width; + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (Type i = 0; i < input_height; ++i) { + std::vector> col_vec; + col_vec.reserve(input_width); + if (input_dim == 1) { + auto e_input = EigenVector::Flatten(*input); + for (Type j = 0; j < input_width; ++j) { + col_vec.emplace_back(std::pair(e_input(j), j)); + } + } else { + auto e_input = EigenMatrix::Reshape(*input, input_dim - 1); + for (Type j = 0; j < input_width; ++j) { + col_vec.emplace_back(std::pair(e_input(i, j), j)); + } + } + if (partial_sort_flag) { + std::partial_sort( + col_vec.begin(), + col_vec.begin() + k, + col_vec.end(), + [&largest](const std::pair& l, const std::pair& r) { + if (largest) { + return (std::isnan(static_cast(l.first)) && + !std::isnan(static_cast(r.first))) || + (l.first > r.first); + } else { + return (!std::isnan(static_cast(l.first)) && + std::isnan(static_cast(r.first))) || + (l.first < r.first); + } + }); + } else { + // use the nth-element to get the K-larger or K-small element + if (largest) { + std::nth_element( + col_vec.begin(), + col_vec.begin() + k - 1, + col_vec.end(), + [](const std::pair& l, const std::pair& r) { + return (std::isnan(static_cast(l.first)) && + !std::isnan(static_cast(r.first))) || + (l.first > r.first); + }); + // the nth-element will get the unorder elements, sort the element + if (sorted) { + std::sort(col_vec.begin(), + col_vec.begin() + k - 1, + [&largest](const std::pair& l, + const std::pair& r) { + return (std::isnan(static_cast(l.first)) && + !std::isnan(static_cast(r.first))) || + (l.first > r.first); + }); + } + } else { + std::nth_element( + col_vec.begin(), + col_vec.begin() + k - 1, + col_vec.end(), + [](const std::pair& l, const std::pair& r) { + return (!std::isnan(static_cast(l.first)) && + std::isnan(static_cast(r.first))) || + (l.first < r.first); + }); + // the nth-element will get the unorder elements, sort the element + if (sorted) { + std::sort( + col_vec.begin(), 
+ col_vec.begin() + k - 1, + [](const std::pair& l, const std::pair& r) { + return (!std::isnan(static_cast(l.first)) && + std::isnan(static_cast(r.first))) || + (l.first < r.first); + }); + } + } + } + for (Type j = 0; j < k; ++j) { + t_out[i * k + j] = col_vec[j].first; + t_indices[i * k + j] = col_vec[j].second; + } + } +} + +template +void TopkKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& k_scalar, + int axis, + bool largest, + bool sorted, + DenseTensor* out, + DenseTensor* indices) { + const auto* input = &x; + // Get the top k elements of each row of input tensor + const auto& in_dims = input->dims(); + + // axis < 0, cacluate the real axis + if (axis < 0) { + axis += in_dims.size(); + } + + int k = k_scalar.to(); + if (k_scalar.FromTensor()) { + auto out_dims = out->dims(); + // accroding to axis to set K value in the dim + out_dims[axis] = k; + out->Resize(out_dims); + indices->Resize(out_dims); + } + + T* out_data = dev_ctx.template Alloc(out); + int64_t* indices_data = dev_ctx.template Alloc(indices); + const auto& out_dims = out->dims(); + if (axis + 1 == in_dims.size()) { + const int64_t& input_height = + phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t& input_width = in_dims[in_dims.size() - 1]; + FullTopK(input_height, + input_width, + in_dims.size(), + input, + out_data, + indices_data, + k, + largest, + sorted); + } else { + // if the topk dims is not last dim, will tranpose and do topk + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.emplace_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.emplace_back(i); + } + trans.emplace_back(axis); + + // get the trans input_dims, out_dims + phi::DDim trans_dims(in_dims); + phi::DDim trans_out_dims(out->dims()); + for (size_t i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + for (size_t i = 0; i < trans.size(); i++) { + trans_out_dims[i] = out_dims[trans[i]]; + } + + DenseTensor trans_inp; + trans_inp.Resize(trans_dims); + dev_ctx.template Alloc(&trans_inp); + int ndims = trans.size(); + + // transpose the input value + funcs::TransCompute( + ndims, dev_ctx, *input, &trans_inp, trans); + + const int64_t input_height = + phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + // Allocate the temp tensor to the save the topk indices, values + DenseTensor tmp_out; + DenseTensor tmp_indices; + tmp_out.Resize(trans_out_dims); + tmp_indices.Resize(trans_out_dims); + T* t_out = dev_ctx.template Alloc(&tmp_out); + auto* t_ind = dev_ctx.template Alloc(&tmp_indices); + + // get the TopK value + FullTopK(input_height, + input_width, + in_dims.size(), + &trans_inp, + t_out, + t_ind, + k, + largest, + sorted); + // transpose back + funcs::TransCompute( + ndims, dev_ctx, tmp_indices, indices, trans); + funcs::TransCompute( + ndims, dev_ctx, tmp_out, out, trans); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + top_k, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) {} diff --git a/paddle/phi/kernels/funcs/math_function.h b/paddle/phi/kernels/funcs/math_function.h index 8e1a4cdd1a9..b735587d3d5 100644 --- a/paddle/phi/kernels/funcs/math_function.h +++ b/paddle/phi/kernels/funcs/math_function.h @@ -125,5 +125,43 @@ struct TensorSetConstantXPU { }; #endif +template +inline void TransCompute(const int dim, + const Context& dev_ctx, + const DenseTensor& in, + DenseTensor* out, + const 
std::vector& axis) { + switch (dim) { + case 1: + Transpose trans1; + trans1(dev_ctx, in, out, axis); + break; + case 2: + Transpose trans2; + trans2(dev_ctx, in, out, axis); + break; + case 3: + Transpose trans3; + trans3(dev_ctx, in, out, axis); + break; + case 4: + Transpose trans4; + trans4(dev_ctx, in, out, axis); + break; + case 5: + Transpose trans5; + trans5(dev_ctx, in, out, axis); + break; + case 6: + Transpose trans6; + trans6(dev_ctx, in, out, axis); + break; + default: + // for dim >= 7 situation + TransposeNormal trans_normal; + trans_normal(dev_ctx, in, out, axis); + } +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/top_k_grad_kernel.cu b/paddle/phi/kernels/gpu/top_k_grad_kernel.cu new file mode 100644 index 00000000000..b0b45223489 --- /dev/null +++ b/paddle/phi/kernels/gpu/top_k_grad_kernel.cu @@ -0,0 +1,87 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/top_k_grad_kernel.h" + +#include "paddle/fluid/operators/top_k_function_cuda.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +namespace ops = paddle::operators; + +template +void TopkGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& indices, + int k, + int axis, + bool largest, + bool sorted, + DenseTensor* x_grad) { + const auto& in_dims = x.dims(); + const auto& out_dims = indices.dims(); + + // get the real the axis and the k + if (axis < 0) { + axis += in_dims.size(); + } + const int& raw_height = in_dims[axis]; + + // allocate the cuda memory for the x_grad + T* x_grad_data = dev_ctx.template Alloc(x_grad); + const T* out_grad_data = out_grad.data(); + const int64_t* indices_data = indices.data(); + + int pre, n, post; + ops::GetDims(in_dims, axis, &pre, &n, &post); + + // calcluate the block and grid num + auto ComputeBlockSize = [](int col) { + if (col > 512) + return 1024; + else if (col > 256 && col <= 512) + return 512; + else if (col > 128 && col <= 256) + return 256; + else if (col > 64 && col <= 128) + return 128; + else + return 64; + }; + int block_size = ComputeBlockSize(post * k); + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1); + int grid_size = std::min(max_blocks, pre); + + // lanuch the cuda kernel to assign the grad + ops::AssignGradWithAxis< + T><<>>( + out_grad_data, indices_data, x_grad_data, pre, post, n, k); +} + +} // namespace phi + +PD_REGISTER_KERNEL(top_k_grad, + GPU, + ALL_LAYOUT, + phi::TopkGradKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu new file mode 100644 index 00000000000..4e9aa88c6cb --- /dev/null +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -0,0 +1,264 @@ +// 
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/top_k_kernel.h" + +#include "paddle/fluid/operators/top_k_function_cuda.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +namespace ops = paddle::operators; + +#define FIXED_BLOCK_DIM_BASE(dim, ...) \ + case (dim): { \ + constexpr auto kBlockDim = (dim); \ + __VA_ARGS__; \ + } break + +#define FIXED_BLOCK_DIM(...) \ + FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) + +template +void TopkKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& k_scalar, + int axis, + bool largest, + bool sorted, + DenseTensor* out, + DenseTensor* indices) { + const auto* input = &x; + // get the input dims + const auto& in_dims = input->dims(); + // calcluate the real axis + if (axis < 0) axis += in_dims.size(); + + int k = k_scalar.to(); + if (k_scalar.FromTensor()) { + phi::DDim out_dims = out->dims(); + out_dims[axis] = k; + out->Resize(out_dims); + indices->Resize(out_dims); + } + + const auto& out_dims = out->dims(); + + const T* input_data = input->data(); + T* output_data = dev_ctx.template Alloc(out); + int64_t* indices_data = dev_ctx.template Alloc(indices); + + if (axis == in_dims.size() - 1) { + // if get the topK from the last axis + const int64_t& input_height = + phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t& input_width = in_dims[in_dims.size() - 1]; + + if (k > input_width) { + k = input_width; + } + + // The conclusion is drawn from the data through multiple sets of + // statistics + if (input_width >= 128 && k >= input_width * 0.75) { + if (ops::SortTopk( + paddle::platform::CUDADeviceContext(dev_ctx.GetPlace()), + input, + input_width, + input_height, + k, + out, + indices, + largest)) { + // Successed, return. + return; + } else { + LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use " + "default topk kernel."; + } + } + + // NOTE: pass lds and dim same to input width. + // NOTE: old matrix implementation of stride is different to eigen. + const int kMaxHeight = 2048; + int gridx = input_height < kMaxHeight ? 
input_height : kMaxHeight; + switch (ops::GetDesiredBlockDim(input_width)) { +#ifdef PADDLE_WITH_HIP + FIXED_BLOCK_DIM(ops::KeMatrixTopK< + T, + 20, + kBlockDim><<>>( + output_data, + k, + indices_data, + input_data, + input_width, + input_width, + static_cast(k), + gridx, + input_height, + largest)); +#else + FIXED_BLOCK_DIM(ops::KeMatrixTopK< + T, + 5, + kBlockDim><<>>( + output_data, + k, + indices_data, + input_data, + input_width, + input_width, + static_cast(k), + gridx, + input_height, + largest)); +#endif + default: + PADDLE_THROW(errors::Fatal( + "the input data shape has error in the topk cuda kernel.")); + } + } else { + // if get topK not from the last axis, will tranpose the tensor and get + // TopK + + // first step, prepare the trans args for the tranpose + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.emplace_back(i); + } + trans.emplace_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.emplace_back(i); + } + trans.emplace_back(axis); + + phi::DDim trans_dims(in_dims); + phi::DDim trans_out_dims(out->dims()); + for (int i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + trans_out_dims[i] = out_dims[trans[i]]; + } + // second step, tranpose the input + DenseTensor trans_input; + trans_input.Resize(trans_dims); + dev_ctx.template Alloc(&trans_input); + int ndims = trans.size(); + funcs::TransCompute( + ndims, dev_ctx, *input, &trans_input, trans); + // third step, calcluate the topk + // allocate the tmp cuda memory for the tmp result + DenseTensor trans_ind; + DenseTensor trans_out; + trans_ind.Resize(trans_out_dims); + trans_out.Resize(trans_out_dims); + dev_ctx.template Alloc(&trans_ind); + dev_ctx.template Alloc(&trans_out); + + const int64_t input_height = + phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + if (k > input_width) k = input_width; + + // The conclusion is drawn from the data through multiple sets of + // statistics + if (input_width >= 128 && k >= input_width * 0.75) { + if (ops::SortTopk( + paddle::platform::CUDADeviceContext(dev_ctx.GetPlace()), + &trans_input, + input_width, + input_height, + k, + &trans_out, + &trans_ind, + largest)) { + // last step, tranpose back the indices and output + funcs::TransCompute( + ndims, dev_ctx, trans_ind, indices, trans); + funcs::TransCompute( + ndims, dev_ctx, trans_out, out, trans); + return; + } else { + LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use " + "default topk kernel."; + } + } + + const int kMaxHeight = 2048; + int gridx = input_height < kMaxHeight ? 
input_height : kMaxHeight; + switch (ops::GetDesiredBlockDim(input_width)) { +#ifdef PADDLE_WITH_HIP + FIXED_BLOCK_DIM(ops::KeMatrixTopK< + T, + 20, + kBlockDim><<>>( + trans_out.data(), + k, + trans_ind.data(), + trans_input.data(), + input_width, + input_width, + static_cast(k), + gridx, + input_height, + largest)); +#else + FIXED_BLOCK_DIM(ops::KeMatrixTopK< + T, + 5, + kBlockDim><<>>( + trans_out.data(), + k, + trans_ind.data(), + trans_input.data(), + input_width, + input_width, + static_cast(k), + gridx, + input_height, + largest)); +#endif + default: + PADDLE_THROW(errors::Fatal( + "the input data shape has error in the topk cuda kernel.")); + } + + // last step, tranpose back the indices and output + funcs::TransCompute( + ndims, dev_ctx, trans_ind, indices, trans); + funcs::TransCompute( + ndims, dev_ctx, trans_out, out, trans); + } +} +#undef FIXED_BLOCK_DIM_BASE +#undef FIXED_BLOCK_DIM + +} // namespace phi + +PD_REGISTER_KERNEL(top_k, + GPU, + ALL_LAYOUT, + phi::TopkKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/top_k_grad_kernel.h b/paddle/phi/kernels/top_k_grad_kernel.h new file mode 100644 index 00000000000..f577b982c57 --- /dev/null +++ b/paddle/phi/kernels/top_k_grad_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TopkGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& indices, + int k, + int axis, + bool largest, + bool sorted, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/top_k_kernel.h b/paddle/phi/kernels/top_k_kernel.h new file mode 100644 index 00000000000..fea76e448b5 --- /dev/null +++ b/paddle/phi/kernels/top_k_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TopkKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& k_scalar, + int axis, + bool largest, + bool sorted, + DenseTensor* out, + DenseTensor* indices); + +} // namespace phi diff --git a/paddle/phi/ops/compat/top_k_sig.cc b/paddle/phi/ops/compat/top_k_sig.cc new file mode 100644 index 00000000000..9bf922b3d1b --- /dev/null +++ b/paddle/phi/ops/compat/top_k_sig.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature TopkOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("K")) { + return KernelSignature( + "top_k", {"X"}, {"K", "axis", "largest", "sorted"}, {"Out", "Indices"}); + + } else { + return KernelSignature( + "top_k", {"X"}, {"k", "axis", "largest", "sorted"}, {"Out", "Indices"}); + } +} + +KernelSignature TopkGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("top_k_grad", + {GradVarName("Out"), "X", "Indices"}, + {"k", "axis", "largest", "sorted"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(top_k_v2, top_k); +PD_REGISTER_BASE_KERNEL_NAME(top_k_v2_grad, top_k_grad); +PD_REGISTER_ARG_MAPPING_FN(top_k_v2, phi::TopkOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(top_k_v2_grad, phi::TopkGradOpArgumentMapping); -- GitLab
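
The GetDims helper that this patch moves into top_k_function_cuda.h flattens an arbitrary-rank shape around the reduced axis into three factors, so the gradient kernel can address the tensor as [pre, n, post]. Below is a minimal standalone C++ sketch of that decomposition; std::vector<int64_t> stands in for phi::DDim so the example compiles on its own, and the main() usage is hypothetical.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Flatten a shape around `axis` into three factors: `pre` outer rows,
    // `n` elements along the reduced axis, and `post` elements per inner
    // slice, so element (i, j, l) lives at offset (i * n + j) * post + l.
    void GetDims(const std::vector<int64_t>& dims, int axis,
                 int64_t* pre, int64_t* n, int64_t* post) {
      *pre = 1;
      *post = 1;
      *n = dims[axis];
      for (int i = 0; i < axis; ++i) *pre *= dims[i];
      for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i)
        *post *= dims[i];
    }

    int main() {
      int64_t pre, n, post;
      GetDims({2, 3, 4, 5}, /*axis=*/2, &pre, &n, &post);
      std::cout << pre << " " << n << " " << post << "\n";  // prints: 6 4 5
      return 0;
    }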
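
The CPU FullTopK path chooses between std::partial_sort (when k * 64 < input_width) and std::nth_element followed by a sort of the leading elements, and both branches rely on a comparator that ranks NaN above every ordinary value when largest is true; the comment block in the deleted top_k_v2_op.h explains that this NaN-as-maximum convention is what distinguishes top_k_v2 from the original top_k. A minimal standalone sketch of that comparator, using only the partial_sort branch for brevity:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<float> row = {3.f, NAN, 7.f, 1.f};
      const int k = 2;
      // Pair each value with its original column index, as FullTopK does.
      std::vector<std::pair<float, int64_t>> col;
      for (int64_t j = 0; j < static_cast<int64_t>(row.size()); ++j)
        col.emplace_back(row[j], j);
      // "largest" comparator: NaN outranks every non-NaN, then ordinary >.
      auto greater = [](const std::pair<float, int64_t>& l,
                        const std::pair<float, int64_t>& r) {
        return (std::isnan(l.first) && !std::isnan(r.first)) ||
               (l.first > r.first);
      };
      std::partial_sort(col.begin(), col.begin() + k, col.end(), greater);
      for (int j = 0; j < k; ++j)
        std::cout << col[j].first << "@" << col[j].second << "\n";
      // prints: nan@1 then 7@2
      return 0;
    }

On the GPU side, the kernels first try the cub-based ops::SortTopk whenever input_width >= 128 and k >= input_width * 0.75 (a threshold the comments say was drawn from multiple sets of measurements) and fall back to KeMatrixTopK with a fixed block dimension if that path fails.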
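
When axis is not the last dimension, both the CPU and GPU kernels build a permutation that swaps axis with the last dimension, transpose the input, run the row-wise top-k on the transposed data, and then transpose the values and indices back using the same vector; a single swap is its own inverse, so one permutation serves both directions. A standalone sketch of the permutation construction mirrored from the patch:

    #include <iostream>
    #include <vector>

    // Build the permutation [0..axis-1, rank-1, axis+1..rank-2, axis]:
    // dims `axis` and `rank-1` trade places, everything else stays put.
    std::vector<int> MakeTransPerm(int rank, int axis) {
      std::vector<int> trans;
      for (int i = 0; i < axis; ++i) trans.push_back(i);
      trans.push_back(rank - 1);
      for (int i = axis + 1; i < rank - 1; ++i) trans.push_back(i);
      trans.push_back(axis);
      return trans;
    }

    int main() {
      for (int d : MakeTransPerm(/*rank=*/4, /*axis=*/1))
        std::cout << d << " ";  // prints: 0 3 2 1
      std::cout << "\n";
      return 0;
    }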
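
The backward kernels (FullTopKAssign on CPU, ops::AssignGradWithAxis on CUDA) implement the same gradient rule: zero the entire input gradient, then scatter each out_grad value to the position recorded in Indices during the forward pass, since inputs that were not selected receive no gradient. A minimal sketch with plain arrays standing in for DenseTensor:

    #include <cstdint>
    #include <cstring>
    #include <iostream>

    int main() {
      const int rows = 1, width = 5, k = 2;
      float out_grad[rows * k] = {10.f, 20.f};
      int64_t indices[rows * k] = {3, 0};      // top-k positions from forward
      float x_grad[rows * width];
      std::memset(x_grad, 0, sizeof(x_grad));  // unselected inputs get zero
      for (int i = 0; i < rows; ++i)
        for (int j = 0; j < k; ++j)
          x_grad[i * width + indices[i * k + j]] = out_grad[i * k + j];
      for (float g : x_grad) std::cout << g << " ";  // prints: 20 0 0 10 0
      std::cout << "\n";
      return 0;
    }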
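
The ComputeBlockSize lambda in the CUDA gradient kernel rounds the per-row workload (post * k) up to the next power-of-two thread-block size, capped at 1024 threads; grid_size is then min(max_blocks, pre), where max_blocks is derived from the device's maximum physical thread count. A sketch of the same heuristic as a free function, with the usage values hypothetical:

    #include <iostream>

    int ComputeBlockSize(int col) {
      if (col > 512) return 1024;
      if (col > 256) return 512;
      if (col > 128) return 256;
      if (col > 64) return 128;
      return 64;
    }

    int main() {
      std::cout << ComputeBlockSize(200) << "\n";  // prints: 256
      return 0;
    }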