Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
55730d95
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
55730d95
编写于
4月 09, 2021
作者:
A
Aurelius84
提交者:
GitHub
4月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Stat] Support DictCmp and zip grammer (#32159)
* support DictCmp and zip grammar * fix code style
上级
dabaca00
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
94 addition
and
9 deletion
+94
-9
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
...addle/fluid/dygraph/dygraph_to_static/loop_transformer.py
+25
-9
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+35
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_dict.py
+34
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
浏览文件 @
55730d95
...
...
@@ -378,6 +378,21 @@ class NameVisitor(gast.NodeVisitor):
:param loop_node: Current loop node.
"""
def
filter_name_nodes_from
(
root_node
,
target_var_names
):
"""
Filter children with gast.Name type from node.(inclusivly)
"""
name_nodes
=
set
()
if
isinstance
(
root_node
,
gast
.
Name
):
if
node
.
id
in
target_var_names
:
name_nodes
.
add
(
root_node
)
for
child_node
in
gast
.
walk
(
root_node
):
if
isinstance
(
child_node
,
gast
.
Name
):
if
child_node
.
id
in
target_var_names
:
name_nodes
.
add
(
child_node
)
return
name_nodes
vars_of_list_generator
=
set
()
target_vars_of_for_node
=
set
()
...
...
@@ -412,15 +427,16 @@ class NameVisitor(gast.NodeVisitor):
# 1.2 vars from target vars used in elt_node
target_var_names
=
{
var
.
id
for
var
in
target_vars
}
listcomp_node
=
self
.
_get_parent_node
(
parent_node
)
elt_node
=
listcomp_node
.
elt
if
isinstance
(
elt_node
,
gast
.
Name
):
if
elt_node
.
id
in
target_var_names
:
vars_of_list_generator
.
add
(
elt_node
)
for
child_node
in
gast
.
walk
(
elt_node
):
if
isinstance
(
child_node
,
gast
.
Name
):
if
child_node
.
id
in
target_var_names
:
vars_of_list_generator
.
add
(
child_node
)
comp_node
=
self
.
_get_parent_node
(
parent_node
)
elt_nodes
=
[]
if
isinstance
(
comp_node
,
gast
.
ListComp
):
elt_nodes
.
append
(
comp_node
.
elt
)
elif
isinstance
(
comp_node
,
gast
.
DictComp
):
elt_nodes
.
extend
([
comp_node
.
key
,
comp_node
.
value
])
for
node
in
elt_nodes
:
vars_of_list_generator
|=
filter_name_nodes_from
(
node
,
target_var_names
)
# 2. Get target vars or vars from target vars used in for-loop but the for-loop is
# 1) not the "loop_node" itself
...
...
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
浏览文件 @
55730d95
...
...
@@ -79,6 +79,7 @@ FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TUPLE_INDEX_PREFIX
=
'__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX
=
'__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX
=
'__for_loop_iter_var'
FOR_ITER_ZIP_TO_LIST_PREFIX
=
'__for_loop_iter_zip'
# FullArgSpec is valid from Python3. Defined a Namedtuple to
# to make it available in Python2.
...
...
@@ -1012,6 +1013,9 @@ class ForNodeVisitor(object):
# - for i, x enumerate(var|var.numpy())
# - for x in var
self
.
iter_var_len_name
=
unique_name
.
generate
(
FOR_ITER_VAR_LEN_PREFIX
)
# - created zip to list var : __for_loop_iter_zip_0
self
.
iter_zip_to_list_name
=
unique_name
.
generate
(
FOR_ITER_ZIP_TO_LIST_PREFIX
)
# - var.numpy()/var
# - for x in var|var.numpy()
...
...
@@ -1083,6 +1087,7 @@ class ForNodeVisitor(object):
def
_parse_for_stmts
(
self
):
init_stmts
=
[]
init_stmts
.
extend
(
self
.
_build_iter_node
())
init_stmts
.
append
(
self
.
_build_index_init_node
())
init_stmts
.
append
(
self
.
_build_var_len_assign_node
())
...
...
@@ -1105,6 +1110,7 @@ class ForNodeVisitor(object):
def
_parse_for_enumerate_stmts
(
self
):
init_stmts
=
[]
init_stmts
.
extend
(
self
.
_build_iter_node
())
init_stmts
.
append
(
self
.
_build_index_init_node
())
init_stmts
.
append
(
self
.
_build_var_len_assign_node
())
init_stmts
.
append
(
self
.
_build_enum_init_node
())
...
...
@@ -1163,6 +1169,34 @@ class ForNodeVisitor(object):
return
convert_len_node
def
_build_iter_node
(
self
):
"""
Process special cases for iter_node inclue:
- Case 1 (for zip):
- for i, val in enumerate(zip(x, y)) # original code:
- __for_loop_iter_zip_0 = list(zip(x, y))
- for i, val in enumerate(__for_loop_iter_zip_0)
"""
new_nodes
=
[]
if
isinstance
(
self
.
iter_node
,
gast
.
Call
)
and
isinstance
(
self
.
iter_node
.
func
,
gast
.
Name
):
if
self
.
iter_node
.
func
.
id
==
'zip'
:
iter_var_name
=
ast_to_source_code
(
self
.
iter_node
).
strip
()
zip_to_list_str
=
"{target} = list({value})"
.
format
(
target
=
self
.
iter_zip_to_list_name
,
value
=
iter_var_name
)
zip_to_list_node
=
gast
.
parse
(
zip_to_list_str
).
body
[
0
]
new_nodes
.
append
(
zip_to_list_node
)
self
.
iter_node
=
gast
.
Name
(
id
=
self
.
iter_zip_to_list_name
,
ctx
=
gast
.
Load
(),
annotation
=
None
,
type_comment
=
None
)
return
new_nodes
def
_build_enum_init_node
(
self
):
if
self
.
is_for_enumerate_iter
()
and
self
.
args_length
!=
1
:
init_value_str
=
ast_to_source_code
(
self
.
iter_args
[
1
]).
strip
()
...
...
@@ -1399,6 +1433,7 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
for
spec
in
src_input_specs
:
if
spec
not
in
desired_input_specs
:
return
False
else
:
for
i
in
range
(
len_specs
):
src_shape
=
src_input_specs
[
i
].
shape
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py
浏览文件 @
55730d95
...
...
@@ -241,5 +241,39 @@ class TestDictPop(TestNetWithDict):
static_result
))
class
TestDictCmpInFor
(
unittest
.
TestCase
):
def
test_with_for
(
self
):
def
func
():
pos
=
[
1
,
3
]
neg
=
[
-
1
,
-
3
]
dict_val
=
{
'minus'
:
0
}
# test `zip` with `for`
for
(
x
,
y
)
in
zip
(
pos
,
neg
):
val
=
x
-
y
dict_val
.
update
(
{
k
:
val
+
dict_val
[
k
]
for
k
,
v
in
dict_val
.
items
()})
return
dict_val
self
.
assertEqual
(
paddle
.
jit
.
to_static
(
func
)()[
'minus'
],
8
)
def
test_with_for_enumerate
(
self
):
def
func
():
pos
=
[
1
,
3
]
neg
=
[
-
1
,
-
3
]
dict_val
=
{
'minus'
:
0
}
# test `zip` with `for`
for
i
,
(
x
,
y
)
in
enumerate
(
zip
(
pos
,
neg
)):
val
=
x
-
y
dict_val
.
update
(
{
k
:
val
+
dict_val
[
k
]
for
k
,
v
in
dict_val
.
items
()})
return
dict_val
self
.
assertEqual
(
paddle
.
jit
.
to_static
(
func
)()[
'minus'
],
8
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录