diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 9d0229627728d798bf555b8c64f1b758c03b8f04..ba92c5dba718dba14aab6903a4960af328495631 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -160,8 +160,8 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis): var_tmp1 = difference * difference variance = mean(var_tmp1, axis=axis, keepdim=True) var_tmp3 = variance + epsilon - sqrt_var = sqrt(var_tmp3) - out = difference / sqrt_var + rsqrt_var = rsqrt(var_tmp3) + out = difference * rsqrt_var if scale is not None: scale = reshape(scale, x.shape[begin_norm_axis:]) diff --git a/python/paddle/incubate/autograd/primitives.py b/python/paddle/incubate/autograd/primitives.py index cc8ba89423d7c86c3b131c1f60b98f51a560e207..9f52d9d69ac23354e5b77413a3139b5a01c95dcf 100644 --- a/python/paddle/incubate/autograd/primitives.py +++ b/python/paddle/incubate/autograd/primitives.py @@ -50,6 +50,7 @@ from paddle.tensor import ones # noqa: F401 from paddle.tensor import pow # noqa: F401 from paddle.tensor import prod # noqa: F401 from paddle.tensor import reshape # noqa: F401 +from paddle.tensor import rsqrt # noqa: F401 from paddle.tensor import sign # noqa: F401 from paddle.tensor import sin # noqa: F401 from paddle.tensor import sinh # noqa: F401 @@ -117,6 +118,7 @@ sub_prim = [ 'ones', 'zeros', 'sqrt', + 'rsqrt', ] others = [