diff --git a/paddle/phi/kernels/funcs/top_k_function_cuda.h b/paddle/phi/kernels/funcs/top_k_function_cuda.h
index f04c7a8da8be1e638aeaa73f8a795e2ea0ecae19..de58c05149a53dbe27f1f389b889e9af3117f8d3 100644
--- a/paddle/phi/kernels/funcs/top_k_function_cuda.h
+++ b/paddle/phi/kernels/funcs/top_k_function_cuda.h
@@ -32,6 +32,17 @@ limitations under the License. */
 #include "paddle/phi/kernels/primitive/functor_primitives.h"
 
 #define FINAL_MASK 0xffffffff
+#define WARP_SIZE 32
+#define MAX_NUM_THREADS 1024
+
+inline static size_t divide_round_up(size_t n, size_t q) {
+  return n % q == 0 ? n / q : n / q + 1;
+}
+
+inline static size_t round_up(size_t n, size_t q) {
+  return divide_round_up(n, q) * q;
+}
+
 #ifdef __HIPCC__
 namespace rocprim {
 namespace detail {
@@ -808,6 +819,61 @@ __device__ void RadixSearch(
   *kth_value = RadixTypeConfig<T>::Deconvert(desired);
 }
 
+template <typename T>
+__global__ void GatherKthValue(const T* input,
+                               const int k,
+                               const int64_t num_rows,
+                               const int64_t num_cols,
+                               T* output,
+                               int64_t* indices) {
+  __shared__ int shared_mem[32];
+  int row =
+      blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x;
+  const T* cur_input = input + row * num_cols;
+
+  // 1. Find the k-th value
+  T kth_value = static_cast<T>(0);
+  RadixSearch<T, typename RadixTypeConfig<T>::RadixType, false>(
+      cur_input, k, num_cols, shared_mem, &kth_value);
+  const auto converted_kth_value = RadixTypeConfig<T>::Convert(kth_value);
+
+  // 2. find the k-th index
+  int64_t kth_index = 0;
+  bool foundKValue = false;
+  for (int64_t i = threadIdx.x; i < num_cols; i += blockDim.x) {
+    bool inRange = (i < num_cols);
+    T v = inRange ? cur_input[i] : static_cast<T>(0);
+    bool isKValue =
+        inRange && ((v == kth_value) || (isnan(static_cast<float>(v)) &&
+                                         isnan(static_cast<float>(kth_value))));
+    if (isKValue) {
+      kth_index = i;
+      foundKValue = true;
+      break;
+    }
+  }
+
+  if (foundKValue) {
+    output[row] = kth_value;
+    indices[row] = kth_index;
+  }
+}
+
+template <typename T>
+void LaunchGatherKthValue(const phi::GPUContext& dev_ctx,
+                          const T* input_data,
+                          const int64_t num_cols,
+                          const int64_t num_rows,
+                          const int k,
+                          T* out_data,
+                          int64_t* indices_data) {
+  int num_threads = std::min(
+      static_cast<int>(round_up(static_cast<size_t>(num_cols), WARP_SIZE)),
+      MAX_NUM_THREADS);
+  GatherKthValue<T><<<num_rows, num_threads, 0, dev_ctx.stream()>>>(
+      input_data, k, num_rows, num_cols, out_data, indices_data);
+}
+
 template <typename T>
 __global__ void RadixTopK(const T* input,
                           int k,
diff --git a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
index 69c65aa83957ae3bc617cff45b4b8c7f34f58825..599beb7a07a787cac3c1604ad54bfe71d3e8e656 100644
--- a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
@@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(kthvalue_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/kthvalue_kernel.cu b/paddle/phi/kernels/gpu/kthvalue_kernel.cu
index 3af11ee5560e9de62f7947c45f19c6f6ac0f3a03..235abdbc803c3982210be7cd30d8831cfa349c15 100644
--- a/paddle/phi/kernels/gpu/kthvalue_kernel.cu
+++ b/paddle/phi/kernels/gpu/kthvalue_kernel.cu
@@ -163,7 +163,6 @@ void KthvalueKernel(const Context& dev_ctx,
   const auto& in_dims = x.dims();
   if (axis < 0) axis += in_dims.size();
   auto out_dims = output->dims();
-  const T* input_data = x.data<T>();
   T* output_data = dev_ctx.template Alloc<T>(output);
   int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
 
@@ -180,15 +179,28 @@ void KthvalueKernel(const Context& dev_ctx,
     phi::funcs::set_constant(dev_ctx, indices, 0);
     return;
   }
+
   if (axis == in_dims.size() - 1) {
     const int64_t& input_height =
         phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
     const int64_t& input_width = in_dims[in_dims.size() - 1];
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
+    const T* input_data = x.data<T>();
+    funcs::LaunchGatherKthValue<T>(dev_ctx,
+                                   input_data,
+                                   input_width,
+                                   input_height,
+                                   k,
+                                   output_data,
+                                   indices_data);
+#else
     PADDLE_ENFORCE_EQ(
         SortKthvalue<T>(
             dev_ctx, &x, input_width, input_height, k, output, indices),
         true,
         phi::errors::External("KthvalueOP: Error when use cub sorting"));
+#endif
+
     return;
   } else {
     std::vector<int> trans;
@@ -222,18 +234,28 @@ void KthvalueKernel(const Context& dev_ctx,
     trans_out_dims[in_dims.size() - 1] = 1;
     DenseTensor trans_input;
     trans_input.Resize(trans_dims);
-    dev_ctx.template Alloc<T>(&trans_input);
+    T* tran_input_data = dev_ctx.template Alloc<T>(&trans_input);
     int ndims = trans.size();
     funcs::TransCompute<phi::GPUContext, T>(
         ndims, dev_ctx, x, &trans_input, trans);
     DenseTensor trans_ind, trans_out;
     trans_ind.Resize(trans_out_dims);
     trans_out.Resize(trans_out_dims);
-    dev_ctx.template Alloc<int64_t>(&trans_ind);
-    dev_ctx.template Alloc<T>(&trans_out);
+    int64_t* tran_indices_data = dev_ctx.template Alloc<int64_t>(&trans_ind);
+    T* tran_output_data = dev_ctx.template Alloc<T>(&trans_out);
     const int64_t input_height =
         phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
     const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
+    funcs::LaunchGatherKthValue<T>(dev_ctx,
+                                   tran_input_data,
+                                   input_width,
+                                   input_height,
+                                   k,
+                                   tran_output_data,
+                                   tran_indices_data);
+#else
     PADDLE_ENFORCE_EQ(
         SortKthvalue<T>(dev_ctx,
                         &trans_input,
@@ -244,6 +266,7 @@ void KthvalueKernel(const Context& dev_ctx,
                         &trans_ind),
         true,
         phi::errors::External("KthvalueOP: Error when use cub sorting"));
+#endif
     funcs::TransCompute<phi::GPUContext, int64_t>(
         ndims, dev_ctx, trans_ind, indices, trans);
     funcs::TransCompute<phi::GPUContext, T>(
@@ -263,6 +286,7 @@ PD_REGISTER_KERNEL(kthvalue,
                    float,
                    double,
                    int,
-                   int64_t) {
+                   int64_t,
+                   phi::dtype::float16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
 }