From ad5f749448f1caf9232f89b96b62a0f2b7a4c2ef Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Wed, 8 Sep 2021 11:52:47 +0800
Subject: [PATCH] fix the bug of layer_norm when batch_size=1 (#35480)

The bug is that mean and var are indexed incorrectly, so the array access
goes out of bounds: mean and var have shape [batch_size], while the thread
index idx ranges over 0~feature_size, so mean[idx] and var[idx] are wrong.
When batch_size=1, the correct accesses are mean[0] and var[0]. A unit test
with batch_size=1 is added.

---
 paddle/fluid/operators/layer_norm_kernel.cu.h              | 4 ++--
 python/paddle/fluid/tests/unittests/test_layer_norm_op.py  | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h
index 37f7e8c5590..06c1eaf8816 100644
--- a/paddle/fluid/operators/layer_norm_kernel.cu.h
+++ b/paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -705,7 +705,7 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
   int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < feature_size) {
     auto var_val =
-        static_cast(real_sqrt(static_cast(var[idx]) + epsilon));
+        static_cast(real_sqrt(static_cast(var[0]) + epsilon));
     if (d_x != nullptr) {
       if (d_scale == nullptr) {
         d_x[idx] = static_cast(static_cast(d_y[idx]) / var_val);
@@ -717,7 +717,7 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
 
     if (d_scale != nullptr) {
       d_scale[idx] = static_cast(d_y[idx]) *
-                     (static_cast(x[idx]) - mean[idx]) / var_val;
+                     (static_cast(x[idx]) - mean[0]) / var_val;
     }
 
     if (d_bias != nullptr) d_bias[idx] = static_cast(d_y[idx]);
diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
index 51224002c96..98a503eb1ea 100644
--- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
@@ -233,6 +233,7 @@ class TestLayerNormOp(unittest.TestCase):
         test_with_place(place, shape, begin_norm_axis)
 
     def test_check_forward_backward_with_scale_and_bias(self):
+        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(
             shape=[2, 3, 4, 5],
--
GitLab
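
For context on the indexing corrected above, here is a minimal sketch of a
batch_size=1 layer_norm backward kernel. It is not the actual Paddle kernel
(which also computes d_x, handles mixed precision, and uses real_sqrt); the
kernel name simple_layer_norm_grad_bs1 and the float-only signature are
illustrative assumptions.

// Illustrative sketch only, assuming float inputs and batch_size == 1:
// mean and var each hold a single element, while idx ranges over
// [0, feature_size), so the statistics must be read at index 0;
// mean[idx] / var[idx] would read past the end for every idx > 0.
#include <cstdint>

__global__ void simple_layer_norm_grad_bs1(const float* x, const float* d_y,
                                           const float* mean,  // shape: [1]
                                           const float* var,   // shape: [1]
                                           float* d_scale,     // shape: [feature_size]
                                           float* d_bias,      // shape: [feature_size]
                                           float epsilon,
                                           int64_t feature_size) {
  int64_t idx = threadIdx.x + blockIdx.x * static_cast<int64_t>(blockDim.x);
  if (idx < feature_size) {
    float std_val = sqrtf(var[0] + epsilon);                   // var[0], not var[idx]
    d_scale[idx] = d_y[idx] * (x[idx] - mean[0]) / std_val;    // mean[0], not mean[idx]
    d_bias[idx] = d_y[idx];
  }
}

A launch would cover feature_size threads, e.g.
simple_layer_norm_grad_bs1<<<(feature_size + 255) / 256, 256>>>(...). The
patched kernel uses the same thread-per-feature layout, which is why every
thread must share the single-element statistics mean[0] and var[0].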