Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
269470d6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
269470d6
编写于
11月 19, 2020
作者:
L
liym27
提交者:
GitHub
11月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dynamic-to-Static] Remove unnecessary variables of the arguments in true_func/false_func (#28722)
上级
7d32e100
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
57 addition
and
10 deletion
+57
-10
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
...dle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
+40
-10
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
...d/tests/unittests/dygraph_to_static/ifelse_simple_func.py
+11
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
...le/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
+6
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
浏览文件 @
269470d6
...
...
@@ -245,23 +245,51 @@ def get_name_ids(nodes, end_node=None):
return
name_visitor
.
name_ids
def
parse_cond_args
(
var_ids_dict
,
return_ids
=
None
,
ctx
=
gast
.
Load
):
def
parse_cond_args
(
parent_ids_dict
,
var_ids_dict
,
modified_ids_dict
=
None
,
ctx
=
gast
.
Load
):
"""
Find out the ast.Name.id list of input by analyzing node's AST information.
"""
name_ids
=
[
# 1. filter the var fit the ctx
arg_name_ids
=
[
var_id
for
var_id
,
var_ctx
in
six
.
iteritems
(
var_ids_dict
)
if
isinstance
(
var_ctx
[
0
],
ctx
)
]
if
return_ids
:
new_args
=
set
(
return_ids
)
-
set
(
name_ids
)
name_ids
.
extend
(
list
(
new_args
))
name_ids
.
sort
()
# 2. args should contain modified var ids in if-body or else-body
# case:
#
# ```
# if b < 1:
# z = y
# else:
# z = x
# ```
#
# In the above case, `z` should be in the args of cond()
if
modified_ids_dict
:
arg_name_ids
=
set
(
arg_name_ids
)
|
set
(
modified_ids_dict
)
# 3. args should not contain the vars not in parent ids
# case :
#
# ```
# x = 1
# if x > y:
# z = [v for v in range(i)]
# ```
#
# In the above case, `v` should not be in the args of cond()
arg_name_ids
=
list
(
set
(
arg_name_ids
)
&
set
(
parent_ids_dict
))
arg_name_ids
.
sort
()
args
=
[
gast
.
Name
(
id
=
name_id
,
ctx
=
gast
.
Load
(),
annotation
=
None
,
type_comment
=
None
)
for
name_id
in
name_ids
for
name_id
in
arg_
name_ids
]
arguments
=
gast
.
arguments
(
args
=
args
,
...
...
@@ -412,7 +440,7 @@ def transform_if_else(node, root):
all_name_ids
=
get_name_ids
([
root
])
for
name
in
all_name_ids
:
before_var_names_ids
=
parent_name_ids
.
get
(
name
,
[])
+
\
body_name_ids
.
get
(
name
,
[])
+
orelse_name_ids
.
get
(
name
,
[])
body_name_ids
.
get
(
name
,
[])
+
orelse_name_ids
.
get
(
name
,
[])
# Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load
# E.g. ctx of `x` can be [<gast.Load object at 0x142a33c90>, <gast.Load object at 0x142a51950>, <gast.Param object at 0x1407d8250>]
after_var_names_ids
=
[
...
...
@@ -444,12 +472,14 @@ def transform_if_else(node, root):
true_func_node
=
create_funcDef_node
(
node
.
body
,
name
=
unique_name
.
generate
(
TRUE_FUNC_PREFIX
),
input_args
=
parse_cond_args
(
body_name_ids
,
modified_name_ids
),
input_args
=
parse_cond_args
(
parent_name_ids
,
body_name_ids
,
modified_name_ids
),
return_name_ids
=
return_name_ids
)
false_func_node
=
create_funcDef_node
(
node
.
orelse
,
name
=
unique_name
.
generate
(
FALSE_FUNC_PREFIX
),
input_args
=
parse_cond_args
(
orelse_name_ids
,
modified_name_ids
),
input_args
=
parse_cond_args
(
parent_name_ids
,
orelse_name_ids
,
modified_name_ids
),
return_name_ids
=
return_name_ids
)
return
create_new_vars_in_parent_stmts
,
true_func_node
,
false_func_node
,
return_name_ids
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
浏览文件 @
269470d6
...
...
@@ -14,6 +14,7 @@
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
...
...
@@ -99,6 +100,16 @@ def dyfunc_with_if_else3(x):
return
x
def
dyfunc_with_if_else_with_list_geneator
(
x
):
if
10
>
5
:
y
=
paddle
.
add_n
(
[
paddle
.
full
(
shape
=
[
2
],
fill_value
=
v
)
for
v
in
range
(
5
)])
else
:
y
=
x
return
y
def
nested_if_else
(
x_v
):
batch_size
=
16
feat_size
=
x_v
.
shape
[
-
1
]
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
浏览文件 @
269470d6
...
...
@@ -69,6 +69,12 @@ class TestDygraphIfElse3(TestDygraphIfElse):
self
.
dyfunc
=
dyfunc_with_if_else3
class
TestDygraphIfElseWithListGenerator
(
TestDygraphIfElse
):
def
setUp
(
self
):
self
.
x
=
np
.
random
.
random
([
10
,
16
]).
astype
(
'float32'
)
self
.
dyfunc
=
dyfunc_with_if_else_with_list_geneator
class
TestDygraphNestedIfElse
(
TestDygraphIfElse
):
def
setUp
(
self
):
self
.
x
=
np
.
random
.
random
([
10
,
16
]).
astype
(
'float32'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录