Unverified commit ad5f7494 authored by zhangkaihuo, committed by GitHub

fix the layer_norm bug when batch_size=1 (#35480)

The bug is an out-of-bounds access to mean and var: both arrays have shape [batch_size], while the thread index idx ranges over 0~feature_size, so mean[idx] and var[idx] read past the end of the arrays.

When batch_size=1, the correct accesses are mean[0] and var[0]. A unit test with batch_size=1 is added.
Parent 4e62af80
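To make the indexing concrete, below is a minimal CUDA sketch of the access pattern the commit message describes. The kernel name (`backward_sketch`), its signature, and the launch configuration are simplified assumptions for illustration only; the real kernel is PaddlePaddle's `LayerNormBackwardWhenBatchSizeIsOne`, shown in the diff that follows. Only the mean[0]/var[0] versus mean[idx]/var[idx] distinction mirrors the actual fix.

```cuda
#include <cuda_runtime.h>

// Hypothetical sketch (not the PaddlePaddle kernel): one row of
// `feature_size` elements, with per-row statistics `mean` and `var`
// of length batch_size (== 1 here). Threads are launched over
// feature_size, so indexing the statistics with the thread index
// (mean[idx], var[idx]) reads out of bounds for every idx >= 1;
// the single row's statistics live at index 0.
__global__ void backward_sketch(const float* d_y, const float* x,
                                const float* mean, const float* var,
                                float* d_scale, int64_t feature_size,
                                float epsilon) {
  int64_t idx = threadIdx.x + blockIdx.x * (int64_t)blockDim.x;
  if (idx < feature_size) {
    // Correct: mean[0] / var[0]. The buggy version used mean[idx] / var[idx].
    float var_val = sqrtf(var[0] + epsilon);
    d_scale[idx] = d_y[idx] * (x[idx] - mean[0]) / var_val;
  }
}

int main() {
  const int64_t feature_size = 60;  // e.g. 3 * 4 * 5 for shape [1, 3, 4, 5]
  float *d_y, *x, *mean, *var, *d_scale;
  cudaMallocManaged(&d_y, feature_size * sizeof(float));
  cudaMallocManaged(&x, feature_size * sizeof(float));
  cudaMallocManaged(&mean, sizeof(float));  // shape [batch_size] == [1]
  cudaMallocManaged(&var, sizeof(float));
  cudaMallocManaged(&d_scale, feature_size * sizeof(float));
  for (int64_t i = 0; i < feature_size; ++i) { d_y[i] = 1.0f; x[i] = 0.5f; }
  mean[0] = 0.0f;
  var[0] = 1.0f;
  // The grid covers feature_size, not batch_size, which is exactly why
  // the statistics must not be indexed by the thread index.
  backward_sketch<<<(feature_size + 255) / 256, 256>>>(
      d_y, x, mean, var, d_scale, feature_size, 1e-5f);
  cudaDeviceSynchronize();
  return 0;
}
```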
@@ -705,7 +705,7 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
   int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < feature_size) {
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<float>(var[idx]) + epsilon));
+        static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
     if (d_x != nullptr) {
       if (d_scale == nullptr) {
         d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
@@ -717,7 +717,7 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
     if (d_scale != nullptr) {
       d_scale[idx] = static_cast<U>(d_y[idx]) *
-                     (static_cast<U>(x[idx]) - mean[idx]) / var_val;
+                     (static_cast<U>(x[idx]) - mean[0]) / var_val;
     }
     if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
...
@@ -233,6 +233,7 @@ class TestLayerNormOp(unittest.TestCase):
                 test_with_place(place, shape, begin_norm_axis)

     def test_check_forward_backward_with_scale_and_bias(self):
+        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(
             shape=[2, 3, 4, 5],
...