Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
637dfe49
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
637dfe49
编写于
2月 10, 2023
作者:
W
wangruting
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
init layer_norm
上级
350cd82a
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
656 addition
and
2 deletion
+656
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
.../unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
+158
-0
python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py
...unittests/prim/composite_ops/test_composite_layer_norm.py
+190
-0
python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py
...ests/prim/composite_ops/test_composite_layer_norm_grad.py
+279
-0
python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py
.../paddle/fluid/tests/unittests/prim/composite_ops/utils.py
+6
-0
python/paddle/incubate/autograd/composite_rules.py
python/paddle/incubate/autograd/composite_rules.py
+23
-2
未找到文件。
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
0 → 100644
浏览文件 @
637dfe49
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
platform
import
unittest
import
numpy
as
np
import
paddle
import
paddle.nn.functional
as
F
from
paddle.fluid
import
core
def
apply_to_static
(
net
,
use_cinn
):
build_strategy
=
paddle
.
static
.
BuildStrategy
()
build_strategy
.
build_cinn_pass
=
use_cinn
return
paddle
.
jit
.
to_static
(
net
,
build_strategy
=
build_strategy
)
class
PrimeNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
):
super
(
PrimeNet
,
self
).
__init__
()
self
.
fc
=
paddle
.
nn
.
Linear
(
4
,
4
)
def
forward
(
self
,
x
,
n_shape
,
w
,
b
):
y
=
self
.
fc
(
x
)
out
=
F
.
layer_norm
(
y
,
n_shape
,
w
,
b
)
return
out
[
0
]
class
TestPrimForward
(
unittest
.
TestCase
):
"""
This case only tests prim_forward + to_static + cinn. Thus we need to
set this flag as False to avoid prim_backward.
core.set_prim_backward(False)
"""
def
setUp
(
self
):
paddle
.
seed
(
2022
)
self
.
x
=
paddle
.
randn
([
2
,
4
])
self
.
n_shape
=
x
.
shape
[
1
:]
self
.
w
=
paddle
.
randn
([
4
])
self
.
b
=
paddle
.
randn
([
4
])
self
.
x
.
stop_gradient
=
False
def
train
(
self
,
use_prim
):
paddle
.
seed
(
2022
)
net
=
PrimeNet
()
sgd
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.1
,
parameters
=
net
.
parameters
()
)
core
.
_set_prim_forward_enabled
(
use_prim
)
if
use_prim
:
net
=
apply_to_static
(
net
,
use_prim
)
res
=
[]
for
_
in
range
(
10
):
out
=
net
(
self
.
x
,
self
.
n_shape
,
self
.
w
,
self
.
b
)
loss
=
paddle
.
mean
(
out
)
loss
.
backward
()
sgd
.
step
()
sgd
.
clear_grad
()
res
.
append
(
out
.
numpy
())
self
.
check_prim
(
net
,
use_prim
)
return
res
def
check_prim
(
self
,
net
,
use_prim
):
if
not
use_prim
:
return
fwd_ops
=
[
op
.
type
for
op
in
net
.
forward
.
main_program
.
block
(
0
).
ops
]
# Ensure that layer_norm is splitted into small ops
self
.
assertTrue
(
'layer_norm'
not
in
fwd_ops
)
def
test_cinn_prim_forward
(
self
):
dy_res
=
self
.
train
(
use_prim
=
False
)
cinn_res
=
self
.
train
(
use_prim
=
True
)
for
i
in
range
(
len
(
dy_res
)):
np
.
testing
.
assert_allclose
(
cinn_res
[
i
],
dy_res
[
i
],
rtol
=
1e-6
,
atol
=
1e-6
)
class
TestPrimForwardAndBackward
(
unittest
.
TestCase
):
"""
Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
"""
def
setUp
(
self
):
paddle
.
seed
(
2022
)
self
.
x
=
paddle
.
randn
([
2
,
4
])
self
.
n_shape
=
x
.
shape
[
1
:]
self
.
w
=
paddle
.
randn
([
4
])
self
.
b
=
paddle
.
randn
([
4
])
self
.
x
.
stop_gradient
=
False
def
train
(
self
,
use_prim
):
paddle
.
seed
(
2022
)
net
=
PrimeNet
()
sgd
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.1
,
parameters
=
net
.
parameters
()
)
core
.
_set_prim_all_enabled
(
use_prim
)
if
use_prim
:
net
=
apply_to_static
(
net
,
use_prim
)
res
=
[]
for
_
in
range
(
10
):
out
=
net
(
self
.
x
,
self
.
n_shape
,
self
.
w
,
self
.
b
)
loss
=
paddle
.
mean
(
out
)
loss
.
backward
()
sgd
.
step
()
sgd
.
clear_grad
()
res
.
append
(
out
.
numpy
())
self
.
check_prim
(
net
,
use_prim
)
return
res
def
check_prim
(
self
,
net
,
use_prim
):
if
not
use_prim
:
return
fwd_ops
=
[
op
.
type
for
op
in
net
.
forward
.
main_program
.
block
(
0
).
ops
]
# Ensure that layer_norm is splitted into small ops
self
.
assertTrue
(
'layer_norm'
not
in
fwd_ops
)
def
test_cinn_prim_forward
(
self
):
plat
=
platform
.
system
()
if
plat
==
"Linux"
:
dy_res
=
self
.
train
(
use_prim
=
False
)
cinn_res
=
self
.
train
(
use_prim
=
True
)
for
i
in
range
(
len
(
dy_res
)):
np
.
testing
.
assert_allclose
(
cinn_res
[
i
],
dy_res
[
i
],
rtol
=
1e-6
,
atol
=
1e-6
)
else
:
pass
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py
0 → 100644
浏览文件 @
637dfe49
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
utils
import
TOLERANCE
import
paddle
import
paddle.nn.functional
as
F
from
paddle.fluid
import
core
from
paddle
import
_C_ops
,
in_dynamic_mode
def
generate_data
(
shape1
,
shape2
,
shape3
,
dtype
=
"float32"
):
np
.
random
.
seed
(
100
)
np_data1
=
np
.
random
.
random
(
shape1
).
astype
(
dtype
)
np_data2
=
np
.
random
.
random
(
shape2
).
astype
(
dtype
)
np_data3
=
np
.
random
.
random
(
shape3
).
astype
(
dtype
)
return
np_data1
,
np_data2
,
np_data3
class
Attr
:
def
__init__
(
self
)
->
None
:
self
.
dtype
=
None
self
.
n_shape
=
None
self
.
shape1
=
None
self
.
shape2
=
None
self
.
shape3
=
None
def
set_dtype
(
self
,
dtype
)
->
None
:
self
.
dtype
=
dtype
return
def
set_shape
(
self
,
n_shape
,
shape1
,
shape2
,
shape3
)
->
None
:
self
.
n_shape
=
n_shape
self
.
shape1
=
shape1
self
.
shape2
=
shape2
self
.
shape3
=
shape3
return
def
get_rtol
(
self
,
flag
):
rtol
=
TOLERANCE
[
self
.
dtype
][
flag
].
get
(
"rtol"
)
return
rtol
def
get_atol
(
self
,
flag
):
atol
=
TOLERANCE
[
self
.
dtype
][
flag
].
get
(
"atol"
)
return
atol
attrs
=
Attr
()
def
fn
(
x
,
norm_shape
,
w
,
b
):
return
F
.
layer_norm
(
x
,
norm_shape
,
w
,
b
)
def
layer_norm_
(
input
,
weight
,
bias
,
epsilon
=
1e-05
,
begin_norm_axis
=
0
):
axis
=
np
.
arange
(
begin_norm_axis
,
len
(
input
.
shape
))
mean
=
paddle
.
mean
(
input
,
axis
=
axis
,
keepdim
=
True
)
t1
=
input
-
mean
t2
=
paddle
.
pow
(
t1
,
2.0
)
t3
=
paddle
.
mean
(
t2
,
axis
=
axis
,
keepdim
=
True
)
t4
=
t3
+
epsilon
t5
=
paddle
.
sqrt
(
t4
)
t7
=
t1
/
t5
out
=
t7
if
weight
is
not
None
:
weight
=
paddle
.
reshape
(
weight
,
input
.
shape
[
begin_norm_axis
:])
out
=
t7
*
paddle
.
broadcast_to
(
weight
,
out
.
shape
)
if
bias
is
not
None
:
bias
=
paddle
.
reshape
(
bias
,
input
.
shape
[
begin_norm_axis
:])
out
=
out
+
paddle
.
broadcast_to
(
bias
,
out
.
shape
)
return
out
def
composite_forward
(
x
,
norm_shape
,
w
,
b
):
b_axis
=
len
(
x
.
shape
)
-
len
(
norm_shape
)
return
layer_norm_
(
x
,
w
,
b
,
begin_norm_axis
=
b_axis
)
def
expect_forward
(
x
,
norm_shape
,
w
,
b
):
return
fn
(
x
,
norm_shape
,
w
,
b
)
class
TestCompositelayer_norm
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
dtypes
=
[
"float16"
,
"float32"
]
self
.
n_shape
=
[[
3
,
4
],[
3
],
[
2
,
3
]]
self
.
shape1s
=
[[
3
,
4
],[
2
,
4
,
3
],
[
2
,
2
,
3
]]
self
.
shape2s
=
[[
12
],[
3
],[
6
]]
self
.
shape3s
=
[[
12
],[
3
],[
6
]]
def
cal_composite
(
self
,
inputs
,
norm_shape
,
weight
,
bias
):
paddle
.
enable_static
()
core
.
_set_prim_forward_enabled
(
True
)
startup_program
=
paddle
.
static
.
Program
()
main_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
x
=
paddle
.
static
.
data
(
'x'
,
shape
=
inputs
.
shape
,
dtype
=
str
(
inputs
.
dtype
)
)
w
=
paddle
.
static
.
data
(
'w'
,
shape
=
weight
.
shape
,
dtype
=
str
(
weight
.
dtype
)
)
b
=
paddle
.
static
.
data
(
'b'
,
shape
=
bias
.
shape
,
dtype
=
str
(
bias
.
dtype
)
)
y
=
fn
(
x
,
norm_shape
,
w
,
b
)
blocks
=
main_program
.
blocks
fwd_ops
=
[
op
.
type
for
op
in
blocks
[
0
].
ops
]
# Ensure that layer_norm in original block
self
.
assertTrue
(
'layer_norm'
in
fwd_ops
)
paddle
.
incubate
.
autograd
.
to_prim
(
blocks
)
fwd_ops_new
=
[
op
.
type
for
op
in
blocks
[
0
].
ops
]
# Ensure that layer_norm is splitted into small ops
self
.
assertTrue
(
'layer_norm'
not
in
fwd_ops_new
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
startup_program
)
res
=
exe
.
run
(
main_program
,
feed
=
{
'x'
:
inputs
,
'w'
:
weight
,
'b'
:
bias
,
},
fetch_list
=
[
y
])
paddle
.
disable_static
()
core
.
_set_prim_forward_enabled
(
False
)
return
res
def
compare_forward
(
self
):
x
,
w
,
b
=
generate_data
(
attrs
.
shape1
,
attrs
.
shape2
,
attrs
.
shape3
)
n_shape
=
attrs
.
n_shape
x_p
=
paddle
.
to_tensor
(
x
)
w_p
=
paddle
.
to_tensor
(
w
)
b_p
=
paddle
.
to_tensor
(
b
)
expect
=
expect_forward
(
x_p
,
n_shape
,
w_p
,
b_p
).
numpy
()
print
(
"expect = "
,
expect
)
#actual = self.cal_composite(x_p, n_shape, w_p, b_p)
actual
=
composite_forward
(
x_p
,
n_shape
,
w_p
,
b_p
).
numpy
()
print
(
"actual = "
,
actual
)
assert
expect
.
dtype
==
actual
.
dtype
np
.
testing
.
assert_allclose
(
expect
,
actual
,
rtol
=
attrs
.
get_rtol
(
"forward"
),
atol
=
attrs
.
get_atol
(
"forward"
),
)
expect_2
=
expect_forward
(
x_p
,
n_shape
,
None
,
None
).
numpy
()
actual_2
=
composite_forward
(
x_p
,
n_shape
,
None
,
None
).
numpy
()
assert
expect_2
.
dtype
==
actual_2
.
dtype
np
.
testing
.
assert_allclose
(
expect_2
,
actual_2
,
rtol
=
attrs
.
get_rtol
(
"forward"
),
atol
=
attrs
.
get_atol
(
"forward"
),
)
def
test_forward
(
self
):
for
j
in
self
.
dtypes
:
for
t
in
range
(
0
,
len
(
self
.
shape1s
)):
attrs
.
set_dtype
(
j
)
attrs
.
set_shape
(
self
.
n_shape
[
t
],
self
.
shape1s
[
t
],
self
.
shape2s
[
t
],
self
.
shape3s
[
t
])
self
.
compare_forward
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py
0 → 100644
浏览文件 @
637dfe49
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
utils
import
TOLERANCE
import
paddle
import
paddle.nn.functional
as
F
from
paddle.fluid
import
core
from
paddle
import
_C_ops
,
in_dynamic_mode
def
generate_data
(
shape1
,
shape2
,
shape3
,
dtype
=
"float32"
):
np
.
random
.
seed
(
100
)
np_data1
=
np
.
random
.
random
(
shape1
).
astype
(
dtype
)
np_data2
=
np
.
random
.
random
(
shape2
).
astype
(
dtype
)
np_data3
=
np
.
random
.
random
(
shape3
).
astype
(
dtype
)
return
np_data1
,
np_data2
,
np_data3
class
Attr
:
def
__init__
(
self
)
->
None
:
self
.
dtype
=
None
self
.
n_shape
=
None
self
.
shape1
=
None
self
.
shape2
=
None
self
.
shape3
=
None
def
set_dtype
(
self
,
dtype
)
->
None
:
self
.
dtype
=
dtype
return
def
set_shape
(
self
,
n_shape
,
shape1
,
shape2
,
shape3
)
->
None
:
self
.
n_shape
=
n_shape
self
.
shape1
=
shape1
self
.
shape2
=
shape2
self
.
shape3
=
shape3
return
def
get_rtol
(
self
,
flag
):
rtol
=
TOLERANCE
[
self
.
dtype
][
flag
].
get
(
"rtol"
)
return
rtol
def
get_atol
(
self
,
flag
):
atol
=
TOLERANCE
[
self
.
dtype
][
flag
].
get
(
"atol"
)
return
atol
attrs
=
Attr
()
def
fn
(
x
,
norm_shape
,
w
,
b
):
return
F
.
layer_norm
(
x
,
norm_shape
,
w
,
b
)
# def layer_norm_ (input, weight, bias, epsilon=1e-05, begin_norm_axis = 0):
# axis = np.arange(begin_norm_axis,len(input.shape))
# mean = paddle.mean(input, axis=axis, keepdim=True)
# t1 = input - mean
# t2 = paddle.pow( t1, 2.0)
# t3 = paddle.mean( t2, axis=axis, keepdim=True)
# t4 = t3 + epsilon
# t5 = paddle.sqrt( t4 )
# t7 = t1 / t5
# out = t7
# if weight is not None:
# weight = paddle.reshape(weight, input.shape[begin_norm_axis:])
# out = t7 * paddle.broadcast_to(weight, out.shape)
# if bias is not None:
# bias = paddle.reshape(bias, input.shape[begin_norm_axis:])
# out = out + paddle.broadcast_to(bias, out.shape)
# return out
# def composite_forward(x, norm_shape, w, b):
# b_axis = len(x.shape) - len(norm_shape)
# return layer_norm_(x, w, b, begin_norm_axis=b_axis)
def
expect_backward
(
x
,
norm_shape
,
w
,
b
):
paddle
.
disable_static
()
x
.
stop_gradient
=
False
res
=
fn
(
x
,
norm_shape
,
w
,
b
)
gradients
=
paddle
.
grad
(
res
,
x
)
return
gradients
class
TestCompositelayer_norm
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
dtypes
=
[
"float16"
,
"float32"
]
self
.
n_shape
=
[[
3
,
4
],[
3
],
[
2
,
3
]]
self
.
shape1s
=
[[
3
,
4
],[
2
,
4
,
3
],
[
2
,
2
,
3
]]
self
.
shape2s
=
[[
12
],[
3
],[
6
]]
self
.
shape3s
=
[[
12
],[
3
],[
6
]]
def
cal_composite_backward
(
self
,
inputs
,
norm_shape
,
weight
,
bias
):
paddle
.
enable_static
()
core
.
_set_prim_forward_enabled
(
True
)
startup_program
=
paddle
.
static
.
Program
()
main_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
x
=
paddle
.
static
.
data
(
'x'
,
shape
=
inputs
.
shape
,
dtype
=
str
(
inputs
.
dtype
)
)
x
.
stop_gradient
=
False
w
=
paddle
.
static
.
data
(
'w'
,
shape
=
weight
.
shape
,
dtype
=
str
(
weight
.
dtype
)
)
b
=
paddle
.
static
.
data
(
'b'
,
shape
=
bias
.
shape
,
dtype
=
str
(
bias
.
dtype
)
)
y
=
fn
(
x
,
norm_shape
,
w
,
b
)
blocks
=
main_program
.
blocks
fwd_ops
=
[
op
.
type
for
op
in
blocks
[
0
].
ops
]
# Ensure that layer_norm in original block
self
.
assertTrue
(
'layer_norm'
in
fwd_ops
)
paddle
.
incubate
.
autograd
.
to_prim
(
blocks
)
fwd_ops_new
=
[
op
.
type
for
op
in
blocks
[
0
].
ops
]
# Ensure that layer_norm is splitted into small ops
self
.
assertTrue
(
'layer_norm'
not
in
fwd_ops_new
)
z
=
paddle
.
static
.
gradients
([
y
],
x
)
fwd_ops_grad
=
[
op
.
type
for
op
in
blocks
[
0
].
ops
]
# Ensure that layer_norm_grad not in grad block
self
.
assertTrue
(
'layer_norm_grad'
not
in
fwd_ops_grad
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
startup_program
)
res
=
exe
.
run
(
main_program
,
feed
=
{
'x'
:
inputs
,
'w'
:
weight
,
'b'
:
bias
,
},
fetch_list
=
[
z
])
paddle
.
disable_static
()
core
.
_set_prim_forward_enabled
(
False
)
return
res
def
compare_backward
(
self
):
x
,
w
,
b
=
generate_data
(
attrs
.
shape1
,
attrs
.
shape2
,
attrs
.
shape3
)
n_shape
=
attrs
.
n_shape
x_p
=
paddle
.
to_tensor
(
x
)
w_p
=
paddle
.
to_tensor
(
w
)
b_p
=
paddle
.
to_tensor
(
b
)
expect
=
expect_backward
(
x_p
,
n_shape
,
w_p
,
b_p
).
numpy
()
actual
=
self
.
cal_composite_backward
(
x_p
,
n_shape
,
w_p
,
b_p
)
assert
expect
.
dtype
==
actual
.
dtype
np
.
testing
.
assert_allclose
(
expect
,
actual
,
rtol
=
attrs
.
get_rtol
(
"forward"
),
atol
=
attrs
.
get_atol
(
"forward"
),
)
expect_2
=
expect_backward
(
x_p
,
n_shape
,
None
,
None
).
numpy
()
actual_2
=
self
.
cal_composite_backward
(
x_p
,
n_shape
,
None
,
None
).
numpy
()
assert
expect_2
.
dtype
==
actual_2
.
dtype
np
.
testing
.
assert_allclose
(
expect_2
,
actual_2
,
rtol
=
attrs
.
get_rtol
(
"forward"
),
atol
=
attrs
.
get_atol
(
"forward"
),
)
def
test_backward
(
self
):
for
j
in
self
.
dtypes
:
for
t
in
range
(
0
,
len
(
self
.
shape1s
)):
attrs
.
set_dtype
(
j
)
attrs
.
set_shape
(
self
.
n_shape
[
t
],
self
.
shape1s
[
t
],
self
.
shape2s
[
t
],
self
.
shape3s
[
t
])
self
.
compare_backward
()
class
TestCompositelayer_normPrimBackward
(
unittest
.
TestCase
):
def
setUp
(
self
):
core
.
_set_prim_backward_enabled
(
True
)
self
.
dtypes
=
[
"float16"
,
"float32"
]
self
.
n_shape
=
[[
3
,
4
],[
3
],
[
2
,
3
]]
self
.
shape1s
=
[[
3
,
4
],[
2
,
4
,
3
],
[
2
,
2
,
3
]]
self
.
shape2s
=
[[
12
],[
3
],[
6
]]
self
.
shape3s
=
[[
12
],[
3
],[
6
]]
def
cal_composite_backward
(
self
,
inputs
,
norm_shape
,
weight
,
bias
):
paddle
.
enable_static
()
core
.
_set_prim_all_enabled
(
True
)
startup_program
=
paddle
.
static
.
Program
()
main_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
x
=
paddle
.
static
.
data
(
'x'
,
shape
=
inputs
.
shape
,
dtype
=
str
(
inputs
.
dtype
)
)
x
.
stop_gradient
=
False
w
=
paddle
.
static
.
data
(
'w'
,
shape
=
weight
.
shape
,
dtype
=
str
(
weight
.
dtype
)
)
b
=
paddle
.
static
.
data
(
'b'
,
shape
=
bias
.
shape
,
dtype
=
str
(
bias
.
dtype
)
)
y
=
fn
(
x
,
norm_shape
,
w
,
b
)
blocks
=
main_program
.
blocks
paddle
.
incubate
.
autograd
.
to_prim
(
blocks
)
z
=
paddle
.
static
.
gradients
([
y
],
x
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
startup_program
)
res
=
exe
.
run
(
main_program
,
feed
=
{
'x'
:
inputs
,
'w'
:
weight
,
'b'
:
bias
,
},
fetch_list
=
[
z
])
paddle
.
disable_static
()
core
.
_set_prim_all_enabled
(
False
)
return
res
def
compare_backward
(
self
):
x
,
w
,
b
=
generate_data
(
attrs
.
shape1
,
attrs
.
shape2
,
attrs
.
shape3
)
n_shape
=
attrs
.
n_shape
x_p
=
paddle
.
to_tensor
(
x
)
w_p
=
paddle
.
to_tensor
(
w
)
b_p
=
paddle
.
to_tensor
(
b
)
expect
=
expect_backward
(
x_p
,
n_shape
,
w_p
,
b_p
).
numpy
()
actual
=
self
.
cal_composite_backward
(
x_p
,
n_shape
,
w_p
,
b_p
)
assert
expect
.
dtype
==
actual
.
dtype
np
.
testing
.
assert_allclose
(
expect
,
actual
,
rtol
=
attrs
.
get_rtol
(
"forward"
),
atol
=
attrs
.
get_atol
(
"forward"
),
)
expect_2
=
expect_backward
(
x_p
,
n_shape
,
None
,
None
).
numpy
()
actual_2
=
self
.
cal_composite_backward
(
x_p
,
n_shape
,
None
,
None
).
numpy
()
assert
expect_2
.
dtype
==
actual_2
.
dtype
np
.
testing
.
assert_allclose
(
expect_2
,
actual_2
,
rtol
=
attrs
.
get_rtol
(
"forward"
),
atol
=
attrs
.
get_atol
(
"forward"
),
)
def
test_prim_backward
(
self
):
for
j
in
self
.
dtypes
:
for
t
in
range
(
0
,
len
(
self
.
shape1s
)):
attrs
.
set_dtype
(
j
)
attrs
.
set_shape
(
self
.
n_shape
[
t
],
self
.
shape1s
[
t
],
self
.
shape2s
[
t
],
self
.
shape3s
[
t
])
self
.
compare_backward
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py
浏览文件 @
637dfe49
...
@@ -14,6 +14,12 @@
...
@@ -14,6 +14,12 @@
# default tolerance
# default tolerance
TOLERANCE
=
{
TOLERANCE
=
{
"float16"
:
{
"forward"
:
{
"rtol"
:
1e-3
,
"atol"
:
1e-3
},
"backward"
:
{
"rtol"
:
1e-3
,
"atol"
:
1e-3
},
"prim_backward"
:
{
"rtol"
:
1e-3
,
"atol"
:
1e-3
},
},
"float32"
:
{
"float32"
:
{
"forward"
:
{
"rtol"
:
1e-6
,
"atol"
:
1e-6
},
"forward"
:
{
"rtol"
:
1e-6
,
"atol"
:
1e-6
},
"backward"
:
{
"rtol"
:
1e-6
,
"atol"
:
1e-6
},
"backward"
:
{
"rtol"
:
1e-6
,
"atol"
:
1e-6
},
...
...
python/paddle/incubate/autograd/composite_rules.py
浏览文件 @
637dfe49
...
@@ -33,8 +33,8 @@ def softmax_composite(x, axis):
...
@@ -33,8 +33,8 @@ def softmax_composite(x, axis):
max_temp
=
max
(
x
,
axis
,
keepdim
=
True
)
max_temp
=
max
(
x
,
axis
,
keepdim
=
True
)
max_temp
.
stop_gradient
=
True
max_temp
.
stop_gradient
=
True
molecular
=
exp
(
x
-
max_temp
)
molecular
=
exp
(
x
-
max_temp
)
denominato
r
=
sum
(
molecular
,
axis
=
axis
,
keepdim
=
True
)
sqrt_va
r
=
sum
(
molecular
,
axis
=
axis
,
keepdim
=
True
)
res
=
divide
(
molecular
,
denominato
r
)
res
=
divide
(
molecular
,
sqrt_va
r
)
return
res
return
res
...
@@ -101,3 +101,24 @@ def composite_batchnorm(
...
@@ -101,3 +101,24 @@ def composite_batchnorm(
return
run_mean_
,
None
,
batch_mean_
,
batch_var_
,
run_var_
,
y
return
run_mean_
,
None
,
batch_mean_
,
batch_var_
,
run_var_
,
y
else
:
else
:
return
run_mean_
,
batch_mean_
,
batch_var_
,
run_var_
,
y
return
run_mean_
,
batch_mean_
,
batch_var_
,
run_var_
,
y
@
REGISTER_COMPOSITE
(
'layer_norm'
)
def
layernorm_composite
(
x
,
scale
,
bias
,
epsilon
,
begin_norm_axis
):
axis
=
np
.
arange
(
begin_norm_axis
,
len
(
x
.
shape
))
mean_
=
mean
(
x
,
axis
=
axis
,
keepdim
=
True
)
difference
=
x
-
mean_
var_tmp1
=
pow
(
difference
,
2.0
)
variance
=
mean
(
var_tmp1
,
axis
=
axis
,
keepdim
=
True
)
var_tmp3
=
variance
+
epsilon
sqrt_var
=
sqrt
(
var_tmp3
)
out
=
difference
/
sqrt_var
if
scale
is
not
None
:
scale
=
reshape
(
scale
,
x
.
shape
[
begin_norm_axis
:])
out
=
t7
*
broadcast_to
(
scale
,
out
.
shape
)
if
bias
is
not
None
:
bias
=
reshape
(
bias
,
x
.
shape
[
begin_norm_axis
:])
out
=
out
+
broadcast_to
(
bias
,
out
.
shape
)
return
out
,
mean_
,
variance
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录