Unverified commit efeeb6fb, authored by Zhang Zheng, committed by GitHub

Fix the calculation of layer_norm_bwd (#53230)

Parent 94e8fc78
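
Summary of the change: in every hunk below, the backward kernels previously computed var_val = real_sqrt(var + epsilon) and then divided by it; after the patch they compute var_val = rsqrt_(var + epsilon) and multiply. Both forms apply the same scaling factor 1/sqrt(var + epsilon); the new form simply expresses it as a multiplication by the reciprocal square root instead of a per-element division. The snippet below is a minimal host-side sketch of that equivalence, not the kernel code; rsqrt_ref is an illustrative stand-in for the device-side rsqrt_ helper.

    // Host-side sketch: old form (divide by sqrt) vs. new form (multiply by rsqrt).
    #include <cassert>
    #include <cmath>

    static float rsqrt_ref(float x) { return 1.0f / std::sqrt(x); }  // stand-in for rsqrt_

    int main() {
      const float eps = 1e-5f;
      const float var = 0.25f;  // example per-row variance
      const float dy  = 2.0f;   // example upstream gradient

      float old_form = dy / std::sqrt(var + eps);  // before the patch
      float new_form = dy * rsqrt_ref(var + eps);  // after the patch

      // The two forms agree up to floating-point rounding; the rsqrt form
      // replaces a division per element with a multiplication.
      assert(std::fabs(old_form - new_form) < 1e-5f);
      return 0;
    }
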
@@ -1397,13 +1397,13 @@ __global__ void LayerNormBackwardGradientAll(
   for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
-    auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
+    auto var_val = rsqrt_(static_cast<U>(var[row_idx]) + epsilon);
     d_scale_partial += static_cast<U>(d_y[i]) *
-                       (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
+                       (static_cast<U>(x[i]) - mean[row_idx]) * var_val;
     d_bias_partial += static_cast<U>(d_y[i]);
     if (HasDx) {
       d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                              static_cast<U>(scale[blockIdx.x + col_offset]) /
+                              static_cast<U>(scale[blockIdx.x + col_offset]) *
                               var_val);
     }
   }
@@ -1453,10 +1453,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
   for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
+        static_cast<U>(rsqrt_(static_cast<float>(var[row_idx]) + epsilon));
     if (HasDScale) {
       d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
-                                   (static_cast<U>(x[i]) - mean[row_idx]) /
+                                   (static_cast<U>(x[i]) - mean[row_idx]) *
                                    var_val;
     } else {  // d_bias != nullptr
       d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
@@ -1465,10 +1465,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
     if (HasDx) {
       if (scale != nullptr) {
         d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                                static_cast<U>(scale[blockIdx.x + col_offset]) /
+                                static_cast<U>(scale[blockIdx.x + col_offset]) *
                                 var_val);
       } else {
-        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
+        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) * var_val);
       }
     }
   }
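
For context (not part of the patch): the per-column reductions accumulated in the two kernels above correspond to the standard layer norm gradients for scale and bias. A hedged host-side reference follows; the function name, signature, and layout assumptions (row-major x and d_y of shape [rows, feature_size]) are illustrative only, not the kernel's own.

    // Reference reductions:
    //   d_scale[j] = sum_i d_y[i][j] * (x[i][j] - mean[i]) * rsqrt(var[i] + eps)
    //   d_bias[j]  = sum_i d_y[i][j]
    #include <cmath>
    #include <cstdint>
    #include <vector>

    void layer_norm_grad_scale_bias_ref(const std::vector<float>& x,
                                        const std::vector<float>& d_y,
                                        const std::vector<float>& mean,
                                        const std::vector<float>& var,
                                        float epsilon,
                                        int rows,
                                        int feature_size,
                                        std::vector<float>* d_scale,
                                        std::vector<float>* d_bias) {
      d_scale->assign(feature_size, 0.0f);
      d_bias->assign(feature_size, 0.0f);
      for (int i = 0; i < rows; ++i) {
        // Same scaling the patch uses: multiply by 1/sqrt(var + eps).
        float inv_std = 1.0f / std::sqrt(var[i] + epsilon);
        for (int j = 0; j < feature_size; ++j) {
          int64_t idx = static_cast<int64_t>(i) * feature_size + j;
          (*d_scale)[j] += d_y[idx] * (x[idx] - mean[i]) * inv_std;
          (*d_bias)[j] += d_y[idx];
        }
      }
    }
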
@@ -1556,13 +1556,13 @@ __global__ void LayerNormBackwardGradientOnlyDX(
   U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
   for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
+        static_cast<U>(rsqrt_(static_cast<float>(block_var) + epsilon));
     if (scale != nullptr) {
       int col_idx = i % feature_size;
       d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
-                              static_cast<U>(scale[col_idx]) / var_val);
+                              static_cast<U>(scale[col_idx]) * var_val);
     } else {
-      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
+      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) * var_val);
     }
     d_x_mean_partial += static_cast<U>(d_x[i]);
     d_x_var_partial +=
@@ -1606,21 +1606,20 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
   int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
   if (idx < feature_size) {
-    auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
+    auto var_val = static_cast<U>(rsqrt_(static_cast<float>(var[0]) + epsilon));
     if (d_x != nullptr) {
       if (d_scale == nullptr) {
-        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
+        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) * var_val);
       } else {
         d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) *
-                                  static_cast<U>(scale[idx]) / var_val);
+                                  static_cast<U>(scale[idx]) * var_val);
       }
     }
     if (d_scale != nullptr) {
       d_scale[idx] =
           static_cast<ScaleBiasT>(static_cast<U>(d_y[idx]) *
-                                  (static_cast<U>(x[idx]) - mean[0]) / var_val);
+                                  (static_cast<U>(x[idx]) - mean[0]) * var_val);
     }
     if (d_bias != nullptr) {
......
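
For completeness (again, not part of the patch): the d_x lines changed in these hunks compute only the d_y * scale * inv_std term, while LayerNormBackwardGradientOnlyDX accumulates d_x_mean_partial / d_x_var_partial for the remaining correction terms later in the kernel (not shown in this diff). A single-row host-side reference for the full d_x, using the standard layer norm backward formula and the same rsqrt-style scaling, might look as follows; the name and signature are illustrative, not taken from the file.

    // d_x = inv_std * (g - mean(g) - x_hat * mean(g * x_hat)),
    // where g = d_y * scale and x_hat = (x - mean) * inv_std.
    #include <cmath>
    #include <vector>

    std::vector<float> layer_norm_grad_x_row_ref(const std::vector<float>& x,
                                                 const std::vector<float>& d_y,
                                                 const std::vector<float>& scale,
                                                 float mean,
                                                 float var,
                                                 float epsilon) {
      const int n = static_cast<int>(x.size());
      const float inv_std = 1.0f / std::sqrt(var + epsilon);  // == rsqrt_(var + eps)

      float sum_g = 0.0f, sum_g_xhat = 0.0f;
      std::vector<float> g(n), x_hat(n);
      for (int j = 0; j < n; ++j) {
        x_hat[j] = (x[j] - mean) * inv_std;
        g[j] = d_y[j] * (scale.empty() ? 1.0f : scale[j]);  // empty scale acts like scale == nullptr
        sum_g += g[j];
        sum_g_xhat += g[j] * x_hat[j];
      }

      std::vector<float> d_x(n);
      for (int j = 0; j < n; ++j) {
        d_x[j] = inv_std * (g[j] - sum_g / n - x_hat[j] * sum_g_xhat / n);
      }
      return d_x;
    }
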