From 33540e109a742c42b8273e389a1ca1c89596869d Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Thu, 7 Jul 2022 17:38:04 +0800 Subject: [PATCH] Fix nan in fast_ln_fwd_kernel when cols > 1024 (#44125) * Fix nan in fast_ln_fwd_kernel when cols > 1024 * delete blas --- .../operators/fused/fused_layernorm_residual_dropout_bias.h | 4 ++-- paddle/fluid/operators/layer_norm_kernel.cu.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index 4aedf4eb79b..301b62524a5 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -573,7 +573,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( smem[warp_m * WARPS_N + warp_n] = mu_local; } __syncthreads(); - if (tidx == 0) { + if (tidx % THREADS_PER_ROW == 0) { mu_local = 0.f; #pragma unroll for (int it = 0; it < WARPS_N; ++it) { @@ -608,7 +608,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( smem[warp_m * WARPS_N + warp_n] = var_local; } __syncthreads(); - if (tidx == 0) { + if (tidx % THREADS_PER_ROW == 0) { var_local = 0.f; #pragma unroll for (int it = 0; it < WARPS_N; ++it) { diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index e37f048235e..8ed706a5443 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -252,7 +252,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel( smem[warp_m * WARPS_N + warp_n] = mu_local; } __syncthreads(); - if (tidx == 0) { + if (tidx % THREADS_PER_ROW == 0) { mu_local = 0.f; #pragma unroll for (int it = 0; it < WARPS_N; ++it) { @@ -289,7 +289,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel( smem[warp_m * WARPS_N + warp_n] = var_local; } __syncthreads(); - if (tidx == 0) { + if (tidx % THREADS_PER_ROW == 0) { var_local = 0.f; #pragma unroll for (int it = 0; it < WARPS_N; ++it) { -- GitLab