Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ba65e4eb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ba65e4eb
编写于
3月 10, 2020
作者:
A
Aurelius84
提交者:
GitHub
3月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support Tensor.shape in control_flow_if test=develop (#22916)
上级
d33c4343
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
126 addition
and
23 deletion
+126
-23
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
...paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+5
-4
python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py
python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py
+58
-11
python/paddle/fluid/tests/unittests/test_ast_util.py
python/paddle/fluid/tests/unittests/test_ast_util.py
+63
-8
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
浏览文件 @
ba65e4eb
...
@@ -46,11 +46,11 @@ class IfElseTransformer(gast.NodeTransformer):
...
@@ -46,11 +46,11 @@ class IfElseTransformer(gast.NodeTransformer):
wrapper_root
,
AstNodeWrapper
wrapper_root
,
AstNodeWrapper
),
"Type of input node should be AstNodeWrapper, but received %s ."
%
type
(
),
"Type of input node should be AstNodeWrapper, but received %s ."
%
type
(
wrapper_root
)
wrapper_root
)
self
.
wrapper_root
=
wrapper_root
self
.
root
=
wrapper_root
.
node
self
.
root
=
wrapper_root
.
node
self
.
static_analysis_visitor
=
StaticAnalysisVisitor
(
self
.
root
)
self
.
new_func_nodes
=
{}
self
.
new_func_nodes
=
{}
def
ast_visit
(
self
):
def
transform
(
self
):
"""
"""
Main function to transform AST.
Main function to transform AST.
"""
"""
...
@@ -59,7 +59,8 @@ class IfElseTransformer(gast.NodeTransformer):
...
@@ -59,7 +59,8 @@ class IfElseTransformer(gast.NodeTransformer):
def
visit_If
(
self
,
node
):
def
visit_If
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
If
)
assert
isinstance
(
node
,
gast
.
If
)
need_transform
=
is_control_flow_if
(
node
.
test
)
need_transform
=
is_control_flow_if
(
node
.
test
,
self
.
static_analysis_visitor
)
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
if
need_transform
:
if
need_transform
:
pred_node
=
node
.
test
pred_node
=
node
.
test
...
@@ -143,7 +144,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
...
@@ -143,7 +144,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
self
.
feed_name_to_arg_name
=
basic_api_trans
.
get_feed_name_to_arg_id
()
self
.
feed_name_to_arg_name
=
basic_api_trans
.
get_feed_name_to_arg_id
()
# Transform all if/else statement of Dygraph into Static Graph.
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer
(
node_wrapper
).
ast_visit
()
IfElseTransformer
(
node_wrapper
).
transform
()
LoopTransformer
(
node_wrapper
).
transform
()
LoopTransformer
(
node_wrapper
).
transform
()
...
...
python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py
浏览文件 @
ba65e4eb
...
@@ -26,6 +26,8 @@ import atexit
...
@@ -26,6 +26,8 @@ import atexit
from
collections
import
defaultdict
from
collections
import
defaultdict
from
paddle.fluid
import
unique_name
from
paddle.fluid
import
unique_name
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
NodeVarType
,
StaticAnalysisVisitor
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_paddle_api
TRUE_FUNC_PREFIX
=
'true_fn'
TRUE_FUNC_PREFIX
=
'true_fn'
FALSE_FUNC_PREFIX
=
'false_fn'
FALSE_FUNC_PREFIX
=
'false_fn'
...
@@ -49,15 +51,28 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
...
@@ -49,15 +51,28 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
because reshape_op may be called before this statement.
because reshape_op may be called before this statement.
"""
"""
def
__init__
(
self
,
node
):
def
__init__
(
self
,
static_analysis_visitor
):
self
.
node
=
node
self
.
static_analysis_visitor
=
static_analysis_visitor
self
.
node_to_wrapper_map
=
self
.
static_analysis_visitor
.
get_node_to_wrapper_map
(
)
self
.
is_control_flow
=
False
self
.
is_control_flow
=
False
def
ast_visit
(
self
):
def
transform
(
self
,
node
):
self
.
visit
(
self
.
node
)
if
self
.
_is_candidate_node
(
node
):
self
.
visit
(
node
)
return
self
.
is_control_flow
return
self
.
is_control_flow
def
visit_BoolOp
(
self
,
node
):
for
child
in
node
.
values
:
if
not
self
.
_is_candidate_node
(
child
):
continue
self
.
generic_visit
(
node
)
return
node
def
visit_Compare
(
self
,
node
):
def
visit_Compare
(
self
,
node
):
# Ignores child node with `if x` or `if x is None`
if
not
self
.
_compare_with_none
(
node
):
self
.
generic_visit
(
node
)
for
child
in
gast
.
walk
(
node
):
for
child
in
gast
.
walk
(
node
):
if
isinstance
(
child
,
gast
.
Subscript
):
if
isinstance
(
child
,
gast
.
Subscript
):
self
.
_visit_Subscript
(
child
)
self
.
_visit_Subscript
(
child
)
...
@@ -65,7 +80,7 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
...
@@ -65,7 +80,7 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
def
_visit_Subscript
(
self
,
node
):
def
_visit_Subscript
(
self
,
node
):
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
if
isinstance
(
node
.
value
,
gast
.
Call
):
if
hasattr
(
node
,
'value'
)
and
isinstance
(
node
.
value
,
gast
.
Call
):
self
.
_visit_Call
(
node
.
value
)
self
.
_visit_Call
(
node
.
value
)
return
node
return
node
...
@@ -73,10 +88,40 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
...
@@ -73,10 +88,40 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
assert
isinstance
(
node
,
gast
.
Call
)
assert
isinstance
(
node
,
gast
.
Call
)
if
isinstance
(
node
.
func
,
gast
.
Attribute
):
if
isinstance
(
node
.
func
,
gast
.
Attribute
):
attr_node
=
node
.
func
attr_node
=
node
.
func
self
.
is_control_flow
=
(
attr_node
.
attr
==
'numpy'
)
if
attr_node
.
attr
==
'numpy'
:
self
.
is_control_flow
=
True
def
visit_Call
(
self
,
node
):
if
is_paddle_api
(
node
):
self
.
is_control_flow
=
True
return
node
def
visit_Name
(
self
,
node
):
wrapper_node
=
self
.
node_to_wrapper_map
.
get
(
node
,
None
)
if
wrapper_node
is
not
None
:
if
wrapper_node
.
node_var_type
&
{
NodeVarType
.
TENSOR
,
NodeVarType
.
PADDLE_RETURN_TYPES
}:
self
.
is_control_flow
=
True
return
node
def
_is_candidate_node
(
self
,
node
):
return
isinstance
(
node
,
(
gast
.
Compare
,
gast
.
BoolOp
))
def
_compare_with_none
(
self
,
node
):
if
isinstance
(
node
,
gast
.
Compare
):
for
child
in
[
node
.
left
,
node
.
comparators
]:
# node.comparators is a list.
if
isinstance
(
child
,
list
):
child
=
child
[
0
]
if
(
isinstance
(
child
,
gast
.
Constant
)
and
child
.
value
is
None
)
or
(
isinstance
(
child
,
gast
.
Name
)
and
child
.
id
==
'None'
):
return
True
return
False
def
is_control_flow_if
(
node
):
def
is_control_flow_if
(
node
,
static_analysis_visitor
=
None
):
"""
"""
Determine whether the node is a plain python `if statement` or
Determine whether the node is a plain python `if statement` or
control flow in Paddle.
control flow in Paddle.
...
@@ -84,7 +129,9 @@ def is_control_flow_if(node):
...
@@ -84,7 +129,9 @@ def is_control_flow_if(node):
assert
isinstance
(
assert
isinstance
(
node
,
gast
.
AST
node
,
gast
.
AST
),
"Type of input node should be gast.AST, but received %s."
%
type
(
node
)
),
"Type of input node should be gast.AST, but received %s."
%
type
(
node
)
return
IsControlFlowIfVisitor
(
node
).
ast_visit
()
if
static_analysis_visitor
is
None
:
static_analysis_visitor
=
StaticAnalysisVisitor
(
node
)
return
IsControlFlowIfVisitor
(
static_analysis_visitor
).
transform
(
node
)
def
get_name_ids
(
nodes
,
not_name_set
=
None
,
node_black_list
=
None
):
def
get_name_ids
(
nodes
,
not_name_set
=
None
,
node_black_list
=
None
):
...
...
python/paddle/fluid/tests/unittests/test_ast_util.py
浏览文件 @
ba65e4eb
...
@@ -21,6 +21,7 @@ import inspect
...
@@ -21,6 +21,7 @@ import inspect
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.dygraph_to_static.ast_utils
import
get_name_ids
,
ast_to_func
,
is_control_flow_if
from
paddle.fluid.dygraph.dygraph_to_static.ast_utils
import
get_name_ids
,
ast_to_func
,
is_control_flow_if
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
StaticAnalysisVisitor
from
test_dygraph_to_static_basic
import
dyfunc_with_if_else
,
dyfunc_with_if_else2
,
nested_if_else
from
test_dygraph_to_static_basic
import
dyfunc_with_if_else
,
dyfunc_with_if_else2
,
nested_if_else
...
@@ -98,35 +99,89 @@ class TestIsControlFlowIf(unittest.TestCase):
...
@@ -98,35 +99,89 @@ class TestIsControlFlowIf(unittest.TestCase):
def
test_expr
(
self
):
def
test_expr
(
self
):
# node is not ast.Compare
# node is not ast.Compare
node
=
gast
.
parse
(
"a + b"
)
node
=
gast
.
parse
(
"a + b"
)
self
.
assertFalse
(
is_control_flow_if
(
node
))
self
.
assertFalse
(
is_control_flow_if
(
node
.
body
[
0
].
value
))
def
test_expr2
(
self
):
def
test_expr2
(
self
):
node
=
gast
.
parse
(
"a + x.numpy()[1]"
)
node
=
gast
.
parse
(
"a + x.numpy()[1]"
)
self
.
assertFalse
(
is_control_flow_if
(
node
))
self
.
assertFalse
(
is_control_flow_if
(
node
.
body
[
0
].
value
))
def
test_is_None
(
self
):
def
test_is_None
(
self
):
node
=
gast
.
parse
(
"x is None"
)
node
=
gast
.
parse
(
"x is None"
)
self
.
assertFalse
(
is_control_flow_if
(
node
))
self
.
assertFalse
(
is_control_flow_if
(
node
.
body
[
0
].
value
))
def
test_is_None2
(
self
):
def
test_is_None2
(
self
):
node
=
gast
.
parse
(
"fluid.layers.sum(x) is None"
)
node
=
gast
.
parse
(
"fluid.layers.sum(x) is None"
)
self
.
assertFalse
(
is_control_flow_if
(
node
))
self
.
assertFalse
(
is_control_flow_if
(
node
.
body
[
0
].
value
))
def
test_is_None3
(
self
):
def
test_is_None3
(
self
):
node
=
gast
.
parse
(
"fluid.layers.sum(x).numpy() != None"
)
node
=
gast
.
parse
(
"fluid.layers.sum(x).numpy() != None"
)
self
.
assertFalse
(
is_control_flow_if
(
node
))
self
.
assertFalse
(
is_control_flow_if
(
node
.
body
[
0
].
value
))
def
test_if
(
self
):
def
test_if
(
self
):
node
=
gast
.
parse
(
"x.numpy()[1] > 1"
)
node
=
gast
.
parse
(
"x.numpy()[1] > 1"
)
self
.
assertTrue
(
is_control_flow_if
(
node
))
self
.
assertTrue
(
is_control_flow_if
(
node
.
body
[
0
].
value
))
def
test_if_with_and
(
self
):
def
test_if_with_and
(
self
):
node
=
gast
.
parse
(
"x is not None and 1 < x.numpy()[1]"
)
node
=
gast
.
parse
(
"x is not None and 1 < x.numpy()[1]"
)
self
.
assertTrue
(
is_control_flow_if
(
node
))
self
.
assertTrue
(
is_control_flow_if
(
node
.
body
[
0
].
value
))
def
test_if_with_or
(
self
):
def
test_if_with_or
(
self
):
node
=
gast
.
parse
(
"1 < fluid.layers.sum(x).numpy()[2] or x+y < 0"
)
node
=
gast
.
parse
(
"1 < fluid.layers.sum(x).numpy()[2] or x+y < 0"
)
self
.
assertTrue
(
is_control_flow_if
(
node
))
self
.
assertTrue
(
is_control_flow_if
(
node
.
body
[
0
].
value
))
def
test_shape
(
self
):
code
=
"""
def foo(x):
batch_size = fluid.layers.shape(x)
if batch_size[0] > 16:
x = x + 1
return x
"""
code
=
textwrap
.
dedent
(
code
)
node
=
gast
.
parse
(
code
)
visitor
=
StaticAnalysisVisitor
(
node
)
test_node
=
node
.
body
[
0
].
body
[
1
].
test
self
.
assertTrue
(
is_control_flow_if
(
test_node
,
visitor
))
def
test_shape_with_andOr
(
self
):
code
=
"""
def foo(x):
batch_size = fluid.layers.shape(x)
if x is not None and batch_size[0] > 16 or 2 > 1:
x = x + 1
return x
"""
code
=
textwrap
.
dedent
(
code
)
node
=
gast
.
parse
(
code
)
visitor
=
StaticAnalysisVisitor
(
node
)
test_node
=
node
.
body
[
0
].
body
[
1
].
test
self
.
assertTrue
(
is_control_flow_if
(
test_node
,
visitor
))
def
test_paddle_api
(
self
):
code
=
"""
def foo(x):
if fluid.layers.shape(x)[0] > 16:
x = x + 1
return x
"""
code
=
textwrap
.
dedent
(
code
)
node
=
gast
.
parse
(
code
)
visitor
=
StaticAnalysisVisitor
(
node
)
test_node
=
node
.
body
[
0
].
body
[
0
].
test
self
.
assertTrue
(
is_control_flow_if
(
test_node
,
visitor
))
def
test_paddle_api_with_andOr
(
self
):
code
=
"""
def foo(x):
if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
x = x + 1
return x
"""
code
=
textwrap
.
dedent
(
code
)
node
=
gast
.
parse
(
code
)
visitor
=
StaticAnalysisVisitor
(
node
)
test_node
=
node
.
body
[
0
].
body
[
0
].
test
self
.
assertTrue
(
is_control_flow_if
(
test_node
,
visitor
))
def
test_raise_error
(
self
):
def
test_raise_error
(
self
):
node
=
"a + b"
node
=
"a + b"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录