Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a8dd425a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a8dd425a
编写于
2月 13, 2020
作者:
H
Huihuang Zheng
提交者:
GitHub
2月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Static Analysis to Construct AstNodeWrapper (#22569)
As the title
上级
146ed409
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
114 addition
and
8 deletion
+114
-8
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
...paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+48
-8
python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py
...d/tests/unittests/test_ast_transformer_static_analysis.py
+66
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
浏览文件 @
a8dd425a
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
ast
__all__
=
[
'
DygraphToStaticAst
'
]
__all__
=
[
'
AstNodeWrapper'
,
'DygraphToStaticAst'
,
'StaticAnalysisVisitor
'
]
class
NodeVarType
(
object
):
...
...
@@ -51,9 +51,54 @@ class AstNodeWrapper(object):
def
__init__
(
self
,
node
):
self
.
node
=
node
self
.
parent
=
None
self
.
children
=
[]
self
.
node_var_type
=
NodeVarType
.
UNKNOWN
class
StaticAnalysisVisitor
(
object
):
"""
A class that does static analysis
"""
def
__init__
(
self
,
ast_root
=
None
):
if
ast_root
is
not
None
:
self
.
run
(
ast_root
)
def
run
(
self
,
ast_root
):
self
.
node_wrapper_root
=
None
self
.
ancestor_wrappers
=
[]
self
.
node_to_wrapper_map
=
{}
self
.
dfs_visit
(
ast_root
)
def
dfs_visit
(
self
,
node
):
# AST reuses some ast.nodes, such as Param node of expr_context
if
node
not
in
self
.
node_to_wrapper_map
:
cur_wrapper
=
AstNodeWrapper
(
node
)
self
.
node_to_wrapper_map
[
node
]
=
cur_wrapper
else
:
cur_wrapper
=
self
.
node_to_wrapper_map
[
node
]
if
self
.
node_wrapper_root
is
None
:
self
.
node_wrapper_root
=
cur_wrapper
if
len
(
self
.
ancestor_wrappers
)
!=
0
:
last_wrapper
=
self
.
ancestor_wrappers
[
-
1
]
last_wrapper
.
children
.
append
(
cur_wrapper
)
cur_wrapper
.
parent
=
last_wrapper
self
.
ancestor_wrappers
.
append
(
cur_wrapper
)
for
child
in
ast
.
iter_child_nodes
(
node
):
self
.
dfs_visit
(
child
)
self
.
ancestor_wrappers
.
pop
()
return
cur_wrapper
.
node_var_type
def
get_node_wrapper_root
(
self
):
return
self
.
node_wrapper_root
def
get_node_to_wrapper_map
(
self
):
return
self
.
node_to_wrapper_map
class
DygraphToStaticAst
(
ast
.
NodeTransformer
):
"""
Main class to transform Dygraph to Static Graph
...
...
@@ -62,15 +107,10 @@ class DygraphToStaticAst(ast.NodeTransformer):
def
get_static_ast
(
self
,
root
):
# save root for some analysis may need global AST
self
.
root
=
root
self
.
static_analysis_root
=
AstNodeWrapper
(
root
)
self
.
visit
(
root
)
self
.
static_analysis_root
=
StaticAnalysisVisitor
(
root
).
get_node_wrapper_root
(
)
self
.
transfer_from_node_type
(
self
.
static_analysis_root
)
return
self
.
static_analysis_root
def
visit
(
self
,
node
):
# TODO construct a tree whose nodes are AstNodeWrapper
# This step also does static node type analysis
print
(
"Not implemented"
)
def
transfer_from_node_type
(
self
,
node
):
print
(
"Not implemented"
)
python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py
0 → 100644
浏览文件 @
a8dd425a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
ast
import
inspect
import
unittest
from
paddle.fluid.dygraph.dygraph_to_static
import
AstNodeWrapper
,
StaticAnalysisVisitor
def
func_to_test_1
(
a
,
b
):
return
a
+
b
def
func_to_test_2
(
x
):
for
i
in
range
(
10
):
x
+=
i
m
=
3
while
m
<
8
:
m
+=
1
if
x
<
0
:
return
0
else
:
return
x
class
TestStaticAnalysis
(
unittest
.
TestCase
):
def
_check_wrapper
(
self
,
wrapper
,
node_to_wrapper_map
):
self
.
assertEqual
(
node_to_wrapper_map
[
wrapper
.
node
],
wrapper
)
if
wrapper
.
parent
is
not
None
:
self
.
assertTrue
(
wrapper
in
wrapper
.
parent
.
children
)
children_ast_nodes
=
[
child
for
child
in
ast
.
iter_child_nodes
(
wrapper
.
node
)
]
self
.
assertEqual
(
len
(
wrapper
.
children
),
len
(
children_ast_nodes
))
for
child
in
wrapper
.
children
:
self
.
assertTrue
(
child
.
node
in
children_ast_nodes
)
self
.
_check_wrapper
(
child
,
node_to_wrapper_map
)
def
test_construct_node_wrapper
(
self
):
for
func
in
[
func_to_test_1
,
func_to_test_2
]:
test_source_code
=
inspect
.
getsource
(
func
)
ast_root
=
ast
.
parse
(
test_source_code
)
visitor
=
StaticAnalysisVisitor
(
ast_root
)
wrapper_root
=
visitor
.
get_node_wrapper_root
()
node_to_wrapper_map
=
visitor
.
get_node_to_wrapper_map
()
self
.
_check_wrapper
(
wrapper_root
,
node_to_wrapper_map
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录