diff --git a/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.cu b/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.cu
index 19128256bcb671f8f84f4296ba07a3b8951d4624..ca59d4e9daeee3dba13243dda5a23dc6564c05b2 100644
--- a/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.cu
@@ -92,8 +92,12 @@ __global__ void layernorm_shift_partition(T *out,
   float mean = 0.0f;
   float variance = 0.0f;
 
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   float local_out =
       (tid < n) ? static_cast<float>(__ldg(input + bid * n + tid)) : 0.0f;
+#else
+  float local_out = (tid < n) ? static_cast<float>(input[bid * n + tid]) : 0.0f;
+#endif
 
   mean = blockReduceSum<float>(local_out);
   if (threadIdx.x == 0) {
@@ -109,14 +113,20 @@ __global__ void layernorm_shift_partition(T *out,
   __syncthreads();
 
   if (tid < n) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
     out[output_bid * n + tid] =
         (T)(((local_out - s_mean) * rsqrtf(s_variance)) *
                 static_cast<float>(__ldg(&gamma[tid])) +
             static_cast<float>(__ldg(&beta[tid])));
+#else
+    out[output_bid * n + tid] =
+        (T)(((local_out - s_mean) * rsqrtf(s_variance)) *
+                static_cast<float>(gamma[tid]) +
+            static_cast<float>(beta[tid]));
+#endif
   }
 }
 
-#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
 template <>
 __global__ void layernorm_shift_partition(half2 *out_ptr,
                                           const half2 *input_ptr,
@@ -129,6 +139,7 @@ __global__ void layernorm_shift_partition(half2 *out_ptr,
                                           int shift_size,
                                           int window_size,
                                           const float eps) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
   const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x;
   const int shifted_H_idx =
@@ -185,8 +196,8 @@ __global__ void layernorm_shift_partition(half2 *out_ptr,
         (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y;
     out_ptr[output_bid * n + tid] = __float22half2_rn(local_out_fp2);
   }
-}
 #endif
+}
 
 #define kITE 4
 template <typename T>
@@ -233,7 +244,11 @@ __global__ void layernorm_shift_partition_v2(T *out,
   for (int i = 0; i < kITE; i++) {
     int col_id = i * blockDim.x + tid;
     if (col_id < n) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
       local_out[i] = static_cast<float>(__ldg(input + offset + col_id));
+#else
+      local_out[i] = static_cast<float>(input[offset + col_id]);
+#endif
       sum += local_out[i];
     }
   }
@@ -265,15 +280,20 @@ __global__ void layernorm_shift_partition_v2(T *out,
   for (int i = 0; i < kITE; i++) {
     int col_id = i * blockDim.x + tid;
     if (col_id < n) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
       out[output_offset + col_id] =
           (T)(local_out[i] * s_variance *
                   static_cast<float>(__ldg(&gamma[col_id])) +
               static_cast<float>(__ldg(&beta[col_id])));
+#else
+      out[output_offset + col_id] =
+          (T)(local_out[i] * s_variance * static_cast<float>(gamma[col_id]) +
+              static_cast<float>(beta[col_id]));
+#endif
     }
   }
 }
 
-#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
 template <>
 __global__ void layernorm_shift_partition_v2(half2 *out_ptr,
                                              const half2 *__restrict input_ptr,
@@ -286,6 +306,7 @@ __global__ void layernorm_shift_partition_v2(half2 *out_ptr,
                                              int shift_size,
                                              int window_size,
                                              const float eps) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   // constexpr int ite = 4;
   const int tid = threadIdx.x;
   const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
@@ -359,8 +380,8 @@ __global__ void layernorm_shift_partition_v2(half2 *out_ptr,
           __ldg(&beta_ptr[col_id]);
     }
   }
-}
 #endif
+}
 
 template <typename T>
 void invokeLayernormShiftPartition(T *out,
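
The sketch below (not part of the patch) illustrates the compile-time pattern the diff applies: the `CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)` guard is moved inside each kernel body, selecting a `__ldg` read-only-cache load when the target architecture qualifies and a plain global load otherwise. It assumes `CUDA_ARCH_FP16_SUPPORTED` expands to an architecture-threshold check; its real definition lives elsewhere in the plugin sources, so the placeholder threshold of 600, the `scale_by` kernel, and the host-side harness here are purely illustrative.

```cuda
// Minimal sketch of an arch-guarded __ldg load with a plain-load fallback.
#include <cstdio>
#include <cuda_runtime.h>

#ifndef CUDA_ARCH_FP16_SUPPORTED
// Placeholder definition for this sketch only; the plugin defines the real macro.
#define CUDA_ARCH_FP16_SUPPORTED(ARCH) ((ARCH) >= 600)
#endif

__global__ void scale_by(float *out, const float *in, float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
#if defined(__CUDA_ARCH__) && CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
    // Guarded path: route the read through the read-only data cache.
    out[i] = alpha * __ldg(in + i);
#else
    // Fallback path (host compilation pass or unsupported arch): plain load.
    out[i] = alpha * in[i];
#endif
  }
}

int main() {
  const int n = 8;
  float h_in[n], h_out[n];
  for (int i = 0; i < n; ++i) h_in[i] = static_cast<float>(i);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  scale_by<<<1, 32>>>(d_out, d_in, 2.0f, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

  for (int i = 0; i < n; ++i) printf("%g ", h_out[i]);  // expect 0 2 4 ... 14
  printf("\n");

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```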