未验证 提交 9474d140 编写于 作者: A Aurelius84 提交者: GitHub

Support Parameter type determination in StaticAnalysis (#23302)

* Support Parameter type determination test=develop
上级 20eed540
......@@ -16,7 +16,7 @@ from __future__ import print_function
import gast
import warnings
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list
__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']
......@@ -260,9 +260,9 @@ class StaticAnalysisVisitor(object):
def get_var_env(self):
return self.var_env
def _get_node_var_type(self, cur_wrapper):
node = cur_wrapper.node
if isinstance(node, gast.Constant):
def _get_constant_node_type(self, node):
assert isinstance(node, gast.Constant), \
"Type of input node should be gast.Constant, but received %s" % type(node)
# singleton: None, True or False
if node.value is None:
return {NodeVarType.NONE}
......@@ -275,6 +275,13 @@ class StaticAnalysisVisitor(object):
if isinstance(node.value, str):
return {NodeVarType.STRING}
return {NodeVarType.UNKNOWN}
def _get_node_var_type(self, cur_wrapper):
node = cur_wrapper.node
if isinstance(node, gast.Constant):
return self._get_constant_node_type(node)
if isinstance(node, gast.BoolOp):
return {NodeVarType.BOOLEAN}
if isinstance(node, gast.Compare):
......@@ -308,8 +315,28 @@ class StaticAnalysisVisitor(object):
if isinstance(node, gast.Name):
if node.id == "None":
return {NodeVarType.NONE}
if node.id == "True" or node.id == "False":
if node.id in {"True", "False"}:
return {NodeVarType.BOOLEAN}
# If node is child of functionDef.arguments
parent_node_wrapper = cur_wrapper.parent
if parent_node_wrapper and isinstance(parent_node_wrapper.node,
gast.arguments):
parent_node = parent_node_wrapper.node
var_type = {NodeVarType.UNKNOWN}
if parent_node.defaults:
index = index_in_list(parent_node.args, node)
args_len = len(parent_node.args)
if index != -1 and args_len - index <= len(
parent_node.defaults):
defaults_node = parent_node.defaults[index - args_len]
if isinstance(defaults_node, gast.Constant):
var_type = self._get_constant_node_type(
defaults_node)
# Add node with identified type into cur_env.
self.var_env.set_var_type(node.id, var_type)
return var_type
return self.var_env.get_var_type(node.id)
if isinstance(node, gast.Return):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,7 +20,7 @@ import numpy as np
import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.dygraph_to_static import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static import NodeVarType, StaticAnalysisVisitor
def func_to_test1(a, b):
......@@ -117,12 +117,34 @@ result_var_type5 = {
'inner_unknown_func': {NodeVarType.UNKNOWN},
}
def func_to_test6(x, y=1):
i = fluid.dygraph.to_variable(x)
def add(x, y):
return x + y
while x < 10:
i = add(i, x)
x = x + y
return i
result_var_type6 = {
'i': {NodeVarType.INT},
'x': {NodeVarType.INT},
'y': {NodeVarType.INT},
'add': {NodeVarType.INT}
}
test_funcs = [
func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5
func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5,
func_to_test6
]
result_var_type = [
result_var_type1, result_var_type2, result_var_type3, result_var_type4,
result_var_type5
result_var_type5, result_var_type6
]
......@@ -150,8 +172,8 @@ class TestStaticAnalysis(unittest.TestCase):
self._check_wrapper(wrapper_root, node_to_wrapper_map)
def test_var_env(self):
for i in range(5):
func = test_funcs[i]
for i, func in enumerate(test_funcs):
var_type = result_var_type[i]
test_source_code = inspect.getsource(func)
ast_root = gast.parse(test_source_code)
......@@ -164,6 +186,7 @@ class TestStaticAnalysis(unittest.TestCase):
var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
scope_var_type = var_env.get_scope_var_type()
print(scope_var_type)
self.assertEqual(len(scope_var_type), len(var_type))
for name in scope_var_type:
print("Test var name %s" % (name))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册