From 2c687df042f456b2b16cb3d8519e9e54fdb68503 Mon Sep 17 00:00:00 2001
From: carryyu <569782149@qq.com>
Date: Thu, 22 Sep 2022 10:57:22 +0800
Subject: [PATCH] Optimize topk's performance when k is small and input_width
 is large (#45312)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Optimize topk's performance when k is small and input_width is large

* Update the blockdim selection logic

* Update top_k_function_cuda.h
---
 paddle/fluid/operators/top_k_function_cuda.h | 144 +++++++++++--------
 paddle/fluid/operators/top_k_op.cu           |  54 ++++---
 paddle/phi/kernels/gpu/top_k_kernel.cu       |  89 ++++++++----
 3 files changed, 184 insertions(+), 103 deletions(-)

diff --git a/paddle/fluid/operators/top_k_function_cuda.h b/paddle/fluid/operators/top_k_function_cuda.h
index 4a038c93a1..40ccbc4a84 100644
--- a/paddle/fluid/operators/top_k_function_cuda.h
+++ b/paddle/fluid/operators/top_k_function_cuda.h
@@ -27,9 +27,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/float16.h"
 
+#define FINAL_MASK 0xffffffff
 #ifdef __HIPCC__
 namespace rocprim {
 namespace detail {
@@ -105,6 +107,14 @@ inline static int GetDesiredBlockDim(int dim) {
   }
 }
 
+inline static int getMaxLength(int k) {
+  if (k / 5 < 1) {
+    return 1;
+  } else if (k / 5 >= 1) {
+    return min(k / 5, 5);
+  }
+}
+
 template <typename T>
 __global__ void InitIndex(T* indices, T num_rows, T num_cols) {
   int col_id = threadIdx.x;
@@ -248,7 +258,11 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
         if (k < MaxLength - (*beam)) {
           topk[k] = topk[k + *beam];
         } else {
-          topk[k].set(-static_cast<T>(INFINITY), -1);
+          if (largest) {
+            topk[k].set(-static_cast<T>(INFINITY), -1);
+          } else {
+            topk[k].set(static_cast<T>(INFINITY), -1);
+          }
         }
       }
       if (!(*is_empty)) {
@@ -258,79 +272,98 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
     }
 
     *max = topk[MaxLength - 1];
-    if ((*max).v == -static_cast<T>(1)) *is_empty = true;
+    if ((*max).id == -1) *is_empty = true;
     *beam = 0;
   }
 }
 
+template <typename T>
+__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input,
+                                              const bool& largest) {
+  if (largest) {
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+      T tmp_val = platform::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
+      int tmp_id = platform::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
+      if (input.v < tmp_val || (input.v == tmp_val && input.id > tmp_id)) {
+        input.v = tmp_val;
+        input.id = tmp_id;
+      }
+    }
+  } else {
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+      T tmp_val = platform::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
+      int tmp_id = platform::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
+      if (input.v > tmp_val || (input.v == tmp_val && input.id > tmp_id)) {
+        input.v = tmp_val;
+        input.id = tmp_id;
+      }
+    }
+  }
+  return input;
+}
+
 template <typename T, int MaxLength, int BlockSize>
-__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk,
-                                            int* maxid,
+__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
                                             Pair<T> topk[],
                                             T** topVal,
                                             int64_t** topIds,
                                             int* beam,
                                             int* k,
                                             const int tid,
-                                            const int warp,
+                                            const int wid,
+                                            const int lane,
                                             const bool& largest) {
   while (true) {
     __syncthreads();
-    if (tid < BlockSize / 2) {
-      if (largest) {
-        if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
-          maxid[tid] = tid + BlockSize / 2;
-        } else {
-          maxid[tid] = tid;
-        }
-      } else {
-        if (sh_topk[tid] > sh_topk[tid + BlockSize / 2]) {
-          maxid[tid] = tid + BlockSize / 2;
-        } else {
-          maxid[tid] = tid;
-        }
-      }
+    Pair<T> input_now = topk[0];
+    input_now = WarpReduce(input_now, largest);
+
+    if (lane == 0) {
+      shared_max[wid] = input_now;
     }
     __syncthreads();
-    for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
-      if (tid < stride) {
-        if (largest) {
-          if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
-            maxid[tid] = maxid[tid + stride];
-          }
-        } else {
-          if (sh_topk[maxid[tid]] > sh_topk[maxid[tid + stride]]) {
-            maxid[tid] = maxid[tid + stride];
-          }
-        }
-      }
-      __syncthreads();
+    if (largest) {
+      input_now = (tid < BlockSize / 32)
+                      ? shared_max[lane]
+                      : Pair<T>(-static_cast<T>(INFINITY), -1);
+    } else {
+      input_now = (tid < BlockSize / 32)
+                      ? shared_max[lane]
+                      : Pair<T>(static_cast<T>(INFINITY), -1);
+    }
+    if (wid == 0) {
+      input_now = WarpReduce(input_now, largest);
+      if (lane == 0) shared_max[0] = input_now;
     }
     __syncthreads();
 
     if (tid == 0) {
-      **topVal = sh_topk[maxid[0]].v;
-      **topIds = sh_topk[maxid[0]].id;
+      **topVal = input_now.v;
+      **topIds = input_now.id;
       (*topVal)++;
       (*topIds)++;
     }
-    if (tid == maxid[0]) (*beam)++;
-    if (--(*k) == 0) break;
-    __syncthreads();
-
-    if (tid == maxid[0]) {
+    int tid_max = shared_max[0].id % BlockSize;
+    if (tid == tid_max) {
+      (*beam)++;
       if (*beam < MaxLength) {
-        sh_topk[tid] = topk[*beam];
+        topk[0] = topk[*beam];
      }
     }
-    // NOTE(zcd): temporary solution
-    unsigned mask = 0u;
-    CREATE_SHFL_MASK(mask, true);
-
-    if (maxid[0] / 32 == warp) {
-      if (platform::CudaShuffleSync(mask, *beam, (maxid[0]) % 32, 32) ==
-          MaxLength)
-        break;
+    if (--(*k) == 0) break;
+
+    if (MaxLength < 5) {
+      if (*beam >= MaxLength) break;
+    } else {
+      unsigned mask = 0u;
+      CREATE_SHFL_MASK(mask, true);
+      if (tid_max / 32 == wid) {
+        if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) ==
+            MaxLength)
+          break;
+      }
     }
   }
 }
@@ -355,14 +388,13 @@ __global__ void KeMatrixTopK(T* output,
                              int grid_dim,
                              int num,
                              bool largest = true) {
-  __shared__ Pair<T> sh_topk[BlockSize];
   const int tid = threadIdx.x;
-  const int warp = threadIdx.x / 32;
-
+  const int wid = tid / 32;
+  const int lane = tid % 32;
   const int bid = blockIdx.x;
   for (int i = bid; i < num; i += grid_dim) {
     int top_num = k;
-    __shared__ int maxid[BlockSize / 2];
+    __shared__ Pair<T> shared_max[BlockSize / 32];
     T* out = output + i * output_stride;
     int64_t* inds = indices + i * k;
     Pair<T> topk[MaxLength];
@@ -389,17 +421,15 @@ __global__ void KeMatrixTopK(T* output,
                                            dim,
                                            tid,
                                            largest);
-
-    sh_topk[tid] = topk[0];
-    BlockReduce<T, MaxLength, BlockSize>(sh_topk,
-                                         maxid,
+    BlockReduce<T, MaxLength, BlockSize>(shared_max,
                                          topk,
                                          &out,
                                          &inds,
                                          &beam,
                                          &top_num,
                                          tid,
-                                         warp,
+                                         wid,
+                                         lane,
                                          largest);
   }
 }
diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu
index 79236f590f..dc58c1cea5 100644
--- a/paddle/fluid/operators/top_k_op.cu
+++ b/paddle/fluid/operators/top_k_op.cu
@@ -38,12 +38,27 @@ using Tensor = framework::Tensor;
     __VA_ARGS__;                             \
   } break
 
-#define FIXED_BLOCK_DIM(...)                 \
-  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
+#define FIXED_MAXLENGTH_BASE(MaxLength, ...) \
+  case (MaxLength): {                        \
+    constexpr auto maxLength = (MaxLength);  \
+    __VA_ARGS__;                             \
+  } break
+
+#define FIXED_BLOCK_DIM(...)                 \
+  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
   FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
 
+#define FIXED_MAXLENGTH(...)              \
+  FIXED_MAXLENGTH_BASE(1, ##__VA_ARGS__); \
+  FIXED_MAXLENGTH_BASE(2, ##__VA_ARGS__); \
+  FIXED_MAXLENGTH_BASE(3, ##__VA_ARGS__); \
+  FIXED_MAXLENGTH_BASE(4, ##__VA_ARGS__); \
+  FIXED_MAXLENGTH_BASE(5, ##__VA_ARGS__)
+
 template <typename T>
 class TopkOpCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -95,18 +110,25 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     // TODO(typhoonzero): refine this kernel.
     const int kMaxHeight = 2048;
     int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
-    switch (GetDesiredBlockDim(input_width)) {
-      FIXED_BLOCK_DIM(
-          KeMatrixTopK<T, 5, kBlockDim>
-          <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
-                                                      k,
-                                                      indices_data,
-                                                      input_data,
-                                                      input_width,
-                                                      input_width,
-                                                      static_cast<int>(k),
-                                                      gridx,
-                                                      input_height));
+    paddle::platform::GpuLaunchConfig config =
+        paddle::platform::GetGpuLaunchConfig1D(dev_ctx, input_width);
+    switch (config.thread_per_block.x) {
+      FIXED_BLOCK_DIM(switch (getMaxLength(k)) {
+        FIXED_MAXLENGTH(
+            KeMatrixTopK<T, maxLength, kBlockDim>
+            <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
+                                                        k,
+                                                        indices_data,
+                                                        input_data,
+                                                        input_width,
+                                                        input_width,
+                                                        static_cast<int>(k),
+                                                        gridx,
+                                                        input_height));
+        default:
+          PADDLE_THROW(platform::errors::Fatal(
+              "the input k has error in the topk cuda kernel."));
+      });
       default:
         PADDLE_THROW(platform::errors::Unavailable(
             "Calculation error occurred in TopK Operator's CUDA Kernel."));
diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu
index 657430e1e7..644cc60814 100644
--- a/paddle/phi/kernels/gpu/top_k_kernel.cu
+++ b/paddle/phi/kernels/gpu/top_k_kernel.cu
@@ -31,12 +31,27 @@ namespace ops = paddle::operators;
     __VA_ARGS__;                             \
   } break
 
-#define FIXED_BLOCK_DIM(...)                 \
-  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
-  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
+#define FIXED_MAXLENGTH_BASE(MaxLength, ...) \
+  case (MaxLength): {                        \
+    constexpr auto maxLength = (MaxLength);  \
+    __VA_ARGS__;                             \
+  } break
+
+#define FIXED_BLOCK_DIM(...)                 \
+  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
   FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
 
+#define FIXED_MAXLENGTH(...)              \
+  FIXED_MAXLENGTH_BASE(1, ##__VA_ARGS__); \
+  FIXED_MAXLENGTH_BASE(2, ##__VA_ARGS__); \
+  FIXED_MAXLENGTH_BASE(3, ##__VA_ARGS__); \
+  FIXED_MAXLENGTH_BASE(4, ##__VA_ARGS__); \
+  FIXED_MAXLENGTH_BASE(5, ##__VA_ARGS__)
+
 template <typename T, typename Context>
 void TopkKernel(const Context& dev_ctx,
                 const DenseTensor& x,
@@ -158,7 +173,9 @@ void TopkKernel(const Context& dev_ctx,
       // NOTE: old matrix implementation of stride is different to eigen.
      const int kMaxHeight = 2048;
      int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
-      switch (ops::GetDesiredBlockDim(input_width)) {
+      paddle::platform::GpuLaunchConfig config =
+          paddle::platform::GetGpuLaunchConfig1D(dev_ctx, input_width);
+      switch (config.thread_per_block.x) {
 #ifdef PADDLE_WITH_HIP
         FIXED_BLOCK_DIM(
             ops::KeMatrixTopK<T, 20, kBlockDim>
             <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
                                                         k,
                                                         indices_data,
                                                         input_data,
                                                         input_width,
                                                         input_width,
                                                         static_cast<int>(k),
                                                         gridx,
                                                         input_height,
                                                         largest));
 #else
-        FIXED_BLOCK_DIM(
-            ops::KeMatrixTopK<T, 5, kBlockDim>
-            <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
-                                                        k,
-                                                        indices_data,
-                                                        input_data,
-                                                        input_width,
-                                                        input_width,
-                                                        static_cast<int>(k),
-                                                        gridx,
-                                                        input_height,
-                                                        largest));
+        FIXED_BLOCK_DIM(switch (ops::getMaxLength(k)) {
+          FIXED_MAXLENGTH(
+              ops::KeMatrixTopK<T, maxLength, kBlockDim>
+              <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
+                                                          k,
+                                                          indices_data,
+                                                          input_data,
+                                                          input_width,
+                                                          input_width,
+                                                          static_cast<int>(k),
+                                                          gridx,
+                                                          input_height,
+                                                          largest));
+          default:
+            PADDLE_THROW(
+                errors::Fatal("the input k has error in the topk cuda kernel."));
+        });
 #endif
         default:
           PADDLE_THROW(errors::Fatal(
@@ -259,7 +281,9 @@ void TopkKernel(const Context& dev_ctx,
 
       const int kMaxHeight = 2048;
       int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
-      switch (ops::GetDesiredBlockDim(input_width)) {
+      paddle::platform::GpuLaunchConfig config =
+          paddle::platform::GetGpuLaunchConfig1D(dev_ctx, input_width);
+      switch (config.thread_per_block.x) {
 #ifdef PADDLE_WITH_HIP
         FIXED_BLOCK_DIM(
             ops::KeMatrixTopK<T, 20, kBlockDim>
             <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(trans_out.data<T>(),
                                                         k,
                                                         trans_ind.data<int64_t>(),
                                                         trans_input.data<T>(),
                                                         input_width,
                                                         input_width,
                                                         static_cast<int>(k),
                                                         gridx,
                                                         input_height,
                                                         largest));
 #else
-        FIXED_BLOCK_DIM(
-            ops::KeMatrixTopK<T, 5, kBlockDim>
-            <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(trans_out.data<T>(),
-                                                        k,
-                                                        trans_ind.data<int64_t>(),
-                                                        trans_input.data<T>(),
-                                                        input_width,
-                                                        input_width,
-                                                        static_cast<int>(k),
-                                                        gridx,
-                                                        input_height,
-                                                        largest));
+        FIXED_BLOCK_DIM(switch (ops::getMaxLength(k)) {
+          FIXED_MAXLENGTH(ops::KeMatrixTopK<T, maxLength, kBlockDim>
+                          <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
+                              trans_out.data<T>(),
+                              k,
+                              trans_ind.data<int64_t>(),
+                              trans_input.data<T>(),
+                              input_width,
+                              input_width,
+                              static_cast<int>(k),
+                              gridx,
+                              input_height,
+                              largest));
+          default:
+            PADDLE_THROW(
+                errors::Fatal("the input k has error in the topk cuda kernel."));
+        });
 #endif
         default:
           PADDLE_THROW(errors::Fatal(
-- 
GitLab
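
A note for readers of this patch: the core device-side change replaces the old shared-memory tree reduction in BlockReduce with a warp-shuffle arg-max (WarpReduce), in which each of the 32 lanes of a warp holds one (value, index) candidate and, after five shuffle-down steps, lane 0 holds the largest value with ties broken toward the smaller index. The standalone CUDA sketch below illustrates only that technique; it is not the Paddle implementation, and the struct, kernel, and host harness names are invented for the example.

// warp_argmax_sketch.cu -- illustrative only, not Paddle code.
// Each lane contributes one candidate; lane 0 ends up with the warp maximum.
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffffu

struct Candidate {
  float v;  // value
  int id;   // original index
};

__device__ Candidate WarpArgMax(Candidate c) {
  // Same comparison rule as the patch's "largest" branch: take the
  // neighbour's candidate if its value is bigger, or if the values tie
  // and its index is smaller.
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    float other_v = __shfl_down_sync(FULL_MASK, c.v, offset);
    int other_id = __shfl_down_sync(FULL_MASK, c.id, offset);
    if (c.v < other_v || (c.v == other_v && c.id > other_id)) {
      c.v = other_v;
      c.id = other_id;
    }
  }
  return c;
}

__global__ void WarpArgMaxKernel(const float* data, int n, Candidate* out) {
  int lane = threadIdx.x;  // launched with exactly one warp (32 threads)
  Candidate c;
  c.v = -INFINITY;  // sentinel for lanes that have no element
  c.id = -1;
  if (lane < n) {
    c.v = data[lane];
    c.id = lane;
  }
  c = WarpArgMax(c);
  if (lane == 0) *out = c;
}

int main() {
  float host_data[32];
  for (int i = 0; i < 32; ++i) host_data[i] = static_cast<float>((i * 7) % 13);
  float* dev_data = nullptr;
  Candidate* dev_out = nullptr;
  cudaMalloc(&dev_data, sizeof(host_data));
  cudaMalloc(&dev_out, sizeof(Candidate));
  cudaMemcpy(dev_data, host_data, sizeof(host_data), cudaMemcpyHostToDevice);
  WarpArgMaxKernel<<<1, 32>>>(dev_data, 32, dev_out);
  Candidate result;
  cudaMemcpy(&result, dev_out, sizeof(result), cudaMemcpyDeviceToHost);
  std::printf("max = %f at index %d\n", result.v, result.id);
  cudaFree(dev_data);
  cudaFree(dev_out);
  return 0;
}

Compared with the old sh_topk/maxid tree, the shuffle version exchanges candidates register to register, needs no BlockSize-sized shared-memory buffer, and in the full BlockReduce only the per-warp winners ever touch shared memory.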
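On the host side, the patch turns kernel selection into a two-level dispatch: GetGpuLaunchConfig1D chooses the block size at run time, getMaxLength(k) chooses how many candidates each thread keeps in registers (MaxLength, capped at 5), and the FIXED_BLOCK_DIM / FIXED_MAXLENGTH case macros map those runtime values onto compile-time template parameters. The sketch below shows the same pattern with plain switch statements and a stub in place of KeMatrixTopK; it compiles as ordinary host C++, and all names in it are illustrative rather than Paddle's.

// dispatch_sketch.cc -- illustrative only, not Paddle code.
#include <algorithm>
#include <cstdio>

// Mirrors the mapping of the patch's getMaxLength(k):
// k < 5 keeps one candidate per thread, otherwise min(k / 5, 5).
inline int GetMaxLength(int k) { return k / 5 < 1 ? 1 : std::min(k / 5, 5); }

template <int MaxLength, int BlockSize>
void LaunchTopK(int k) {
  // Stand-in for KeMatrixTopK<T, MaxLength, BlockSize><<<...>>>(...).
  std::printf("would launch MaxLength=%d BlockSize=%d for k=%d\n",
              MaxLength, BlockSize, k);
}

template <int BlockSize>
void DispatchMaxLength(int k) {
  switch (GetMaxLength(k)) {  // runtime value -> compile-time MaxLength
    case 1: LaunchTopK<1, BlockSize>(k); break;
    case 2: LaunchTopK<2, BlockSize>(k); break;
    case 3: LaunchTopK<3, BlockSize>(k); break;
    case 4: LaunchTopK<4, BlockSize>(k); break;
    case 5: LaunchTopK<5, BlockSize>(k); break;
    default: std::printf("unsupported k\n"); break;
  }
}

void Dispatch(int block_dim, int k) {
  switch (block_dim) {  // runtime block size -> compile-time BlockSize
    case 1024: DispatchMaxLength<1024>(k); break;
    case 512: DispatchMaxLength<512>(k); break;
    case 256: DispatchMaxLength<256>(k); break;
    case 128: DispatchMaxLength<128>(k); break;
    case 64: DispatchMaxLength<64>(k); break;
    default: DispatchMaxLength<32>(k); break;
  }
}

int main() {
  Dispatch(256, 3);    // small k -> MaxLength 1
  Dispatch(1024, 40);  // larger k -> MaxLength capped at 5
  return 0;
}

The compile-time MaxLength matters because the kernel's per-thread candidate buffer (Pair<T> topk[MaxLength]) must have a size known at compile time so it can stay in registers, which is exactly what makes the small-k case cheap.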