diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 2c5387f0521e94705a1dc129d95e6690d16e7ffc..a69d8fe6321f00f4124e347784d6a3cf4421c562 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -16,7 +16,7 @@ from __future__ import print_function import ast -__all__ = ['DygraphToStaticAst'] +__all__ = ['AstNodeWrapper', 'DygraphToStaticAst', 'StaticAnalysisVisitor'] class NodeVarType(object): @@ -51,9 +51,54 @@ class AstNodeWrapper(object): def __init__(self, node): self.node = node self.parent = None + self.children = [] self.node_var_type = NodeVarType.UNKNOWN +class StaticAnalysisVisitor(object): + """ + A class that does static analysis + """ + + def __init__(self, ast_root=None): + if ast_root is not None: + self.run(ast_root) + + def run(self, ast_root): + self.node_wrapper_root = None + self.ancestor_wrappers = [] + self.node_to_wrapper_map = {} + self.dfs_visit(ast_root) + + def dfs_visit(self, node): + # AST reuses some ast.nodes, such as Param node of expr_context + if node not in self.node_to_wrapper_map: + cur_wrapper = AstNodeWrapper(node) + self.node_to_wrapper_map[node] = cur_wrapper + else: + cur_wrapper = self.node_to_wrapper_map[node] + + if self.node_wrapper_root is None: + self.node_wrapper_root = cur_wrapper + + if len(self.ancestor_wrappers) != 0: + last_wrapper = self.ancestor_wrappers[-1] + last_wrapper.children.append(cur_wrapper) + cur_wrapper.parent = last_wrapper + + self.ancestor_wrappers.append(cur_wrapper) + for child in ast.iter_child_nodes(node): + self.dfs_visit(child) + self.ancestor_wrappers.pop() + return cur_wrapper.node_var_type + + def get_node_wrapper_root(self): + return self.node_wrapper_root + + def get_node_to_wrapper_map(self): + return self.node_to_wrapper_map + + class DygraphToStaticAst(ast.NodeTransformer): """ Main class to transform Dygraph to Static Graph @@ -62,15 +107,10 @@ class DygraphToStaticAst(ast.NodeTransformer): def get_static_ast(self, root): # save root for some analysis may need global AST self.root = root - self.static_analysis_root = AstNodeWrapper(root) - self.visit(root) + self.static_analysis_root = StaticAnalysisVisitor( + root).get_node_wrapper_root() self.transfer_from_node_type(self.static_analysis_root) return self.static_analysis_root - def visit(self, node): - # TODO construct a tree whose nodes are AstNodeWrapper - # This step also does static node type analysis - print("Not implemented") - def transfer_from_node_type(self, node): print("Not implemented") diff --git a/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py b/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..f8937058000d3065474a52707d82c1f6397f51f2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py @@ -0,0 +1,66 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import ast +import inspect +import unittest + +from paddle.fluid.dygraph.dygraph_to_static import AstNodeWrapper, StaticAnalysisVisitor + + +def func_to_test_1(a, b): + return a + b + + +def func_to_test_2(x): + for i in range(10): + x += i + m = 3 + while m < 8: + m += 1 + if x < 0: + return 0 + else: + return x + + +class TestStaticAnalysis(unittest.TestCase): + def _check_wrapper(self, wrapper, node_to_wrapper_map): + self.assertEqual(node_to_wrapper_map[wrapper.node], wrapper) + if wrapper.parent is not None: + self.assertTrue(wrapper in wrapper.parent.children) + + children_ast_nodes = [ + child for child in ast.iter_child_nodes(wrapper.node) + ] + self.assertEqual(len(wrapper.children), len(children_ast_nodes)) + for child in wrapper.children: + self.assertTrue(child.node in children_ast_nodes) + self._check_wrapper(child, node_to_wrapper_map) + + def test_construct_node_wrapper(self): + for func in [func_to_test_1, func_to_test_2]: + test_source_code = inspect.getsource(func) + ast_root = ast.parse(test_source_code) + + visitor = StaticAnalysisVisitor(ast_root) + wrapper_root = visitor.get_node_wrapper_root() + node_to_wrapper_map = visitor.get_node_to_wrapper_map() + self._check_wrapper(wrapper_root, node_to_wrapper_map) + + +if __name__ == '__main__': + unittest.main()