diff --git a/paddle/fluid/operators/top_k_function_cuda.h b/paddle/fluid/operators/top_k_function_cuda.h index a52b461d40ed7a622ec114639d03e785a3bd5156..faf2b08089157d4e2f4127175cdf27f4391876c8 100644 --- a/paddle/fluid/operators/top_k_function_cuda.h +++ b/paddle/fluid/operators/top_k_function_cuda.h @@ -354,16 +354,11 @@ __device__ __forceinline__ void BlockReduce(Pair shared_max[], } if (--(*k) == 0) break; - if (MaxLength < 5) { - if (*beam >= MaxLength) break; - } else { - unsigned mask = 0u; - CREATE_SHFL_MASK(mask, true); - if (tid_max / 32 == wid) { - if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == - MaxLength) - break; - } + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, true); + if (tid_max / 32 == wid) { + if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == MaxLength) + break; } } } diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 86399bfa1b91d38a7ce03896f84547cf0dbe8d91..f0b92cb01435ab84d69f05575d9b9b8d39db46ad 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -127,7 +127,8 @@ class TopkOpCUDAKernel : public framework::OpKernel { input_height)); default: PADDLE_THROW(platform::errors::Fatal( - "the input k has error in the topk cuda kernel.")); + "the input k has error when use getMaxLength function to get the " + "maxLength.")); }); default: PADDLE_THROW(platform::errors::Unavailable( diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index 644cc6081436ac67b63ccba1b45a68ca2a070ec2..9fc21b19a156c2587df02b3fbcde47a26fe59383 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -205,7 +205,8 @@ void TopkKernel(const Context& dev_ctx, largest)); default: PADDLE_THROW( - errors::Fatal("the input k has error in the topk cuda kernel.")); + errors::Fatal("the input k has error when use getMaxLength " + "function to get the maxLength.")); }); #endif default: @@ -313,7 +314,8 @@ void TopkKernel(const Context& dev_ctx, largest)); default: PADDLE_THROW( - errors::Fatal("the input k has error in the topk cuda kernel.")); + errors::Fatal("the input k has error when use getMaxLength " + "function to get the maxLength.")); }); #endif default: