Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4ea95b6f
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4ea95b6f
编写于
3月 06, 2020
作者:
L
liym27
提交者:
GitHub
3月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support Tensor.shape in dygraph_to_static (#22830)
* support basic tensor.shape. * Support tensor.shape with dependencies.
上级
1644926a
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
242 addition
and
13 deletion
+242
-13
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
...paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+129
-11
python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py
python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py
+3
-1
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+9
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py
...le/fluid/tests/unittests/dygraph_to_static/test_resnet.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
...id/tests/unittests/dygraph_to_static/test_tensor_shape.py
+100
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
浏览文件 @
4ea95b6f
...
@@ -14,19 +14,21 @@
...
@@ -14,19 +14,21 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
copy
import
inspect
import
textwrap
import
astor
import
astor
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
# See details in https://github.com/serge-sans-paille/gast/
import
gast
import
gast
import
textwrap
import
inspect
from
paddle.fluid
import
unique_name
from
paddle.fluid
import
unique_name
from
paddle.fluid.dygraph.dygraph_to_static.loop_transformer
import
LoopTransformer
from
paddle.fluid.dygraph.dygraph_to_static.loop_transformer
import
LoopTransformer
from
.ast_utils
import
is_control_flow_if
,
create_cond_node
,
transform_if_else
,
ast_to_func
from
.ast_utils
import
is_control_flow_if
,
create_cond_node
,
transform_if_else
,
ast_to_func
from
.static_analysis
import
AstNodeWrapper
,
StaticAnalysisVisitor
from
.static_analysis
import
AstNodeWrapper
,
NodeVarType
,
StaticAnalysisVisitor
from
.utils
import
*
from
.utils
import
*
__all__
=
[
'DygraphToStaticAst'
,
'convert_to_static'
]
__all__
=
[
'DygraphToStaticAst'
,
'convert_to_static'
]
...
@@ -121,8 +123,10 @@ class DygraphToStaticAst(gast.NodeTransformer):
...
@@ -121,8 +123,10 @@ class DygraphToStaticAst(gast.NodeTransformer):
def
get_static_ast
(
self
,
root
):
def
get_static_ast
(
self
,
root
):
# save root for some analysis may need global AST
# save root for some analysis may need global AST
self
.
root
=
root
self
.
root
=
root
self
.
static_analysis_root
=
StaticAnalysisVisitor
(
self
.
static_analysis_visitor
=
StaticAnalysisVisitor
(
root
)
root
).
get_node_wrapper_root
()
self
.
static_analysis_root
=
self
.
static_analysis_visitor
.
get_node_wrapper_root
(
)
self
.
decorate_func_name
=
None
self
.
decorate_func_name
=
None
self
.
arg_name_to_idx
=
{}
self
.
arg_name_to_idx
=
{}
self
.
transfer_from_node_type
(
self
.
static_analysis_root
)
self
.
transfer_from_node_type
(
self
.
static_analysis_root
)
...
@@ -133,7 +137,8 @@ class DygraphToStaticAst(gast.NodeTransformer):
...
@@ -133,7 +137,8 @@ class DygraphToStaticAst(gast.NodeTransformer):
self
.
visit
(
node_wrapper
.
node
)
self
.
visit
(
node_wrapper
.
node
)
# Transform basic api of dygraph to static graph
# Transform basic api of dygraph to static graph
basic_api_trans
=
BasicApiTransformer
(
node_wrapper
)
basic_api_trans
=
BasicApiTransformer
(
node_wrapper
,
self
.
static_analysis_visitor
)
basic_api_trans
.
ast_visit
()
basic_api_trans
.
ast_visit
()
self
.
feed_name_to_arg_name
=
basic_api_trans
.
get_feed_name_to_arg_id
()
self
.
feed_name_to_arg_name
=
basic_api_trans
.
get_feed_name_to_arg_id
()
...
@@ -178,14 +183,31 @@ class BasicApiTransformer(gast.NodeTransformer):
...
@@ -178,14 +183,31 @@ class BasicApiTransformer(gast.NodeTransformer):
Class to transform basic API from dygraph to static graph.
Class to transform basic API from dygraph to static graph.
"""
"""
def
__init__
(
self
,
wrapper_root
):
def
__init__
(
self
,
wrapper_root
,
static_analysis_visitor
):
assert
isinstance
(
assert
isinstance
(
wrapper_root
,
AstNodeWrapper
wrapper_root
,
AstNodeWrapper
),
"Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
),
"Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self
.
wrapper_root
=
wrapper_root
self
.
wrapper_root
=
wrapper_root
self
.
root
=
wrapper_root
.
node
self
.
root
=
wrapper_root
.
node
self
.
class_node_dict
=
{}
self
.
class_node_dict
=
{}
# Used for transformation of data feed
self
.
feed_name_to_arg_id
=
{}
self
.
feed_name_to_arg_id
=
{}
self
.
name_to_tensor_shape
=
{}
# Used for transformation of Tensor.shape
self
.
static_analysis_visitor
=
static_analysis_visitor
self
.
node_to_wrapper_map
=
self
.
static_analysis_visitor
.
get_node_to_wrapper_map
(
)
self
.
scope_var_type_dict
=
{}
self
.
_run_static_visitor
()
def
_run_static_visitor
(
self
):
var_env
=
copy
.
deepcopy
(
self
.
static_analysis_visitor
.
get_var_env
())
# TODO: Consider that Tensor.shape is used in sub function and sub_scopes is empty
var_env
.
cur_scope
=
var_env
.
cur_scope
.
sub_scopes
[
0
]
self
.
scope_var_type_dict
=
var_env
.
get_scope_var_type
()
def
ast_visit
(
self
):
def
ast_visit
(
self
):
self
.
visit
(
self
.
root
)
self
.
visit
(
self
.
root
)
...
@@ -204,11 +226,12 @@ class BasicApiTransformer(gast.NodeTransformer):
...
@@ -204,11 +226,12 @@ class BasicApiTransformer(gast.NodeTransformer):
if
self
.
_update_class_node_dict
(
node
):
if
self
.
_update_class_node_dict
(
node
):
return
None
return
None
value_node
=
node
.
value
if
self
.
_update_name_to_tensor_shape
(
node
):
for
child_node
in
gast
.
walk
(
value_node
):
return
node
for
child_node
in
gast
.
walk
(
node
.
value
):
if
isinstance
(
child_node
,
gast
.
Call
):
if
isinstance
(
child_node
,
gast
.
Call
):
self
.
_visit_Call
(
child_node
)
self
.
_visit_Call
(
child_node
)
return
node
return
node
def
visit_Expr
(
self
,
node
):
def
visit_Expr
(
self
,
node
):
...
@@ -219,19 +242,41 @@ class BasicApiTransformer(gast.NodeTransformer):
...
@@ -219,19 +242,41 @@ class BasicApiTransformer(gast.NodeTransformer):
return
return
else
:
else
:
self
.
_visit_Call
(
child_node
)
self
.
_visit_Call
(
child_node
)
return
node
def
visit_Attribute
(
self
,
node
):
if
self
.
_used_by_paddle_api
(
node
):
if
self
.
is_tensor_shape
(
node
):
return
create_api_shape_node
(
node
)
return
node
def
visit_Name
(
self
,
node
):
if
node
.
id
in
self
.
name_to_tensor_shape
:
if
self
.
_used_by_paddle_api
(
node
):
tensor_shape_node
=
self
.
name_to_tensor_shape
[
node
.
id
]
if
isinstance
(
tensor_shape_node
,
gast
.
Attribute
):
return
create_api_shape_node
(
tensor_shape_node
)
elif
isinstance
(
tensor_shape_node
,
gast
.
Subscript
):
result_node
=
copy
.
deepcopy
(
tensor_shape_node
)
result_node
.
value
=
create_api_shape_node
(
tensor_shape_node
.
value
)
return
result_node
return
node
return
node
def
_visit_Call
(
self
,
node
):
def
_visit_Call
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Call
)
assert
isinstance
(
node
,
gast
.
Call
)
# Replace API `to_variable` with `fluid.layers.assign`
# Replace API `to_variable` with `fluid.layers.assign`
if
is_to_variable
(
node
):
if
is_to_variable
(
node
):
self
.
_update_feed_dict
(
node
)
self
.
_update_feed_dict
(
node
)
node
=
to_assign_node
(
node
)
node
=
to_assign_node
(
node
)
return
node
return
node
if
is_paddle_api
(
node
):
# Visit gast.Attribute and gast.Name to replace tensor.shape if necessary
self
.
generic_visit
(
node
)
func_name
=
astor
.
to_source
(
gast
.
gast_to_ast
(
node
.
func
))
func_name
=
astor
.
to_source
(
gast
.
gast_to_ast
(
node
.
func
))
if
self
.
_is_dygraph_forward
(
func_name
):
if
self
.
_is_dygraph_forward
(
func_name
):
class_node
=
self
.
_get_class_node
(
func_name
)
class_node
=
self
.
_get_class_node
(
func_name
)
static_node
=
to_static_ast
(
node
,
class_node
)
static_node
=
to_static_ast
(
node
,
class_node
)
...
@@ -239,6 +284,53 @@ class BasicApiTransformer(gast.NodeTransformer):
...
@@ -239,6 +284,53 @@ class BasicApiTransformer(gast.NodeTransformer):
else
:
else
:
return
node
return
node
def
is_tensor_shape
(
self
,
node
):
"""
Return True if node is like `x.shape` and x is Tensor, return False otherwise.
"""
assert
isinstance
(
node
,
gast
.
Attribute
)
if
node
.
attr
!=
'shape'
:
return
False
try
:
value_id
=
node
.
value
.
id
except
AttributeError
:
return
False
if
value_id
in
self
.
name_to_tensor_shape
:
return
True
# TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
# Need a better way to confirm whether `value_id` is a Tensor.
try
:
var_type_set
=
self
.
scope_var_type_dict
[
value_id
]
except
KeyError
:
return
False
if
NodeVarType
.
NUMPY_NDARRAY
in
var_type_set
:
return
False
if
NodeVarType
.
TENSOR
not
in
var_type_set
and
NodeVarType
.
PADDLE_RETURN_TYPES
not
in
var_type_set
:
return
False
return
True
def
_used_by_paddle_api
(
self
,
node
):
assert
isinstance
(
node
,
(
gast
.
Attribute
,
gast
.
Name
))
wrapper_node
=
self
.
node_to_wrapper_map
.
get
(
node
)
if
not
wrapper_node
:
# Transformed node is not in node_to_wrapper_map
return
False
while
wrapper_node
.
parent
:
parent_node
=
wrapper_node
.
parent
.
node
if
isinstance
(
parent_node
,
gast
.
Call
):
if
is_paddle_api
(
parent_node
):
return
True
else
:
return
False
wrapper_node
=
wrapper_node
.
parent
return
False
def
_is_dygraph_forward
(
self
,
func_id
):
def
_is_dygraph_forward
(
self
,
func_id
):
return
func_id
in
self
.
class_node_dict
return
func_id
in
self
.
class_node_dict
...
@@ -280,6 +372,32 @@ class BasicApiTransformer(gast.NodeTransformer):
...
@@ -280,6 +372,32 @@ class BasicApiTransformer(gast.NodeTransformer):
def
get_feed_name_to_arg_id
(
self
):
def
get_feed_name_to_arg_id
(
self
):
return
self
.
feed_name_to_arg_id
return
self
.
feed_name_to_arg_id
def
_update_name_to_tensor_shape
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Assign
)
# TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
target_node
=
node
.
targets
[
0
]
try
:
target_id
=
target_node
.
id
except
AttributeError
:
return
False
value_node
=
node
.
value
if
isinstance
(
value_node
,
gast
.
Name
):
if
value_node
.
id
in
self
.
name_to_tensor_shape
:
self
.
name_to_tensor_shape
[
target_id
]
=
self
.
name_to_tensor_shape
[
value_node
.
id
]
return
True
if
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
is_tensor_shape
(
value_node
):
# eg: x.shape
self
.
name_to_tensor_shape
[
target_id
]
=
value_node
return
True
if
isinstance
(
value_node
,
gast
.
Subscript
):
if
isinstance
(
value_node
.
value
,
gast
.
Attribute
):
if
self
.
is_tensor_shape
(
value_node
.
value
):
# eg: x.shape[0]
self
.
name_to_tensor_shape
[
target_id
]
=
value_node
return
True
return
False
def
convert_to_static
(
dyfunc
):
def
convert_to_static
(
dyfunc
):
"""
"""
...
...
python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py
浏览文件 @
4ea95b6f
...
@@ -360,7 +360,9 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
...
@@ -360,7 +360,9 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
# TODO(Aurelius84): more elegant way to transform ast into callable object
# TODO(Aurelius84): more elegant way to transform ast into callable object
import_str
=
"import paddle
\n
"
\
import_str
=
"import paddle
\n
"
\
"import paddle.fluid as fluid
\n
"
\
"import paddle.fluid as fluid
\n
"
\
"import paddle.fluid.layers as layers
\n
"
"import paddle.fluid.layers as layers
\n
"
\
"import numpy as np
\n
"
\
"import numpy
\n
"
with
f
:
with
f
:
module_name
=
os
.
path
.
basename
(
f
.
name
[:
-
3
])
module_name
=
os
.
path
.
basename
(
f
.
name
[:
-
3
])
f
.
write
(
import_str
)
f
.
write
(
import_str
)
...
...
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
浏览文件 @
4ea95b6f
...
@@ -181,3 +181,12 @@ def update_args_of_func(node, dygraph_node, method_name):
...
@@ -181,3 +181,12 @@ def update_args_of_func(node, dygraph_node, method_name):
node
.
args
=
[]
node
.
args
=
[]
node
.
keywords
=
added_keywords
+
node
.
keywords
node
.
keywords
=
added_keywords
+
node
.
keywords
def
create_api_shape_node
(
tensor_shape_node
):
assert
isinstance
(
tensor_shape_node
,
gast
.
Attribute
)
api_shape_node
=
gast
.
Call
(
func
=
gast
.
parse
(
'fluid.layers.shape'
).
body
[
0
].
value
,
args
=
[
tensor_shape_node
.
value
],
keywords
=
[])
return
api_shape_node
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py
浏览文件 @
4ea95b6f
# Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
0 → 100644
浏览文件 @
4ea95b6f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
numpy
import
unittest
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.jit
import
dygraph_to_static_graph
def
dyfunc_tensor_shape_1
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
res
=
fluid
.
layers
.
reshape
(
x
,
shape
=
x
.
shape
)
return
res
def
dyfunc_tensor_shape_2
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
shape
=
x
.
shape
shape2
=
shape
res
=
fluid
.
layers
.
reshape
(
x
,
shape2
)
return
res
def
dyfunc_tensor_shape_3
(
x
):
# Don't transform y.shape because y is numpy.ndarray
x
=
fluid
.
dygraph
.
to_variable
(
x
)
y
=
numpy
.
ones
(
5
)
res
=
fluid
.
layers
.
reshape
(
x
,
shape
=
y
.
shape
)
return
res
def
dyfunc_tensor_shape_4
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
res
=
fluid
.
layers
.
reshape
(
x
,
shape
=
(
-
1
,
x
.
shape
[
0
],
len
(
x
.
shape
)))
return
res
def
dyfunc_tensor_shape_5
(
x
):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1, fluid.layers.shape(x)[0]))`
x
=
fluid
.
dygraph
.
to_variable
(
x
)
s
=
x
.
shape
[
0
]
res
=
fluid
.
layers
.
reshape
(
x
,
shape
=
(
-
1
,
s
))
return
res
test_funcs
=
[
dyfunc_tensor_shape_1
,
dyfunc_tensor_shape_2
,
dyfunc_tensor_shape_3
,
dyfunc_tensor_shape_4
,
dyfunc_tensor_shape_5
]
class
TestTensorShape
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
input
=
numpy
.
ones
(
5
).
astype
(
"int32"
)
self
.
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
(
)
else
fluid
.
CPUPlace
()
def
get_dygraph_output
(
self
):
with
fluid
.
dygraph
.
guard
():
res
=
self
.
dygraph_func
(
self
.
input
).
numpy
()
return
res
def
get_static_output
(
self
):
main_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
):
static_out
=
dygraph_to_static_graph
(
self
.
dygraph_func
)(
self
.
input
)
exe
=
fluid
.
Executor
(
self
.
place
)
static_res
=
exe
.
run
(
main_program
,
fetch_list
=
static_out
)
return
static_res
[
0
]
def
test_transformed_static_result
(
self
):
for
func
in
test_funcs
:
self
.
dygraph_func
=
func
static_res
=
self
.
get_static_output
()
dygraph_res
=
self
.
get_dygraph_output
()
self
.
assertTrue
(
numpy
.
allclose
(
dygraph_res
,
static_res
),
msg
=
'dygraph res is {}
\n
static_res is {}'
.
format
(
dygraph_res
,
static_res
))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录