diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index adaf5cc092b4e97653e5a022950062aa7d68cd9f..8262023826b328c2250e9c95c70b72b7eb4b212b 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -98,7 +98,7 @@ void TopkKernel(const Context& dev_ctx, } #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000 - if (input_width >= 1024 && input_height == 1) { + if (input_width >= 1024 && in_dims.size() == 1) { // 1. Gather TopK, but without sorting constexpr int max_num_threads = 1024; if (largest) {