From a20605682ea1fab07c861123676719f9cc97527a Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Wed, 12 Apr 2023 10:20:50 +0800 Subject: [PATCH] Modify LayerNorm Composite Rule (#52712) * [Do NOT merge] Expr PR on Composite * Expr PR on Composite * Revert some composite experiment * Remove unnecessary composite code * Add rsqrt as sub primitives --- python/paddle/incubate/autograd/composite_rules.py | 4 ++-- python/paddle/incubate/autograd/primitives.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 9d022962772..ba92c5dba71 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -160,8 +160,8 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis): var_tmp1 = difference * difference variance = mean(var_tmp1, axis=axis, keepdim=True) var_tmp3 = variance + epsilon - sqrt_var = sqrt(var_tmp3) - out = difference / sqrt_var + rsqrt_var = rsqrt(var_tmp3) + out = difference * rsqrt_var if scale is not None: scale = reshape(scale, x.shape[begin_norm_axis:]) diff --git a/python/paddle/incubate/autograd/primitives.py b/python/paddle/incubate/autograd/primitives.py index cc8ba89423d..9f52d9d69ac 100644 --- a/python/paddle/incubate/autograd/primitives.py +++ b/python/paddle/incubate/autograd/primitives.py @@ -50,6 +50,7 @@ from paddle.tensor import ones # noqa: F401 from paddle.tensor import pow # noqa: F401 from paddle.tensor import prod # noqa: F401 from paddle.tensor import reshape # noqa: F401 +from paddle.tensor import rsqrt # noqa: F401 from paddle.tensor import sign # noqa: F401 from paddle.tensor import sin # noqa: F401 from paddle.tensor import sinh # noqa: F401 @@ -117,6 +118,7 @@ sub_prim = [ 'ones', 'zeros', 'sqrt', + 'rsqrt', ] others = [ -- GitLab