Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
eb1c0901
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eb1c0901
编写于
6月 18, 2020
作者:
L
liym27
提交者:
GitHub
6月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Stat]Remove unnecessary vars from gast.comprehension in LoopTransformer. (#25094)
上级
a7944904
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
72 addition
and
24 deletion
+72
-24
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
...addle/fluid/dygraph/dygraph_to_static/loop_transformer.py
+56
-20
python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
+16
-4
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
浏览文件 @
eb1c0901
...
...
@@ -117,15 +117,16 @@ class NameVisitor(gast.NodeVisitor):
var_node
.
ctx
)
in_loop_vars
=
set
(
in_loop_vars_list
)
in_loop_vars
=
self
.
_remove_unnecessary_vars
(
in_loop_vars
,
node
)
in_loop_name_strs
=
self
.
_var_nodes_to_names
(
in_loop_vars
)
before_loop_body_vars
=
self
.
before_loop_body_vars
[
node
]
before_loop_body_vars
=
self
.
_remove_
target_vars_of_for
(
before_loop_body_vars
=
self
.
_remove_
unnecessary_vars
(
before_loop_body_vars
,
node
)
before_loop_name_strs
=
self
.
_var_nodes_to_names
(
before_loop_body_vars
)
after_loop_vars
=
self
.
current_seen_vars
-
before_loop_body_vars
-
in_loop_vars
after_loop_vars
=
self
.
_remove_
target_vars_of_for
(
after_loop_vars
,
node
)
after_loop_vars
=
self
.
_remove_
unnecessary_vars
(
after_loop_vars
,
node
)
after_loop_name_strs
=
self
.
_var_nodes_to_names
(
after_loop_vars
,
read_context
)
condition_vars
=
self
.
condition_vars
[
node
]
...
...
@@ -138,7 +139,6 @@ class NameVisitor(gast.NodeVisitor):
for
var
in
in_loop_vars
:
wrapper
=
self
.
node_to_wrapper_map
[
var
]
name_to_type
[
self
.
_var_node_to_name
(
var
)]
=
wrapper
.
node_var_type
for
name
in
in_loop_name_strs
:
if
name
in
before_loop_name_strs
:
# If a variable is used in loop and created before loop
...
...
@@ -296,47 +296,83 @@ class NameVisitor(gast.NodeVisitor):
return
parent_node
return
None
def
_remove_
target_vars_of_for
(
self
,
before_or_after_
loop_vars
,
loop_node
):
def
_remove_
unnecessary_vars
(
self
,
loop_vars
,
loop_node
):
"""
Remove target vars of gast.For from before_loop_vars or after_loop_vars.
:param before_or_after_loop_vars: before_loop_vars or after_loop_vars of loop_node.
Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node.
1. Remove target vars of gast.For from before_loop_vars or after_loop_vars.
2. Remove vars only in gast.comprehension.
:param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node.
:param loop_node: Current loop node.
"""
removed_vars
=
set
()
for
name_node
in
before_or_after_loop_vars
:
vars_of_list_generator
=
set
()
target_vars_of_for_node
=
set
()
for
name_node
in
loop_vars
:
if
not
isinstance
(
name_node
,
gast
.
Name
):
continue
parent_node
=
self
.
_get_parent_node
(
name_node
)
# NOTE: gast.For.target can be gast.Tuple.
# For example: `for i, j in enumerate(x)` has two target vars: i and j
# NOTE: gast.For.target or gast.comprehension.target can be gast.Tuple.
# For examples:
# 1) `for i, j in enumerate(x)` has two target vars: i and j
# 2) `[x for x,y in array]` has two target vars: x and y
if
isinstance
(
parent_node
,
gast
.
Tuple
):
parent_node
=
self
.
_get_parent_node
(
parent_node
)
if
isinstance
(
parent_node
,
gast
.
For
)
and
parent_node
is
not
loop_node
:
# 1. Get vars only in gast.comprehension.
# For examples:
# 1) [x for x,y in array] -> x, x, y
# 2) [f(x) for x in array] -> x
# 3) [func(x, y) for x in array] -> x, x
if
isinstance
(
parent_node
,
gast
.
comprehension
):
# 1.1 target vars in list/set comprehensions
target_node
=
parent_node
.
target
if
isinstance
(
target_node
,
gast
.
Tuple
):
target_vars
=
target_node
.
elts
else
:
target_vars
=
[
target_node
]
if
name_node
in
target_vars
:
removed_vars
.
add
(
name_node
)
vars_of_list_generator
=
vars_of_list_generator
|
set
(
target_vars
)
# 1.2 vars from target vars used in elt_node
target_var_names
=
{
var
.
id
for
var
in
target_vars
}
listcomp_node
=
self
.
_get_parent_node
(
parent_node
)
elt_node
=
listcomp_node
.
elt
if
isinstance
(
elt_node
,
gast
.
Name
):
if
elt_node
.
id
in
target_var_names
:
vars_of_list_generator
.
add
(
elt_node
)
for
child_node
in
gast
.
walk
(
elt_node
):
if
isinstance
(
child_node
,
gast
.
Name
):
if
child_node
.
id
in
target_var_names
:
vars_of_list_generator
.
add
(
child_node
)
# 2. Get target vars or vars from target vars used in for-loop.
elif
isinstance
(
parent_node
,
gast
.
For
)
and
parent_node
is
not
loop_node
:
# 2.1 target vars in gast.For node.
target_node
=
parent_node
.
target
if
isinstance
(
target_node
,
gast
.
Tuple
):
target_vars
=
target_node
.
elts
else
:
target_vars
=
[
target_node
]
removed_vars_name_strs
=
{
var
.
id
for
var
in
removed_vars
}
target_vars_of_for_node
=
target_vars_of_for_node
|
set
(
target_vars
)
for
var
in
before_or_after_loop_vars
:
# 2.2 vars from target vars used in for-loop
target_vars_name_strs
=
{
var
.
id
for
var
in
target_vars_of_for_node
}
for
var
in
loop_vars
:
if
not
isinstance
(
var
,
gast
.
Name
):
continue
if
var
.
id
in
removed
_vars_name_strs
and
var
not
in
self
.
condition_vars
[
if
var
.
id
in
target
_vars_name_strs
and
var
not
in
self
.
condition_vars
[
loop_node
]:
removed_vars
.
add
(
var
)
target_vars_of_for_node
.
add
(
var
)
return
before_or_after_loop_vars
-
removed_vars
removed_vars
=
target_vars_of_for_node
|
vars_of_list_generator
return
loop_vars
-
removed_vars
class
LoopTransformer
(
gast
.
NodeTransformer
):
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
浏览文件 @
eb1c0901
...
...
@@ -169,15 +169,28 @@ def nested_for_loop_dyfunc():
return
b
def
for_loop_dufunc_with_listcomp
(
array
):
a
=
1
for
j
in
range
(
array
):
res
=
[
x
+
a
for
x
in
array
]
res
=
[
i
for
i
in
array
]
x
=
1
b
=
[
i
for
i
in
array
]
print
(
x
)
return
res
class
TestNameVisitor
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
loop_funcs
=
[
while_loop_dyfunc
,
for_loop_dyfunc
,
while_loop_dyfunc_with_none
while_loop_dyfunc
,
for_loop_dyfunc
,
while_loop_dyfunc_with_none
,
for_loop_dufunc_with_listcomp
]
self
.
loop_var_names
=
[
set
([
"i"
,
"x"
]),
set
([
"i"
,
"ret"
,
"max_len"
]),
set
([
"i"
,
"x"
])
set
([
"i"
,
"x"
]),
set
([
"i"
,
"ret"
,
"max_len"
]),
set
([
"i"
,
"x"
]),
set
([
"j"
,
"array"
,
"res"
,
"x"
])
]
self
.
create_var_names
=
[
set
(),
set
([
"ret"
]),
set
()]
self
.
create_var_names
=
[
set
(),
set
([
"ret"
]),
set
()
,
set
([
"res"
,
"x"
])
]
self
.
nested_for_loop_func
=
nested_for_loop_dyfunc
...
...
@@ -211,7 +224,6 @@ class TestNameVisitor(unittest.TestCase):
if
isinstance
(
node
,
(
gast
.
While
,
gast
.
For
)):
loop_var_names
,
create_var_names
=
name_visitor
.
get_loop_var_names
(
node
)
# print(loop_var_names)
self
.
assertEqual
(
loop_var_names
,
self
.
loop_var_names
[
i
])
self
.
assertEqual
(
create_var_names
,
self
.
create_var_names
[
i
])
i
+=
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录