未验证 提交 2c687df0 编写于 作者: C carryyu 提交者: GitHub

Optimize topk's performance when k is small and input_width is large (#45312)

* Optimize topk's performance when k is small and input_width is large

* 修改blockdim设置逻辑

* Update top_k_function_cuda.h
上级 18860735
...@@ -27,9 +27,11 @@ limitations under the License. */ ...@@ -27,9 +27,11 @@ limitations under the License. */
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" #include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#define FINAL_MASK 0xffffffff
#ifdef __HIPCC__ #ifdef __HIPCC__
namespace rocprim { namespace rocprim {
namespace detail { namespace detail {
...@@ -105,6 +107,14 @@ inline static int GetDesiredBlockDim(int dim) { ...@@ -105,6 +107,14 @@ inline static int GetDesiredBlockDim(int dim) {
} }
} }
// Maps the requested `k` to the value used as the `MaxLength` template
// argument when dispatching the top-k kernel.
//
// Returns k / 5 clamped to the closed range [1, 5]:
//   - k / 5 < 1 (including k <= 0)  -> 1
//   - 1 <= k / 5 <= 5               -> k / 5
//   - k / 5 > 5                     -> 5
//
// Every path ends in an explicit return, so the compiler never sees control
// reaching the end of a non-void function, and no unqualified `min` is
// needed in host code.
inline static int getMaxLength(int k) {
  const int per_chunk = k / 5;
  if (per_chunk < 1) {
    return 1;
  }
  return per_chunk < 5 ? per_chunk : 5;
}
template <typename T> template <typename T>
__global__ void InitIndex(T* indices, T num_rows, T num_cols) { __global__ void InitIndex(T* indices, T num_rows, T num_cols) {
int col_id = threadIdx.x; int col_id = threadIdx.x;
...@@ -248,7 +258,11 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], ...@@ -248,7 +258,11 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
if (k < MaxLength - (*beam)) { if (k < MaxLength - (*beam)) {
topk[k] = topk[k + *beam]; topk[k] = topk[k + *beam];
} else { } else {
topk[k].set(-static_cast<T>(INFINITY), -1); if (largest) {
topk[k].set(-static_cast<T>(INFINITY), -1);
} else {
topk[k].set(static_cast<T>(INFINITY), -1);
}
} }
} }
if (!(*is_empty)) { if (!(*is_empty)) {
...@@ -258,79 +272,98 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], ...@@ -258,79 +272,98 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
} }
*max = topk[MaxLength - 1]; *max = topk[MaxLength - 1];
if ((*max).v == -static_cast<T>(1)) *is_empty = true; if ((*max).id == -1) *is_empty = true;
*beam = 0; *beam = 0;
} }
} }
// Combines `input` with values fetched from other lanes via
// platform::CudaShuffleDownSync(FINAL_MASK, ..., offset) for offsets
// 16, 8, 4, 2, 1, keeping at each step the preferred (value, id) pair.
//
//   largest == true  : the fetched pair replaces the local one when its
//                      value is strictly greater.
//   largest == false : the fetched pair replaces the local one when its
//                      value is strictly smaller.
//   Equal values     : the fetched pair replaces the local one when its id
//                      is smaller than the local id.
//
// Returns the calling lane's running pair after the last step.
// NOTE(review): presumably only lane 0 ends up holding the reduction over
// the whole warp (standard shuffle-down behaviour) -- confirm against
// platform::CudaShuffleDownSync.
template <typename T>
__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input,
                                              const bool& largest) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    const T peer_val =
        platform::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
    const int peer_id =
        platform::CudaShuffleDownSync(FINAL_MASK, input.id, offset);

    // Direction-dependent strict comparison on the value.
    const bool peer_is_better =
        largest ? (input.v < peer_val) : (input.v > peer_val);
    // Tie on the value: prefer the pair carrying the smaller id.
    const bool peer_wins_tie = (input.v == peer_val) && (input.id > peer_id);

    if (peer_is_better || peer_wins_tie) {
      input.v = peer_val;
      input.id = peer_id;
    }
  }
  return input;
}
template <typename T, int MaxLength, int BlockSize> template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
int* maxid,
Pair<T> topk[], Pair<T> topk[],
T** topVal, T** topVal,
int64_t** topIds, int64_t** topIds,
int* beam, int* beam,
int* k, int* k,
const int tid, const int tid,
const int warp, const int wid,
const int lane,
const bool& largest) { const bool& largest) {
while (true) { while (true) {
__syncthreads(); __syncthreads();
if (tid < BlockSize / 2) { Pair<T> input_now = topk[0];
if (largest) { input_now = WarpReduce(input_now, largest);
if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
maxid[tid] = tid + BlockSize / 2; if (lane == 0) {
} else { shared_max[wid] = input_now;
maxid[tid] = tid;
}
} else {
if (sh_topk[tid] > sh_topk[tid + BlockSize / 2]) {
maxid[tid] = tid + BlockSize / 2;
} else {
maxid[tid] = tid;
}
}
} }
__syncthreads(); __syncthreads();
for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) { if (largest) {
if (tid < stride) { input_now = (tid < BlockSize / 32)
if (largest) { ? shared_max[lane]
if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) { : Pair<T>(-static_cast<T>(INFINITY), -1);
maxid[tid] = maxid[tid + stride]; } else {
} input_now = (tid < BlockSize / 32)
} else { ? shared_max[lane]
if (sh_topk[maxid[tid]] > sh_topk[maxid[tid + stride]]) { : Pair<T>(static_cast<T>(INFINITY), -1);
maxid[tid] = maxid[tid + stride]; }
} if (wid == 0) {
} input_now = WarpReduce(input_now, largest);
} if (lane == 0) shared_max[0] = input_now;
__syncthreads();
} }
__syncthreads(); __syncthreads();
if (tid == 0) { if (tid == 0) {
**topVal = sh_topk[maxid[0]].v; **topVal = input_now.v;
**topIds = sh_topk[maxid[0]].id; **topIds = input_now.id;
(*topVal)++; (*topVal)++;
(*topIds)++; (*topIds)++;
} }
if (tid == maxid[0]) (*beam)++; int tid_max = shared_max[0].id % BlockSize;
if (--(*k) == 0) break; if (tid == tid_max) {
__syncthreads(); (*beam)++;
if (tid == maxid[0]) {
if (*beam < MaxLength) { if (*beam < MaxLength) {
sh_topk[tid] = topk[*beam]; topk[0] = topk[*beam];
} }
} }
// NOTE(zcd): temporary solution if (--(*k) == 0) break;
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true); if (MaxLength < 5) {
if (*beam >= MaxLength) break;
if (maxid[0] / 32 == warp) { } else {
if (platform::CudaShuffleSync(mask, *beam, (maxid[0]) % 32, 32) == unsigned mask = 0u;
MaxLength) CREATE_SHFL_MASK(mask, true);
break; if (tid_max / 32 == wid) {
if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) ==
MaxLength)
break;
}
} }
} }
} }
...@@ -355,14 +388,13 @@ __global__ void KeMatrixTopK(T* output, ...@@ -355,14 +388,13 @@ __global__ void KeMatrixTopK(T* output,
int grid_dim, int grid_dim,
int num, int num,
bool largest = true) { bool largest = true) {
__shared__ Pair<T> sh_topk[BlockSize];
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int warp = threadIdx.x / 32; const int wid = tid / 32;
const int lane = tid % 32;
const int bid = blockIdx.x; const int bid = blockIdx.x;
for (int i = bid; i < num; i += grid_dim) { for (int i = bid; i < num; i += grid_dim) {
int top_num = k; int top_num = k;
__shared__ int maxid[BlockSize / 2]; __shared__ Pair<T> shared_max[BlockSize / 32];
T* out = output + i * output_stride; T* out = output + i * output_stride;
int64_t* inds = indices + i * k; int64_t* inds = indices + i * k;
Pair<T> topk[MaxLength]; Pair<T> topk[MaxLength];
...@@ -389,17 +421,15 @@ __global__ void KeMatrixTopK(T* output, ...@@ -389,17 +421,15 @@ __global__ void KeMatrixTopK(T* output,
dim, dim,
tid, tid,
largest); largest);
BlockReduce<T, MaxLength, BlockSize>(shared_max,
sh_topk[tid] = topk[0];
BlockReduce<T, MaxLength, BlockSize>(sh_topk,
maxid,
topk, topk,
&out, &out,
&inds, &inds,
&beam, &beam,
&top_num, &top_num,
tid, tid,
warp, wid,
lane,
largest); largest);
} }
} }
......
...@@ -38,12 +38,27 @@ using Tensor = framework::Tensor; ...@@ -38,12 +38,27 @@ using Tensor = framework::Tensor;
__VA_ARGS__; \ __VA_ARGS__; \
} break } break
#define FIXED_BLOCK_DIM(...) \ #define FIXED_MAXLENGTH_BASE(MaxLength, ...) \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ case (MaxLength): { \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ constexpr auto maxLength = (MaxLength); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ __VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
#define FIXED_MAXLENGTH(...) \
FIXED_MAXLENGTH_BASE(1, ##__VA_ARGS__); \
FIXED_MAXLENGTH_BASE(2, ##__VA_ARGS__); \
FIXED_MAXLENGTH_BASE(3, ##__VA_ARGS__); \
FIXED_MAXLENGTH_BASE(4, ##__VA_ARGS__); \
FIXED_MAXLENGTH_BASE(5, ##__VA_ARGS__)
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> { class TopkOpCUDAKernel : public framework::OpKernel<T> {
public: public:
...@@ -95,18 +110,25 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -95,18 +110,25 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
// TODO(typhoonzero): refine this kernel. // TODO(typhoonzero): refine this kernel.
const int kMaxHeight = 2048; const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (GetDesiredBlockDim(input_width)) { paddle::platform::GpuLaunchConfig config =
FIXED_BLOCK_DIM( paddle::platform::GetGpuLaunchConfig1D(dev_ctx, input_width);
KeMatrixTopK<T, 5, kBlockDim> switch (config.thread_per_block.x) {
<<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data, FIXED_BLOCK_DIM(switch (getMaxLength(k)) {
k, FIXED_MAXLENGTH(
indices_data, KeMatrixTopK<T, maxLength, kBlockDim>
input_data, <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
input_width, k,
input_width, indices_data,
static_cast<int>(k), input_data,
gridx, input_width,
input_height)); input_width,
static_cast<int>(k),
gridx,
input_height));
default:
PADDLE_THROW(platform::errors::Fatal(
"the input k has error in the topk cuda kernel."));
});
default: default:
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
"Calculation error occurred in TopK Operator's CUDA Kernel.")); "Calculation error occurred in TopK Operator's CUDA Kernel."));
......
...@@ -31,12 +31,27 @@ namespace ops = paddle::operators; ...@@ -31,12 +31,27 @@ namespace ops = paddle::operators;
__VA_ARGS__; \ __VA_ARGS__; \
} break } break
#define FIXED_BLOCK_DIM(...) \ #define FIXED_MAXLENGTH_BASE(MaxLength, ...) \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ case (MaxLength): { \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ constexpr auto maxLength = (MaxLength); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ __VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
#define FIXED_MAXLENGTH(...) \
FIXED_MAXLENGTH_BASE(1, ##__VA_ARGS__); \
FIXED_MAXLENGTH_BASE(2, ##__VA_ARGS__); \
FIXED_MAXLENGTH_BASE(3, ##__VA_ARGS__); \
FIXED_MAXLENGTH_BASE(4, ##__VA_ARGS__); \
FIXED_MAXLENGTH_BASE(5, ##__VA_ARGS__)
template <typename T, typename Context> template <typename T, typename Context>
void TopkKernel(const Context& dev_ctx, void TopkKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -158,7 +173,9 @@ void TopkKernel(const Context& dev_ctx, ...@@ -158,7 +173,9 @@ void TopkKernel(const Context& dev_ctx,
// NOTE: old matrix implementation of stride is different to eigen. // NOTE: old matrix implementation of stride is different to eigen.
const int kMaxHeight = 2048; const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (ops::GetDesiredBlockDim(input_width)) { paddle::platform::GpuLaunchConfig config =
paddle::platform::GetGpuLaunchConfig1D(dev_ctx, input_width);
switch (config.thread_per_block.x) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM( FIXED_BLOCK_DIM(
ops::KeMatrixTopK<T, 20, kBlockDim> ops::KeMatrixTopK<T, 20, kBlockDim>
...@@ -173,18 +190,23 @@ void TopkKernel(const Context& dev_ctx, ...@@ -173,18 +190,23 @@ void TopkKernel(const Context& dev_ctx,
input_height, input_height,
largest)); largest));
#else #else
FIXED_BLOCK_DIM( FIXED_BLOCK_DIM(switch (ops::getMaxLength(k)) {
ops::KeMatrixTopK<T, 5, kBlockDim> FIXED_MAXLENGTH(
<<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data, ops::KeMatrixTopK<T, maxLength, kBlockDim>
k, <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
indices_data, k,
input_data, indices_data,
input_width, input_data,
input_width, input_width,
static_cast<int>(k), input_width,
gridx, static_cast<int>(k),
input_height, gridx,
largest)); input_height,
largest));
default:
PADDLE_THROW(
errors::Fatal("the input k has error in the topk cuda kernel."));
});
#endif #endif
default: default:
PADDLE_THROW(errors::Fatal( PADDLE_THROW(errors::Fatal(
...@@ -259,7 +281,9 @@ void TopkKernel(const Context& dev_ctx, ...@@ -259,7 +281,9 @@ void TopkKernel(const Context& dev_ctx,
const int kMaxHeight = 2048; const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (ops::GetDesiredBlockDim(input_width)) { paddle::platform::GpuLaunchConfig config =
paddle::platform::GetGpuLaunchConfig1D(dev_ctx, input_width);
switch (config.thread_per_block.x) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM( FIXED_BLOCK_DIM(
ops::KeMatrixTopK<T, 20, kBlockDim> ops::KeMatrixTopK<T, 20, kBlockDim>
...@@ -274,18 +298,23 @@ void TopkKernel(const Context& dev_ctx, ...@@ -274,18 +298,23 @@ void TopkKernel(const Context& dev_ctx,
input_height, input_height,
largest)); largest));
#else #else
FIXED_BLOCK_DIM( FIXED_BLOCK_DIM(switch (ops::getMaxLength(k)) {
ops::KeMatrixTopK<T, 5, kBlockDim> FIXED_MAXLENGTH(ops::KeMatrixTopK<T, maxLength, kBlockDim>
<<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(trans_out.data<T>(), <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
k, trans_out.data<T>(),
trans_ind.data<int64_t>(), k,
trans_input.data<T>(), trans_ind.data<int64_t>(),
input_width, trans_input.data<T>(),
input_width, input_width,
static_cast<int>(k), input_width,
gridx, static_cast<int>(k),
input_height, gridx,
largest)); input_height,
largest));
default:
PADDLE_THROW(
errors::Fatal("the input k has error in the topk cuda kernel."));
});
#endif #endif
default: default:
PADDLE_THROW(errors::Fatal( PADDLE_THROW(errors::Fatal(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册