Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
eb1c0901
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eb1c0901
编写于
6月 18, 2020
作者:
L
liym27
提交者:
GitHub
6月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Stat]Remove unnecessary vars from gast.comprehension in LoopTransformer. (#25094)
上级
a7944904
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
72 addition
and
24 deletion
+72
-24
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
...addle/fluid/dygraph/dygraph_to_static/loop_transformer.py
+56
-20
python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
+16
-4
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
浏览文件 @
eb1c0901
...
...
@@ -117,15 +117,16 @@ class NameVisitor(gast.NodeVisitor):
var_node
.
ctx
)
in_loop_vars
=
set
(
in_loop_vars_list
)
in_loop_vars
=
self
.
_remove_unnecessary_vars
(
in_loop_vars
,
node
)
in_loop_name_strs
=
self
.
_var_nodes_to_names
(
in_loop_vars
)
before_loop_body_vars
=
self
.
before_loop_body_vars
[
node
]
before_loop_body_vars
=
self
.
_remove_
target_vars_of_for
(
before_loop_body_vars
=
self
.
_remove_
unnecessary_vars
(
before_loop_body_vars
,
node
)
before_loop_name_strs
=
self
.
_var_nodes_to_names
(
before_loop_body_vars
)
after_loop_vars
=
self
.
current_seen_vars
-
before_loop_body_vars
-
in_loop_vars
after_loop_vars
=
self
.
_remove_
target_vars_of_for
(
after_loop_vars
,
node
)
after_loop_vars
=
self
.
_remove_
unnecessary_vars
(
after_loop_vars
,
node
)
after_loop_name_strs
=
self
.
_var_nodes_to_names
(
after_loop_vars
,
read_context
)
condition_vars
=
self
.
condition_vars
[
node
]
...
...
@@ -138,7 +139,6 @@ class NameVisitor(gast.NodeVisitor):
for
var
in
in_loop_vars
:
wrapper
=
self
.
node_to_wrapper_map
[
var
]
name_to_type
[
self
.
_var_node_to_name
(
var
)]
=
wrapper
.
node_var_type
for
name
in
in_loop_name_strs
:
if
name
in
before_loop_name_strs
:
# If a variable is used in loop and created before loop
...
...
@@ -296,47 +296,83 @@ class NameVisitor(gast.NodeVisitor):
return
parent_node
return
None
def
_remove_
target_vars_of_for
(
self
,
before_or_after_
loop_vars
,
loop_node
):
def
_remove_
unnecessary_vars
(
self
,
loop_vars
,
loop_node
):
"""
Remove target vars of gast.For from before_loop_vars or after_loop_vars.
:param before_or_after_loop_vars: before_loop_vars or after_loop_vars of loop_node.
Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node.
1. Remove target vars of gast.For from before_loop_vars or after_loop_vars.
2. Remove vars only in gast.comprehension.
:param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node.
:param loop_node: Current loop node.
"""
removed_vars
=
set
()
for
name_node
in
before_or_after_loop_vars
:
vars_of_list_generator
=
set
()
target_vars_of_for_node
=
set
()
for
name_node
in
loop_vars
:
if
not
isinstance
(
name_node
,
gast
.
Name
):
continue
parent_node
=
self
.
_get_parent_node
(
name_node
)
# NOTE: gast.For.target can be gast.Tuple.
# For example: `for i, j in enumerate(x)` has two target vars: i and j
# NOTE: gast.For.target or gast.comprehension.target can be gast.Tuple.
# For examples:
# 1) `for i, j in enumerate(x)` has two target vars: i and j
# 2) `[x for x,y in array]` has two target vars: x and y
if
isinstance
(
parent_node
,
gast
.
Tuple
):
parent_node
=
self
.
_get_parent_node
(
parent_node
)
if
isinstance
(
parent_node
,
gast
.
For
)
and
parent_node
is
not
loop_node
:
# 1. Get vars only in gast.comprehension.
# For examples:
# 1) [x for x,y in array] -> x, x, y
# 2) [f(x) for x in array] -> x
# 3) [func(x, y) for x in array] -> x, x
if
isinstance
(
parent_node
,
gast
.
comprehension
):
# 1.1 target vars in list/set comprehensions
target_node
=
parent_node
.
target
if
isinstance
(
target_node
,
gast
.
Tuple
):
target_vars
=
target_node
.
elts
else
:
target_vars
=
[
target_node
]
if
name_node
in
target_vars
:
removed_vars
.
add
(
name_node
)
vars_of_list_generator
=
vars_of_list_generator
|
set
(
target_vars
)
# 1.2 vars from target vars used in elt_node
target_var_names
=
{
var
.
id
for
var
in
target_vars
}
listcomp_node
=
self
.
_get_parent_node
(
parent_node
)
elt_node
=
listcomp_node
.
elt
if
isinstance
(
elt_node
,
gast
.
Name
):
if
elt_node
.
id
in
target_var_names
:
vars_of_list_generator
.
add
(
elt_node
)
for
child_node
in
gast
.
walk
(
elt_node
):
if
isinstance
(
child_node
,
gast
.
Name
):
if
child_node
.
id
in
target_var_names
:
vars_of_list_generator
.
add
(
child_node
)
# 2. Get target vars or vars from target vars used in for-loop.
elif
isinstance
(
parent_node
,
gast
.
For
)
and
parent_node
is
not
loop_node
:
# 2.1 target vars in gast.For node.
target_node
=
parent_node
.
target
if
isinstance
(
target_node
,
gast
.
Tuple
):
target_vars
=
target_node
.
elts
else
:
target_vars
=
[
target_node
]
removed_vars_name_strs
=
{
var
.
id
for
var
in
removed_vars
}
target_vars_of_for_node
=
target_vars_of_for_node
|
set
(
target_vars
)
for
var
in
before_or_after_loop_vars
:
# 2.2 vars from target vars used in for-loop
target_vars_name_strs
=
{
var
.
id
for
var
in
target_vars_of_for_node
}
for
var
in
loop_vars
:
if
not
isinstance
(
var
,
gast
.
Name
):
continue
if
var
.
id
in
removed
_vars_name_strs
and
var
not
in
self
.
condition_vars
[
if
var
.
id
in
target
_vars_name_strs
and
var
not
in
self
.
condition_vars
[
loop_node
]:
removed_vars
.
add
(
var
)
target_vars_of_for_node
.
add
(
var
)
return
before_or_after_loop_vars
-
removed_vars
removed_vars
=
target_vars_of_for_node
|
vars_of_list_generator
return
loop_vars
-
removed_vars
class
LoopTransformer
(
gast
.
NodeTransformer
):
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
浏览文件 @
eb1c0901
...
...
@@ -169,15 +169,28 @@ def nested_for_loop_dyfunc():
return
b
def
for_loop_dufunc_with_listcomp
(
array
):
a
=
1
for
j
in
range
(
array
):
res
=
[
x
+
a
for
x
in
array
]
res
=
[
i
for
i
in
array
]
x
=
1
b
=
[
i
for
i
in
array
]
print
(
x
)
return
res
class
TestNameVisitor
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
loop_funcs
=
[
while_loop_dyfunc
,
for_loop_dyfunc
,
while_loop_dyfunc_with_none
while_loop_dyfunc
,
for_loop_dyfunc
,
while_loop_dyfunc_with_none
,
for_loop_dufunc_with_listcomp
]
self
.
loop_var_names
=
[
set
([
"i"
,
"x"
]),
set
([
"i"
,
"ret"
,
"max_len"
]),
set
([
"i"
,
"x"
])
set
([
"i"
,
"x"
]),
set
([
"i"
,
"ret"
,
"max_len"
]),
set
([
"i"
,
"x"
]),
set
([
"j"
,
"array"
,
"res"
,
"x"
])
]
self
.
create_var_names
=
[
set
(),
set
([
"ret"
]),
set
()]
self
.
create_var_names
=
[
set
(),
set
([
"ret"
]),
set
()
,
set
([
"res"
,
"x"
])
]
self
.
nested_for_loop_func
=
nested_for_loop_dyfunc
...
...
@@ -211,7 +224,6 @@ class TestNameVisitor(unittest.TestCase):
if
isinstance
(
node
,
(
gast
.
While
,
gast
.
For
)):
loop_var_names
,
create_var_names
=
name_visitor
.
get_loop_var_names
(
node
)
# print(loop_var_names)
self
.
assertEqual
(
loop_var_names
,
self
.
loop_var_names
[
i
])
self
.
assertEqual
(
create_var_names
,
self
.
create_var_names
[
i
])
i
+=
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录