Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
96126532
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
96126532
编写于
11月 24, 2020
作者:
H
Huihuang Zheng
提交者:
GitHub
11月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix Incorrect After Node Vars in IfElseTransformer, test=develop (#28992)
The PR description is long. See details in the PR link.
上级
982fd0f3
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
102 addition
and
26 deletion
+102
-26
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
...dle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
+31
-26
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
...le/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
+71
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
浏览文件 @
96126532
...
...
@@ -91,22 +91,27 @@ class IfElseTransformer(gast.NodeTransformer):
class
NameVisitor
(
gast
.
NodeVisitor
):
def
__init__
(
self
,
end_node
=
None
):
def
__init__
(
self
,
after_node
=
None
,
end_node
=
None
):
# The start node (exclusive) of the visitor
self
.
after_node
=
after_node
# The terminate node of the visitor.
self
.
end_node
=
end_node
# Dict to store the names and ctxs of vars.
self
.
name_ids
=
defaultdict
(
list
)
# List of current visited nodes
self
.
ancestor_nodes
=
[]
#
Available only when end_node is set
.
self
.
_i
s_finished
=
Fals
e
#
True when in range (after_node, end_node)
.
self
.
_i
n_range
=
after_node
is
Non
e
self
.
_candidate_ctxs
=
(
gast
.
Store
,
gast
.
Load
,
gast
.
Param
)
self
.
_def_func_names
=
set
()
def
visit
(
self
,
node
):
"""Visit a node."""
if
node
==
self
.
end_node
or
self
.
_is_finished
:
self
.
_is_finished
=
True
if
self
.
after_node
is
not
None
and
node
==
self
.
after_node
:
self
.
_in_range
=
True
return
if
node
==
self
.
end_node
:
self
.
_in_range
=
False
return
self
.
ancestor_nodes
.
append
(
node
)
...
...
@@ -137,18 +142,19 @@ class NameVisitor(gast.NodeVisitor):
In above two cases, we should consider to manage the scope of vars to parsing
the arguments and returned vars correctly.
"""
if
not
self
.
end_node
:
if
not
self
.
_in_range
or
not
self
.
end_node
:
self
.
generic_visit
(
node
)
return
else
:
before_if_name_ids
=
copy
.
deepcopy
(
self
.
name_ids
)
body_name_ids
=
self
.
_visit_child
(
node
.
body
)
# If traversal process stops early in `if.body`, return the currently seen name_ids.
if
self
.
_is_finished
:
if
not
self
.
_in_range
:
self
.
_update_name_ids
(
before_if_name_ids
)
else
:
else_name_ids
=
self
.
_visit_child
(
node
.
orelse
)
# If traversal process stops early in `if.orelse`, return the currently seen name_ids.
if
self
.
_is_finished
:
if
not
self
.
_in_range
:
self
.
_update_name_ids
(
before_if_name_ids
)
else
:
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
...
...
@@ -161,10 +167,13 @@ class NameVisitor(gast.NodeVisitor):
self
.
name_ids
=
before_if_name_ids
def
visit_Attribute
(
self
,
node
):
if
not
self
.
_is_call_func_name_node
(
node
):
if
not
self
.
_i
n_range
or
not
self
.
_i
s_call_func_name_node
(
node
):
self
.
generic_visit
(
node
)
def
visit_Name
(
self
,
node
):
if
not
self
.
_in_range
:
self
.
generic_visit
(
node
)
return
blacklist
=
{
'True'
,
'False'
,
'None'
}
if
node
.
id
in
blacklist
:
return
if
node
.
id
in
self
.
_def_func_names
:
...
...
@@ -174,11 +183,17 @@ class NameVisitor(gast.NodeVisitor):
self
.
name_ids
[
node
.
id
].
append
(
node
.
ctx
)
def
visit_Assign
(
self
,
node
):
if
not
self
.
_in_range
:
self
.
generic_visit
(
node
)
return
# Visit `value` firstly.
node
.
_fields
=
(
'value'
,
'targets'
)
self
.
generic_visit
(
node
)
def
visit_FunctionDef
(
self
,
node
):
if
not
self
.
_in_range
:
self
.
generic_visit
(
node
)
return
self
.
_def_func_names
.
add
(
node
.
name
)
if
not
self
.
end_node
:
self
.
generic_visit
(
node
)
...
...
@@ -187,7 +202,7 @@ class NameVisitor(gast.NodeVisitor):
self
.
name_ids
=
defaultdict
(
list
)
self
.
generic_visit
(
node
)
if
self
.
_is_finished
:
if
not
self
.
_in_range
:
self
.
_update_name_ids
(
before_name_ids
)
else
:
self
.
name_ids
=
before_name_ids
...
...
@@ -235,11 +250,13 @@ class NameVisitor(gast.NodeVisitor):
self
.
name_ids
[
name_id
]
=
ctxs
+
self
.
name_ids
[
name_id
]
def
get_name_ids
(
nodes
,
end_node
=
None
):
def
get_name_ids
(
nodes
,
after_node
=
None
,
end_node
=
None
):
"""
Return all ast.Name.id of python variable in nodes.
Return all ast.Name.id of python variable in nodes range from
(after_node, end_node) exclusively. If after_node or end_node is None, the
range is unlimited.
"""
name_visitor
=
NameVisitor
(
end_node
)
name_visitor
=
NameVisitor
(
after_node
,
end_node
)
for
node
in
nodes
:
name_visitor
.
visit
(
node
)
return
name_visitor
.
name_ids
...
...
@@ -434,20 +451,8 @@ def transform_if_else(node, root):
parent_name_ids
=
get_name_ids
([
root
],
end_node
=
node
)
body_name_ids
=
get_name_ids
(
node
.
body
)
orelse_name_ids
=
get_name_ids
(
node
.
orelse
)
# Get after_ifelse_name_ids, which means used var names after If.body and If.orelse node.
after_ifelse_name_ids
=
defaultdict
(
list
)
all_name_ids
=
get_name_ids
([
root
])
for
name
in
all_name_ids
:
before_var_names_ids
=
parent_name_ids
.
get
(
name
,
[])
+
\
body_name_ids
.
get
(
name
,
[])
+
orelse_name_ids
.
get
(
name
,
[])
# Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load
# E.g. ctx of `x` can be [<gast.Load object at 0x142a33c90>, <gast.Load object at 0x142a51950>, <gast.Param object at 0x1407d8250>]
after_var_names_ids
=
[
ctx
for
ctx
in
all_name_ids
[
name
]
if
ctx
not
in
before_var_names_ids
]
if
after_var_names_ids
:
after_ifelse_name_ids
[
name
]
=
after_var_names_ids
after_ifelse_name_ids
=
get_name_ids
([
root
],
after_node
=
node
)
return_name_ids
,
modified_name_ids_from_parent
,
new_vars_to_create
=
parse_cond_return
(
parent_name_ids
,
body_name_ids
,
orelse_name_ids
,
after_ifelse_name_ids
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
浏览文件 @
96126532
...
...
@@ -17,6 +17,7 @@ from __future__ import print_function
import
numpy
as
np
import
unittest
import
paddle
from
paddle.fluid.dygraph.jit
import
declarative
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
ProgramTranslator
...
...
@@ -271,5 +272,75 @@ class TestNetWithExternalFunc(TestDygraphIfElseNet):
self
.
Net
=
NetWithExternalFunc
class
DiffModeNet1
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
mode
):
super
(
DiffModeNet1
,
self
).
__init__
()
self
.
mode
=
mode
@
paddle
.
jit
.
to_static
def
forward
(
self
,
x
,
y
):
if
self
.
mode
==
'train'
:
out
=
x
+
y
elif
self
.
mode
==
'infer'
:
out
=
x
-
y
else
:
raise
ValueError
(
'Illegal mode'
)
return
out
class
DiffModeNet2
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
mode
):
super
(
DiffModeNet2
,
self
).
__init__
()
self
.
mode
=
mode
@
paddle
.
jit
.
to_static
def
forward
(
self
,
x
,
y
):
if
self
.
mode
==
'train'
:
out
=
x
+
y
return
out
elif
self
.
mode
==
'infer'
:
out
=
x
-
y
return
out
else
:
raise
ValueError
(
'Illegal mode'
)
class
TestDiffModeNet
(
unittest
.
TestCase
):
"""
TestCase for the net with different modes
"""
def
setUp
(
self
):
self
.
x
=
paddle
.
randn
([
10
,
16
],
'float32'
)
self
.
y
=
paddle
.
randn
([
10
,
16
],
'float32'
)
self
.
init_net
()
def
init_net
(
self
):
self
.
Net
=
DiffModeNet1
def
_run
(
self
,
mode
,
to_static
):
prog_trans
=
ProgramTranslator
()
prog_trans
.
enable
(
to_static
)
net
=
self
.
Net
(
mode
)
ret
=
net
(
self
.
x
,
self
.
y
)
return
ret
.
numpy
()
def
test_train_mode
(
self
):
self
.
assertTrue
((
self
.
_run
(
mode
=
'train'
,
to_static
=
True
)
==
self
.
_run
(
mode
=
'train'
,
to_static
=
False
)).
all
())
def
test_infer_mode
(
self
):
self
.
assertTrue
((
self
.
_run
(
mode
=
'infer'
,
to_static
=
True
)
==
self
.
_run
(
mode
=
'infer'
,
to_static
=
False
)).
all
())
class
TestDiffModeNet2
(
TestDiffModeNet
):
def
init_net
(
self
):
self
.
Net
=
DiffModeNet2
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录