Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2403362d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2403362d
编写于
3月 18, 2020
作者:
A
Aurelius84
提交者:
GitHub
3月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
BugFix for parsing Arguments and inserting funcs in IfElseTransormer (#23035)
* Support and/or in controlFlow if test=develop
上级
01ab8a06
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
206 addition
and
59 deletion
+206
-59
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
...dle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
+134
-47
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
...d/tests/unittests/dygraph_to_static/ifelse_simple_func.py
+52
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py
.../fluid/tests/unittests/dygraph_to_static/test_ast_util.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
...le/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
+12
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py
...id/tests/unittests/dygraph_to_static/test_ifelse_basic.py
+7
-11
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
浏览文件 @
2403362d
...
...
@@ -95,7 +95,7 @@ class IfElseTransformer(gast.NodeTransformer):
"""
self
.
_insert_func_nodes
(
node
)
def
_insert_func_nodes
(
self
,
parent_
node
):
def
_insert_func_nodes
(
self
,
node
):
"""
Defined `true_func` and `false_func` will be inserted in front of corresponding
`layers.cond` statement instead of inserting them all into body of parent node.
...
...
@@ -103,13 +103,18 @@ class IfElseTransformer(gast.NodeTransformer):
For example, `self.var_dict["key"]`. In this case, nested structure of newly
defined functions is easier to understand.
"""
if
not
(
self
.
new_func_nodes
and
hasattr
(
parent_node
,
'body'
))
:
if
not
self
.
new_func_nodes
:
return
idx
=
len
(
parent_node
.
body
)
-
1
idx
=
-
1
if
isinstance
(
node
,
list
):
idx
=
len
(
node
)
-
1
elif
isinstance
(
node
,
gast
.
AST
):
for
_
,
child
in
gast
.
iter_fields
(
node
):
self
.
_insert_func_nodes
(
child
)
while
idx
>=
0
:
child_node
=
parent_node
.
body
[
idx
]
child_node
=
node
[
idx
]
if
child_node
in
self
.
new_func_nodes
:
parent_node
.
body
[
idx
:
idx
]
=
self
.
new_func_nodes
[
child_node
]
node
[
idx
:
idx
]
=
self
.
new_func_nodes
[
child_node
]
idx
=
idx
+
len
(
self
.
new_func_nodes
[
child_node
])
-
1
del
self
.
new_func_nodes
[
child_node
]
else
:
...
...
@@ -366,51 +371,133 @@ class IfConditionVisitor(object):
return
new_node
,
new_assign_nodes
def
get_name_ids
(
nodes
,
not_name_set
=
None
,
node_black_list
=
None
):
class
NameVisitor
(
gast
.
NodeVisitor
):
def
__init__
(
self
,
node_black_set
=
None
):
# Set of nodes that will not be visited.
self
.
node_black_set
=
node_black_set
or
set
()
# Dict to store the names and ctxs of vars.
self
.
name_ids
=
defaultdict
(
list
)
# List of current visited nodes
self
.
ancestor_nodes
=
[]
# Available only when node_black_set is set.
self
.
_is_finished
=
False
self
.
_candidate_ctxs
=
(
gast
.
Store
,
gast
.
Load
,
gast
.
Param
)
def
visit
(
self
,
node
):
"""Visit a node."""
if
node
in
self
.
node_black_set
or
self
.
_is_finished
:
self
.
_is_finished
=
True
return
self
.
ancestor_nodes
.
append
(
node
)
method
=
'visit_'
+
node
.
__class__
.
__name__
visitor
=
getattr
(
self
,
method
,
self
.
generic_visit
)
ret
=
visitor
(
node
)
self
.
ancestor_nodes
.
pop
()
return
ret
def
visit_If
(
self
,
node
):
"""
For nested `if/else`, the created vars are not always visible for parent node.
In addition, the vars created in `if.body` are not visible for `if.orelse`.
Case 1:
x = 1
if m > 1:
res = new_tensor
res = res + 1 # Error, `res` is not visible here.
Case 2:
if x_tensor > 0:
res = new_tensor
else:
res = res + 1 # Error, `res` is not visible here.
In above two cases, we should consider to manage the scope of vars to parsing
the arguments and returned vars correctly.
"""
before_if_name_ids
=
copy
.
deepcopy
(
self
.
name_ids
)
body_name_ids
=
self
.
_visit_child
(
node
.
body
)
# If the traversal process stops early, just return the name_ids that have been seen.
if
self
.
_is_finished
:
for
name_id
,
ctxs
in
before_if_name_ids
.
items
():
self
.
name_ids
[
name_id
]
=
ctxs
+
self
.
name_ids
[
name_id
]
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
# into name_ids.
else
:
else_name_ids
=
self
.
_visit_child
(
node
.
orelse
)
new_name_ids
=
self
.
_find_new_name_ids
(
body_name_ids
,
else_name_ids
)
for
new_name_id
in
new_name_ids
:
before_if_name_ids
[
new_name_id
].
append
(
gast
.
Store
())
self
.
name_ids
=
before_if_name_ids
def
visit_Attribute
(
self
,
node
):
if
not
self
.
_is_call_func_name_node
(
node
):
self
.
generic_visit
(
node
)
def
visit_Name
(
self
,
node
):
if
not
self
.
_is_call_func_name_node
(
node
):
if
isinstance
(
node
.
ctx
,
self
.
_candidate_ctxs
):
self
.
name_ids
[
node
.
id
].
append
(
node
.
ctx
)
def
visit_Assign
(
self
,
node
):
# Visit `value` firstly.
node
.
_fields
=
(
'value'
,
'targets'
)
self
.
generic_visit
(
node
)
def
visit_Return
(
self
,
node
):
# Ignore the vars in return
return
def
_visit_child
(
self
,
node
):
self
.
name_ids
=
defaultdict
(
list
)
if
isinstance
(
node
,
list
):
for
item
in
node
:
if
isinstance
(
item
,
gast
.
AST
):
self
.
visit
(
item
)
elif
isinstance
(
node
,
gast
.
AST
):
self
.
visit
(
node
)
return
copy
.
deepcopy
(
self
.
name_ids
)
def
_find_new_name_ids
(
self
,
body_name_ids
,
else_name_ids
):
def
is_required_ctx
(
ctxs
,
required_ctx
):
for
ctx
in
ctxs
:
if
isinstance
(
ctx
,
required_ctx
):
return
True
return
False
candidate_name_ids
=
set
(
body_name_ids
.
keys
())
&
set
(
else_name_ids
.
keys
(
))
store_ctx
=
gast
.
Store
new_name_ids
=
set
()
for
name_id
in
candidate_name_ids
:
if
is_required_ctx
(
body_name_ids
[
name_id
],
store_ctx
)
and
is_required_ctx
(
else_name_ids
[
name_id
],
store_ctx
):
new_name_ids
.
add
(
name_id
)
return
new_name_ids
def
_is_call_func_name_node
(
self
,
node
):
if
len
(
self
.
ancestor_nodes
)
>
1
:
assert
self
.
ancestor_nodes
[
-
1
]
==
node
parent_node
=
self
.
ancestor_nodes
[
-
2
]
if
isinstance
(
parent_node
,
gast
.
Call
)
and
parent_node
.
func
==
node
:
return
True
return
False
def
get_name_ids
(
nodes
,
node_black_set
=
None
):
"""
Return all ast.Name.id of python variable in nodes.
"""
if
not
isinstance
(
nodes
,
(
list
,
tuple
,
set
)):
raise
ValueError
(
"nodes must be one of list, tuple, set, but received %s"
%
type
(
nodes
))
if
not_name_set
is
None
:
not_name_set
=
set
()
def
update
(
old_dict
,
new_dict
):
for
k
,
v
in
new_dict
.
items
():
old_dict
[
k
].
extend
(
v
)
name_ids
=
defaultdict
(
list
)
name_visitor
=
NameVisitor
(
node_black_set
)
for
node
in
nodes
:
if
node_black_list
and
node
in
node_black_list
:
break
if
isinstance
(
node
,
gast
.
AST
):
# In two case, the ast.Name should be filtered.
# 1. Function name like `my_func` of my_func(x)
# 2. api prefix like `fluid` of `fluid.layers.mean`
if
isinstance
(
node
,
gast
.
Return
):
continue
elif
isinstance
(
node
,
gast
.
Call
)
and
isinstance
(
node
.
func
,
gast
.
Name
):
not_name_set
.
add
(
node
.
func
.
id
)
elif
isinstance
(
node
,
gast
.
Attribute
)
and
isinstance
(
node
.
value
,
gast
.
Name
):
not_name_set
.
add
(
node
.
value
.
id
)
if
isinstance
(
node
,
gast
.
Name
)
and
node
.
id
not
in
name_ids
and
node
.
id
not
in
not_name_set
:
if
isinstance
(
node
.
ctx
,
(
gast
.
Store
,
gast
.
Load
,
gast
.
Param
)):
name_ids
[
node
.
id
].
append
(
node
.
ctx
)
else
:
if
isinstance
(
node
,
gast
.
Assign
):
node
=
copy
.
copy
(
node
)
node
.
_fields
=
(
'value'
,
'targets'
)
for
field
,
value
in
gast
.
iter_fields
(
node
):
value
=
value
if
isinstance
(
value
,
list
)
else
[
value
]
update
(
name_ids
,
get_name_ids
(
value
,
not_name_set
,
node_black_list
))
return
name_ids
name_visitor
.
visit
(
node
)
return
name_visitor
.
name_ids
def
parse_cond_args
(
var_ids_dict
,
return_ids
=
None
,
ctx
=
gast
.
Load
):
...
...
@@ -508,7 +595,7 @@ def transform_if_else(node, root):
"""
Transform ast.If into control flow statement of Paddle static graph.
"""
parent_name_ids
=
get_name_ids
([
root
],
node_black_
lis
t
=
[
node
])
parent_name_ids
=
get_name_ids
([
root
],
node_black_
se
t
=
[
node
])
if_name_ids
=
get_name_ids
(
node
.
body
)
else_name_ids
=
get_name_ids
(
node
.
orelse
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
浏览文件 @
2403362d
...
...
@@ -65,6 +65,58 @@ def nested_if_else(x_v):
return
y
def
nested_if_else_2
(
x
):
y
=
fluid
.
layers
.
reshape
(
x
,
[
-
1
,
1
])
b
=
2
if
b
<
1
:
# var `z` is not visible for outer scope
z
=
y
x_shape_0
=
x
.
shape
[
0
]
if
x_shape_0
<
1
:
if
fluid
.
layers
.
shape
(
y
).
numpy
()[
0
]
<
1
:
res
=
fluid
.
layers
.
fill_constant
(
value
=
2
,
shape
=
x
.
shape
,
dtype
=
"int32"
)
# `z` is a new var here.
z
=
y
+
1
else
:
res
=
fluid
.
layers
.
fill_constant
(
value
=
3
,
shape
=
x
.
shape
,
dtype
=
"int32"
)
else
:
res
=
x
return
res
def
nested_if_else_3
(
x
):
y
=
fluid
.
layers
.
reshape
(
x
,
[
-
1
,
1
])
b
=
2
# var `z` is visible for func.body
if
b
<
1
:
z
=
y
else
:
z
=
x
if
b
<
1
:
res
=
x
# var `out` is only visible for current `if`
if
b
>
1
:
out
=
x
+
1
else
:
out
=
x
-
1
else
:
y_shape
=
fluid
.
layers
.
shape
(
y
)
if
y_shape
.
numpy
()[
0
]
<
1
:
res
=
fluid
.
layers
.
fill_constant
(
value
=
2
,
shape
=
x
.
shape
,
dtype
=
"int32"
)
# `z` is created in above code block.
z
=
y
+
1
else
:
res
=
fluid
.
layers
.
fill_constant
(
value
=
3
,
shape
=
x
.
shape
,
dtype
=
"int32"
)
# `out` is a new var.
out
=
x
+
1
return
res
class
NetWithControlFlowIf
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
hidden_dim
=
16
):
super
(
NetWithControlFlowIf
,
self
).
__init__
()
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ast_util.py
浏览文件 @
2403362d
...
...
@@ -22,7 +22,7 @@ import numpy as np
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_func
from
test_basi
c
import
dyfunc_with_if_else
,
dyfunc_with_if_else2
,
nested_if_else
from
ifelse_simple_fun
c
import
dyfunc_with_if_else
,
dyfunc_with_if_else2
,
nested_if_else
class
TestAST2Func
(
unittest
.
TestCase
):
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_
basic
.py
→
python/paddle/fluid/tests/unittests/dygraph_to_static/test_
ifelse
.py
浏览文件 @
2403362d
...
...
@@ -72,6 +72,18 @@ class TestDygraphIfElse3(TestDygraphIfElse):
self
.
dyfunc
=
nested_if_else
class
TestDygraphIfElse4
(
TestDygraphIfElse
):
def
setUp
(
self
):
self
.
x
=
np
.
random
.
random
([
10
,
16
]).
astype
(
'float32'
)
self
.
dyfunc
=
nested_if_else_2
class
TestDygraphIfElse5
(
TestDygraphIfElse
):
def
setUp
(
self
):
self
.
x
=
np
.
random
.
random
([
10
,
16
]).
astype
(
'float32'
)
self
.
dyfunc
=
nested_if_else_3
class
TestDygraphIfElseWithAndOr
(
TestDygraphIfElse
):
def
setUp
(
self
):
self
.
x
=
np
.
random
.
random
([
10
,
16
]).
astype
(
'float32'
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py
浏览文件 @
2403362d
...
...
@@ -65,14 +65,10 @@ class TestGetNameIds2(TestGetNameIds):
return z
"""
self
.
all_name_ids
=
{
'x'
:
[
gast
.
Param
(),
gast
.
Store
(),
gast
.
Load
(),
gast
.
Load
(),
gast
.
Load
()
],
'a'
:
[
gast
.
Store
(),
gast
.
Load
(),
gast
.
Load
()],
'y'
:
[
gast
.
Param
(),
gast
.
Load
(),
gast
.
Load
(),
gast
.
Load
(),
gast
.
Load
()],
'z'
:
[
gast
.
Store
(),
gast
.
Load
(),
gast
.
Store
(),
gast
.
Store
()]
'x'
:
[
gast
.
Param
(),
gast
.
Store
()],
'a'
:
[
gast
.
Store
(),
gast
.
Load
()],
'y'
:
[
gast
.
Param
(),
gast
.
Load
()],
'z'
:
[
gast
.
Store
()]
}
...
...
@@ -87,9 +83,9 @@ class TestGetNameIds3(TestGetNameIds):
return z
"""
self
.
all_name_ids
=
{
'x'
:
[
gast
.
Param
()
,
gast
.
Load
(),
gast
.
Load
(),
gast
.
Load
()
],
'y'
:
[
gast
.
Param
()
,
gast
.
Load
(),
gast
.
Load
()
],
'z'
:
[
gast
.
Store
()
,
gast
.
Store
(),
gast
.
Load
(),
gast
.
Store
()
]
'x'
:
[
gast
.
Param
()],
'y'
:
[
gast
.
Param
()],
'z'
:
[
gast
.
Store
()]
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录