From c6de43428076ce8471c9a962796227ca843349db Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Fri, 25 Nov 2022 17:52:59 +0800 Subject: [PATCH] [Paddle Inference]fix token prune plugin (#48367) * fix --- .../inference/tensorrt/plugin/fused_token_prune_op_plugin.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu index b0c800d31bf..97a60b37088 100644 --- a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu @@ -36,10 +36,12 @@ __global__ void ElementwiseMask(const T* a, const T* b, T* res, int num_elements) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) auto tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= num_elements) return; const T zero = 0; res[tid] = b[tid] >= zero ? a[tid] : zero; +#endif } template @@ -119,6 +121,7 @@ __global__ void ReduceSum2( template <> __global__ void ReduceSum2( const half* src, half* dst, int bsz, int nb_head, int max_seq_len) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) int tid = threadIdx.x; int bid = blockIdx.x; int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len); @@ -150,6 +153,7 @@ __global__ void ReduceSum2( static_cast(bsz * max_seq_len), static_cast(res_half[0])); } +#endif } template -- GitLab