未验证 提交 36154ba9 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference] fix_qkv_plugin: fix half scale (#37096)

* fix_qkv_plugin: half_scale

* [Paddle-Inference] fix_qkv_plugin: fix half scale
上级 9574bcd7
// Diff hunk @@ -229,7 +229,9 @@ -- the kernel below is the post-commit version.
//
// apply_scale: multiplies the elements data[0] .. data[n-1] by `scale`, in
// place, with each thread handling at most one element.
//
//   data  - array to scale; written in place.
//   scale - factor applied to every element (same type T as the data).
//   n     - number of elements to scale.
//
// Indexing uses only the .x components, so a 1-D launch is expected; the
// launch must supply at least n threads in total, since a thread never
// touches more than one element.
//
// The body is compiled only when CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
// evaluates to true; otherwise the kernel compiles to an empty function and
// leaves `data` unchanged.
template <typename T>
__global__ void apply_scale(T *data, T scale, int n) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // Bounds guard introduced by this change: when the launch provides more
  // threads than n, the surplus threads must not read or write past the end
  // of `data`.
  if (tid < n) {
    data[tid] = data[tid] * scale;
  }
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册