未验证 提交 c6de4342 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference]fix token prune plugin (#48367)

* fix
上级 c2f07f5b
......@@ -36,10 +36,12 @@ __global__ void ElementwiseMask(const T* a,
const T* b,
T* res,
int num_elements) {
// Masked element-wise select: res[i] = a[i] where mask b[i] >= 0, else 0.
// One thread per element; threads past num_elements exit early (grid tail).
// NOTE(review): the guard below presumably compiles the body only on
// architectures with native fp16 support (needed when T = half, since
// half comparison/assignment require SM53+) — confirm what
// CUDA_ARCH_FP16_SUPPORTED expands to. On unsupported archs the kernel
// body is empty and res is left untouched.
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
// Bounds check: launch grid rarely divides num_elements evenly.
if (tid >= num_elements) return;
const T zero = 0;
res[tid] = b[tid] >= zero ? a[tid] : zero;
#endif
}
template <typename T>
......@@ -119,6 +121,7 @@ __global__ void ReduceSum2(
template <>
// Half-precision specialization of ReduceSum2.
// NOTE(review): this view is a diff excerpt — the middle of the function
// (between the two hunks below) is not visible here; comments describe only
// the visible lines.
__global__ void ReduceSum2<half>(
const half* src, half* dst, int bsz, int nb_head, int max_seq_len) {
// Guard added by this commit: restrict the half body to architectures
// with native fp16 support — presumably SM53+; confirm the macro.
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int tid = threadIdx.x;
int bid = blockIdx.x;
// Blocks per attention head when each block covers blockDim.x columns of a
// max_seq_len x max_seq_len score matrix — TODO confirm layout from the
// elided middle of the function.
int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len);
......@@ -150,6 +153,7 @@ __global__ void ReduceSum2<half>(
static_cast<size_t>(bsz * max_seq_len),
static_cast<platform::float16>(res_half[0]));
}
// Closes the fp16-support guard (matching #if is above); on unsupported
// archs the whole body compiles away.
#endif
}
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册