From 89fb1964304ced407e9086d3ae2da3a5a3518d5e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 25 Nov 2021 11:22:45 +0800 Subject: [PATCH] [cherry-pick-2.2.1]Opt topk (#37325) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 目前的fused_attention_op不支持attn_mask=None的输入,本PR对此进行了补充,并补充了相应的单测逻辑。 --- paddle/fluid/operators/top_k_v2_op.cu | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/top_k_v2_op.cu b/paddle/fluid/operators/top_k_v2_op.cu index 6e74ca46d2c..1ba6846dce5 100644 --- a/paddle/fluid/operators/top_k_v2_op.cu +++ b/paddle/fluid/operators/top_k_v2_op.cu @@ -83,7 +83,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel { if (k > input_width) k = input_width; - if ((input_width <= 1024 || k >= 128 || k == input_width)) { + // The conclusion is drawn from the data through multiple sets of + // statistics + if (input_width >= 128 && k >= input_width * 0.75) { if (SortTopk(dev_ctx, input, input_width, input_height, k, output, indices, largest)) { // Successed, return. @@ -159,8 +161,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel { if (k > input_width) k = input_width; - if (((input_width <= 1024 && input_height <= 2048) || k >= 128 || - k == input_width)) { + // The conclusion is drawn from the data through multiple sets of + // statistics + if (input_width >= 128 && k >= input_width * 0.75) { if (SortTopk(dev_ctx, &trans_input, input_width, input_height, k, &trans_out, &trans_ind, largest)) { // last step, tranpose back the indices and output -- GitLab