Unverified · Commit db8425ec authored by 0x45f, committed by GitHub

[Dy2stat] Support Python3 type annotation (#36544)

* Support Py3 type annotations in @to_static

* support type hint for args in func

* support type hint assign

* if annotation and value (Constant) are different types, use the value type

* polish type_from_annotation()

* code format

* code format

* remove useless commentary

* fix review
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Parent 0590277a
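For context, here is a minimal sketch (not taken from the PR; the function body is hypothetical) of the kind of annotated code that @to_static can now analyze, covering the three commit points: annotated arguments, annotated assignment, and a constant value overriding a mismatched annotation:

    import paddle
    from paddle.jit import to_static

    @to_static
    def forward(x: paddle.Tensor, scale: float = 2.0) -> paddle.Tensor:
        # annotated assignment: `ratio` is inferred as FLOAT from the annotation
        ratio: float = 0.5
        # annotation and constant value disagree: the value's type (INT) is used
        steps: float = 3
        return x * scale * ratio * steps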
@@ -208,10 +208,10 @@ class ErrorData(object):
             message_lines.append("")
         # Add paddle traceback after user code traceback
-        paddle_traceback_start_idnex = user_code_traceback_index[
+        paddle_traceback_start_index = user_code_traceback_index[
             -1] + 1 if user_code_traceback_index else 0
         for filepath, lineno, funcname, code in self.origin_traceback[
-                paddle_traceback_start_idnex:]:
+                paddle_traceback_start_index:]:
             traceback_frame = TraceBackFrame(
                 Location(filepath, lineno), funcname, code)
             message_lines.append(traceback_frame.formated_message())
@@ -305,10 +305,10 @@ class ErrorData(object):
             error_frame.append("")
         # Add paddle traceback after user code traceback
-        paddle_traceback_start_idnex = user_code_traceback_index[
+        paddle_traceback_start_index = user_code_traceback_index[
             -1] + 1 if user_code_traceback_index else 0
         for filepath, lineno, funcname, code in error_traceback[
-                paddle_traceback_start_idnex:]:
+                paddle_traceback_start_index:]:
             traceback_frame = TraceBackFrame(
                 Location(filepath, lineno), funcname, code)
             error_frame.append(traceback_frame.formated_message())
......
@@ -15,7 +15,8 @@
 from __future__ import print_function
 from paddle.utils import gast
-from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list
+from .logging_utils import warn
+from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list, ast_to_source_code
 __all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']
@@ -57,6 +58,15 @@ class NodeVarType(object):
     # If node.node_var_type in TENSOR_TYPES, it can be considered as tensor-dependent.
     TENSOR_TYPES = {TENSOR, PADDLE_RETURN_TYPES}
+    Annotation_map = {
+        "Tensor": TENSOR,
+        "paddle.Tensor": TENSOR,
+        "int": INT,
+        "float": FLOAT,
+        "bool": BOOLEAN,
+        "str": STRING
+    }
     @staticmethod
     def binary_op_output_type(in_type1, in_type2):
         if in_type1 == in_type2:
@@ -83,6 +93,16 @@ class NodeVarType(object):
             return NodeVarType.UNKNOWN
         return max(in_type1, in_type2)
+    @staticmethod
+    def type_from_annotation(annotation):
+        annotation_str = ast_to_source_code(annotation).strip()
+        if annotation_str in NodeVarType.Annotation_map:
+            return NodeVarType.Annotation_map[annotation_str]
+        # raise warning if not found
+        warn("Currently we don't support annotation: %s" % annotation_str)
+        return NodeVarType.UNKNOWN
 class AstNodeWrapper(object):
     """
@@ -316,6 +336,18 @@ class StaticAnalysisVisitor(object):
                     self.var_env.set_var_type(target.id, ret_type)
             return ret_type
+        if isinstance(node, gast.AnnAssign):
+            # TODO(0x45f): Determine whether we need to support assignment statements
+            # like `self.x: float = 2.1`.
+            ret_type = {NodeVarType.type_from_annotation(node.annotation)}
+            # if annotation and value (Constant) are different types, use the value type
+            if node.value:
+                ret_type = self.node_to_wrapper_map[node.value].node_var_type
+            if isinstance(node.target, gast.Name):
+                self.node_to_wrapper_map[node.target].node_var_type = ret_type
+                self.var_env.set_var_type(node.target.id, ret_type)
+            return ret_type
         if isinstance(node, gast.Name):
             if node.id == "None":
                 return {NodeVarType.NONE}
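An annotated assignment such as `k: float = 1.0` parses into an AnnAssign node whose annotation, value and target fields drive the branch above. A quick way to inspect that node shape with the standard ast module (gast, which Paddle uses, keeps the same field names):

    import ast

    node = ast.parse("k: float = 1.0").body[0]
    print(type(node).__name__)           # AnnAssign
    print(ast.unparse(node.annotation))  # float
    print(type(node.value).__name__, node.value.value)  # Constant 1.0
    print(node.target.id)                # k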
@@ -325,21 +357,8 @@ class StaticAnalysisVisitor(object):
             parent_node_wrapper = cur_wrapper.parent
             if parent_node_wrapper and isinstance(parent_node_wrapper.node,
                                                   gast.arguments):
-                parent_node = parent_node_wrapper.node
-                var_type = {NodeVarType.UNKNOWN}
-                if parent_node.defaults:
-                    index = index_in_list(parent_node.args, node)
-                    args_len = len(parent_node.args)
-                    if index != -1 and args_len - index <= len(
-                            parent_node.defaults):
-                        defaults_node = parent_node.defaults[index - args_len]
-                        if isinstance(defaults_node, gast.Constant):
-                            var_type = self._get_constant_node_type(
-                                defaults_node)
-                # Add node with identified type into cur_env.
-                self.var_env.set_var_type(node.id, var_type)
-                return var_type
+                return self._get_func_argument_type(parent_node_wrapper, node)
             return self.var_env.get_var_type(node.id)
@@ -373,3 +392,42 @@ class StaticAnalysisVisitor(object):
             return {NodeVarType.TENSOR}
         return {NodeVarType.STATEMENT}
+    def _get_func_argument_type(self, parent_node_wrapper, node):
+        """
+        Returns type information by parsing annotation or default values.
+        For example:
+            1. parse by default values.
+                foo(x, y=1, z='s') -> x: UNKNOWN, y: INT, z: STR
+            2. parse by Py3 type annotation.
+                foo(x: Tensor, y: int, z: str) -> x: Tensor, y: INT, z: STR
+            3. parse by type annotation and default values.
+                foo(x: Tensor, y: int, z: str = 'abc') -> x: Tensor, y: INT, z: STR
+        NOTE: Currently, we only support Tensor, int, bool, float, str, etc.
+        Other complicated types will be supported later.
+        """
+        assert isinstance(node, gast.Name)
+        parent_node = parent_node_wrapper.node
+        var_type = {NodeVarType.UNKNOWN}
+        if node.annotation is not None:
+            var_type = {NodeVarType.type_from_annotation(node.annotation)}
+            self.var_env.set_var_type(node.id, var_type)
+        # if annotation and value (Constant) are different types, use the value type
+        if parent_node.defaults:
+            index = index_in_list(parent_node.args, node)
+            args_len = len(parent_node.args)
+            if index != -1 and args_len - index <= len(parent_node.defaults):
+                defaults_node = parent_node.defaults[index - args_len]
+                if isinstance(defaults_node, gast.Constant):
+                    var_type = self._get_constant_node_type(defaults_node)
+                    # Add node with identified type into cur_env.
+                    self.var_env.set_var_type(node.id, var_type)
+        return var_type
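The precedence encoded here (annotation first, then a constant default overriding it) can be demonstrated outside Paddle; the helper below is a simplified stand-in written against the standard ast module, not the actual _get_func_argument_type implementation:

    import ast

    def infer_arg_types(source):
        func = ast.parse(source).body[0]
        args, defaults = func.args.args, func.args.defaults
        types = {}
        for i, arg in enumerate(args):
            # start from the annotation, if any
            inferred = ast.unparse(arg.annotation) if arg.annotation else "UNKNOWN"
            # defaults align with the trailing arguments
            d_index = i - (len(args) - len(defaults))
            if d_index >= 0 and isinstance(defaults[d_index], ast.Constant):
                # a constant default overrides a mismatched annotation
                inferred = type(defaults[d_index].value).__name__
            types[arg.arg] = inferred
        return types

    print(infer_arg_types("def foo(x: Tensor, y: int, z: str = 'abc', w=1.5): pass"))
    # {'x': 'Tensor', 'y': 'int', 'z': 'str', 'w': 'float'}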
@@ -520,7 +520,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
     def _inject_import_statements():
         import_statements = [
-            "import paddle", "import paddle.fluid as fluid", "from typing import *",
+            "import paddle", "from paddle import Tensor",
+            "import paddle.fluid as fluid", "from typing import *",
             "import numpy as np"
         ]
         return '\n'.join(import_statements) + '\n'
......
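With this change, the header injected in front of every generated static function becomes the five lines below (joined with newlines by _inject_import_statements). The extra "from paddle import Tensor" line is why the expected static line numbers shift up by one in the test_origin_info.py hunks that follow:

    import paddle
    from paddle import Tensor
    import paddle.fluid as fluid
    from typing import *
    import numpy as np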
@@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase):
         self.func = simple_func
     def set_static_lineno(self):
-        self.static_abs_lineno_list = [5, 6, 7]
+        self.static_abs_lineno_list = [6, 7, 8]
     def set_dygraph_info(self):
         self.line_num = 3
@@ -149,7 +149,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo):
         self.func = nested_func
     def set_static_lineno(self):
-        self.static_abs_lineno_list = [5, 7, 8, 9, 10]
+        self.static_abs_lineno_list = [6, 8, 9, 10, 11]
     def set_dygraph_info(self):
         self.line_num = 5
@@ -174,7 +174,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo):
         self.func = decorated_func
     def set_static_lineno(self):
-        self.static_abs_lineno_list = [5, 6]
+        self.static_abs_lineno_list = [6, 7]
     def set_dygraph_info(self):
         self.line_num = 2
@@ -208,7 +208,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo):
         self.func = decorated_func2
     def set_static_lineno(self):
-        self.static_abs_lineno_list = [5, 6]
+        self.static_abs_lineno_list = [6, 7]
     def set_dygraph_info(self):
         self.line_num = 2
......
@@ -57,6 +57,8 @@ def func_to_test3():
     h = None
     i = False
     j = None + 1
+    k: float = 1.0
+    l: paddle.Tensor = paddle.to_tensor([1, 2])
 result_var_type3 = {
@@ -69,7 +71,9 @@ result_var_type3 = {
     'g': {NodeVarType.STRING},
     'h': {NodeVarType.NONE},
     'i': {NodeVarType.BOOLEAN},
-    'j': {NodeVarType.UNKNOWN}
+    'j': {NodeVarType.UNKNOWN},
+    'k': {NodeVarType.FLOAT},
+    'l': {NodeVarType.PADDLE_RETURN_TYPES}
 }
@@ -139,13 +143,25 @@ result_var_type6 = {
     'add': {NodeVarType.INT}
 }
+def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float='diff'):
+    a = True
+result_var_type7 = {
+    'a': {NodeVarType.BOOLEAN},
+    'b': {NodeVarType.FLOAT},
+    'c': {NodeVarType.TENSOR},
+    'd': {NodeVarType.STRING}
+}
 test_funcs = [
     func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5,
-    func_to_test6
+    func_to_test6, func_to_test7
 ]
 result_var_type = [
     result_var_type1, result_var_type2, result_var_type3, result_var_type4,
-    result_var_type5, result_var_type6
+    result_var_type5, result_var_type6, result_var_type7
 ]
......
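Two of the expected results above follow directly from the rules added in static_analysis.py: `l` is annotated as paddle.Tensor, but its value is a paddle.to_tensor(...) call, so the value's inferred type (PADDLE_RETURN_TYPES) overrides the annotation; similarly, `d: float='diff'` in func_to_test7 resolves to STRING because the constant default's type wins over the float annotation.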