Unverified commit a0aff194, authored by Zhang Zheng, committed by GitHub

Fix the calculation of layer_norm_bwd (#53224)

* Fix the calculation of layer_norm_bwd

* fix
Parent: bfa5d6b8
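Note on the change pattern: in each of the four backward kernels touched here (LayerNormBackwardGradientAll, LayerNormBackwardGradientScaleOrBias, LayerNormBackwardGradientOnlyDX, LayerNormBackwardWhenBatchSizeIsOne), the variance term is switched from sqrt(var + epsilon) used as a divisor to rsqrt_(var + epsilon) used as a multiplier. The two forms are mathematically equivalent, and the rsqrt form avoids a per-element floating-point division. The rsqrt_ helper itself is not part of this diff; the lines below are only a sketch of what such a device helper could look like on top of CUDA's rsqrtf/rsqrt intrinsics, not the actual PaddlePaddle implementation.

// Hypothetical sketch (assumption, not the Paddle source): a device-side
// reciprocal-square-root helper with the same call shape as rsqrt_ in the diff.
__device__ __forceinline__ float rsqrt_(const float v) { return rsqrtf(v); }
__device__ __forceinline__ double rsqrt_(const double v) { return rsqrt(v); }

// With such a helper, the rewritten pattern in the kernels reads:
//   auto var_val = rsqrt_(static_cast<U>(var[row_idx]) + epsilon);
//   d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) * scale_val * var_val);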
@@ -1603,13 +1603,13 @@ __global__ void LayerNormBackwardGradientAll(
   for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
-    auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
+    auto var_val = rsqrt_(static_cast<U>(var[row_idx]) + epsilon);
     d_scale_partial += static_cast<U>(d_y[i]) *
-                       (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
+                       (static_cast<U>(x[i]) - mean[row_idx]) * var_val;
     d_bias_partial += static_cast<U>(d_y[i]);
     if (HasDx) {
       d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                              static_cast<U>(scale[blockIdx.x + col_offset]) /
+                              static_cast<U>(scale[blockIdx.x + col_offset]) *
                               var_val);
     }
   }
@@ -1659,10 +1659,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
   for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
+        static_cast<U>(rsqrt_(static_cast<float>(var[row_idx]) + epsilon));
     if (HasDScale) {
       d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
-                                   (static_cast<U>(x[i]) - mean[row_idx]) /
+                                   (static_cast<U>(x[i]) - mean[row_idx]) *
                                    var_val;
     } else {  // d_bias != nullptr
       d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
@@ -1671,10 +1671,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
     if (HasDx) {
       if (scale != nullptr) {
         d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                                static_cast<U>(scale[blockIdx.x + col_offset]) /
+                                static_cast<U>(scale[blockIdx.x + col_offset]) *
                                 var_val);
       } else {
-        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
+        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) * var_val);
       }
     }
   }
@@ -1762,13 +1762,13 @@ __global__ void LayerNormBackwardGradientOnlyDX(
   U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
   for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
+        static_cast<U>(rsqrt_(static_cast<float>(block_var) + epsilon));
     if (scale != nullptr) {
       int col_idx = i % feature_size;
       d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                              static_cast<U>(scale[col_idx]) / var_val);
+                              static_cast<U>(scale[col_idx]) * var_val);
     } else {
-      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
+      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) * var_val);
     }
     d_x_mean_partial += static_cast<U>(d_x[i]);
     d_x_var_partial +=
@@ -1812,21 +1812,20 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
   int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
   if (idx < feature_size) {
-    auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
+    auto var_val = static_cast<U>(rsqrt_(static_cast<float>(var[0]) + epsilon));
     if (d_x != nullptr) {
       if (d_scale == nullptr) {
-        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
+        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) * var_val);
       } else {
         d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) *
-                                  static_cast<U>(scale[idx]) / var_val);
+                                  static_cast<U>(scale[idx]) * var_val);
       }
     }
     if (d_scale != nullptr) {
       d_scale[idx] =
           static_cast<ScaleBiasT>(static_cast<U>(d_y[idx]) *
-                                  (static_cast<U>(x[idx]) - mean[0]) / var_val);
+                                  (static_cast<U>(x[idx]) - mean[0]) * var_val);
     }
     if (d_bias != nullptr) {
...
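For reference, the rewrite does not change the gradient formula itself: d_y * scale / sqrt(var + eps) and d_y * scale * rsqrt(var + eps) are the same expression up to floating-point rounding (rsqrtf is a fast approximate instruction accurate to a few ulps). A minimal stand-alone check along those lines is sketched below; the constants and names are illustrative assumptions, not code from this PR or its tests.

// Illustrative equivalence check (assumption: not part of this PR).
#include <cstdio>
#include <cmath>
#include <cuda_runtime.h>

__global__ void BackwardBothForms(const float* var, float* dx_div, float* dx_mul, int n) {
  const float epsilon = 1e-5f;
  const float d_y = 0.37f, scale = 1.25f;  // arbitrary sample values
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    dx_div[i] = d_y * scale / sqrtf(var[i] + epsilon);   // old form: divide by sqrt
    dx_mul[i] = d_y * scale * rsqrtf(var[i] + epsilon);  // new form: multiply by rsqrt
  }
}

int main() {
  const int n = 1024;
  float h_var[n], h_div[n], h_mul[n];
  for (int i = 0; i < n; ++i) h_var[i] = 1e-3f + 0.01f * i;  // sample variances
  float *d_var, *d_div, *d_mul;
  cudaMalloc(&d_var, n * sizeof(float));
  cudaMalloc(&d_div, n * sizeof(float));
  cudaMalloc(&d_mul, n * sizeof(float));
  cudaMemcpy(d_var, h_var, n * sizeof(float), cudaMemcpyHostToDevice);
  BackwardBothForms<<<(n + 255) / 256, 256>>>(d_var, d_div, d_mul, n);
  cudaMemcpy(h_div, d_div, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_mul, d_mul, n * sizeof(float), cudaMemcpyDeviceToHost);
  float max_diff = 0.f;
  for (int i = 0; i < n; ++i) max_diff = fmaxf(max_diff, fabsf(h_div[i] - h_mul[i]));
  printf("max |sqrt-div form - rsqrt-mul form| = %g\n", max_diff);  // expected: a few float ulps
  cudaFree(d_var); cudaFree(d_div); cudaFree(d_mul);
  return 0;
}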