未验证 提交 66991218 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add Basic Function Return Type Analysis (#22747)

1. To support functions, I changed the node type from a single value to a set, because a Python function is allowed to return different types. The set represents all possible types.

2. I added scope_name and scope_type to AstVarScope, because in Python a variable inside a function may belong to a different scope. For example:

```
a = 3
def foo(b):
    a = 9
    return a + b
```

The `a` inside `foo` is different from the `a` outside `foo`. The same applies to class fields. The scope_name lets the static analysis know the enclosing function's name when it encounters a `return` statement.
上级 60aaa715
...@@ -25,34 +25,41 @@ __all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor'] ...@@ -25,34 +25,41 @@ __all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']
# TODO: _is_paddle_dygraph_api is duplicated in Yamei's utils.py. Merge the two # TODO: _is_paddle_dygraph_api is duplicated in Yamei's utils.py. Merge the two
# function code together when Yamei finish her PR. # function code together when Yamei finish her PR.
def _is_paddle_dygraph_api(obj): def _is_api_in_module_helper(obj, module_prefix):
m = inspect.getmodule(obj) m = inspect.getmodule(obj)
return m is not None and m.__name__.startswith("paddle.fluid.dygraph") return m is not None and m.__name__.startswith(module_prefix)
# TODO: is_dygraph_api is duplicated in Yamei's utils.py. Merge the two # TODO: is_dygraph_api is duplicated in Yamei's utils.py. Merge the two
# function code together when Yamei finish her PR. # function code together when Yamei finish her PR.
def is_dygraph_api(node): def is_api_in_module(node, module_prefix):
assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api" assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api"
func_src = astor.to_source(node.func) func_str = astor.to_source(gast.gast_to_ast(node.func))
try: try:
import paddle.fluid as fluid import paddle.fluid as fluid
return eval("_is_paddle_dygraph_api({})".format(func_src)) import paddle
return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
module_prefix))
except NameError: except NameError:
return False return False
def _is_numpy_api_helper(obj): def is_dygraph_api(node):
m = inspect.getmodule(obj) return is_api_in_module(node, "paddle.fluid.dygraph")
return m is not None and m.__name__.startswith("numpy")
def is_paddle_api(node):
return is_api_in_module(node, "paddle.fluid")
# is_numpy_api cannot reuse is_api_in_module because of a numpy module problem
def is_numpy_api(node): def is_numpy_api(node):
assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api" assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api"
func_str = astor.to_source(node.func) func_str = astor.to_source(gast.gast_to_ast(node.func))
try: try:
import numpy as np import numpy as np
module_result = eval("_is_numpy_api_helper({})".format(func_str)) module_result = eval("_is_api_in_module_helper({}, '{}')".format(
func_str, "numpy"))
# BUG: np.random.uniform doesn't have module and cannot be analyzed # BUG: np.random.uniform doesn't have module and cannot be analyzed
# TODO: find a better way # TODO: find a better way
if not module_result: if not module_result:
...@@ -91,6 +98,9 @@ class NodeVarType(object): ...@@ -91,6 +98,9 @@ class NodeVarType(object):
PADDLE_CONTROL_IF = 301 PADDLE_CONTROL_IF = 301
PADDLE_CONTROL_WHILE = 302 PADDLE_CONTROL_WHILE = 302
PADDLE_CONTROL_FOR = 303 PADDLE_CONTROL_FOR = 303
# The source code of a Paddle API may not be available for analysis.
# We use this enum value to denote the type returned by a Paddle API.
PADDLE_RETURN_TYPES = 304
@staticmethod @staticmethod
def binary_op_output_type(in_type1, in_type2): def binary_op_output_type(in_type1, in_type2):
...@@ -137,7 +147,7 @@ class AstNodeWrapper(object): ...@@ -137,7 +147,7 @@ class AstNodeWrapper(object):
self.node = node self.node = node
self.parent = None self.parent = None
self.children = [] self.children = []
self.node_var_type = NodeVarType.UNKNOWN self.node_var_type = {NodeVarType.UNKNOWN}
class AstVarScope(object): class AstVarScope(object):
...@@ -145,16 +155,35 @@ class AstVarScope(object): ...@@ -145,16 +155,35 @@ class AstVarScope(object):
AstVarScope is a class holding the map from current scope variable to its AstVarScope is a class holding the map from current scope variable to its
type. type.
""" """
SCOPE_TYPE_SCRIPT = 0
def __init__(self, parent_scope=None): SCOPE_TYPE_FUNCTION = 1
SCOPE_TYPE_CLASS = 2
def __init__(self,
scope_name='',
scope_type=SCOPE_TYPE_SCRIPT,
parent_scope=None):
self.sub_scopes = [] self.sub_scopes = []
self.name_to_id = {} self.name_to_id = {}
self.id_to_type = {} self.id_to_type = {}
self.cur_id = 0 self.cur_id = 0
self.scope_name = scope_name
self.scope_type = scope_type
self.parent_scope = parent_scope self.parent_scope = parent_scope
if parent_scope is not None: if parent_scope is not None:
parent_scope.sub_scopes.append(self) parent_scope.sub_scopes.append(self)
def add_var_type(self, var_name, node_var_type):
var_type = self.get_var_type(var_name)
if var_type == {NodeVarType.UNKNOWN}:
self.set_var_type(var_name, node_var_type)
else:
if isinstance(node_var_type, set):
var_type.update(node_var_type)
else:
var_type.add(node_var_type)
def set_var_type(self, var_name, node_var_type): def set_var_type(self, var_name, node_var_type):
if var_name in self.name_to_id: if var_name in self.name_to_id:
num_id = self.name_to_id[var_name] num_id = self.name_to_id[var_name]
...@@ -162,27 +191,29 @@ class AstVarScope(object): ...@@ -162,27 +191,29 @@ class AstVarScope(object):
num_id = self.cur_id num_id = self.cur_id
self.cur_id += 1 self.cur_id += 1
self.name_to_id[var_name] = num_id self.name_to_id[var_name] = num_id
self.id_to_type[num_id] = node_var_type self.id_to_type[num_id] = node_var_type if isinstance(
node_var_type, set) else {node_var_type}
def get_var_type(self, var_name): def get_var_type(self, var_name):
if var_name in self.name_to_id: if var_name in self.name_to_id:
num_id = self.name_to_id[var_name] num_id = self.name_to_id[var_name]
return self.id_to_type[num_id] return self.id_to_type[num_id]
if self.parent_scope is None: if self.parent_scope is None:
return NodeVarType.UNKNOWN return {NodeVarType.UNKNOWN}
return self.parent_scope.get_var_type(var_name) return self.parent_scope.get_var_type(var_name)
class AstVarEnv(object): class AstVarEnv(object):
""" """
A class maintains scopes and mapping from variable name to type. A class maintains scopes and mapping from name strings to type.
""" """
def __init__(self): def __init__(self):
self.cur_scope = AstVarScope() self.cur_scope = AstVarScope()
def enter_scope(self): def enter_scope(self, scope_name, scope_type):
self.cur_scope = AstVarScope(parent_scope=self.cur_scope) self.cur_scope = AstVarScope(
scope_name, scope_type, parent_scope=self.cur_scope)
return self.cur_scope return self.cur_scope
def exit_scope(self): def exit_scope(self):
...@@ -191,6 +222,14 @@ class AstVarEnv(object): ...@@ -191,6 +222,14 @@ class AstVarEnv(object):
self.cur_scope = self.cur_scope.parent_scope self.cur_scope = self.cur_scope.parent_scope
return self.cur_scope return self.cur_scope
def get_parent_scope(self):
assert self.cur_scope.parent_scope is not None, "Call parent_scope in "\
"AstVarEnv when current scope doesn't have parent scope."
return self.cur_scope.parent_scope
def add_var_type(self, var_name, node_var_type):
self.cur_scope.add_var_type(var_name, node_var_type)
def set_var_type(self, var_name, node_var_type): def set_var_type(self, var_name, node_var_type):
self.cur_scope.set_var_type(var_name, node_var_type) self.cur_scope.set_var_type(var_name, node_var_type)
...@@ -244,6 +283,15 @@ class StaticAnalysisVisitor(object): ...@@ -244,6 +283,15 @@ class StaticAnalysisVisitor(object):
self.ancestor_wrappers.append(cur_wrapper) self.ancestor_wrappers.append(cur_wrapper)
for child in gast.iter_child_nodes(node): for child in gast.iter_child_nodes(node):
if isinstance(child, gast.FunctionDef) or isinstance(
child, gast.AsyncFunctionDef):
# TODO: current version is function name mapping to its type
# consider complex case involving parameters
self.var_env.enter_scope(child.name,
AstVarScope.SCOPE_TYPE_FUNCTION)
func_type = self.dfs_visit(child)
self.var_env.exit_scope()
else:
self.dfs_visit(child) self.dfs_visit(child)
self.ancestor_wrappers.pop() self.ancestor_wrappers.pop()
...@@ -264,25 +312,25 @@ class StaticAnalysisVisitor(object): ...@@ -264,25 +312,25 @@ class StaticAnalysisVisitor(object):
if isinstance(node, gast.Constant): if isinstance(node, gast.Constant):
# singleton: None, True or False # singleton: None, True or False
if node.value is None: if node.value is None:
return NodeVarType.NONE return {NodeVarType.NONE}
if isinstance(node.value, bool): if isinstance(node.value, bool):
return NodeVarType.BOOLEAN return {NodeVarType.BOOLEAN}
if isinstance(node.value, int): if isinstance(node.value, int):
return NodeVarType.INT return {NodeVarType.INT}
if isinstance(node.value, float): if isinstance(node.value, float):
return NodeVarType.FLOAT return {NodeVarType.FLOAT}
if isinstance(node.value, str): if isinstance(node.value, str):
return NodeVarType.STRING return {NodeVarType.STRING}
if isinstance(node, gast.BoolOp): if isinstance(node, gast.BoolOp):
return NodeVarType.BOOLEAN return {NodeVarType.BOOLEAN}
if isinstance(node, gast.Compare): if isinstance(node, gast.Compare):
return NodeVarType.BOOLEAN return {NodeVarType.BOOLEAN}
if isinstance(node, gast.Dict): if isinstance(node, gast.Dict):
return NodeVarType.DICT return {NodeVarType.DICT}
if isinstance(node, gast.Set): if isinstance(node, gast.Set):
return NodeVarType.SET return {NodeVarType.SET}
if isinstance(node, gast.UnaryOp): if isinstance(node, gast.UnaryOp):
return self.node_to_wrapper_map[node.operand].node_var_type return self.node_to_wrapper_map[node.operand].node_var_type
...@@ -290,7 +338,11 @@ class StaticAnalysisVisitor(object): ...@@ -290,7 +338,11 @@ class StaticAnalysisVisitor(object):
if isinstance(node, gast.BinOp): if isinstance(node, gast.BinOp):
left_type = self.node_to_wrapper_map[node.left].node_var_type left_type = self.node_to_wrapper_map[node.left].node_var_type
right_type = self.node_to_wrapper_map[node.right].node_var_type right_type = self.node_to_wrapper_map[node.right].node_var_type
return NodeVarType.binary_op_output_type(left_type, right_type) result_type = set()
for l in left_type:
for r in right_type:
result_type.add(NodeVarType.binary_op_output_type(l, r))
return result_type
if isinstance(node, gast.Assign): if isinstance(node, gast.Assign):
ret_type = self.node_to_wrapper_map[node.value].node_var_type ret_type = self.node_to_wrapper_map[node.value].node_var_type
...@@ -302,18 +354,31 @@ class StaticAnalysisVisitor(object): ...@@ -302,18 +354,31 @@ class StaticAnalysisVisitor(object):
if isinstance(node, gast.Name): if isinstance(node, gast.Name):
if node.id == "None": if node.id == "None":
return NodeVarType.NONE return {NodeVarType.NONE}
if node.id == "True" or node.id == "False": if node.id == "True" or node.id == "False":
return NodeVarType.BOOLEAN return {NodeVarType.BOOLEAN}
return self.var_env.get_var_type(node.id) return self.var_env.get_var_type(node.id)
if isinstance(node, gast.Return):
return_type = self.node_to_wrapper_map[node.value].node_var_type
assert self.var_env.cur_scope.scope_type == AstVarScope.SCOPE_TYPE_FUNCTION, "Return at non-function scope"
func_name = self.var_env.cur_scope.scope_name
parent_scope = self.var_env.get_parent_scope()
parent_scope.add_var_type(func_name, return_type)
return return_type
if isinstance(node, gast.Call): if isinstance(node, gast.Call):
if is_dygraph_api(node): if is_dygraph_api(node):
api_name = node.func.attr if isinstance(node.func, gast.Attribute):
if api_name == "to_variable": if node.func.attr == "to_variable":
return NodeVarType.TENSOR return {NodeVarType.TENSOR}
if is_paddle_api(node):
return {NodeVarType.PADDLE_RETURN_TYPES}
if is_numpy_api(node): if is_numpy_api(node):
# In this simple version we assume numpy api returns nd-array # In this simple version we assume numpy api returns nd-array
return NodeVarType.NUMPY_NDARRAY return {NodeVarType.NUMPY_NDARRAY}
if isinstance(node.func, gast.Name):
return self.var_env.get_var_type(node.func.id)
return NodeVarType.STATEMENT return {NodeVarType.STATEMENT}
...@@ -42,7 +42,7 @@ def func_to_test2(x): ...@@ -42,7 +42,7 @@ def func_to_test2(x):
return x return x
result_var_type2 = {'m': NodeVarType.INT} result_var_type2 = {'m': {NodeVarType.INT}}
def func_to_test3(): def func_to_test3():
...@@ -59,16 +59,16 @@ def func_to_test3(): ...@@ -59,16 +59,16 @@ def func_to_test3():
result_var_type3 = { result_var_type3 = {
'a': NodeVarType.INT, 'a': {NodeVarType.INT},
'b': NodeVarType.FLOAT, 'b': {NodeVarType.FLOAT},
'c': NodeVarType.FLOAT, 'c': {NodeVarType.FLOAT},
'd': NodeVarType.FLOAT, 'd': {NodeVarType.FLOAT},
'e': NodeVarType.BOOLEAN, 'e': {NodeVarType.BOOLEAN},
'f': NodeVarType.INT, 'f': {NodeVarType.INT},
'g': NodeVarType.STRING, 'g': {NodeVarType.STRING},
'h': NodeVarType.NONE, 'h': {NodeVarType.NONE},
'i': NodeVarType.BOOLEAN, 'i': {NodeVarType.BOOLEAN},
'j': NodeVarType.UNKNOWN 'j': {NodeVarType.UNKNOWN}
} }
...@@ -81,15 +81,48 @@ def func_to_test4(): ...@@ -81,15 +81,48 @@ def func_to_test4():
result_var_type4 = { result_var_type4 = {
'a': NodeVarType.NUMPY_NDARRAY, 'a': {NodeVarType.NUMPY_NDARRAY},
'b': NodeVarType.NUMPY_NDARRAY, 'b': {NodeVarType.NUMPY_NDARRAY},
'c': NodeVarType.TENSOR, 'c': {NodeVarType.TENSOR},
'd': NodeVarType.TENSOR 'd': {NodeVarType.TENSOR}
} }
test_funcs = [func_to_test1, func_to_test2, func_to_test3, func_to_test4]
def func_to_test5():
def inner_int_func():
return 1
def inner_bool_float_func(x):
a = 1.0
if x > 0:
return a
return False
def inner_unknown_func(x):
return x
a = inner_int_func()
b = inner_bool_float_func(3)
c = inner_unknown_func(None)
d = paddle.fluid.data('x', [1, 2])
result_var_type5 = {
'a': {NodeVarType.INT},
'b': {NodeVarType.FLOAT, NodeVarType.BOOLEAN},
'c': {NodeVarType.UNKNOWN},
'd': {NodeVarType.PADDLE_RETURN_TYPES},
'inner_int_func': {NodeVarType.INT},
'inner_bool_float_func': {NodeVarType.FLOAT, NodeVarType.BOOLEAN},
'inner_unknown_func': {NodeVarType.UNKNOWN},
}
test_funcs = [
func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5
]
result_var_type = [ result_var_type = [
result_var_type1, result_var_type2, result_var_type3, result_var_type4 result_var_type1, result_var_type2, result_var_type3, result_var_type4,
result_var_type5
] ]
...@@ -117,7 +150,7 @@ class TestStaticAnalysis(unittest.TestCase): ...@@ -117,7 +150,7 @@ class TestStaticAnalysis(unittest.TestCase):
self._check_wrapper(wrapper_root, node_to_wrapper_map) self._check_wrapper(wrapper_root, node_to_wrapper_map)
def test_var_env(self): def test_var_env(self):
for i in range(4): for i in range(5):
func = test_funcs[i] func = test_funcs[i]
var_type = result_var_type[i] var_type = result_var_type[i]
test_source_code = inspect.getsource(func) test_source_code = inspect.getsource(func)
...@@ -125,6 +158,11 @@ class TestStaticAnalysis(unittest.TestCase): ...@@ -125,6 +158,11 @@ class TestStaticAnalysis(unittest.TestCase):
print(gast.dump(ast_root)) print(gast.dump(ast_root))
visitor = StaticAnalysisVisitor(ast_root) visitor = StaticAnalysisVisitor(ast_root)
var_env = visitor.get_var_env() var_env = visitor.get_var_env()
# There must be 1 sub scope for the test function
self.assertEqual(1, len(var_env.cur_scope.sub_scopes))
var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
scope_var_type = var_env.get_scope_var_type() scope_var_type = var_env.get_scope_var_type()
self.assertEqual(len(scope_var_type), len(var_type)) self.assertEqual(len(scope_var_type), len(var_type))
for name in scope_var_type: for name in scope_var_type:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册