Unverified commit 14672a63, authored by Huihuang Zheng, committed by GitHub

Add Basic Node Var Type Analysis (#22603)

1. Move AstNodeWrapper, StaticAnalysisVisitor to a new python file: static_analysis.py
2. Add basic node var type analysis
Parent commit: a089072c
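For orientation, here is a minimal usage sketch of the new analysis pass, modeled on the unit test added in this commit; the example function and the asserted types are illustrative, not part of the change itself:

import inspect

import gast

from paddle.fluid.dygraph.dygraph_to_static import NodeVarType, StaticAnalysisVisitor


def example(a, b):
    c = a < b      # Compare node -> BOOLEAN
    d = 1 + 2.0    # BinOp of INT and FLOAT -> FLOAT
    return c, d


# Parse the function source with gast and run the static analysis pass over it.
ast_root = gast.parse(inspect.getsource(example))
visitor = StaticAnalysisVisitor(ast_root)

# The variable environment records the inferred NodeVarType of each assigned name.
scope_var_type = visitor.get_var_env().get_scope_var_type()
assert scope_var_type['c'] == NodeVarType.BOOLEAN
assert scope_var_type['d'] == NodeVarType.FLOAT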
@@ -17,5 +17,9 @@ from __future__ import print_function
from . import ast_transformer
from .ast_transformer import *
from . import static_analysis
from .static_analysis import *
__all__ = []
__all__ += ast_transformer.__all__
__all__ += static_analysis.__all__
@@ -16,87 +16,9 @@ from __future__ import print_function
import gast
-__all__ = ['AstNodeWrapper', 'DygraphToStaticAst', 'StaticAnalysisVisitor']
+from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor
__all__ = ['DygraphToStaticAst']
class NodeVarType(object):
"""
Enum class of python variable types. We have to know some variable types
at compile time in order to transform the AST. For example, a string variable
and a tensor variable in an if clause may lead to different conversions from
dygraph to static graph.
"""
UNKNOWN = 0 # Reserved for AST nodes whose type is not known yet
STATEMENT = 1 # For nodes representing a statement (not a variable)
PADDLE_DYGRAPH_API = 2
PADDLE_CONTROL_IF = 3
PADDLE_CONTROL_WHILE = 4
PADDLE_CONTROL_FOR = 5
NONE = 100
INT = 101
FLOAT = 102
STRING = 103
TENSOR = 104
class AstNodeWrapper(object):
"""
Wrapper for a python ast node. We need a node wrapper because an ast node
doesn't store all of the information required when transforming the AST, so
we collect the additional information that the actual transformation needs.
"""
def __init__(self, node):
self.node = node
self.parent = None
self.children = []
self.node_var_type = NodeVarType.UNKNOWN
class StaticAnalysisVisitor(object):
"""
A class that does static analysis
"""
def __init__(self, ast_root=None):
if ast_root is not None:
self.run(ast_root)
def run(self, ast_root):
self.node_wrapper_root = None
self.ancestor_wrappers = []
self.node_to_wrapper_map = {}
self.dfs_visit(ast_root)
def dfs_visit(self, node):
# The AST reuses some ast nodes, such as the Param node of expr_context
if node not in self.node_to_wrapper_map:
cur_wrapper = AstNodeWrapper(node)
self.node_to_wrapper_map[node] = cur_wrapper
else:
cur_wrapper = self.node_to_wrapper_map[node]
if self.node_wrapper_root is None:
self.node_wrapper_root = cur_wrapper
if len(self.ancestor_wrappers) != 0:
last_wrapper = self.ancestor_wrappers[-1]
last_wrapper.children.append(cur_wrapper)
cur_wrapper.parent = last_wrapper
self.ancestor_wrappers.append(cur_wrapper)
for child in gast.iter_child_nodes(node):
self.dfs_visit(child)
self.ancestor_wrappers.pop()
return cur_wrapper.node_var_type
def get_node_wrapper_root(self):
return self.node_wrapper_root
def get_node_to_wrapper_map(self):
return self.node_to_wrapper_map
class DygraphToStaticAst(gast.NodeTransformer):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import astor
import gast
import inspect
import six
import warnings
__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']
# TODO: _is_paddle_dygraph_api is duplicated in Yamei's utils.py. Merge the two
# functions when Yamei finishes her PR.
def _is_paddle_dygraph_api(obj):
m = inspect.getmodule(obj)
return m is not None and m.__name__.startswith("paddle.fluid.dygraph")
# TODO: is_dygraph_api is duplicated in Yamei's utils.py. Merge the two
# functions when Yamei finishes her PR.
def is_dygraph_api(node):
assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api"
func_src = astor.to_source(node.func)
try:
import paddle.fluid as fluid
return eval("_is_paddle_dygraph_api({})".format(func_src))
except NameError:
return False
def _is_numpy_api_helper(obj):
m = inspect.getmodule(obj)
return m is not None and m.__name__.startswith("numpy")
def is_numpy_api(node):
assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api"
func_str = astor.to_source(node.func)
try:
import numpy as np
module_result = eval("_is_numpy_api_helper({})".format(func_str))
# BUG: np.random.uniform doesn't have a module, so the module check alone
# cannot identify it; fall back to matching the source prefix.
# TODO: find a better way
if not module_result:
return func_str.startswith("numpy.") or func_str.startswith("np.")
return module_result
except NameError:
return False
class NodeVarType(object):
"""
Enum class of python variable types. We have to know some variable types
at compile time in order to transform the AST. For example, a string variable
and a tensor variable in an if clause may lead to different conversions from
dygraph to static graph.
"""
ERROR = -1 # Returned when static analysis encounters an error
UNKNOWN = 0 # Reserved for AST nodes whose type is not known yet
STATEMENT = 1 # For nodes representing a statement (not a variable)
CALLABLE = 2
# python data types
NONE = 100
BOOLEAN = 101
INT = 102
FLOAT = 103
STRING = 104
TENSOR = 105
NUMPY_NDARRAY = 106
# python collections
LIST = 200
SET = 201
DICT = 202
PADDLE_DYGRAPH_API = 300
PADDLE_CONTROL_IF = 301
PADDLE_CONTROL_WHILE = 302
PADDLE_CONTROL_FOR = 303
@staticmethod
def binary_op_output_type(in_type1, in_type2):
if in_type1 == in_type2:
return in_type1
if in_type1 == NodeVarType.UNKNOWN:
return in_type2
if in_type2 == NodeVarType.UNKNOWN:
return in_type1
supported_types = [
NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR
]
if in_type1 not in supported_types:
warnings.warn("Binary op on unsupported in_type1 = %d" % (in_type1))
return NodeVarType.UNKNOWN
if in_type2 not in supported_types:
warnings.warn("Binary op on unsupported in_type2 = %d" % (in_type2))
return NodeVarType.UNKNOWN
forbidden_types = [NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR]
if in_type1 in forbidden_types and in_type2 in forbidden_types:
warnings.warn(
"Binary op on unsupported type combination: in_type1 = %d, in_type2 = %d"
% (in_type1, in_type2))
return NodeVarType.UNKNOWN
return max(in_type1, in_type2)
class AstNodeWrapper(object):
"""
Wrapper for a python gast node. We need a node wrapper because a gast node
doesn't store all of the information required when transforming the AST, so
we collect the additional information that the actual transformation needs.
"""
def __init__(self, node):
self.node = node
self.parent = None
self.children = []
self.node_var_type = NodeVarType.UNKNOWN
class AstVarScope(object):
"""
AstVarScope is a class holding the map from each variable in the current
scope to its type.
"""
def __init__(self, parent_scope=None):
self.sub_scopes = []
self.name_to_id = {}
self.id_to_type = {}
self.cur_id = 0
self.parent_scope = parent_scope
if parent_scope is not None:
parent_scope.sub_scopes.append(self)
def set_var_type(self, var_name, node_var_type):
if var_name in self.name_to_id:
num_id = self.name_to_id[var_name]
else:
num_id = self.cur_id
self.cur_id += 1
self.name_to_id[var_name] = num_id
self.id_to_type[num_id] = node_var_type
def get_var_type(self, var_name):
if var_name in self.name_to_id:
num_id = self.name_to_id[var_name]
return self.id_to_type[num_id]
if self.parent_scope is None:
return NodeVarType.UNKNOWN
return self.parent_scope.get_var_type(var_name)
class AstVarEnv(object):
"""
A class that maintains scopes and the mapping from variable names to types.
"""
def __init__(self):
self.cur_scope = AstVarScope()
def enter_scope(self):
self.cur_scope = AstVarScope(parent_scope=self.cur_scope)
return self.cur_scope
def exit_scope(self):
assert self.cur_scope.parent_scope is not None, "Cannot exit_scope in "\
"AstVarEnv when the current scope doesn't have a parent scope."
self.cur_scope = self.cur_scope.parent_scope
return self.cur_scope
def set_var_type(self, var_name, node_var_type):
self.cur_scope.set_var_type(var_name, node_var_type)
def get_var_type(self, var_name):
return self.cur_scope.get_var_type(var_name)
def get_scope_var_type(self):
'''
Returns a dict mapping variable names in the current scope to their types.
Used for debugging and tests.
'''
cur_scope_dict = {}
for name in self.cur_scope.name_to_id:
node_var_type = self.cur_scope.get_var_type(name)
cur_scope_dict[name] = node_var_type
return cur_scope_dict
class StaticAnalysisVisitor(object):
"""
A class that performs static analysis on a gast AST: it wraps every node in
an AstNodeWrapper and infers a NodeVarType for each node.
"""
def __init__(self, ast_root=None):
if ast_root is not None:
self.run(ast_root)
def run(self, ast_root):
self.node_wrapper_root = None
self.ancestor_wrappers = []
self.node_to_wrapper_map = {}
self.var_env = AstVarEnv()
self.dfs_visit(ast_root)
def dfs_visit(self, node):
# The AST reuses some gast nodes, such as the Param node of expr_context
if node not in self.node_to_wrapper_map:
cur_wrapper = AstNodeWrapper(node)
self.node_to_wrapper_map[node] = cur_wrapper
else:
cur_wrapper = self.node_to_wrapper_map[node]
if self.node_wrapper_root is None:
self.node_wrapper_root = cur_wrapper
if len(self.ancestor_wrappers) != 0:
last_wrapper = self.ancestor_wrappers[-1]
last_wrapper.children.append(cur_wrapper)
cur_wrapper.parent = last_wrapper
self.ancestor_wrappers.append(cur_wrapper)
for child in gast.iter_child_nodes(node):
self.dfs_visit(child)
self.ancestor_wrappers.pop()
cur_wrapper.node_var_type = self._get_node_var_type(cur_wrapper)
return cur_wrapper.node_var_type
def get_node_wrapper_root(self):
return self.node_wrapper_root
def get_node_to_wrapper_map(self):
return self.node_to_wrapper_map
def get_var_env(self):
return self.var_env
def _get_node_var_type(self, cur_wrapper):
node = cur_wrapper.node
if isinstance(node, gast.Constant):
# singleton: None, True or False
if node.value is None:
return NodeVarType.NONE
if isinstance(node.value, bool):
return NodeVarType.BOOLEAN
if isinstance(node.value, int):
return NodeVarType.INT
if isinstance(node.value, float):
return NodeVarType.FLOAT
if isinstance(node.value, str):
return NodeVarType.STRING
if isinstance(node, gast.BoolOp):
return NodeVarType.BOOLEAN
if isinstance(node, gast.Compare):
return NodeVarType.BOOLEAN
if isinstance(node, gast.Dict):
return NodeVarType.DICT
if isinstance(node, gast.Set):
return NodeVarType.SET
if isinstance(node, gast.UnaryOp):
return self.node_to_wrapper_map[node.operand].node_var_type
if isinstance(node, gast.BinOp):
left_type = self.node_to_wrapper_map[node.left].node_var_type
right_type = self.node_to_wrapper_map[node.right].node_var_type
return NodeVarType.binary_op_output_type(left_type, right_type)
if isinstance(node, gast.Assign):
ret_type = self.node_to_wrapper_map[node.value].node_var_type
for target in node.targets:
if isinstance(target, gast.Name):
self.node_to_wrapper_map[target].node_var_type = ret_type
self.var_env.set_var_type(target.id, ret_type)
return ret_type
if isinstance(node, gast.Name):
if node.id == "None":
return NodeVarType.NONE
if node.id == "True" or node.id == "False":
return NodeVarType.BOOLEAN
return self.var_env.get_var_type(node.id)
if isinstance(node, gast.Call):
if is_dygraph_api(node):
api_name = node.func.attr
if api_name == "to_variable":
return NodeVarType.TENSOR
if is_numpy_api(node):
# In this simple version we assume a numpy API call returns an ndarray
return NodeVarType.NUMPY_NDARRAY
return NodeVarType.STATEMENT
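To make the promotion rule in binary_op_output_type and the scope lookup concrete, here is a small sketch; the module path paddle.fluid.dygraph.dygraph_to_static.static_analysis is assumed from the commit message, since AstVarEnv is not re-exported in __all__:

from paddle.fluid.dygraph.dygraph_to_static import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstVarEnv

# Mixed numeric types promote to the larger enum value,
# e.g. INT (102) combined with FLOAT (103) yields FLOAT.
assert NodeVarType.binary_op_output_type(NodeVarType.INT, NodeVarType.FLOAT) == NodeVarType.FLOAT
# Mixing NUMPY_NDARRAY with TENSOR is rejected and falls back to UNKNOWN (with a warning).
assert NodeVarType.binary_op_output_type(NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR) == NodeVarType.UNKNOWN

# AstVarEnv resolves a name through parent scopes when the current scope has no entry.
env = AstVarEnv()
env.set_var_type('x', NodeVarType.INT)
env.enter_scope()
assert env.get_var_type('x') == NodeVarType.INT
env.exit_scope()
assert env.get_scope_var_type() == {'x': NodeVarType.INT}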
@@ -14,18 +14,23 @@
from __future__ import print_function
-import ast
+import gast
import inspect
import numpy as np
import paddle.fluid as fluid
import unittest
-from paddle.fluid.dygraph.dygraph_to_static import AstNodeWrapper, StaticAnalysisVisitor
+from paddle.fluid.dygraph.dygraph_to_static import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
-def func_to_test_1(a, b):
+def func_to_test1(a, b):
return a + b
-def func_to_test_2(x):
+result_var_type1 = {}
def func_to_test2(x):
for i in range(10):
x += i
m = 3
@@ -37,6 +42,57 @@ def func_to_test_2(x):
return x
result_var_type2 = {'m': NodeVarType.INT}
def func_to_test3():
a = 1
b = 3.0
c = a * b
d = True + c
e = a < b
f = 9 * (a * 4)
g = "dddy"
h = None
i = False
j = None + 1
result_var_type3 = {
'a': NodeVarType.INT,
'b': NodeVarType.FLOAT,
'c': NodeVarType.FLOAT,
'd': NodeVarType.FLOAT,
'e': NodeVarType.BOOLEAN,
'f': NodeVarType.INT,
'g': NodeVarType.STRING,
'h': NodeVarType.NONE,
'i': NodeVarType.BOOLEAN,
'j': NodeVarType.UNKNOWN
}
def func_to_test4():
with fluid.dygraph.guard():
a = np.random.uniform(0.1, 1, [1, 2])
b = 1 + a
c = fluid.dygraph.to_variable(b)
d = (c + 1) * 0.3
result_var_type4 = {
'a': NodeVarType.NUMPY_NDARRAY,
'b': NodeVarType.NUMPY_NDARRAY,
'c': NodeVarType.TENSOR,
'd': NodeVarType.TENSOR
}
test_funcs = [func_to_test1, func_to_test2, func_to_test3, func_to_test4]
result_var_type = [
result_var_type1, result_var_type2, result_var_type3, result_var_type4
]
class TestStaticAnalysis(unittest.TestCase):
def _check_wrapper(self, wrapper, node_to_wrapper_map):
self.assertEqual(node_to_wrapper_map[wrapper.node], wrapper)
@@ -44,7 +100,7 @@ class TestStaticAnalysis(unittest.TestCase):
self.assertTrue(wrapper in wrapper.parent.children)
children_ast_nodes = [
-child for child in ast.iter_child_nodes(wrapper.node)
+child for child in gast.iter_child_nodes(wrapper.node)
]
self.assertEqual(len(wrapper.children), len(children_ast_nodes))
for child in wrapper.children:
@@ -52,15 +108,30 @@ class TestStaticAnalysis(unittest.TestCase):
self._check_wrapper(child, node_to_wrapper_map)
def test_construct_node_wrapper(self):
-for func in [func_to_test_1, func_to_test_2]:
+for func in test_funcs:
test_source_code = inspect.getsource(func)
-ast_root = ast.parse(test_source_code)
+ast_root = gast.parse(test_source_code)
visitor = StaticAnalysisVisitor(ast_root)
wrapper_root = visitor.get_node_wrapper_root()
node_to_wrapper_map = visitor.get_node_to_wrapper_map()
self._check_wrapper(wrapper_root, node_to_wrapper_map)
def test_var_env(self):
for i in range(4):
func = test_funcs[i]
var_type = result_var_type[i]
test_source_code = inspect.getsource(func)
ast_root = gast.parse(test_source_code)
print(gast.dump(ast_root))
visitor = StaticAnalysisVisitor(ast_root)
var_env = visitor.get_var_env()
scope_var_type = var_env.get_scope_var_type()
self.assertEqual(len(scope_var_type), len(var_type))
for name in scope_var_type:
print("Test var name %s" % (name))
self.assertTrue(name in var_type)
self.assertEqual(scope_var_type[name], var_type[name])
if __name__ == '__main__':
unittest.main()
@@ -18,3 +18,4 @@ decorator
prettytable
objgraph
gast
astor
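The astor entry appears to be added because the new static_analysis.py imports astor to render AST nodes back into source text for the dygraph/numpy API checks; a tiny illustration, using the standard ast module for simplicity:

import ast

import astor

# astor turns an AST back into source text; is_dygraph_api and is_numpy_api above
# use the rendered text of a call's func node to decide which API the call targets.
tree = ast.parse("np.random.uniform(0.1, 1, [1, 2])")
print(astor.to_source(tree))  # -> "np.random.uniform(0.1, 1, [1, 2])"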