diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h
index 37f7e8c55901c4a526db937ee2aa034d999167bf..06c1eaf881626cbea7e872291119682766da6f33 100644
--- a/paddle/fluid/operators/layer_norm_kernel.cu.h
+++ b/paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -705,7 +705,7 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
   int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < feature_size) {
     auto var_val =
-        static_cast<U>(real_sqrt(static_cast<U>(var[idx]) + epsilon));
+        static_cast<U>(real_sqrt(static_cast<U>(var[0]) + epsilon));
     if (d_x != nullptr) {
       if (d_scale == nullptr) {
         d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
@@ -717,7 +717,7 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
 
     if (d_scale != nullptr) {
       d_scale[idx] = static_cast<U>(d_y[idx]) *
-                     (static_cast<U>(x[idx]) - mean[idx]) / var_val;
+                     (static_cast<U>(x[idx]) - mean[0]) / var_val;
     }
 
     if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
index 51224002c96039b0679d605a068f063883d1cd52..98a503eb1ea6f6dd6ecd49579a231cf0e52b7b73 100644
--- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
@@ -233,6 +233,7 @@ class TestLayerNormOp(unittest.TestCase):
         test_with_place(place, shape, begin_norm_axis)
 
     def test_check_forward_backward_with_scale_and_bias(self):
+        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(
             shape=[2, 3, 4, 5],
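
Note on the change above: LayerNormBackwardWhenBatchSizeIsOne only runs when the rows left of begin_norm_axis collapse to a single batch row, so mean and var hold exactly one element each. Indexing them with the per-feature thread index idx therefore read out of bounds for every idx > 0; the fix reads element 0, and the new shape=[1, 3, 4, 5] unit-test case exercises exactly this path. Below is a minimal repro sketch in Python, assuming the paddle 2.x dygraph API (paddle.set_device, paddle.to_tensor, paddle.nn.functional.layer_norm); it is illustrative only and not part of the patch:

    # Repro sketch for the batch-size-1 backward path (assumes paddle 2.x).
    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.set_device("gpu")  # the affected kernel is the CUDA backward path

    # Batch dim of 1 with begin_norm_axis=1: mean/var have a single element,
    # so the backward kernel must read mean[0]/var[0], never mean[idx].
    x = paddle.to_tensor(
        np.random.rand(1, 3, 4, 5).astype("float32"), stop_gradient=False)
    weight = paddle.ones([3 * 4 * 5])  # scale, flattened over normalized dims
    bias = paddle.zeros([3 * 4 * 5])

    y = F.layer_norm(x, normalized_shape=x.shape[1:], weight=weight, bias=bias)
    y.sum().backward()
    # Before this fix, x.grad and the scale gradient came from out-of-bounds
    # reads of mean/var on this path; after it, they match the CPU reference.
    print(x.grad.shape)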