未验证 提交 c6de4342 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference]fix token prune plugin (#48367)

* fix
上级 c2f07f5b
...@@ -36,10 +36,12 @@ __global__ void ElementwiseMask(const T* a, ...@@ -36,10 +36,12 @@ __global__ void ElementwiseMask(const T* a,
const T* b, const T* b,
T* res, T* res,
int num_elements) { int num_elements) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
auto tid = threadIdx.x + blockIdx.x * blockDim.x; auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_elements) return; if (tid >= num_elements) return;
const T zero = 0; const T zero = 0;
res[tid] = b[tid] >= zero ? a[tid] : zero; res[tid] = b[tid] >= zero ? a[tid] : zero;
#endif
} }
template <typename T> template <typename T>
...@@ -119,6 +121,7 @@ __global__ void ReduceSum2( ...@@ -119,6 +121,7 @@ __global__ void ReduceSum2(
template <> template <>
__global__ void ReduceSum2<half>( __global__ void ReduceSum2<half>(
const half* src, half* dst, int bsz, int nb_head, int max_seq_len) { const half* src, half* dst, int bsz, int nb_head, int max_seq_len) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int tid = threadIdx.x; int tid = threadIdx.x;
int bid = blockIdx.x; int bid = blockIdx.x;
int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len); int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len);
...@@ -150,6 +153,7 @@ __global__ void ReduceSum2<half>( ...@@ -150,6 +153,7 @@ __global__ void ReduceSum2<half>(
static_cast<size_t>(bsz * max_seq_len), static_cast<size_t>(bsz * max_seq_len),
static_cast<platform::float16>(res_half[0])); static_cast<platform::float16>(res_half[0]));
} }
#endif
} }
template <typename T> template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册