未验证 提交 9474d140 编写于 作者: A Aurelius84 提交者: GitHub

Support Parameter type determination in StaticAnalysis (#23302)

* Support Parameter type determination test=develop
上级 20eed540
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import gast import gast
import warnings import warnings
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list
__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor'] __all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']
...@@ -260,20 +260,27 @@ class StaticAnalysisVisitor(object): ...@@ -260,20 +260,27 @@ class StaticAnalysisVisitor(object):
def get_var_env(self): def get_var_env(self):
return self.var_env return self.var_env
def _get_constant_node_type(self, node):
assert isinstance(node, gast.Constant), \
"Type of input node should be gast.Constant, but received %s" % type(node)
# singleton: None, True or False
if node.value is None:
return {NodeVarType.NONE}
if isinstance(node.value, bool):
return {NodeVarType.BOOLEAN}
if isinstance(node.value, int):
return {NodeVarType.INT}
if isinstance(node.value, float):
return {NodeVarType.FLOAT}
if isinstance(node.value, str):
return {NodeVarType.STRING}
return {NodeVarType.UNKNOWN}
def _get_node_var_type(self, cur_wrapper): def _get_node_var_type(self, cur_wrapper):
node = cur_wrapper.node node = cur_wrapper.node
if isinstance(node, gast.Constant): if isinstance(node, gast.Constant):
# singleton: None, True or False return self._get_constant_node_type(node)
if node.value is None:
return {NodeVarType.NONE}
if isinstance(node.value, bool):
return {NodeVarType.BOOLEAN}
if isinstance(node.value, int):
return {NodeVarType.INT}
if isinstance(node.value, float):
return {NodeVarType.FLOAT}
if isinstance(node.value, str):
return {NodeVarType.STRING}
if isinstance(node, gast.BoolOp): if isinstance(node, gast.BoolOp):
return {NodeVarType.BOOLEAN} return {NodeVarType.BOOLEAN}
...@@ -308,8 +315,28 @@ class StaticAnalysisVisitor(object): ...@@ -308,8 +315,28 @@ class StaticAnalysisVisitor(object):
if isinstance(node, gast.Name): if isinstance(node, gast.Name):
if node.id == "None": if node.id == "None":
return {NodeVarType.NONE} return {NodeVarType.NONE}
if node.id == "True" or node.id == "False": if node.id in {"True", "False"}:
return {NodeVarType.BOOLEAN} return {NodeVarType.BOOLEAN}
# If node is child of functionDef.arguments
parent_node_wrapper = cur_wrapper.parent
if parent_node_wrapper and isinstance(parent_node_wrapper.node,
gast.arguments):
parent_node = parent_node_wrapper.node
var_type = {NodeVarType.UNKNOWN}
if parent_node.defaults:
index = index_in_list(parent_node.args, node)
args_len = len(parent_node.args)
if index != -1 and args_len - index <= len(
parent_node.defaults):
defaults_node = parent_node.defaults[index - args_len]
if isinstance(defaults_node, gast.Constant):
var_type = self._get_constant_node_type(
defaults_node)
# Add node with identified type into cur_env.
self.var_env.set_var_type(node.id, var_type)
return var_type
return self.var_env.get_var_type(node.id) return self.var_env.get_var_type(node.id)
if isinstance(node, gast.Return): if isinstance(node, gast.Return):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import unittest import unittest
from paddle.fluid.dygraph.dygraph_to_static import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static import NodeVarType, StaticAnalysisVisitor
def func_to_test1(a, b): def func_to_test1(a, b):
...@@ -117,12 +117,34 @@ result_var_type5 = { ...@@ -117,12 +117,34 @@ result_var_type5 = {
'inner_unknown_func': {NodeVarType.UNKNOWN}, 'inner_unknown_func': {NodeVarType.UNKNOWN},
} }
def func_to_test6(x, y=1):
i = fluid.dygraph.to_variable(x)
def add(x, y):
return x + y
while x < 10:
i = add(i, x)
x = x + y
return i
result_var_type6 = {
'i': {NodeVarType.INT},
'x': {NodeVarType.INT},
'y': {NodeVarType.INT},
'add': {NodeVarType.INT}
}
test_funcs = [ test_funcs = [
func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5 func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5,
func_to_test6
] ]
result_var_type = [ result_var_type = [
result_var_type1, result_var_type2, result_var_type3, result_var_type4, result_var_type1, result_var_type2, result_var_type3, result_var_type4,
result_var_type5 result_var_type5, result_var_type6
] ]
...@@ -150,8 +172,8 @@ class TestStaticAnalysis(unittest.TestCase): ...@@ -150,8 +172,8 @@ class TestStaticAnalysis(unittest.TestCase):
self._check_wrapper(wrapper_root, node_to_wrapper_map) self._check_wrapper(wrapper_root, node_to_wrapper_map)
def test_var_env(self): def test_var_env(self):
for i in range(5):
func = test_funcs[i] for i, func in enumerate(test_funcs):
var_type = result_var_type[i] var_type = result_var_type[i]
test_source_code = inspect.getsource(func) test_source_code = inspect.getsource(func)
ast_root = gast.parse(test_source_code) ast_root = gast.parse(test_source_code)
...@@ -164,6 +186,7 @@ class TestStaticAnalysis(unittest.TestCase): ...@@ -164,6 +186,7 @@ class TestStaticAnalysis(unittest.TestCase):
var_env.cur_scope = var_env.cur_scope.sub_scopes[0] var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
scope_var_type = var_env.get_scope_var_type() scope_var_type = var_env.get_scope_var_type()
print(scope_var_type)
self.assertEqual(len(scope_var_type), len(var_type)) self.assertEqual(len(scope_var_type), len(var_type))
for name in scope_var_type: for name in scope_var_type:
print("Test var name %s" % (name)) print("Test var name %s" % (name))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册