diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py index a36d2c220fa878f4148b85643637bf741764ec63..fab33cbd137c92d3396f60d0335e31fe8d8e23f6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py @@ -17,5 +17,9 @@ from __future__ import print_function from . import ast_transformer from .ast_transformer import * +from . import static_analysis +from .static_analysis import * + __all__ = [] __all__ += ast_transformer.__all__ +__all__ += static_analysis.__all__ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index b2401b25432d819e148b21de6813e8f40389128f..258bca47ae970054515045092af4c11d583937ed 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -16,87 +16,9 @@ from __future__ import print_function import gast -__all__ = ['AstNodeWrapper', 'DygraphToStaticAst', 'StaticAnalysisVisitor'] +from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor - -class NodeVarType(object): - """ - Enum class of python variable types. We have to know some variable types - during compile time to transfer AST. For example, a string variable and a - tensor variable in if clause may lead to different conversion from dygraph - to static graph. - """ - UNKNOWN = 0 # Reserve for AST nodes have not known the type - STATEMENT = 1 # For nodes representing statement (non-variable type) - PADDLE_DYGRAPH_API = 2 - PADDLE_CONTROL_IF = 3 - PADDLE_CONTROL_WHILE = 4 - PADDLE_CONTROL_FOR = 5 - - NONE = 100 - INT = 101 - FLOAT = 102 - STRING = 103 - TENSOR = 104 - - -class AstNodeWrapper(object): - """ - Wrapper for python ast.node. We need a node wrapper because ast.node - doesn't store all required information when we are transforming AST. - We should collect additional information which the actual transformation - needs. - """ - - def __init__(self, node): - self.node = node - self.parent = None - self.children = [] - self.node_var_type = NodeVarType.UNKNOWN - - -class StaticAnalysisVisitor(object): - """ - A class that does static analysis - """ - - def __init__(self, ast_root=None): - if ast_root is not None: - self.run(ast_root) - - def run(self, ast_root): - self.node_wrapper_root = None - self.ancestor_wrappers = [] - self.node_to_wrapper_map = {} - self.dfs_visit(ast_root) - - def dfs_visit(self, node): - # AST reuses some ast.nodes, such as Param node of expr_context - if node not in self.node_to_wrapper_map: - cur_wrapper = AstNodeWrapper(node) - self.node_to_wrapper_map[node] = cur_wrapper - else: - cur_wrapper = self.node_to_wrapper_map[node] - - if self.node_wrapper_root is None: - self.node_wrapper_root = cur_wrapper - - if len(self.ancestor_wrappers) != 0: - last_wrapper = self.ancestor_wrappers[-1] - last_wrapper.children.append(cur_wrapper) - cur_wrapper.parent = last_wrapper - - self.ancestor_wrappers.append(cur_wrapper) - for child in gast.iter_child_nodes(node): - self.dfs_visit(child) - self.ancestor_wrappers.pop() - return cur_wrapper.node_var_type - - def get_node_wrapper_root(self): - return self.node_wrapper_root - - def get_node_to_wrapper_map(self): - return self.node_to_wrapper_map +__all__ = ['DygraphToStaticAst'] class DygraphToStaticAst(gast.NodeTransformer): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..8fd3ef7266c91b24801b75fbe1309c521eb15a17 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -0,0 +1,319 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import astor +import gast +import inspect +import six +import warnings + +__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor'] + + +# TODO: _is_paddle_dygraph_api is duplicated in Yamei's utils.py. Merge the two +# function code together when Yamei finish her PR. +def _is_paddle_dygraph_api(obj): + m = inspect.getmodule(obj) + return m is not None and m.__name__.startswith("paddle.fluid.dygraph") + + +# TODO: is_dygraph_api is duplicated in Yamei's utils.py. Merge the two +# function code together when Yamei finish her PR. +def is_dygraph_api(node): + assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api" + func_src = astor.to_source(node.func) + try: + import paddle.fluid as fluid + return eval("_is_paddle_dygraph_api({})".format(func_src)) + except NameError: + return False + + +def _is_numpy_api_helper(obj): + m = inspect.getmodule(obj) + return m is not None and m.__name__.startswith("numpy") + + +def is_numpy_api(node): + assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api" + func_str = astor.to_source(node.func) + try: + import numpy as np + module_result = eval("_is_numpy_api_helper({})".format(func_str)) + # BUG: np.random.uniform doesn't have module and cannot be analyzed + # TODO: find a better way + if not module_result: + return func_str.startswith("numpy.") or func_str.startswith("np.") + except NameError: + return False + + +class NodeVarType(object): + """ + Enum class of python variable types. We have to know some variable types + during compile time to transfer AST. For example, a string variable and a + tensor variable in if clause may lead to different conversion from dygraph + to static graph. + """ + ERROR = -1 # Returns when static analysis gets error + UNKNOWN = 0 # Reserve for AST nodes have not known the type + STATEMENT = 1 # For nodes representing statement (non-variable type) + CALLABLE = 2 + + # python data types + NONE = 100 + BOOLEAN = 101 + INT = 102 + FLOAT = 103 + STRING = 104 + TENSOR = 105 + NUMPY_NDARRAY = 106 + + # python collections + LIST = 200 + SET = 201 + DICT = 202 + + PADDLE_DYGRAPH_API = 300 + PADDLE_CONTROL_IF = 301 + PADDLE_CONTROL_WHILE = 302 + PADDLE_CONTROL_FOR = 303 + + @staticmethod + def binary_op_output_type(in_type1, in_type2): + if in_type1 == in_type2: + return in_type1 + + if in_type1 == NodeVarType.UNKNOWN: + return in_type2 + if in_type2 == NodeVarType.UNKNOWN: + return in_type1 + + supported_types = [ + NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT, + NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR + ] + + if in_type1 not in supported_types: + warnings.warn("Binary Op on un supported in_type1 = %d " % + (in_type1)) + return NodeVarType.UNKNOWN + if in_type2 not in supported_types: + warnings.warn("Binary Op on un supported in_type2 = %d " % + (in_type2)) + return NodeVarType.UNKNOWN + + forbidden_types = [NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR] + if in_type1 in forbidden_types and in_type2 in forbidden_types: + warnings.warn( + "Binary Op on un supported types: in_type1 = %d, in_type2 = %d" + % (in_type1, in_type2)) + return NodeVarType.UNKNOWN + return max(in_type1, in_type2) + + +class AstNodeWrapper(object): + """ + Wrapper for python gast.node. We need a node wrapper because gast.node + doesn't store all required information when we are transforming AST. + We should collect additional information which the actual transformation + needs. + """ + + def __init__(self, node): + self.node = node + self.parent = None + self.children = [] + self.node_var_type = NodeVarType.UNKNOWN + + +class AstVarScope(object): + """ + AstVarScope is a class holding the map from current scope variable to its + type. + """ + + def __init__(self, parent_scope=None): + self.sub_scopes = [] + self.name_to_id = {} + self.id_to_type = {} + self.cur_id = 0 + self.parent_scope = parent_scope + if parent_scope is not None: + parent_scope.sub_scopes.append(self) + + def set_var_type(self, var_name, node_var_type): + if var_name in self.name_to_id: + num_id = self.name_to_id[var_name] + else: + num_id = self.cur_id + self.cur_id += 1 + self.name_to_id[var_name] = num_id + self.id_to_type[num_id] = node_var_type + + def get_var_type(self, var_name): + if var_name in self.name_to_id: + num_id = self.name_to_id[var_name] + return self.id_to_type[num_id] + if self.parent_scope is None: + return NodeVarType.UNKNOWN + return self.parent_scope.get_var_type(var_name) + + +class AstVarEnv(object): + """ + A class maintains scopes and mapping from variable name to type. + """ + + def __init__(self): + self.cur_scope = AstVarScope() + + def enter_scope(self): + self.cur_scope = AstVarScope(parent_scope=self.cur_scope) + return self.cur_scope + + def exit_scope(self): + assert self.cur_scope.parent_scope is not None, "Call exit_scope in "\ + "AstVarEnv when current scope doens't have parent scope." + self.cur_scope = self.cur_scope.parent_scope + return self.cur_scope + + def set_var_type(self, var_name, node_var_type): + self.cur_scope.set_var_type(var_name, node_var_type) + + def get_var_type(self, var_name): + return self.cur_scope.get_var_type(var_name) + + def get_scope_var_type(self): + ''' + Returns a dict mapping from variable name to type. Used for debug and + test. + ''' + cur_scope_dict = {} + for name in self.cur_scope.name_to_id: + node_var_type = self.cur_scope.get_var_type(name) + cur_scope_dict[name] = node_var_type + return cur_scope_dict + + +class StaticAnalysisVisitor(object): + """ + A class that does static analysis + """ + + def __init__(self, ast_root=None): + if ast_root is not None: + self.run(ast_root) + + def run(self, ast_root): + self.node_wrapper_root = None + self.ancestor_wrappers = [] + self.node_to_wrapper_map = {} + self.var_env = AstVarEnv() + + self.dfs_visit(ast_root) + + def dfs_visit(self, node): + # AST reuses some gast.nodes, such as Param node of expr_context + if node not in self.node_to_wrapper_map: + cur_wrapper = AstNodeWrapper(node) + self.node_to_wrapper_map[node] = cur_wrapper + else: + cur_wrapper = self.node_to_wrapper_map[node] + + if self.node_wrapper_root is None: + self.node_wrapper_root = cur_wrapper + + if len(self.ancestor_wrappers) != 0: + last_wrapper = self.ancestor_wrappers[-1] + last_wrapper.children.append(cur_wrapper) + cur_wrapper.parent = last_wrapper + + self.ancestor_wrappers.append(cur_wrapper) + for child in gast.iter_child_nodes(node): + self.dfs_visit(child) + self.ancestor_wrappers.pop() + + cur_wrapper.node_var_type = self._get_node_var_type(cur_wrapper) + return cur_wrapper.node_var_type + + def get_node_wrapper_root(self): + return self.node_wrapper_root + + def get_node_to_wrapper_map(self): + return self.node_to_wrapper_map + + def get_var_env(self): + return self.var_env + + def _get_node_var_type(self, cur_wrapper): + node = cur_wrapper.node + if isinstance(node, gast.Constant): + # singleton: None, True or False + if node.value is None: + return NodeVarType.NONE + if isinstance(node.value, bool): + return NodeVarType.BOOLEAN + if isinstance(node.value, int): + return NodeVarType.INT + if isinstance(node.value, float): + return NodeVarType.FLOAT + if isinstance(node.value, str): + return NodeVarType.STRING + + if isinstance(node, gast.BoolOp): + return NodeVarType.BOOLEAN + if isinstance(node, gast.Compare): + return NodeVarType.BOOLEAN + + if isinstance(node, gast.Dict): + return NodeVarType.DICT + if isinstance(node, gast.Set): + return NodeVarType.SET + + if isinstance(node, gast.UnaryOp): + return self.node_to_wrapper_map[node.operand].node_var_type + + if isinstance(node, gast.BinOp): + left_type = self.node_to_wrapper_map[node.left].node_var_type + right_type = self.node_to_wrapper_map[node.right].node_var_type + return NodeVarType.binary_op_output_type(left_type, right_type) + + if isinstance(node, gast.Assign): + ret_type = self.node_to_wrapper_map[node.value].node_var_type + for target in node.targets: + if isinstance(target, gast.Name): + self.node_to_wrapper_map[target].node_var_type = ret_type + self.var_env.set_var_type(target.id, ret_type) + return ret_type + + if isinstance(node, gast.Name): + if node.id == "None": + return NodeVarType.NONE + if node.id == "True" or node.id == "False": + return NodeVarType.BOOLEAN + return self.var_env.get_var_type(node.id) + + if isinstance(node, gast.Call): + if is_dygraph_api(node): + api_name = node.func.attr + if api_name == "to_variable": + return NodeVarType.TENSOR + if is_numpy_api(node): + # In this simple version we assume numpy api returns nd-array + return NodeVarType.NUMPY_NDARRAY + + return NodeVarType.STATEMENT diff --git a/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py b/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py index f8937058000d3065474a52707d82c1f6397f51f2..fed55d2169daa29bc252bbf94de8aa03c16be56c 100644 --- a/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py +++ b/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py @@ -14,18 +14,23 @@ from __future__ import print_function -import ast +import gast import inspect +import numpy as np +import paddle.fluid as fluid import unittest -from paddle.fluid.dygraph.dygraph_to_static import AstNodeWrapper, StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor -def func_to_test_1(a, b): +def func_to_test1(a, b): return a + b -def func_to_test_2(x): +result_var_type1 = {} + + +def func_to_test2(x): for i in range(10): x += i m = 3 @@ -37,6 +42,57 @@ def func_to_test_2(x): return x +result_var_type2 = {'m': NodeVarType.INT} + + +def func_to_test3(): + a = 1 + b = 3.0 + c = a * b + d = True + c + e = a < b + f = 9 * (a * 4) + g = "dddy" + h = None + i = False + j = None + 1 + + +result_var_type3 = { + 'a': NodeVarType.INT, + 'b': NodeVarType.FLOAT, + 'c': NodeVarType.FLOAT, + 'd': NodeVarType.FLOAT, + 'e': NodeVarType.BOOLEAN, + 'f': NodeVarType.INT, + 'g': NodeVarType.STRING, + 'h': NodeVarType.NONE, + 'i': NodeVarType.BOOLEAN, + 'j': NodeVarType.UNKNOWN +} + + +def func_to_test4(): + with fluid.dygraph.guard(): + a = np.random.uniform(0.1, 1, [1, 2]) + b = 1 + a + c = fluid.dygraph.to_variable(b) + d = (c + 1) * 0.3 + + +result_var_type4 = { + 'a': NodeVarType.NUMPY_NDARRAY, + 'b': NodeVarType.NUMPY_NDARRAY, + 'c': NodeVarType.TENSOR, + 'd': NodeVarType.TENSOR +} + +test_funcs = [func_to_test1, func_to_test2, func_to_test3, func_to_test4] +result_var_type = [ + result_var_type1, result_var_type2, result_var_type3, result_var_type4 +] + + class TestStaticAnalysis(unittest.TestCase): def _check_wrapper(self, wrapper, node_to_wrapper_map): self.assertEqual(node_to_wrapper_map[wrapper.node], wrapper) @@ -44,7 +100,7 @@ class TestStaticAnalysis(unittest.TestCase): self.assertTrue(wrapper in wrapper.parent.children) children_ast_nodes = [ - child for child in ast.iter_child_nodes(wrapper.node) + child for child in gast.iter_child_nodes(wrapper.node) ] self.assertEqual(len(wrapper.children), len(children_ast_nodes)) for child in wrapper.children: @@ -52,15 +108,30 @@ class TestStaticAnalysis(unittest.TestCase): self._check_wrapper(child, node_to_wrapper_map) def test_construct_node_wrapper(self): - for func in [func_to_test_1, func_to_test_2]: + for func in test_funcs: test_source_code = inspect.getsource(func) - ast_root = ast.parse(test_source_code) - + ast_root = gast.parse(test_source_code) visitor = StaticAnalysisVisitor(ast_root) wrapper_root = visitor.get_node_wrapper_root() node_to_wrapper_map = visitor.get_node_to_wrapper_map() self._check_wrapper(wrapper_root, node_to_wrapper_map) + def test_var_env(self): + for i in range(4): + func = test_funcs[i] + var_type = result_var_type[i] + test_source_code = inspect.getsource(func) + ast_root = gast.parse(test_source_code) + print(gast.dump(ast_root)) + visitor = StaticAnalysisVisitor(ast_root) + var_env = visitor.get_var_env() + scope_var_type = var_env.get_scope_var_type() + self.assertEqual(len(scope_var_type), len(var_type)) + for name in scope_var_type: + print("Test var name %s" % (name)) + self.assertTrue(name in var_type) + self.assertEqual(scope_var_type[name], var_type[name]) + if __name__ == '__main__': unittest.main() diff --git a/python/requirements.txt b/python/requirements.txt index a7389ea7dbf2166b55daa7904560238b4c9e382f..6c82f5a1a6269ee88f441145c7f439138c20737d 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -18,3 +18,4 @@ decorator prettytable objgraph gast +astor