From a0aff194a181de8d7b31932bbfcaee11ce7fcd8b Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Mon, 24 Apr 2023 17:01:07 +0800 Subject: [PATCH] Fix the calculation of layer_norm_bwd (#53224) * Fix the calculation of layer_norm_bwd * fix --- paddle/phi/kernels/funcs/layer_norm_impl.cu.h | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h index d8ade4612c8..b240be28ec9 100644 --- a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h +++ b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h @@ -1603,13 +1603,13 @@ __global__ void LayerNormBackwardGradientAll( for (int64_t i = beg_idx; i < end_idx; i += stride) { int row_idx = i / feature_size; - auto var_val = real_sqrt(static_cast(var[row_idx]) + epsilon); + auto var_val = rsqrt_(static_cast(var[row_idx]) + epsilon); d_scale_partial += static_cast(d_y[i]) * - (static_cast(x[i]) - mean[row_idx]) / var_val; + (static_cast(x[i]) - mean[row_idx]) * var_val; d_bias_partial += static_cast(d_y[i]); if (HasDx) { d_x[i] = static_cast(static_cast(d_y[i]) * - static_cast(scale[blockIdx.x + col_offset]) / + static_cast(scale[blockIdx.x + col_offset]) * var_val); } } @@ -1659,10 +1659,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias( for (int64_t i = beg_idx; i < end_idx; i += stride) { int row_idx = i / feature_size; auto var_val = - static_cast(real_sqrt(static_cast(var[row_idx]) + epsilon)); + static_cast(rsqrt_(static_cast(var[row_idx]) + epsilon)); if (HasDScale) { d_scale_or_d_bias_partial += static_cast(d_y[i]) * - (static_cast(x[i]) - mean[row_idx]) / + (static_cast(x[i]) - mean[row_idx]) * var_val; } else { // d_bias != nullptr d_scale_or_d_bias_partial += static_cast(d_y[i]); @@ -1671,10 +1671,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias( if (HasDx) { if (scale != nullptr) { d_x[i] = static_cast(static_cast(d_y[i]) * - static_cast(scale[blockIdx.x + col_offset]) / + static_cast(scale[blockIdx.x + col_offset]) * var_val); } else { - d_x[i] = static_cast(static_cast(d_y[i]) / var_val); + d_x[i] = static_cast(static_cast(d_y[i]) * var_val); } } } @@ -1762,13 +1762,13 @@ __global__ void LayerNormBackwardGradientOnlyDX( U d_x_mean_partial = static_cast(0), d_x_var_partial = static_cast(0); for (int64_t i = beg_idx; i < end_idx; i += BlockDim) { auto var_val = - static_cast(real_sqrt(static_cast(block_var) + epsilon)); + static_cast(rsqrt_(static_cast(block_var) + epsilon)); if (scale != nullptr) { int col_idx = i % feature_size; d_x[i] = static_cast(static_cast(d_y[i]) * - static_cast(scale[col_idx]) / var_val); + static_cast(scale[col_idx]) * var_val); } else { - d_x[i] = static_cast(static_cast(d_y[i]) / var_val); + d_x[i] = static_cast(static_cast(d_y[i]) * var_val); } d_x_mean_partial += static_cast(d_x[i]); d_x_var_partial += @@ -1812,21 +1812,20 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne( int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; using ScaleBiasT = LayerNormScaleBiasT; if (idx < feature_size) { - auto var_val = - static_cast(real_sqrt(static_cast(var[0]) + epsilon)); + auto var_val = static_cast(rsqrt_(static_cast(var[0]) + epsilon)); if (d_x != nullptr) { if (d_scale == nullptr) { - d_x[idx] = static_cast(static_cast(d_y[idx]) / var_val); + d_x[idx] = static_cast(static_cast(d_y[idx]) * var_val); } else { d_x[idx] = static_cast(static_cast(d_y[idx]) * - static_cast(scale[idx]) / var_val); + static_cast(scale[idx]) * var_val); } } if (d_scale != nullptr) { d_scale[idx] = static_cast(static_cast(d_y[idx]) * - (static_cast(x[idx]) - mean[0]) / var_val); + (static_cast(x[idx]) - mean[0]) * var_val); } if (d_bias != nullptr) { -- GitLab