Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
1950a360
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1950a360
编写于
6月 14, 2022
作者:
W
WangZhen
提交者:
GitHub
6月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2St]Refine ifelse early return (#43328)
* Refine ifelse early return
上级
083d769b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
149 addition
and
44 deletion
+149
-44
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
...paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+2
-0
python/paddle/fluid/dygraph/dygraph_to_static/early_return_transformer.py
...uid/dygraph/dygraph_to_static/early_return_transformer.py
+88
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
...d/tests/unittests/dygraph_to_static/ifelse_simple_func.py
+24
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py
...ts/unittests/dygraph_to_static/test_program_translator.py
+35
-44
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
浏览文件 @
1950a360
...
@@ -20,6 +20,7 @@ from __future__ import print_function
...
@@ -20,6 +20,7 @@ from __future__ import print_function
# See details in https://github.com/serge-sans-paille/gast/
# See details in https://github.com/serge-sans-paille/gast/
import
os
import
os
from
paddle.utils
import
gast
from
paddle.utils
import
gast
from
paddle.fluid.dygraph.dygraph_to_static.early_return_transformer
import
EarlyReturnTransformer
from
paddle.fluid.dygraph.dygraph_to_static.assert_transformer
import
AssertTransformer
from
paddle.fluid.dygraph.dygraph_to_static.assert_transformer
import
AssertTransformer
from
paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer
import
BasicApiTransformer
from
paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer
import
BasicApiTransformer
from
paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer
import
BreakContinueTransformer
from
paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer
import
BreakContinueTransformer
...
@@ -87,6 +88,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
...
@@ -87,6 +88,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
self
.
visit
(
node_wrapper
.
node
)
self
.
visit
(
node_wrapper
.
node
)
transformers
=
[
transformers
=
[
EarlyReturnTransformer
,
BasicApiTransformer
,
# Basic Api
BasicApiTransformer
,
# Basic Api
TensorShapeTransformer
,
# Tensor.shape -> layers.shape(Tensor)
TensorShapeTransformer
,
# Tensor.shape -> layers.shape(Tensor)
ListTransformer
,
# List used in control flow
ListTransformer
,
# List used in control flow
...
...
python/paddle/fluid/dygraph/dygraph_to_static/early_return_transformer.py
0 → 100644
浏览文件 @
1950a360
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
paddle.utils
import
gast
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
class
EarlyReturnTransformer
(
gast
.
NodeTransformer
):
"""
Transform if/else return statement of Dygraph into Static Graph.
"""
def
__init__
(
self
,
wrapper_root
):
assert
isinstance
(
wrapper_root
,
AstNodeWrapper
),
"Type of input node should be AstNodeWrapper, but received %s ."
%
type
(
wrapper_root
)
self
.
root
=
wrapper_root
.
node
def
transform
(
self
):
"""
Main function to transform AST.
"""
self
.
visit
(
self
.
root
)
def
is_define_return_in_if
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
If
),
"Type of input node should be gast.If, but received %s ."
%
type
(
node
)
for
child
in
node
.
body
:
if
isinstance
(
child
,
gast
.
Return
):
return
True
return
False
def
visit_block_nodes
(
self
,
nodes
):
result_nodes
=
[]
destination_nodes
=
result_nodes
for
node
in
nodes
:
rewritten_node
=
self
.
visit
(
node
)
if
isinstance
(
rewritten_node
,
(
list
,
tuple
)):
destination_nodes
.
extend
(
rewritten_node
)
else
:
destination_nodes
.
append
(
rewritten_node
)
# append other nodes to if.orelse even though if.orelse is not empty
if
isinstance
(
node
,
gast
.
If
)
and
self
.
is_define_return_in_if
(
node
):
destination_nodes
=
node
.
orelse
# handle stmt like `if/elif/elif`
while
len
(
destination_nodes
)
>
0
and
\
isinstance
(
destination_nodes
[
0
],
gast
.
If
)
and
\
self
.
is_define_return_in_if
(
destination_nodes
[
0
]):
destination_nodes
=
destination_nodes
[
0
].
orelse
return
result_nodes
def
visit_If
(
self
,
node
):
node
.
body
=
self
.
visit_block_nodes
(
node
.
body
)
node
.
orelse
=
self
.
visit_block_nodes
(
node
.
orelse
)
return
node
def
visit_While
(
self
,
node
):
node
.
body
=
self
.
visit_block_nodes
(
node
.
body
)
node
.
orelse
=
self
.
visit_block_nodes
(
node
.
orelse
)
return
node
def
visit_For
(
self
,
node
):
node
.
body
=
self
.
visit_block_nodes
(
node
.
body
)
node
.
orelse
=
self
.
visit_block_nodes
(
node
.
orelse
)
return
node
def
visit_FunctionDef
(
self
,
node
):
node
.
body
=
self
.
visit_block_nodes
(
node
.
body
)
return
node
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
浏览文件 @
1950a360
...
@@ -100,6 +100,30 @@ def dyfunc_with_if_else3(x):
...
@@ -100,6 +100,30 @@ def dyfunc_with_if_else3(x):
return
x
return
x
def
dyfunc_with_if_else_early_return1
():
x
=
paddle
.
to_tensor
([
10
])
if
x
==
0
:
a
=
paddle
.
zeros
([
2
,
2
])
b
=
paddle
.
zeros
([
3
,
3
])
return
a
,
b
a
=
paddle
.
zeros
([
2
,
2
])
+
1
return
a
def
dyfunc_with_if_else_early_return2
():
x
=
paddle
.
to_tensor
([
10
])
if
x
==
0
:
a
=
paddle
.
zeros
([
2
,
2
])
b
=
paddle
.
zeros
([
3
,
3
])
return
a
,
b
elif
x
==
1
:
c
=
paddle
.
zeros
([
2
,
2
])
+
1
d
=
paddle
.
zeros
([
3
,
3
])
+
1
return
c
,
d
e
=
paddle
.
zeros
([
2
,
2
])
+
3
return
e
def
dyfunc_with_if_else_with_list_geneator
(
x
):
def
dyfunc_with_if_else_with_list_geneator
(
x
):
if
10
>
5
:
if
10
>
5
:
y
=
paddle
.
add_n
(
y
=
paddle
.
add_n
(
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py
浏览文件 @
1950a360
...
@@ -29,7 +29,7 @@ from paddle.fluid.dygraph.nn import Linear
...
@@ -29,7 +29,7 @@ from paddle.fluid.dygraph.nn import Linear
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
func_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
func_to_source_code
import
paddle.jit.dy2static
as
_jst
import
paddle.jit.dy2static
as
_jst
from
ifelse_simple_func
import
dyfunc_with_if_else
from
ifelse_simple_func
import
dyfunc_with_if_else
,
dyfunc_with_if_else_early_return1
,
dyfunc_with_if_else_early_return2
np
.
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
...
@@ -83,34 +83,22 @@ class StaticCode1():
...
@@ -83,34 +83,22 @@ class StaticCode1():
x_v
=
_jst
.
convert_ifelse
(
x_v
=
_jst
.
convert_ifelse
(
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
true_fn_0
,
false_fn_0
,
(
x_v
,
),
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
true_fn_0
,
false_fn_0
,
(
x_v
,
),
(
x_v
,
))
(
x_v
,
))
__return_0
=
_jst
.
create_bool_as_type
(
label
is
not
None
,
False
)
def
true_fn_1
(
__return_
0
,
__return_
value_0
,
label
,
x_v
):
def
true_fn_1
(
__return_value_0
,
label
,
x_v
):
loss
=
fluid
.
layers
.
cross_entropy
(
x_v
,
label
)
loss
=
fluid
.
layers
.
cross_entropy
(
x_v
,
label
)
__return_0
=
_jst
.
create_bool_as_type
(
label
is
not
None
,
True
)
__return_0
=
_jst
.
create_bool_as_type
(
label
is
not
None
,
True
)
__return_value_0
=
loss
__return_value_0
=
loss
return
__return_0
,
__return_value_0
def
false_fn_1
(
__return_0
,
__return_value_0
):
return
__return_0
,
__return_value_0
__return_0
,
__return_value_0
=
_jst
.
convert_ifelse
(
label
is
not
None
,
true_fn_1
,
false_fn_1
,
(
__return_0
,
__return_value_0
,
label
,
x_v
),
(
__return_0
,
__return_value_0
))
def
true_fn_2
(
__return_0
,
__return_value_0
,
x_v
):
__return_1
=
_jst
.
create_bool_as_type
(
_jst
.
convert_logical_not
(
__return_0
),
True
)
__return_value_0
=
x_v
return
__return_value_0
return
__return_value_0
def
false_fn_2
(
__return_value_0
):
def
false_fn_1
(
__return_value_0
,
label
,
x_v
):
__return_1
=
_jst
.
create_bool_as_type
(
label
is
not
None
,
True
)
__return_value_0
=
x_v
return
__return_value_0
return
__return_value_0
__return_value_0
=
_jst
.
convert_ifelse
(
__return_value_0
=
_jst
.
convert_ifelse
(
label
is
not
None
,
true_fn_1
,
_jst
.
convert_logical_not
(
__return_0
),
true_fn_2
,
false_fn_2
,
false_fn_1
,
(
__return_0
,
__return_value_0
,
x_v
),
(
__return_value_0
,
))
(
__return_value_0
,
label
,
x_v
),
(
__return_value_0
,
label
,
x_v
))
return
__return_value_0
return
__return_value_0
...
@@ -123,45 +111,33 @@ class StaticCode2():
...
@@ -123,45 +111,33 @@ class StaticCode2():
name
=
'__return_value_init_1'
)
name
=
'__return_value_init_1'
)
__return_value_1
=
__return_value_init_1
__return_value_1
=
__return_value_init_1
def
true_fn_
3
(
x_v
):
def
true_fn_
2
(
x_v
):
x_v
=
x_v
-
1
x_v
=
x_v
-
1
return
x_v
return
x_v
def
false_fn_
3
(
x_v
):
def
false_fn_
2
(
x_v
):
x_v
=
x_v
+
1
x_v
=
x_v
+
1
return
x_v
return
x_v
x_v
=
_jst
.
convert_ifelse
(
x_v
=
_jst
.
convert_ifelse
(
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
true_fn_
3
,
false_fn_3
,
(
x_v
,
),
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
true_fn_
2
,
false_fn_2
,
(
x_v
,
),
(
x_v
,
))
(
x_v
,
))
__return_2
=
_jst
.
create_bool_as_type
(
label
is
not
None
,
False
)
def
true_fn_
4
(
__return_2
,
__return_value_1
,
label
,
x_v
):
def
true_fn_
3
(
__return_value_1
,
label
,
x_v
):
loss
=
fluid
.
layers
.
cross_entropy
(
x_v
,
label
)
loss
=
fluid
.
layers
.
cross_entropy
(
x_v
,
label
)
__return_2
=
_jst
.
create_bool_as_type
(
label
is
not
None
,
True
)
__return_2
=
_jst
.
create_bool_as_type
(
label
is
not
None
,
True
)
__return_value_1
=
loss
__return_value_1
=
loss
return
__return_2
,
__return_value_1
def
false_fn_4
(
__return_2
,
__return_value_1
):
return
__return_2
,
__return_value_1
__return_2
,
__return_value_1
=
_jst
.
convert_ifelse
(
label
is
not
None
,
true_fn_4
,
false_fn_4
,
(
__return_2
,
__return_value_1
,
label
,
x_v
),
(
__return_2
,
__return_value_1
))
def
true_fn_5
(
__return_2
,
__return_value_1
,
x_v
):
__return_3
=
_jst
.
create_bool_as_type
(
_jst
.
convert_logical_not
(
__return_2
),
True
)
__return_value_1
=
x_v
return
__return_value_1
return
__return_value_1
def
false_fn_5
(
__return_value_1
):
def
false_fn_3
(
__return_value_1
,
label
,
x_v
):
__return_3
=
_jst
.
create_bool_as_type
(
label
is
not
None
,
True
)
__return_value_1
=
x_v
return
__return_value_1
return
__return_value_1
__return_value_1
=
_jst
.
convert_ifelse
(
__return_value_1
=
_jst
.
convert_ifelse
(
label
is
not
None
,
true_fn_3
,
_jst
.
convert_logical_not
(
__return_2
),
true_fn_5
,
false_fn_5
,
false_fn_3
,
(
__return_2
,
__return_value_1
,
x_v
),
(
__return_value_1
,
))
(
__return_value_1
,
label
,
x_v
),
(
__return_value_1
,
label
,
x_v
))
return
__return_value_1
return
__return_value_1
...
@@ -358,6 +334,21 @@ class TestFunctionTrainEvalMode(unittest.TestCase):
...
@@ -358,6 +334,21 @@ class TestFunctionTrainEvalMode(unittest.TestCase):
net
.
foo
.
train
()
net
.
foo
.
train
()
class
TestIfElseEarlyReturn
(
unittest
.
TestCase
):
def
test_ifelse_early_return1
(
self
):
answer
=
np
.
zeros
([
2
,
2
])
+
1
static_func
=
paddle
.
jit
.
to_static
(
dyfunc_with_if_else_early_return1
)
out
=
static_func
()
self
.
assertTrue
(
np
.
allclose
(
answer
,
out
.
numpy
()))
def
test_ifelse_early_return2
(
self
):
answer
=
np
.
zeros
([
2
,
2
])
+
3
static_func
=
paddle
.
jit
.
to_static
(
dyfunc_with_if_else_early_return2
)
out
=
static_func
()
self
.
assertTrue
(
np
.
allclose
(
answer
,
out
.
numpy
()))
class
TestRemoveCommentInDy2St
(
unittest
.
TestCase
):
class
TestRemoveCommentInDy2St
(
unittest
.
TestCase
):
def
func_with_comment
(
self
):
def
func_with_comment
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录