未验证 提交 c4862d99 编写于 作者: Z zhangkaihuo 提交者: GitHub

Opt topk (#37256)

topk 有 cub 和手写 kernel 两种实现。cub 通过排序来获取 topk;多组数据的测试表明,只有当 input_width >= 128 且 k 不小于 input_width 的 75% 时,cub 实现的性能才优于手写 kernel。
上级 162ac048
......@@ -83,7 +83,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
if (k > input_width) k = input_width;
if ((input_width <= 1024 || k >= 128 || k == input_width)) {
// Threshold chosen empirically from multiple sets of benchmark data: the
// sort-based path is only taken when input_width >= 128 and k is at least
// 75% of input_width.
if (input_width >= 128 && k >= input_width * 0.75) {
if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
indices, largest)) {
// Succeeded, return.
......@@ -159,8 +161,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
if (k > input_width) k = input_width;
if (((input_width <= 1024 && input_height <= 2048) || k >= 128 ||
k == input_width)) {
// Threshold chosen empirically from multiple sets of benchmark data: the
// sort-based path is only taken when input_width >= 128 and k is at least
// 75% of input_width.
if (input_width >= 128 && k >= input_width * 0.75) {
if (SortTopk<T>(dev_ctx, &trans_input, input_width, input_height, k,
&trans_out, &trans_ind, largest)) {
// last step, transpose back the indices and output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册