diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu
index 0d877fe23244474eacfe0cb7cfd3f565436dad81..bc8860eaa055e391c9b5212eb2447d55a932873e 100644
--- a/paddle/fluid/operators/layer_norm_op.cu
+++ b/paddle/fluid/operators/layer_norm_op.cu
@@ -135,7 +135,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
   }
   __syncthreads();
   mean_val = mean[blockIdx.x];
-  var_val = static_cast<U>(real_sqrt(var[blockIdx.x]) + epsilon);
+  var_val = static_cast<U>(real_sqrt(var[blockIdx.x] + epsilon));
 
   // Step 2: Calculate y
   if (scale != nullptr) {
diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
index d17942fe3be1e89b9ae29ebb7bdd8a52efb5c5ea..51224002c96039b0679d605a068f063883d1cd52 100644
--- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
@@ -211,7 +211,7 @@ class TestLayerNormOp(unittest.TestCase):
                     for name in ['x', 'scale', 'bias', 'y@GRAD']
                 },
                 fetch_list=fetch_list)
-            self.__assert_close(y, out[0], "y", 1e-3)
+            self.__assert_close(y, out[0], "y")
             self.__assert_close(mean, out[1], "mean")
             self.__assert_close(variance, out[2], "variance", 1e-3)
             self.__assert_close(x_grad, out[3], "x_grad")
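
For reference, a minimal NumPy sketch of the normalization this kernel fix restores (the function name, shapes, and default epsilon below are illustrative, not taken from the test suite): epsilon must be added to the variance before the square root, i.e. y = (x - mean) / sqrt(var + eps) * scale + bias, rather than to sqrt(var) afterwards, which is why the forward-output tolerance in the unit test could be tightened back to the default.

```python
import numpy as np

def layer_norm_ref(x, scale, bias, epsilon=1e-5):
    """Reference layer norm over the last axis (NumPy sketch, not the CUDA kernel)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    # Fixed placement: epsilon inside the square root.
    y = (x - mean) / np.sqrt(var + epsilon)
    # Old, incorrect placement for comparison:
    # y = (x - mean) / (np.sqrt(var) + epsilon)
    return y * scale + bias

x = np.random.rand(4, 8).astype(np.float32)
scale = np.ones(8, dtype=np.float32)
bias = np.zeros(8, dtype=np.float32)
print(layer_norm_ref(x, scale, bias).shape)  # (4, 8)
```

The two placements only diverge noticeably when the variance is small, which is exactly the regime where the looser 1e-3 tolerance on "y" had been masking the error.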