From 66991218ecc237ce7a7a2d04e3149004ec7e17dc Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Wed, 26 Feb 2020 10:11:33 +0800 Subject: [PATCH] Add Basic Function Return Type Analysis (#22747) 1. Considering functions, I have to change the node type from single value to a set. Because python function is allowed to return different types. The set represent all possible types 2. I added scope_name and scope_type for AstVarScope, because in python functions, variable may have different scope. For example: ``` a = 3 def foo(b): a = 9 return a + b ``` the `a` in `foo` is different to the `a` out of `foo`. Similar to class field. The scope_name will help me to know the function name when static analysis finds a `return` sentence. --- .../dygraph_to_static/static_analysis.py | 137 +++++++++++++----- .../test_ast_transformer_static_analysis.py | 74 +++++++--- 2 files changed, 157 insertions(+), 54 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index 0ca36122fb..c8ab1a4586 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -25,34 +25,41 @@ __all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor'] # TODO: _is_paddle_dygraph_api is duplicated in Yamei's utils.py. Merge the two # function code together when Yamei finish her PR. -def _is_paddle_dygraph_api(obj): +def _is_api_in_module_helper(obj, module_prefix): m = inspect.getmodule(obj) - return m is not None and m.__name__.startswith("paddle.fluid.dygraph") + return m is not None and m.__name__.startswith(module_prefix) # TODO: is_dygraph_api is duplicated in Yamei's utils.py. Merge the two # function code together when Yamei finish her PR. -def is_dygraph_api(node): +def is_api_in_module(node, module_prefix): assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api" - func_src = astor.to_source(node.func) + func_str = astor.to_source(gast.gast_to_ast(node.func)) try: import paddle.fluid as fluid - return eval("_is_paddle_dygraph_api({})".format(func_src)) + import paddle + return eval("_is_api_in_module_helper({}, '{}')".format(func_str, + module_prefix)) except NameError: return False -def _is_numpy_api_helper(obj): - m = inspect.getmodule(obj) - return m is not None and m.__name__.startswith("numpy") +def is_dygraph_api(node): + return is_api_in_module(node, "paddle.fluid.dygraph") + +def is_paddle_api(node): + return is_api_in_module(node, "paddle.fluid") + +# Is numpy_api cannot reuse is_api_in_module because of numpy module problem def is_numpy_api(node): assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api" - func_str = astor.to_source(node.func) + func_str = astor.to_source(gast.gast_to_ast(node.func)) try: import numpy as np - module_result = eval("_is_numpy_api_helper({})".format(func_str)) + module_result = eval("_is_api_in_module_helper({}, '{}')".format( + func_str, "numpy")) # BUG: np.random.uniform doesn't have module and cannot be analyzed # TODO: find a better way if not module_result: @@ -91,6 +98,9 @@ class NodeVarType(object): PADDLE_CONTROL_IF = 301 PADDLE_CONTROL_WHILE = 302 PADDLE_CONTROL_FOR = 303 + # Paddle API may not be visible to get source code. + # We use this enum value to denote the type return by a Paddle API + PADDLE_RETURN_TYPES = 304 @staticmethod def binary_op_output_type(in_type1, in_type2): @@ -137,7 +147,7 @@ class AstNodeWrapper(object): self.node = node self.parent = None self.children = [] - self.node_var_type = NodeVarType.UNKNOWN + self.node_var_type = {NodeVarType.UNKNOWN} class AstVarScope(object): @@ -145,16 +155,35 @@ class AstVarScope(object): AstVarScope is a class holding the map from current scope variable to its type. """ - - def __init__(self, parent_scope=None): + SCOPE_TYPE_SCRIPT = 0 + SCOPE_TYPE_FUNCTION = 1 + SCOPE_TYPE_CLASS = 2 + + def __init__(self, + scope_name='', + scope_type=SCOPE_TYPE_SCRIPT, + parent_scope=None): self.sub_scopes = [] self.name_to_id = {} self.id_to_type = {} self.cur_id = 0 + + self.scope_name = scope_name + self.scope_type = scope_type self.parent_scope = parent_scope if parent_scope is not None: parent_scope.sub_scopes.append(self) + def add_var_type(self, var_name, node_var_type): + var_type = self.get_var_type(var_name) + if var_type == {NodeVarType.UNKNOWN}: + self.set_var_type(var_name, node_var_type) + else: + if isinstance(node_var_type, set): + var_type.update(node_var_type) + else: + var_type.add(node_var_type) + def set_var_type(self, var_name, node_var_type): if var_name in self.name_to_id: num_id = self.name_to_id[var_name] @@ -162,27 +191,29 @@ class AstVarScope(object): num_id = self.cur_id self.cur_id += 1 self.name_to_id[var_name] = num_id - self.id_to_type[num_id] = node_var_type + self.id_to_type[num_id] = node_var_type if isinstance( + node_var_type, set) else {node_var_type} def get_var_type(self, var_name): if var_name in self.name_to_id: num_id = self.name_to_id[var_name] return self.id_to_type[num_id] if self.parent_scope is None: - return NodeVarType.UNKNOWN + return {NodeVarType.UNKNOWN} return self.parent_scope.get_var_type(var_name) class AstVarEnv(object): """ - A class maintains scopes and mapping from variable name to type. + A class maintains scopes and mapping from name strings to type. """ def __init__(self): self.cur_scope = AstVarScope() - def enter_scope(self): - self.cur_scope = AstVarScope(parent_scope=self.cur_scope) + def enter_scope(self, scope_name, scope_type): + self.cur_scope = AstVarScope( + scope_name, scope_type, parent_scope=self.cur_scope) return self.cur_scope def exit_scope(self): @@ -191,6 +222,14 @@ class AstVarEnv(object): self.cur_scope = self.cur_scope.parent_scope return self.cur_scope + def get_parent_scope(self): + assert self.cur_scope.parent_scope is not None, "Call parent_scope in "\ + "AstVarEnv when current scope doesn't have parent scope." + return self.cur_scope.parent_scope + + def add_var_type(self, var_name, node_var_type): + self.cur_scope.add_var_type(var_name, node_var_type) + def set_var_type(self, var_name, node_var_type): self.cur_scope.set_var_type(var_name, node_var_type) @@ -244,7 +283,16 @@ class StaticAnalysisVisitor(object): self.ancestor_wrappers.append(cur_wrapper) for child in gast.iter_child_nodes(node): - self.dfs_visit(child) + if isinstance(child, gast.FunctionDef) or isinstance( + child, gast.AsyncFunctionDef): + # TODO: current version is function name mapping to its type + # consider complex case involving parameters + self.var_env.enter_scope(child.name, + AstVarScope.SCOPE_TYPE_FUNCTION) + func_type = self.dfs_visit(child) + self.var_env.exit_scope() + else: + self.dfs_visit(child) self.ancestor_wrappers.pop() cur_wrapper.node_var_type = self._get_node_var_type(cur_wrapper) @@ -264,25 +312,25 @@ class StaticAnalysisVisitor(object): if isinstance(node, gast.Constant): # singleton: None, True or False if node.value is None: - return NodeVarType.NONE + return {NodeVarType.NONE} if isinstance(node.value, bool): - return NodeVarType.BOOLEAN + return {NodeVarType.BOOLEAN} if isinstance(node.value, int): - return NodeVarType.INT + return {NodeVarType.INT} if isinstance(node.value, float): - return NodeVarType.FLOAT + return {NodeVarType.FLOAT} if isinstance(node.value, str): - return NodeVarType.STRING + return {NodeVarType.STRING} if isinstance(node, gast.BoolOp): - return NodeVarType.BOOLEAN + return {NodeVarType.BOOLEAN} if isinstance(node, gast.Compare): - return NodeVarType.BOOLEAN + return {NodeVarType.BOOLEAN} if isinstance(node, gast.Dict): - return NodeVarType.DICT + return {NodeVarType.DICT} if isinstance(node, gast.Set): - return NodeVarType.SET + return {NodeVarType.SET} if isinstance(node, gast.UnaryOp): return self.node_to_wrapper_map[node.operand].node_var_type @@ -290,7 +338,11 @@ class StaticAnalysisVisitor(object): if isinstance(node, gast.BinOp): left_type = self.node_to_wrapper_map[node.left].node_var_type right_type = self.node_to_wrapper_map[node.right].node_var_type - return NodeVarType.binary_op_output_type(left_type, right_type) + result_type = set() + for l in left_type: + for r in right_type: + result_type.add(NodeVarType.binary_op_output_type(l, r)) + return result_type if isinstance(node, gast.Assign): ret_type = self.node_to_wrapper_map[node.value].node_var_type @@ -302,18 +354,31 @@ class StaticAnalysisVisitor(object): if isinstance(node, gast.Name): if node.id == "None": - return NodeVarType.NONE + return {NodeVarType.NONE} if node.id == "True" or node.id == "False": - return NodeVarType.BOOLEAN + return {NodeVarType.BOOLEAN} return self.var_env.get_var_type(node.id) + if isinstance(node, gast.Return): + return_type = self.node_to_wrapper_map[node.value].node_var_type + assert self.var_env.cur_scope.scope_type == AstVarScope.SCOPE_TYPE_FUNCTION, "Return at non-function scope" + func_name = self.var_env.cur_scope.scope_name + parent_scope = self.var_env.get_parent_scope() + parent_scope.add_var_type(func_name, return_type) + return return_type + if isinstance(node, gast.Call): if is_dygraph_api(node): - api_name = node.func.attr - if api_name == "to_variable": - return NodeVarType.TENSOR + if isinstance(node.func, gast.Attribute): + if node.func.attr == "to_variable": + return {NodeVarType.TENSOR} + if is_paddle_api(node): + return {NodeVarType.PADDLE_RETURN_TYPES} if is_numpy_api(node): # In this simple version we assume numpy api returns nd-array - return NodeVarType.NUMPY_NDARRAY + return {NodeVarType.NUMPY_NDARRAY} + + if isinstance(node.func, gast.Name): + return self.var_env.get_var_type(node.func.id) - return NodeVarType.STATEMENT + return {NodeVarType.STATEMENT} diff --git a/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py b/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py index fed55d2169..6822397b1f 100644 --- a/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py +++ b/python/paddle/fluid/tests/unittests/test_ast_transformer_static_analysis.py @@ -42,7 +42,7 @@ def func_to_test2(x): return x -result_var_type2 = {'m': NodeVarType.INT} +result_var_type2 = {'m': {NodeVarType.INT}} def func_to_test3(): @@ -59,16 +59,16 @@ def func_to_test3(): result_var_type3 = { - 'a': NodeVarType.INT, - 'b': NodeVarType.FLOAT, - 'c': NodeVarType.FLOAT, - 'd': NodeVarType.FLOAT, - 'e': NodeVarType.BOOLEAN, - 'f': NodeVarType.INT, - 'g': NodeVarType.STRING, - 'h': NodeVarType.NONE, - 'i': NodeVarType.BOOLEAN, - 'j': NodeVarType.UNKNOWN + 'a': {NodeVarType.INT}, + 'b': {NodeVarType.FLOAT}, + 'c': {NodeVarType.FLOAT}, + 'd': {NodeVarType.FLOAT}, + 'e': {NodeVarType.BOOLEAN}, + 'f': {NodeVarType.INT}, + 'g': {NodeVarType.STRING}, + 'h': {NodeVarType.NONE}, + 'i': {NodeVarType.BOOLEAN}, + 'j': {NodeVarType.UNKNOWN} } @@ -81,15 +81,48 @@ def func_to_test4(): result_var_type4 = { - 'a': NodeVarType.NUMPY_NDARRAY, - 'b': NodeVarType.NUMPY_NDARRAY, - 'c': NodeVarType.TENSOR, - 'd': NodeVarType.TENSOR + 'a': {NodeVarType.NUMPY_NDARRAY}, + 'b': {NodeVarType.NUMPY_NDARRAY}, + 'c': {NodeVarType.TENSOR}, + 'd': {NodeVarType.TENSOR} } -test_funcs = [func_to_test1, func_to_test2, func_to_test3, func_to_test4] + +def func_to_test5(): + def inner_int_func(): + return 1 + + def inner_bool_float_func(x): + a = 1.0 + if x > 0: + return a + return False + + def inner_unknown_func(x): + return x + + a = inner_int_func() + b = inner_bool_float_func(3) + c = inner_unknown_func(None) + d = paddle.fluid.data('x', [1, 2]) + + +result_var_type5 = { + 'a': {NodeVarType.INT}, + 'b': {NodeVarType.FLOAT, NodeVarType.BOOLEAN}, + 'c': {NodeVarType.UNKNOWN}, + 'd': {NodeVarType.PADDLE_RETURN_TYPES}, + 'inner_int_func': {NodeVarType.INT}, + 'inner_bool_float_func': {NodeVarType.FLOAT, NodeVarType.BOOLEAN}, + 'inner_unknown_func': {NodeVarType.UNKNOWN}, +} + +test_funcs = [ + func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5 +] result_var_type = [ - result_var_type1, result_var_type2, result_var_type3, result_var_type4 + result_var_type1, result_var_type2, result_var_type3, result_var_type4, + result_var_type5 ] @@ -117,7 +150,7 @@ class TestStaticAnalysis(unittest.TestCase): self._check_wrapper(wrapper_root, node_to_wrapper_map) def test_var_env(self): - for i in range(4): + for i in range(5): func = test_funcs[i] var_type = result_var_type[i] test_source_code = inspect.getsource(func) @@ -125,6 +158,11 @@ class TestStaticAnalysis(unittest.TestCase): print(gast.dump(ast_root)) visitor = StaticAnalysisVisitor(ast_root) var_env = visitor.get_var_env() + + # There must be 1 sub scope for the test function + self.assertEqual(1, len(var_env.cur_scope.sub_scopes)) + var_env.cur_scope = var_env.cur_scope.sub_scopes[0] + scope_var_type = var_env.get_scope_var_type() self.assertEqual(len(scope_var_type), len(var_type)) for name in scope_var_type: -- GitLab