From 667082c0adc72770bfb2ecbede8f726af35bc460 Mon Sep 17 00:00:00 2001 From: carryyu <569782149@qq.com> Date: Thu, 29 Sep 2022 11:41:09 +0800 Subject: [PATCH] fix P40 topk: Make the optimized topk compatible with P40. (#46547) * fix P40 topk: Make the optimized topk compatible with P40. * fix P40 topk: Make the optimized topk compatible with P40. * fix P40 topk: Make the optimized topk compatible with P40. --- paddle/fluid/operators/top_k_function_cuda.h | 15 +++++---------- paddle/fluid/operators/top_k_op.cu | 3 ++- paddle/phi/kernels/gpu/top_k_kernel.cu | 6 ++++-- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/top_k_function_cuda.h b/paddle/fluid/operators/top_k_function_cuda.h index a52b461d40e..faf2b080891 100644 --- a/paddle/fluid/operators/top_k_function_cuda.h +++ b/paddle/fluid/operators/top_k_function_cuda.h @@ -354,16 +354,11 @@ __device__ __forceinline__ void BlockReduce(Pair shared_max[], } if (--(*k) == 0) break; - if (MaxLength < 5) { - if (*beam >= MaxLength) break; - } else { - unsigned mask = 0u; - CREATE_SHFL_MASK(mask, true); - if (tid_max / 32 == wid) { - if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == - MaxLength) - break; - } + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, true); + if (tid_max / 32 == wid) { + if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == MaxLength) + break; } } } diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 86399bfa1b9..f0b92cb0143 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -127,7 +127,8 @@ class TopkOpCUDAKernel : public framework::OpKernel { input_height)); default: PADDLE_THROW(platform::errors::Fatal( - "the input k has error in the topk cuda kernel.")); + "the input k has error when use getMaxLength function to get the " + "maxLength.")); }); default: PADDLE_THROW(platform::errors::Unavailable( diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index 644cc608143..9fc21b19a15 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -205,7 +205,8 @@ void TopkKernel(const Context& dev_ctx, largest)); default: PADDLE_THROW( - errors::Fatal("the input k has error in the topk cuda kernel.")); + errors::Fatal("the input k has error when use getMaxLength " + "function to get the maxLength.")); }); #endif default: @@ -313,7 +314,8 @@ void TopkKernel(const Context& dev_ctx, largest)); default: PADDLE_THROW( - errors::Fatal("the input k has error in the topk cuda kernel.")); + errors::Fatal("the input k has error when use getMaxLength " + "function to get the maxLength.")); }); #endif default: -- GitLab