Unverified commit 41e4f7ea, authored by qingqing01 and committed by GitHub

Optimize Topk when height is large. (#13710)

Parent 65ed45a1
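The change below replaces the one-block-per-row launch with a capped grid: the grid is limited to kMaxHeight = 2048 blocks and each block walks over rows with a stride of grid_dim, so a very large input height no longer needs an equally large grid. The following is a minimal, self-contained sketch of that looping pattern only, not the Paddle kernel itself; the kernel name RowMaxKernel, the per-row max reduction, and all sizes are illustrative assumptions.

#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel: each block reduces one row of a [num_rows x row_len]
// matrix to its maximum. When num_rows exceeds gridDim.x, a block simply
// moves on to row (blockIdx.x + gridDim.x), and so on.
__global__ void RowMaxKernel(const float* src, float* out, int num_rows,
                             int row_len) {
  __shared__ float sh[256];
  for (int row = blockIdx.x; row < num_rows; row += gridDim.x) {
    float v = -INFINITY;
    for (int j = threadIdx.x; j < row_len; j += blockDim.x) {
      v = fmaxf(v, src[row * row_len + j]);
    }
    sh[threadIdx.x] = v;
    __syncthreads();
    // Standard tree reduction inside the block (blockDim.x is a power of two).
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
      if (threadIdx.x < s) {
        sh[threadIdx.x] = fmaxf(sh[threadIdx.x], sh[threadIdx.x + s]);
      }
      __syncthreads();
    }
    if (threadIdx.x == 0) out[row] = sh[0];
    __syncthreads();  // make sh safe to reuse for the next row
  }
}

int main() {
  const int num_rows = 100000, row_len = 512;
  float *d_src, *d_out;
  cudaMalloc(&d_src, sizeof(float) * num_rows * row_len);
  cudaMalloc(&d_out, sizeof(float) * num_rows);
  cudaMemset(d_src, 0, sizeof(float) * num_rows * row_len);
  // Cap the grid the same way the commit does (kMaxHeight = 2048).
  const int kMaxHeight = 2048;
  int gridx = num_rows < kMaxHeight ? num_rows : kMaxHeight;
  RowMaxKernel<<<gridx, 256>>>(d_src, d_out, num_rows, row_len);
  cudaDeviceSynchronize();
  printf("done: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d_src);
  cudaFree(d_out);
  return 0;
}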
@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
 * 3. go to the second step, until one thread's topk value is null;
 * 4. go to the first step, until the topk value is obtained.
*/
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
-                             const T* src, int lds, int dim, int k) {
+                             const T* src, int lds, int dim, int k,
+                             int grid_dim, int num) {
  __shared__ Pair<T> sh_topk[BlockSize];
  __shared__ int maxid[BlockSize / 2];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;
-  output += blockIdx.x * output_stride;
-  indices += blockIdx.x * k;
-  Pair<T> topk[MaxLength];
-  int beam = MaxLength;
-  Pair<T> max;
-  bool is_empty = false;
-  bool firststep = true;
-  for (int k = 0; k < MaxLength; k++) {
-    topk[k].set(-INFINITY, -1);
-  }
-  while (k) {
-    ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, k,
-                                           src + blockIdx.x * lds, &firststep,
-                                           &is_empty, &max, dim, tid);
-    sh_topk[tid] = topk[0];
-    BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
-                                         &indices, &beam, &k, tid, warp);
-  }
+  const int bid = blockIdx.x;
+  for (int i = bid; i < num; i += grid_dim) {
+    output += i * output_stride;
+    indices += i * k;
+    Pair<T> topk[MaxLength];
+    int beam = MaxLength;
+    Pair<T> max;
+    bool is_empty = false;
+    bool firststep = true;
+    for (int k = 0; k < MaxLength; k++) {
+      topk[k].set(-INFINITY, -1);
+    }
+    while (k) {
+      ThreadGetTopK<T, MaxLength, BlockSize>(
+          topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
+      sh_topk[tid] = topk[0];
+      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
+                                           &indices, &beam, &k, tid, warp);
+    }
+  }
}
+inline static int GetDesiredBlockDim(int dim) {
+  if (dim > 128) {
+    return 256;
+  } else if (dim > 64) {
+    return 128;
+  } else if (dim > 32) {
+    return 64;
+  } else {
+    return 32;
+  }
+}
+#define FIXED_BLOCK_DIM_BASE(dim, ...) \
+  case (dim): {                        \
+    constexpr auto kBlockDim = (dim);  \
+    __VA_ARGS__;                       \
+  } break
+#define FIXED_BLOCK_DIM(...)                \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template <typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> {
public:
@@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
// TODO(typhoonzero): refine this kernel.
-    dim3 threads(256, 1);
-    dim3 grid(input_height, 1);
-    KeMatrixTopK<T, 5, 256><<<
-        grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                              ctx.device_context())
-                              .stream()>>>(
-        output_data, output->dims()[1], indices_data, input_data, input_width,
-        input_width, static_cast<int>(k));
+    const int kMaxHeight = 2048;
+    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
+    auto& dev_ctx = ctx.cuda_device_context();
+    switch (GetDesiredBlockDim(input_width)) {
+      FIXED_BLOCK_DIM(
+          KeMatrixTopK<T, 5,
+                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
+              output_data, output->dims()[1], indices_data, input_data,
+              input_width, input_width, static_cast<int>(k), gridx,
+              input_height));
+      default:
+        PADDLE_THROW("Error");
+    }
}
};
+#undef FIXED_BLOCK_DIM_BASE
+#undef FIXED_BLOCK_DIM
} // namespace operators
} // namespace paddle
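The launch-side change also introduces a small compile-time dispatch: GetDesiredBlockDim maps the runtime input_width to one of a few supported block sizes, and the FIXED_BLOCK_DIM macros expand into switch cases in which kBlockDim is a constexpr that can be passed as a template argument. Below is a stripped-down sketch of that pattern under stated assumptions; DemoKernel and the sizes in main are hypothetical stand-ins for KeMatrixTopK and the operator's inputs.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel: BlockSize is known at compile time, so it can size a
// shared-memory array, which is exactly what KeMatrixTopK needs for sh_topk.
template <int BlockSize>
__global__ void DemoKernel(int* out) {
  __shared__ int sh[BlockSize];
  sh[threadIdx.x] = threadIdx.x;
  __syncthreads();
  if (threadIdx.x == 0) *out = sh[BlockSize - 1] + 1;  // == BlockSize
}

// Same shape as the helper added in the commit: map a runtime width to one of
// a few supported block sizes.
inline static int GetDesiredBlockDim(int dim) {
  if (dim > 128) return 256;
  if (dim > 64) return 128;
  if (dim > 32) return 64;
  return 32;
}

// Each case introduces a constexpr kBlockDim and pastes the launch expression
// after it, so the kernel's template argument is a compile-time constant.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

int main() {
  int* d_out;
  cudaMalloc(&d_out, sizeof(int));
  int width = 100;  // a runtime value, playing the role of input_width
  switch (GetDesiredBlockDim(width)) {
    FIXED_BLOCK_DIM(DemoKernel<kBlockDim><<<1, kBlockDim>>>(d_out));
    default:
      printf("unsupported block dim\n");
  }
  int h_out = 0;
  cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
  printf("launched with block dim %d\n", h_out);  // expect 128 for width 100
  cudaFree(d_out);
  return 0;
}

Note that the commas inside the <<<...>>> launch configuration are preserved because both macros are variadic: __VA_ARGS__ re-joins the pieces, which is what lets the real code pass the full KeMatrixTopK launch expression through FIXED_BLOCK_DIM.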