BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 82713eb6: init layer_norm
Author: wangruting
Committed: Feb 10, 2023
Parent commit: 637dfe49
Showing 4 changed files with 77 additions and 71 deletions (+77, -71).
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py (+4, -4)
python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py (+27, -27)
python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py (+37, -31)
python/paddle/incubate/autograd/composite_rules.py (+9, -9)
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py @ 82713eb6

```diff
@@ -49,7 +49,7 @@ class TestPrimForward(unittest.TestCase):
     def setUp(self):
         paddle.seed(2022)
         self.x = paddle.randn([2, 4])
-        self.n_shape = x.shape[1:]
+        self.n_shape = self.x.shape
         self.w = paddle.randn([4])
         self.b = paddle.randn([4])
         self.x.stop_gradient = False
@@ -86,7 +86,7 @@ class TestPrimForward(unittest.TestCase):
         self.assertTrue('layer_norm' not in fwd_ops)

     def test_cinn_prim_forward(self):
         dy_res = self.train(use_prim=False)
         cinn_res = self.train(use_prim=True)
@@ -94,7 +94,7 @@ class TestPrimForward(unittest.TestCase):
             np.testing.assert_allclose(
                 cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
             )


 class TestPrimForwardAndBackward(unittest.TestCase):
     """
@@ -104,7 +104,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def setUp(self):
         paddle.seed(2022)
         self.x = paddle.randn([2, 4])
-        self.n_shape = x.shape[1:]
+        self.n_shape = self.x.shape
         self.w = paddle.randn([4])
         self.b = paddle.randn([4])
         self.x.stop_gradient = False
```
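Both setUp fixes above replace a reference to an undefined local `x` with the instance attribute `self.x`. For orientation (this sketch is not part of the commit; it only reuses the fixture's [2, 4] shapes), `F.layer_norm` takes the trailing normalized shape as its second argument, and `weight`/`bias` must have as many elements as that shape covers:

```python
# Illustrative sketch only, not code from the commit.
import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 4])
w = paddle.randn([4])
b = paddle.randn([4])
# normalized_shape must match x's trailing dimensions; here [4].
out = F.layer_norm(x, x.shape[1:], w, b)
print(out.shape)  # [2, 4]
```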
python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py @ 82713eb6

```diff
@@ -20,7 +20,6 @@ from utils import TOLERANCE
 import paddle
 import paddle.nn.functional as F
 from paddle.fluid import core
-from paddle import _C_ops, in_dynamic_mode


 def generate_data(shape1, shape2, shape3, dtype="float32"):
@@ -38,7 +37,6 @@ class Attr:
         self.shape1 = None
         self.shape2 = None
         self.shape3 = None

     def set_dtype(self, dtype) -> None:
         self.dtype = dtype
@@ -66,14 +64,15 @@ attrs = Attr()
 def fn(x, norm_shape, w, b):
     return F.layer_norm(x, norm_shape, w, b)


 def layer_norm_(input, weight, bias, epsilon=1e-05, begin_norm_axis=0):
     axis = np.arange(begin_norm_axis, len(input.shape))
     mean = paddle.mean(input, axis=axis, keepdim=True)
     t1 = input - mean
     t2 = paddle.pow(t1, 2.0)
     t3 = paddle.mean(t2, axis=axis, keepdim=True)
     t4 = t3 + epsilon
     t5 = paddle.sqrt(t4)
     t7 = t1 / t5
     out = t7
     if weight is not None:
```
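Reading the eager reference above: `t1` is the centered input, `t3` the biased variance, `t5` the standard deviation, so `t7` is the normalized output. A hedged eager-mode check (not part of the commit) that `layer_norm_` agrees with `F.layer_norm` under the same `begin_norm_axis` mapping this file's `composite_forward` uses:

```python
# Sanity-check sketch; assumes layer_norm_ from this test file is in scope.
import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 4, 3])
w = paddle.randn([12])  # numel(w) == prod(norm_shape)
b = paddle.randn([12])
norm_shape = [4, 3]
b_axis = len(x.shape) - len(norm_shape)  # mapping used by composite_forward

expect = F.layer_norm(x, norm_shape, w, b)
actual = layer_norm_(x, w, b, begin_norm_axis=b_axis)
np.testing.assert_allclose(expect.numpy(), actual.numpy(), rtol=1e-5, atol=1e-5)
```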
```diff
@@ -82,15 +81,15 @@ def layer_norm_ (input, weight, bias, epsilon=1e-05, begin_norm_axis = 0):
     if bias is not None:
         bias = paddle.reshape(bias, input.shape[begin_norm_axis:])
         out = out + paddle.broadcast_to(bias, out.shape)
     return out


 def composite_forward(x, norm_shape, w, b):
     b_axis = len(x.shape) - len(norm_shape)
     return layer_norm_(x, w, b, begin_norm_axis=b_axis)


 def expect_forward(x, norm_shape, w, b):
     return fn(x, norm_shape, w, b)
@@ -98,10 +97,10 @@ def expect_forward(x, norm_shape, w, b):
 class TestCompositelayer_norm(unittest.TestCase):
     def setUp(self):
         self.dtypes = ["float16", "float32"]
-        self.n_shape = [[3, 4],[3], [2, 3]]
-        self.shape1s = [[3, 4],[2, 4, 3], [2, 2, 3]]
+        self.n_shape = [[3, 4], [3], [2, 3]]
+        self.shape1s = [[3, 4], [2, 4, 3], [2, 2, 3]]
         self.shape2s = [[12], [3], [6]]
         self.shape3s = [[12], [3], [6]]

     def cal_composite(self, inputs, norm_shape, weight, bias):
         paddle.enable_static()
@@ -115,11 +114,9 @@ class TestCompositelayer_norm(unittest.TestCase):
             w = paddle.static.data(
                 'w', shape=weight.shape, dtype=str(weight.dtype)
             )
-            b = paddle.static.data(
-                'b', shape=bias.shape, dtype=str(bias.dtype))
+            b = paddle.static.data(
+                'b', shape=bias.shape, dtype=str(bias.dtype)
+            )
             y = fn(x, norm_shape, w, b)
             blocks = main_program.blocks
             fwd_ops = [op.type for op in blocks[0].ops]
@@ -135,13 +132,14 @@ class TestCompositelayer_norm(unittest.TestCase):
         exe = paddle.static.Executor()
         exe.run(startup_program)
         res = exe.run(
             main_program,
             feed={
                 'x': inputs,
                 'w': weight,
                 'b': bias,
             },
-            fetch_list=[y])
+            fetch_list=[y],
+        )
         paddle.disable_static()
         core._set_prim_forward_enabled(False)
         return res
@@ -154,12 +152,9 @@ class TestCompositelayer_norm(unittest.TestCase):
         b_p = paddle.to_tensor(b)

         expect = expect_forward(x_p, n_shape, w_p, b_p).numpy()
-        # actual = self.cal_composite(x_p, n_shape, w_p, b_p)
-        print("expect = ", expect)
+        #actual = self.cal_composite(x_p, n_shape, w_p, b_p)
         actual = composite_forward(x_p, n_shape, w_p, b_p).numpy()
-        print("actual = ", actual)
         assert expect.dtype == actual.dtype
         np.testing.assert_allclose(
             expect,
@@ -180,9 +175,14 @@ class TestCompositelayer_norm(unittest.TestCase):
     def test_forward(self):
         for j in self.dtypes:
             for t in range(0, len(self.shape1s)):
                 attrs.set_dtype(j)
-                attrs.set_shape(self.n_shape[t], self.shape1s[t], self.shape2s[t], self.shape3s[t])
+                attrs.set_shape(
+                    self.n_shape[t],
+                    self.shape1s[t],
+                    self.shape2s[t],
+                    self.shape3s[t],
+                )
                 self.compare_forward()
```
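`cal_composite` above follows the usual Paddle static-graph test pattern: switch to static mode, declare feeds with `paddle.static.data`, build the op, lower it with the prim machinery, then run an Executor. A condensed, hypothetical sketch of that pattern (the shapes and the `'x'` feed name here are illustrative, not the commit's):

```python
# Hypothetical, condensed version of the harness; not code from the commit.
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core

paddle.enable_static()
core._set_prim_forward_enabled(True)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data('x', shape=[2, 4], dtype='float32')
    y = F.layer_norm(x, [4])
    # Lower composite ops (such as layer_norm) into primitive ops.
    paddle.incubate.autograd.to_prim(main_program.blocks)

exe = paddle.static.Executor()
exe.run(startup_program)
(res,) = exe.run(
    main_program,
    feed={'x': np.random.randn(2, 4).astype('float32')},
    fetch_list=[y],
)
core._set_prim_forward_enabled(False)
paddle.disable_static()
```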
python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py @ 82713eb6

```diff
@@ -20,7 +20,6 @@ from utils import TOLERANCE
 import paddle
 import paddle.nn.functional as F
 from paddle.fluid import core
-from paddle import _C_ops, in_dynamic_mode


 def generate_data(shape1, shape2, shape3, dtype="float32"):
@@ -38,7 +37,6 @@ class Attr:
         self.shape1 = None
         self.shape2 = None
         self.shape3 = None

     def set_dtype(self, dtype) -> None:
         self.dtype = dtype
@@ -66,6 +64,7 @@ attrs = Attr()
 def fn(x, norm_shape, w, b):
     return F.layer_norm(x, norm_shape, w, b)


 # def layer_norm_ (input, weight, bias, epsilon=1e-05, begin_norm_axis = 0):
 # axis = np.arange(begin_norm_axis,len(input.shape))
 # mean = paddle.mean(input, axis=axis, keepdim=True)
@@ -82,7 +81,7 @@ def fn(x, norm_shape, w, b):
 # if bias is not None:
 # bias = paddle.reshape(bias, input.shape[begin_norm_axis:])
 # out = out + paddle.broadcast_to(bias, out.shape)
 # return out

 # def composite_forward(x, norm_shape, w, b):
@@ -90,11 +89,10 @@ def fn(x, norm_shape, w, b):
 # return layer_norm_(x, w, b, begin_norm_axis=b_axis)


 def expect_backward(x, norm_shape, w, b):
     paddle.disable_static()
     x.stop_gradient = False
     res = fn(x, norm_shape, w, b)
     gradients = paddle.grad(res, x)
     return gradients
```
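`expect_backward` builds the dygraph reference gradient: run `F.layer_norm` eagerly and differentiate with `paddle.grad`, whose initial output gradient defaults to ones. A minimal sketch of that pattern (not from the commit; shapes illustrative, weight and bias omitted):

```python
# Minimal dygraph reference-gradient sketch; not code from the commit.
import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 4, 3])
x.stop_gradient = False
y = F.layer_norm(x, x.shape[1:])  # weight/bias omitted for brevity
(dx,) = paddle.grad(y, x)         # grad_outputs defaults to ones
print(dx.shape)                   # [2, 4, 3]
```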
```diff
@@ -103,10 +101,10 @@ def expect_backward(x, norm_shape, w, b):
 class TestCompositelayer_norm(unittest.TestCase):
     def setUp(self):
         self.dtypes = ["float16", "float32"]
-        self.n_shape = [[3, 4],[3], [2, 3]]
-        self.shape1s = [[3, 4],[2, 4, 3], [2, 2, 3]]
+        self.n_shape = [[3, 4], [3], [2, 3]]
+        self.shape1s = [[3, 4], [2, 4, 3], [2, 2, 3]]
         self.shape2s = [[12], [3], [6]]
         self.shape3s = [[12], [3], [6]]

     def cal_composite_backward(self, inputs, norm_shape, weight, bias):
         paddle.enable_static()
@@ -121,11 +119,9 @@ class TestCompositelayer_norm(unittest.TestCase):
             w = paddle.static.data(
                 'w', shape=weight.shape, dtype=str(weight.dtype)
             )
-            b = paddle.static.data(
-                'b', shape=bias.shape, dtype=str(bias.dtype))
+            b = paddle.static.data(
+                'b', shape=bias.shape, dtype=str(bias.dtype)
+            )
             y = fn(x, norm_shape, w, b)
             blocks = main_program.blocks
             fwd_ops = [op.type for op in blocks[0].ops]
@@ -147,13 +143,14 @@ class TestCompositelayer_norm(unittest.TestCase):
         exe = paddle.static.Executor()
         exe.run(startup_program)
         res = exe.run(
             main_program,
             feed={
                 'x': inputs,
                 'w': weight,
                 'b': bias,
             },
-            fetch_list=[z])
+            fetch_list=[z],
+        )
         paddle.disable_static()
         core._set_prim_forward_enabled(False)
         return res
@@ -188,9 +185,14 @@ class TestCompositelayer_norm(unittest.TestCase):
     def test_backward(self):
         for j in self.dtypes:
             for t in range(0, len(self.shape1s)):
                 attrs.set_dtype(j)
-                attrs.set_shape(self.n_shape[t], self.shape1s[t], self.shape2s[t], self.shape3s[t])
+                attrs.set_shape(
+                    self.n_shape[t],
+                    self.shape1s[t],
+                    self.shape2s[t],
+                    self.shape3s[t],
+                )
                 self.compare_backward()
@@ -198,10 +200,10 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
     def setUp(self):
         core._set_prim_backward_enabled(True)
         self.dtypes = ["float16", "float32"]
-        self.n_shape = [[3, 4],[3], [2, 3]]
-        self.shape1s = [[3, 4],[2, 4, 3], [2, 2, 3]]
+        self.n_shape = [[3, 4], [3], [2, 3]]
+        self.shape1s = [[3, 4], [2, 4, 3], [2, 2, 3]]
         self.shape2s = [[12], [3], [6]]
         self.shape3s = [[12], [3], [6]]

     def cal_composite_backward(self, inputs, norm_shape, weight, bias):
         paddle.enable_static()
@@ -216,11 +218,9 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
             w = paddle.static.data(
                 'w', shape=weight.shape, dtype=str(weight.dtype)
             )
-            b = paddle.static.data(
-                'b', shape=bias.shape, dtype=str(bias.dtype))
+            b = paddle.static.data(
+                'b', shape=bias.shape, dtype=str(bias.dtype)
+            )
             y = fn(x, norm_shape, w, b)
             blocks = main_program.blocks
             paddle.incubate.autograd.to_prim(blocks)
             z = paddle.static.gradients([y], x)
@@ -228,13 +228,14 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
         exe = paddle.static.Executor()
         exe.run(startup_program)
         res = exe.run(
             main_program,
             feed={
                 'x': inputs,
                 'w': weight,
                 'b': bias,
             },
-            fetch_list=[z])
+            fetch_list=[z],
+        )
         paddle.disable_static()
         core._set_prim_all_enabled(False)
         return res
@@ -269,9 +270,14 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase):
     def test_prim_backward(self):
         for j in self.dtypes:
             for t in range(0, len(self.shape1s)):
                 attrs.set_dtype(j)
-                attrs.set_shape(self.n_shape[t], self.shape1s[t], self.shape2s[t], self.shape3s[t])
+                attrs.set_shape(
+                    self.n_shape[t],
+                    self.shape1s[t],
+                    self.shape2s[t],
+                    self.shape3s[t],
+                )
                 self.compare_backward()
```
python/paddle/incubate/autograd/composite_rules.py @ 82713eb6

```diff
@@ -104,21 +104,21 @@ def composite_batchnorm(
 @REGISTER_COMPOSITE('layer_norm')
 def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
     axis = np.arange(begin_norm_axis, len(x.shape))
     mean_ = mean(x, axis=axis, keepdim=True)
     difference = x - mean_
     var_tmp1 = pow(difference, 2.0)
     variance = mean(var_tmp1, axis=axis, keepdim=True)
     var_tmp3 = variance + epsilon
     sqrt_var = sqrt(var_tmp3)
     out = difference / sqrt_var

     if scale is not None:
         scale = reshape(scale, x.shape[begin_norm_axis:])
-        out = t7 * broadcast_to(scale, out.shape)
+        out = out * broadcast_to(scale, out.shape)
     if bias is not None:
         bias = reshape(bias, x.shape[begin_norm_axis:])
         out = out + broadcast_to(bias, out.shape)

     return out, mean_, variance
\ No newline at end of file
```
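The functional fix in this file replaces a stray `t7` (a leftover name from the test file's reference implementation) with `out`, so the scale multiply applies to the normalized tensor and the rule no longer references an undefined variable. In formula terms the composite rule computes out = (x - mean(x)) / sqrt(var(x) + epsilon) * scale + bias over the trailing axes from `begin_norm_axis` on. A quick eager-mode sanity check of that decomposition against `F.layer_norm` (not part of the commit; shapes follow the tests):

```python
# Eager-mode sanity check of the decomposition; not code from the commit.
import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 4, 3])
scale = paddle.randn([12])  # numel == prod of the normalized dims [4, 3]
bias = paddle.randn([12])
begin_norm_axis = 1
axis = list(range(begin_norm_axis, len(x.shape)))

mean_ = paddle.mean(x, axis=axis, keepdim=True)
variance = paddle.mean((x - mean_) ** 2, axis=axis, keepdim=True)
out = (x - mean_) / paddle.sqrt(variance + 1e-5)
out = out * paddle.broadcast_to(paddle.reshape(scale, x.shape[1:]), out.shape)
out = out + paddle.broadcast_to(paddle.reshape(bias, x.shape[1:]), out.shape)

expect = F.layer_norm(x, x.shape[1:], scale, bias)
np.testing.assert_allclose(expect.numpy(), out.numpy(), rtol=1e-5, atol=1e-5)
```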