Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a0846b62
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a0846b62
编写于
5月 26, 2020
作者:
L
liym27
提交者:
GitHub
5月 26, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove target vars of gast.For from before_loop_vars or after_loop_vars (#24732)
上级
d15fc95e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
96 addition
and
3 deletion
+96
-3
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
...addle/fluid/dygraph/dygraph_to_static/loop_transformer.py
+59
-3
python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
+37
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
浏览文件 @
a0846b62
...
...
@@ -166,13 +166,19 @@ class NameVisitor(gast.NodeVisitor):
in_loop_vars
=
self
.
in_loop_vars
[
node
]
in_loop_name_strs
=
self
.
_var_nodes_to_names
(
in_loop_vars
)
before_loop_body_vars
=
self
.
before_loop_body_vars
[
node
]
before_loop_body_vars
=
self
.
_remove_target_vars_of_for
(
before_loop_body_vars
,
node
)
before_loop_name_strs
=
self
.
_var_nodes_to_names
(
before_loop_body_vars
)
after_loop_vars
=
self
.
current_seen_vars
-
before_loop_body_vars
-
in_loop_vars
after_loop_vars
=
self
.
_remove_target_vars_of_for
(
after_loop_vars
,
node
)
after_loop_name_strs
=
self
.
_var_nodes_to_names
(
after_loop_vars
,
read_context
)
condition_vars
=
self
.
condition_vars
[
node
]
condition_names
=
self
.
_var_nodes_to_names
(
condition_vars
)
write_vars
=
self
.
write_in_loop
[
node
]
write_names
=
self
.
_var_nodes_to_names
(
write_vars
)
...
...
@@ -203,6 +209,7 @@ class NameVisitor(gast.NodeVisitor):
# vars out
loop_var_names
.
add
(
name
)
create_var_names
.
add
(
name
)
return
loop_var_names
,
create_var_names
def
visit_Name
(
self
,
node
):
...
...
@@ -221,8 +228,8 @@ class NameVisitor(gast.NodeVisitor):
self
.
in_loop_vars
[
loop_node
].
add
(
node
)
if
type
(
node
.
ctx
)
in
write_context
:
self
.
write_in_loop
[
loop_node
].
add
(
node
)
if
self
.
in_condition
:
self
.
condition_vars
[
loop_node
].
add
(
node
)
if
self
.
in_condition
:
self
.
condition_vars
[
loop_node
].
add
(
node
)
self
.
generic_visit
(
node
)
def
visit_FunctionDef
(
self
,
node
):
...
...
@@ -309,11 +316,60 @@ class NameVisitor(gast.NodeVisitor):
return
False
def
_is_call_func_name_node
(
self
,
node
):
parent_node
=
self
.
node_to_wrapper_map
[
node
].
parent
.
node
parent_node
=
self
.
_get_parent_node
(
node
)
if
isinstance
(
parent_node
,
gast
.
Call
)
and
parent_node
.
func
==
node
:
return
True
return
False
def
_get_parent_node
(
self
,
node
):
wrapper_node
=
self
.
node_to_wrapper_map
.
get
(
node
)
if
wrapper_node
:
parent_node
=
wrapper_node
.
parent
.
node
return
parent_node
return
None
def
_remove_target_vars_of_for
(
self
,
before_or_after_loop_vars
,
loop_node
):
"""
Remove target vars of gast.For from before_loop_vars or after_loop_vars.
:param before_or_after_loop_vars: before_loop_vars or after_loop_vars of loop_node.
:param loop_node: Current loop node.
"""
removed_vars
=
set
()
for
name_node
in
before_or_after_loop_vars
:
if
not
isinstance
(
name_node
,
gast
.
Name
):
continue
parent_node
=
self
.
_get_parent_node
(
name_node
)
# NOTE: gast.For.target can be gast.Tuple.
# For example: `for i, j in enumerate(x)` has two target vars: i and j
if
isinstance
(
parent_node
,
gast
.
Tuple
):
parent_node
=
self
.
_get_parent_node
(
parent_node
)
if
isinstance
(
parent_node
,
gast
.
For
)
and
parent_node
is
not
loop_node
:
target_node
=
parent_node
.
target
if
isinstance
(
target_node
,
gast
.
Tuple
):
target_vars
=
target_node
.
elts
else
:
target_vars
=
[
target_node
]
if
name_node
in
target_vars
:
removed_vars
.
add
(
name_node
)
removed_vars_name_strs
=
{
var
.
id
for
var
in
removed_vars
}
for
var
in
before_or_after_loop_vars
:
if
not
isinstance
(
var
,
gast
.
Name
):
continue
if
var
.
id
in
removed_vars_name_strs
and
var
not
in
self
.
condition_vars
[
loop_node
]:
removed_vars
.
add
(
var
)
return
before_or_after_loop_vars
-
removed_vars
class
LoopTransformer
(
gast
.
NodeTransformer
):
"""
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
浏览文件 @
a0846b62
...
...
@@ -132,6 +132,19 @@ def var_create_in_for_loop(max_len):
return
ret
def
nested_for_loop_dyfunc
():
two
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
value
=
2
,
dtype
=
"int32"
)
three
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
value
=
3
,
dtype
=
"int32"
)
for
j
in
range
(
two
):
for
i
in
range
(
10
):
a
=
2
for
i
in
range
(
three
):
b
=
fluid
.
layers
.
zeros
(
shape
=
[
1
],
dtype
=
'float32'
)
return
b
class
TestNameVisitor
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
loop_funcs
=
[
...
...
@@ -142,6 +155,8 @@ class TestNameVisitor(unittest.TestCase):
]
self
.
create_var_names
=
[
set
(),
set
([
"ret"
]),
set
()]
self
.
nested_for_loop_func
=
nested_for_loop_dyfunc
def
test_loop_vars
(
self
):
for
i
in
range
(
len
(
self
.
loop_funcs
)):
func
=
self
.
loop_funcs
[
i
]
...
...
@@ -155,6 +170,28 @@ class TestNameVisitor(unittest.TestCase):
self
.
assertEqual
(
loop_var_names
,
self
.
loop_var_names
[
i
])
self
.
assertEqual
(
create_var_names
,
self
.
create_var_names
[
i
])
def
test_nested_loop_vars
(
self
):
func
=
self
.
nested_for_loop_func
test_func
=
inspect
.
getsource
(
func
)
gast_root
=
gast
.
parse
(
test_func
)
name_visitor
=
NameVisitor
(
gast_root
)
self
.
loop_var_names
=
[
set
([
"j"
,
"two"
]),
set
([
"i"
,
"three"
,
"b"
]),
set
([
"i"
]),
]
self
.
create_var_names
=
[
set
(),
set
([
"b"
]),
set
()]
i
=
0
for
node
in
gast
.
walk
(
gast_root
):
if
isinstance
(
node
,
(
gast
.
While
,
gast
.
For
)):
loop_var_names
,
create_var_names
=
name_visitor
.
get_loop_var_names
(
node
)
# print(loop_var_names)
self
.
assertEqual
(
loop_var_names
,
self
.
loop_var_names
[
i
])
self
.
assertEqual
(
create_var_names
,
self
.
create_var_names
[
i
])
i
+=
1
class
TestTransformWhileLoop
(
unittest
.
TestCase
):
def
setUp
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录