From efeeb6fb4b65b17a7d5e53d377ffda728502f072 Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Mon, 24 Apr 2023 20:32:41 +0800
Subject: [PATCH] Fix the calculation of layer_norm_bwd (#53230)

---
 paddle/fluid/operators/layer_norm_kernel.cu.h | 29 +++++++++----------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h
index 899eae3efb4..17c6adc171c 100644
--- a/paddle/fluid/operators/layer_norm_kernel.cu.h
+++ b/paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -1397,13 +1397,13 @@ __global__ void LayerNormBackwardGradientAll(
 
   for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
-    auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
+    auto var_val = rsqrt_(static_cast<U>(var[row_idx]) + epsilon);
     d_scale_partial += static_cast<U>(d_y[i]) *
-                       (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
+                       (static_cast<U>(x[i]) - mean[row_idx]) * var_val;
     d_bias_partial += static_cast<U>(d_y[i]);
     if (HasDx) {
       d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                              static_cast<U>(scale[blockIdx.x + col_offset]) /
+                              static_cast<U>(scale[blockIdx.x + col_offset]) *
                               var_val);
     }
   }
@@ -1453,10 +1453,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
   for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
+        static_cast<U>(rsqrt_(static_cast<float>(var[row_idx]) + epsilon));
     if (HasDScale) {
       d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
-                                   (static_cast<U>(x[i]) - mean[row_idx]) /
+                                   (static_cast<U>(x[i]) - mean[row_idx]) *
                                    var_val;
     } else {  // d_bias != nullptr
       d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
@@ -1465,10 +1465,10 @@
     if (HasDx) {
       if (scale != nullptr) {
         d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                                static_cast<U>(scale[blockIdx.x + col_offset]) /
+                                static_cast<U>(scale[blockIdx.x + col_offset]) *
                                 var_val);
       } else {
-        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
+        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) * var_val);
       }
     }
   }
@@ -1556,13 +1556,13 @@ __global__ void LayerNormBackwardGradientOnlyDX(
   U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
   for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
+        static_cast<U>(rsqrt_(static_cast<float>(block_var) + epsilon));
     if (scale != nullptr) {
       int col_idx = i % feature_size;
       d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                              static_cast<U>(scale[col_idx]) / var_val);
+                              static_cast<U>(scale[col_idx]) * var_val);
     } else {
-      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
+      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) * var_val);
     }
     d_x_mean_partial += static_cast<U>(d_x[i]);
     d_x_var_partial +=
@@ -1606,21 +1606,20 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
   int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
   if (idx < feature_size) {
-    auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
+    auto var_val = static_cast<U>(rsqrt_(static_cast<float>(var[0]) + epsilon));
     if (d_x != nullptr) {
       if (d_scale == nullptr) {
-        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
+        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) * var_val);
       } else {
         d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) *
-                                  static_cast<U>(scale[idx]) / var_val);
+                                  static_cast<U>(scale[idx]) * var_val);
       }
     }
 
     if (d_scale != nullptr) {
       d_scale[idx] =
           static_cast<ScaleBiasT>(static_cast<U>(d_y[idx]) *
-                                  (static_cast<U>(x[idx]) - mean[0]) / var_val);
+                                  (static_cast<U>(x[idx]) - mean[0]) * var_val);
     }
 
     if (d_bias != nullptr) {
-- 
GitLab
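
Reviewer note on the arithmetic behind the change: every hunk rewrites g / sqrt(var + epsilon) as g * rsqrt(var + epsilon), which is the same quantity up to floating-point rounding but replaces a square root plus a divide with a single reciprocal-square-root that can be reused by multiplication. Below is a minimal standalone sketch of that pattern for the d_scale accumulation, assuming float-only data; the kernel name scale_grad_rows, the one-block-per-row launch, and the atomicAdd reduction are illustrative assumptions, not the templated block-reduction kernels in layer_norm_kernel.cu.h.

// Illustration only, not part of the patch.
#include <cstdio>
#include <cuda_runtime.h>

// One block per row: accumulate d_scale using (x - mean) * rsqrt(var + eps),
// i.e. multiplication by the inverse std instead of division by the std.
__global__ void scale_grad_rows(const float *x, const float *d_y,
                                const float *mean, const float *var,
                                float epsilon, int feature_size,
                                float *d_scale) {
  int row = blockIdx.x;
  int col = threadIdx.x;
  if (col < feature_size) {
    int i = row * feature_size + col;
    // d_y * (x - mean) / sqrt(var + eps) == d_y * (x - mean) * rsqrt(var + eps)
    float inv_std = rsqrtf(var[row] + epsilon);
    atomicAdd(&d_scale[col], d_y[i] * (x[i] - mean[row]) * inv_std);
  }
}

int main() {
  const int rows = 2, cols = 4;
  float h_x[rows * cols], h_dy[rows * cols], h_mean[rows], h_var[rows];
  for (int i = 0; i < rows * cols; ++i) { h_x[i] = 0.5f * i; h_dy[i] = 1.0f; }
  for (int r = 0; r < rows; ++r) { h_mean[r] = 1.0f; h_var[r] = 2.0f; }

  float *d_x, *d_dy, *d_mean, *d_var, *d_scale;
  cudaMalloc(&d_x, sizeof(h_x));
  cudaMalloc(&d_dy, sizeof(h_dy));
  cudaMalloc(&d_mean, sizeof(h_mean));
  cudaMalloc(&d_var, sizeof(h_var));
  cudaMalloc(&d_scale, cols * sizeof(float));
  cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
  cudaMemcpy(d_dy, h_dy, sizeof(h_dy), cudaMemcpyHostToDevice);
  cudaMemcpy(d_mean, h_mean, sizeof(h_mean), cudaMemcpyHostToDevice);
  cudaMemcpy(d_var, h_var, sizeof(h_var), cudaMemcpyHostToDevice);
  cudaMemset(d_scale, 0, cols * sizeof(float));

  scale_grad_rows<<<rows, cols>>>(d_x, d_dy, d_mean, d_var, 1e-5f, cols,
                                  d_scale);

  float h_scale[cols];
  cudaMemcpy(h_scale, d_scale, sizeof(h_scale), cudaMemcpyDeviceToHost);
  for (int c = 0; c < cols; ++c) printf("d_scale[%d] = %f\n", c, h_scale[c]);

  cudaFree(d_x); cudaFree(d_dy); cudaFree(d_mean);
  cudaFree(d_var); cudaFree(d_scale);
  return 0;
}

Compile with nvcc. Each thread computes its row's inverse standard deviation once and multiplies by it, mirroring the "/ var_val" to "* var_val" replacements in the patch above.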