未验证 提交 a2060568 编写于 作者: H Huihuang Zheng 提交者: GitHub

Modify LayerNorm Composite Rule (#52712)

* [Do NOT merge] Expr PR on Composite

* Expr PR on Composite

* Revert some composite experiment

* Remove unnecessary composite code

* Add rsqrt as sub primitives
上级 b0f17d05
......@@ -160,8 +160,8 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
var_tmp1 = difference * difference
variance = mean(var_tmp1, axis=axis, keepdim=True)
var_tmp3 = variance + epsilon
sqrt_var = sqrt(var_tmp3)
out = difference / sqrt_var
rsqrt_var = rsqrt(var_tmp3)
out = difference * rsqrt_var
if scale is not None:
scale = reshape(scale, x.shape[begin_norm_axis:])
......
......@@ -50,6 +50,7 @@ from paddle.tensor import ones # noqa: F401
from paddle.tensor import pow # noqa: F401
from paddle.tensor import prod # noqa: F401
from paddle.tensor import reshape # noqa: F401
from paddle.tensor import rsqrt # noqa: F401
from paddle.tensor import sign # noqa: F401
from paddle.tensor import sin # noqa: F401
from paddle.tensor import sinh # noqa: F401
......@@ -117,6 +118,7 @@ sub_prim = [
'ones',
'zeros',
'sqrt',
'rsqrt',
]
others = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册