diff --git a/paddle/fluid/operators/top_k_v2_op.cu b/paddle/fluid/operators/top_k_v2_op.cu index 6e74ca46d2cd2796346b7dc2acbea355058826e2..1ba6846dce52b28d5c277b7036902ad34ee19644 100644 --- a/paddle/fluid/operators/top_k_v2_op.cu +++ b/paddle/fluid/operators/top_k_v2_op.cu @@ -83,7 +83,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel { if (k > input_width) k = input_width; - if ((input_width <= 1024 || k >= 128 || k == input_width)) { + // The conclusion is drawn from the data through multiple sets of + // statistics + if (input_width >= 128 && k >= input_width * 0.75) { if (SortTopk(dev_ctx, input, input_width, input_height, k, output, indices, largest)) { // Successed, return. @@ -159,8 +161,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel { if (k > input_width) k = input_width; - if (((input_width <= 1024 && input_height <= 2048) || k >= 128 || - k == input_width)) { + // The conclusion is drawn from the data through multiple sets of + // statistics + if (input_width >= 128 && k >= input_width * 0.75) { if (SortTopk(dev_ctx, &trans_input, input_width, input_height, k, &trans_out, &trans_ind, largest)) { // last step, tranpose back the indices and output