未验证 提交 a8dd425a 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add Static Analysis to Construct AstNodeWrapper (#22569)

As the title
上级 146ed409
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import ast import ast
__all__ = ['DygraphToStaticAst'] __all__ = ['AstNodeWrapper', 'DygraphToStaticAst', 'StaticAnalysisVisitor']
class NodeVarType(object): class NodeVarType(object):
...@@ -51,9 +51,54 @@ class AstNodeWrapper(object): ...@@ -51,9 +51,54 @@ class AstNodeWrapper(object):
def __init__(self, node): def __init__(self, node):
self.node = node self.node = node
self.parent = None self.parent = None
self.children = []
self.node_var_type = NodeVarType.UNKNOWN self.node_var_type = NodeVarType.UNKNOWN
class StaticAnalysisVisitor(object):
"""
A class that does static analysis
"""
def __init__(self, ast_root=None):
if ast_root is not None:
self.run(ast_root)
def run(self, ast_root):
self.node_wrapper_root = None
self.ancestor_wrappers = []
self.node_to_wrapper_map = {}
self.dfs_visit(ast_root)
def dfs_visit(self, node):
# AST reuses some ast.nodes, such as Param node of expr_context
if node not in self.node_to_wrapper_map:
cur_wrapper = AstNodeWrapper(node)
self.node_to_wrapper_map[node] = cur_wrapper
else:
cur_wrapper = self.node_to_wrapper_map[node]
if self.node_wrapper_root is None:
self.node_wrapper_root = cur_wrapper
if len(self.ancestor_wrappers) != 0:
last_wrapper = self.ancestor_wrappers[-1]
last_wrapper.children.append(cur_wrapper)
cur_wrapper.parent = last_wrapper
self.ancestor_wrappers.append(cur_wrapper)
for child in ast.iter_child_nodes(node):
self.dfs_visit(child)
self.ancestor_wrappers.pop()
return cur_wrapper.node_var_type
def get_node_wrapper_root(self):
return self.node_wrapper_root
def get_node_to_wrapper_map(self):
return self.node_to_wrapper_map
class DygraphToStaticAst(ast.NodeTransformer): class DygraphToStaticAst(ast.NodeTransformer):
""" """
Main class to transform Dygraph to Static Graph Main class to transform Dygraph to Static Graph
...@@ -62,15 +107,10 @@ class DygraphToStaticAst(ast.NodeTransformer): ...@@ -62,15 +107,10 @@ class DygraphToStaticAst(ast.NodeTransformer):
def get_static_ast(self, root): def get_static_ast(self, root):
# save root for some analysis may need global AST # save root for some analysis may need global AST
self.root = root self.root = root
self.static_analysis_root = AstNodeWrapper(root) self.static_analysis_root = StaticAnalysisVisitor(
self.visit(root) root).get_node_wrapper_root()
self.transfer_from_node_type(self.static_analysis_root) self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root return self.static_analysis_root
def visit(self, node):
# TODO construct a tree whose nodes are AstNodeWrapper
# This step also does static node type analysis
print("Not implemented")
def transfer_from_node_type(self, node): def transfer_from_node_type(self, node):
print("Not implemented") print("Not implemented")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import ast
import inspect
import unittest
from paddle.fluid.dygraph.dygraph_to_static import AstNodeWrapper, StaticAnalysisVisitor
def func_to_test_1(a, b):
return a + b
def func_to_test_2(x):
for i in range(10):
x += i
m = 3
while m < 8:
m += 1
if x < 0:
return 0
else:
return x
class TestStaticAnalysis(unittest.TestCase):
def _check_wrapper(self, wrapper, node_to_wrapper_map):
self.assertEqual(node_to_wrapper_map[wrapper.node], wrapper)
if wrapper.parent is not None:
self.assertTrue(wrapper in wrapper.parent.children)
children_ast_nodes = [
child for child in ast.iter_child_nodes(wrapper.node)
]
self.assertEqual(len(wrapper.children), len(children_ast_nodes))
for child in wrapper.children:
self.assertTrue(child.node in children_ast_nodes)
self._check_wrapper(child, node_to_wrapper_map)
def test_construct_node_wrapper(self):
for func in [func_to_test_1, func_to_test_2]:
test_source_code = inspect.getsource(func)
ast_root = ast.parse(test_source_code)
visitor = StaticAnalysisVisitor(ast_root)
wrapper_root = visitor.get_node_wrapper_root()
node_to_wrapper_map = visitor.get_node_to_wrapper_map()
self._check_wrapper(wrapper_root, node_to_wrapper_map)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册