Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2403362d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2403362d
编写于
3月 18, 2020
作者:
A
Aurelius84
提交者:
GitHub
3月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
BugFix for parsing Arguments and inserting funcs in IfElseTransormer (#23035)
* Support and/or in controlFlow if test=develop
上级
01ab8a06
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
206 addition
and
59 deletion
+206
-59
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
...dle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
+134
-47
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
...d/tests/unittests/dygraph_to_static/ifelse_simple_func.py
+52
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py
.../fluid/tests/unittests/dygraph_to_static/test_ast_util.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
...le/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
+12
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py
...id/tests/unittests/dygraph_to_static/test_ifelse_basic.py
+7
-11
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
浏览文件 @
2403362d
...
...
@@ -95,7 +95,7 @@ class IfElseTransformer(gast.NodeTransformer):
"""
self
.
_insert_func_nodes
(
node
)
def
_insert_func_nodes
(
self
,
parent_
node
):
def
_insert_func_nodes
(
self
,
node
):
"""
Defined `true_func` and `false_func` will be inserted in front of corresponding
`layers.cond` statement instead of inserting them all into body of parent node.
...
...
@@ -103,13 +103,18 @@ class IfElseTransformer(gast.NodeTransformer):
For example, `self.var_dict["key"]`. In this case, nested structure of newly
defined functions is easier to understand.
"""
if
not
(
self
.
new_func_nodes
and
hasattr
(
parent_node
,
'body'
))
:
if
not
self
.
new_func_nodes
:
return
idx
=
len
(
parent_node
.
body
)
-
1
idx
=
-
1
if
isinstance
(
node
,
list
):
idx
=
len
(
node
)
-
1
elif
isinstance
(
node
,
gast
.
AST
):
for
_
,
child
in
gast
.
iter_fields
(
node
):
self
.
_insert_func_nodes
(
child
)
while
idx
>=
0
:
child_node
=
parent_node
.
body
[
idx
]
child_node
=
node
[
idx
]
if
child_node
in
self
.
new_func_nodes
:
parent_node
.
body
[
idx
:
idx
]
=
self
.
new_func_nodes
[
child_node
]
node
[
idx
:
idx
]
=
self
.
new_func_nodes
[
child_node
]
idx
=
idx
+
len
(
self
.
new_func_nodes
[
child_node
])
-
1
del
self
.
new_func_nodes
[
child_node
]
else
:
...
...
@@ -366,51 +371,133 @@ class IfConditionVisitor(object):
return
new_node
,
new_assign_nodes
def
get_name_ids
(
nodes
,
not_name_set
=
None
,
node_black_list
=
None
):
class
NameVisitor
(
gast
.
NodeVisitor
):
def
__init__
(
self
,
node_black_set
=
None
):
# Set of nodes that will not be visited.
self
.
node_black_set
=
node_black_set
or
set
()
# Dict to store the names and ctxs of vars.
self
.
name_ids
=
defaultdict
(
list
)
# List of current visited nodes
self
.
ancestor_nodes
=
[]
# Available only when node_black_set is set.
self
.
_is_finished
=
False
self
.
_candidate_ctxs
=
(
gast
.
Store
,
gast
.
Load
,
gast
.
Param
)
def
visit
(
self
,
node
):
"""Visit a node."""
if
node
in
self
.
node_black_set
or
self
.
_is_finished
:
self
.
_is_finished
=
True
return
self
.
ancestor_nodes
.
append
(
node
)
method
=
'visit_'
+
node
.
__class__
.
__name__
visitor
=
getattr
(
self
,
method
,
self
.
generic_visit
)
ret
=
visitor
(
node
)
self
.
ancestor_nodes
.
pop
()
return
ret
def
visit_If
(
self
,
node
):
"""
For nested `if/else`, the created vars are not always visible for parent node.
In addition, the vars created in `if.body` are not visible for `if.orelse`.
Case 1:
x = 1
if m > 1:
res = new_tensor
res = res + 1 # Error, `res` is not visible here.
Case 2:
if x_tensor > 0:
res = new_tensor
else:
res = res + 1 # Error, `res` is not visible here.
In above two cases, we should consider to manage the scope of vars to parsing
the arguments and returned vars correctly.
"""
before_if_name_ids
=
copy
.
deepcopy
(
self
.
name_ids
)
body_name_ids
=
self
.
_visit_child
(
node
.
body
)
# If the traversal process stops early, just return the name_ids that have been seen.
if
self
.
_is_finished
:
for
name_id
,
ctxs
in
before_if_name_ids
.
items
():
self
.
name_ids
[
name_id
]
=
ctxs
+
self
.
name_ids
[
name_id
]
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
# into name_ids.
else
:
else_name_ids
=
self
.
_visit_child
(
node
.
orelse
)
new_name_ids
=
self
.
_find_new_name_ids
(
body_name_ids
,
else_name_ids
)
for
new_name_id
in
new_name_ids
:
before_if_name_ids
[
new_name_id
].
append
(
gast
.
Store
())
self
.
name_ids
=
before_if_name_ids
def
visit_Attribute
(
self
,
node
):
if
not
self
.
_is_call_func_name_node
(
node
):
self
.
generic_visit
(
node
)
def
visit_Name
(
self
,
node
):
if
not
self
.
_is_call_func_name_node
(
node
):
if
isinstance
(
node
.
ctx
,
self
.
_candidate_ctxs
):
self
.
name_ids
[
node
.
id
].
append
(
node
.
ctx
)
def
visit_Assign
(
self
,
node
):
# Visit `value` firstly.
node
.
_fields
=
(
'value'
,
'targets'
)
self
.
generic_visit
(
node
)
def
visit_Return
(
self
,
node
):
# Ignore the vars in return
return
def
_visit_child
(
self
,
node
):
self
.
name_ids
=
defaultdict
(
list
)
if
isinstance
(
node
,
list
):
for
item
in
node
:
if
isinstance
(
item
,
gast
.
AST
):
self
.
visit
(
item
)
elif
isinstance
(
node
,
gast
.
AST
):
self
.
visit
(
node
)
return
copy
.
deepcopy
(
self
.
name_ids
)
def
_find_new_name_ids
(
self
,
body_name_ids
,
else_name_ids
):
def
is_required_ctx
(
ctxs
,
required_ctx
):
for
ctx
in
ctxs
:
if
isinstance
(
ctx
,
required_ctx
):
return
True
return
False
candidate_name_ids
=
set
(
body_name_ids
.
keys
())
&
set
(
else_name_ids
.
keys
(
))
store_ctx
=
gast
.
Store
new_name_ids
=
set
()
for
name_id
in
candidate_name_ids
:
if
is_required_ctx
(
body_name_ids
[
name_id
],
store_ctx
)
and
is_required_ctx
(
else_name_ids
[
name_id
],
store_ctx
):
new_name_ids
.
add
(
name_id
)
return
new_name_ids
def
_is_call_func_name_node
(
self
,
node
):
if
len
(
self
.
ancestor_nodes
)
>
1
:
assert
self
.
ancestor_nodes
[
-
1
]
==
node
parent_node
=
self
.
ancestor_nodes
[
-
2
]
if
isinstance
(
parent_node
,
gast
.
Call
)
and
parent_node
.
func
==
node
:
return
True
return
False
def
get_name_ids
(
nodes
,
node_black_set
=
None
):
"""
Return all ast.Name.id of python variable in nodes.
"""
if
not
isinstance
(
nodes
,
(
list
,
tuple
,
set
)):
raise
ValueError
(
"nodes must be one of list, tuple, set, but received %s"
%
type
(
nodes
))
if
not_name_set
is
None
:
not_name_set
=
set
()
def
update
(
old_dict
,
new_dict
):
for
k
,
v
in
new_dict
.
items
():
old_dict
[
k
].
extend
(
v
)
name_ids
=
defaultdict
(
list
)
name_visitor
=
NameVisitor
(
node_black_set
)
for
node
in
nodes
:
if
node_black_list
and
node
in
node_black_list
:
break
if
isinstance
(
node
,
gast
.
AST
):
# In two case, the ast.Name should be filtered.
# 1. Function name like `my_func` of my_func(x)
# 2. api prefix like `fluid` of `fluid.layers.mean`
if
isinstance
(
node
,
gast
.
Return
):
continue
elif
isinstance
(
node
,
gast
.
Call
)
and
isinstance
(
node
.
func
,
gast
.
Name
):
not_name_set
.
add
(
node
.
func
.
id
)
elif
isinstance
(
node
,
gast
.
Attribute
)
and
isinstance
(
node
.
value
,
gast
.
Name
):
not_name_set
.
add
(
node
.
value
.
id
)
if
isinstance
(
node
,
gast
.
Name
)
and
node
.
id
not
in
name_ids
and
node
.
id
not
in
not_name_set
:
if
isinstance
(
node
.
ctx
,
(
gast
.
Store
,
gast
.
Load
,
gast
.
Param
)):
name_ids
[
node
.
id
].
append
(
node
.
ctx
)
else
:
if
isinstance
(
node
,
gast
.
Assign
):
node
=
copy
.
copy
(
node
)
node
.
_fields
=
(
'value'
,
'targets'
)
for
field
,
value
in
gast
.
iter_fields
(
node
):
value
=
value
if
isinstance
(
value
,
list
)
else
[
value
]
update
(
name_ids
,
get_name_ids
(
value
,
not_name_set
,
node_black_list
))
return
name_ids
name_visitor
.
visit
(
node
)
return
name_visitor
.
name_ids
def
parse_cond_args
(
var_ids_dict
,
return_ids
=
None
,
ctx
=
gast
.
Load
):
...
...
@@ -508,7 +595,7 @@ def transform_if_else(node, root):
"""
Transform ast.If into control flow statement of Paddle static graph.
"""
parent_name_ids
=
get_name_ids
([
root
],
node_black_
lis
t
=
[
node
])
parent_name_ids
=
get_name_ids
([
root
],
node_black_
se
t
=
[
node
])
if_name_ids
=
get_name_ids
(
node
.
body
)
else_name_ids
=
get_name_ids
(
node
.
orelse
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
浏览文件 @
2403362d
...
...
@@ -65,6 +65,58 @@ def nested_if_else(x_v):
return
y
def
nested_if_else_2
(
x
):
y
=
fluid
.
layers
.
reshape
(
x
,
[
-
1
,
1
])
b
=
2
if
b
<
1
:
# var `z` is not visible for outer scope
z
=
y
x_shape_0
=
x
.
shape
[
0
]
if
x_shape_0
<
1
:
if
fluid
.
layers
.
shape
(
y
).
numpy
()[
0
]
<
1
:
res
=
fluid
.
layers
.
fill_constant
(
value
=
2
,
shape
=
x
.
shape
,
dtype
=
"int32"
)
# `z` is a new var here.
z
=
y
+
1
else
:
res
=
fluid
.
layers
.
fill_constant
(
value
=
3
,
shape
=
x
.
shape
,
dtype
=
"int32"
)
else
:
res
=
x
return
res
def
nested_if_else_3
(
x
):
y
=
fluid
.
layers
.
reshape
(
x
,
[
-
1
,
1
])
b
=
2
# var `z` is visible for func.body
if
b
<
1
:
z
=
y
else
:
z
=
x
if
b
<
1
:
res
=
x
# var `out` is only visible for current `if`
if
b
>
1
:
out
=
x
+
1
else
:
out
=
x
-
1
else
:
y_shape
=
fluid
.
layers
.
shape
(
y
)
if
y_shape
.
numpy
()[
0
]
<
1
:
res
=
fluid
.
layers
.
fill_constant
(
value
=
2
,
shape
=
x
.
shape
,
dtype
=
"int32"
)
# `z` is created in above code block.
z
=
y
+
1
else
:
res
=
fluid
.
layers
.
fill_constant
(
value
=
3
,
shape
=
x
.
shape
,
dtype
=
"int32"
)
# `out` is a new var.
out
=
x
+
1
return
res
class
NetWithControlFlowIf
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
hidden_dim
=
16
):
super
(
NetWithControlFlowIf
,
self
).
__init__
()
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py
浏览文件 @
2403362d
...
...
@@ -22,7 +22,7 @@ import numpy as np
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_func
from
test_basi
c
import
dyfunc_with_if_else
,
dyfunc_with_if_else2
,
nested_if_else
from
ifelse_simple_fun
c
import
dyfunc_with_if_else
,
dyfunc_with_if_else2
,
nested_if_else
class
TestAST2Func
(
unittest
.
TestCase
):
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_
basic
.py
→
python/paddle/fluid/tests/unittests/dygraph_to_static/test_
ifelse
.py
浏览文件 @
2403362d
...
...
@@ -72,6 +72,18 @@ class TestDygraphIfElse3(TestDygraphIfElse):
self
.
dyfunc
=
nested_if_else
class
TestDygraphIfElse4
(
TestDygraphIfElse
):
def
setUp
(
self
):
self
.
x
=
np
.
random
.
random
([
10
,
16
]).
astype
(
'float32'
)
self
.
dyfunc
=
nested_if_else_2
class
TestDygraphIfElse5
(
TestDygraphIfElse
):
def
setUp
(
self
):
self
.
x
=
np
.
random
.
random
([
10
,
16
]).
astype
(
'float32'
)
self
.
dyfunc
=
nested_if_else_3
class
TestDygraphIfElseWithAndOr
(
TestDygraphIfElse
):
def
setUp
(
self
):
self
.
x
=
np
.
random
.
random
([
10
,
16
]).
astype
(
'float32'
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py
浏览文件 @
2403362d
...
...
@@ -65,14 +65,10 @@ class TestGetNameIds2(TestGetNameIds):
return z
"""
self
.
all_name_ids
=
{
'x'
:
[
gast
.
Param
(),
gast
.
Store
(),
gast
.
Load
(),
gast
.
Load
(),
gast
.
Load
()
],
'a'
:
[
gast
.
Store
(),
gast
.
Load
(),
gast
.
Load
()],
'y'
:
[
gast
.
Param
(),
gast
.
Load
(),
gast
.
Load
(),
gast
.
Load
(),
gast
.
Load
()],
'z'
:
[
gast
.
Store
(),
gast
.
Load
(),
gast
.
Store
(),
gast
.
Store
()]
'x'
:
[
gast
.
Param
(),
gast
.
Store
()],
'a'
:
[
gast
.
Store
(),
gast
.
Load
()],
'y'
:
[
gast
.
Param
(),
gast
.
Load
()],
'z'
:
[
gast
.
Store
()]
}
...
...
@@ -87,9 +83,9 @@ class TestGetNameIds3(TestGetNameIds):
return z
"""
self
.
all_name_ids
=
{
'x'
:
[
gast
.
Param
()
,
gast
.
Load
(),
gast
.
Load
(),
gast
.
Load
()
],
'y'
:
[
gast
.
Param
()
,
gast
.
Load
(),
gast
.
Load
()
],
'z'
:
[
gast
.
Store
()
,
gast
.
Store
(),
gast
.
Load
(),
gast
.
Store
()
]
'x'
:
[
gast
.
Param
()],
'y'
:
[
gast
.
Param
()],
'z'
:
[
gast
.
Store
()]
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录