From db8425ec9bb41898a95ebaa361260ebca98af78a Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Wed, 3 Nov 2021 15:15:51 +0800 Subject: [PATCH] [Dy2stat]support Python3 type annotation (#36544) * Support Py3 type annotations in @to_static * support type hint for args in func * support type hint assign * if annotation and value(Constant) are of different types, we use the value type * polish type_from_annotation() * code format * code format * remove useless commentary * fix review Co-authored-by: Aurelius84 --- .../fluid/dygraph/dygraph_to_static/error.py | 8 +- .../dygraph_to_static/static_analysis.py | 90 +++++++++++++++---- .../fluid/dygraph/dygraph_to_static/utils.py | 3 +- .../dygraph_to_static/test_origin_info.py | 8 +- .../dygraph_to_static/test_static_analysis.py | 22 ++++- 5 files changed, 103 insertions(+), 28 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index 008070fcead..69ec89a5af6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -208,10 +208,10 @@ class ErrorData(object): message_lines.append("") # Add paddle traceback after user code traceback - paddle_traceback_start_idnex = user_code_traceback_index[ + paddle_traceback_start_index = user_code_traceback_index[ -1] + 1 if user_code_traceback_index else 0 for filepath, lineno, funcname, code in self.origin_traceback[ - paddle_traceback_start_idnex:]: + paddle_traceback_start_index:]: traceback_frame = TraceBackFrame( Location(filepath, lineno), funcname, code) message_lines.append(traceback_frame.formated_message()) @@ -305,10 +305,10 @@ class ErrorData(object): error_frame.append("") # Add paddle traceback after user code traceback - paddle_traceback_start_idnex = user_code_traceback_index[ + paddle_traceback_start_index = user_code_traceback_index[ -1] + 1 if user_code_traceback_index else 0 for 
filepath, lineno, funcname, code in error_traceback[ - paddle_traceback_start_idnex:]: + paddle_traceback_start_index:]: traceback_frame = TraceBackFrame( Location(filepath, lineno), funcname, code) error_frame.append(traceback_frame.formated_message()) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index ce5f50137b7..45a42d481b5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -15,7 +15,8 @@ from __future__ import print_function from paddle.utils import gast -from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list +from .logging_utils import warn +from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list, ast_to_source_code __all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor'] @@ -57,6 +58,15 @@ class NodeVarType(object): # If node.node_var_type in TENSOR_TYPES, it can be considered as tensor-dependent. 
TENSOR_TYPES = {TENSOR, PADDLE_RETURN_TYPES} + Annotation_map = { + "Tensor": TENSOR, + "paddle.Tensor": TENSOR, + "int": INT, + "float": FLOAT, + "bool": BOOLEAN, + "str": STRING + } + @staticmethod def binary_op_output_type(in_type1, in_type2): if in_type1 == in_type2: @@ -83,6 +93,16 @@ class NodeVarType(object): return NodeVarType.UNKNOWN return max(in_type1, in_type2) + @staticmethod + def type_from_annotation(annotation): + annotation_str = ast_to_source_code(annotation).strip() + if annotation_str in NodeVarType.Annotation_map: + return NodeVarType.Annotation_map[annotation_str] + + # raise warning if not found + warn("Currently we don't support annotation: %s" % annotation_str) + return NodeVarType.UNKNOWN + class AstNodeWrapper(object): """ @@ -316,6 +336,18 @@ class StaticAnalysisVisitor(object): self.var_env.set_var_type(target.id, ret_type) return ret_type + if isinstance(node, gast.AnnAssign): + # TODO(0x45f): To determine whether need to support assignment statements + # like `self.x: float = 2.1`. 
+ ret_type = {NodeVarType.type_from_annotation(node.annotation)} + # if annotation and value(Constant) are of different types, we use the value type + if node.value: + ret_type = self.node_to_wrapper_map[node.value].node_var_type + if isinstance(node.target, gast.Name): + self.node_to_wrapper_map[node.target].node_var_type = ret_type + self.var_env.set_var_type(node.target.id, ret_type) + return ret_type + if isinstance(node, gast.Name): if node.id == "None": return {NodeVarType.NONE} @@ -325,21 +357,8 @@ class StaticAnalysisVisitor(object): parent_node_wrapper = cur_wrapper.parent if parent_node_wrapper and isinstance(parent_node_wrapper.node, gast.arguments): - parent_node = parent_node_wrapper.node - var_type = {NodeVarType.UNKNOWN} - if parent_node.defaults: - index = index_in_list(parent_node.args, node) - args_len = len(parent_node.args) - if index != -1 and args_len - index <= len( - parent_node.defaults): - defaults_node = parent_node.defaults[index - args_len] - if isinstance(defaults_node, gast.Constant): - var_type = self._get_constant_node_type( - defaults_node) - - # Add node with identified type into cur_env. - self.var_env.set_var_type(node.id, var_type) - return var_type + + return self._get_func_argument_type(parent_node_wrapper, node) return self.var_env.get_var_type(node.id) @@ -373,3 +392,42 @@ class StaticAnalysisVisitor(object): return {NodeVarType.TENSOR} return {NodeVarType.STATEMENT} + + def _get_func_argument_type(self, parent_node_wrapper, node): + """ + Returns type information by parsing annotation or default values. + + For example: + 1. parse by default values. + foo(x, y=1, z='s') -> x: UNKNOWN, y: INT, z: STR + + 2. parse by Py3 type annotation. + foo(x: Tensor, y: int, z: str) -> x: Tensor, y: INT, z: STR + + 3. parse by type annotation and default values. + foo(x: Tensor, y: int, z: str = 'abc') -> x: Tensor, y: INT, z: STR + + NOTE: Currently, we only support Tensor, int, bool, float, str, etc. 
+ Other complicated types will be supported later. + """ + assert isinstance(node, gast.Name) + + parent_node = parent_node_wrapper.node + var_type = {NodeVarType.UNKNOWN} + if node.annotation is not None: + var_type = {NodeVarType.type_from_annotation(node.annotation)} + self.var_env.set_var_type(node.id, var_type) + + # if annotation and value(Constant) are of different types, we use the value type + if parent_node.defaults: + index = index_in_list(parent_node.args, node) + args_len = len(parent_node.args) + if index != -1 and args_len - index <= len(parent_node.defaults): + defaults_node = parent_node.defaults[index - args_len] + if isinstance(defaults_node, gast.Constant): + var_type = self._get_constant_node_type(defaults_node) + + # Add node with identified type into cur_env. + self.var_env.set_var_type(node.id, var_type) + + return var_type diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 650857eefb3..4da898d7441 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -520,7 +520,8 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): def _inject_import_statements(): import_statements = [ - "import paddle", "import paddle.fluid as fluid", "from typing import *", + "import paddle", "from paddle import Tensor", + "import paddle.fluid as fluid", "from typing import *", "import numpy as np" ] return '\n'.join(import_statements) + '\n' diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py index 016a1b3b588..e3d34184a38 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py @@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase): self.func = simple_func def set_static_lineno(self): - 
self.static_abs_lineno_list = [5, 6, 7] + self.static_abs_lineno_list = [6, 7, 8] def set_dygraph_info(self): self.line_num = 3 @@ -149,7 +149,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo): self.func = nested_func def set_static_lineno(self): - self.static_abs_lineno_list = [5, 7, 8, 9, 10] + self.static_abs_lineno_list = [6, 8, 9, 10, 11] def set_dygraph_info(self): self.line_num = 5 @@ -174,7 +174,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo): self.func = decorated_func def set_static_lineno(self): - self.static_abs_lineno_list = [5, 6] + self.static_abs_lineno_list = [6, 7] def set_dygraph_info(self): self.line_num = 2 @@ -208,7 +208,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo): self.func = decorated_func2 def set_static_lineno(self): - self.static_abs_lineno_list = [5, 6] + self.static_abs_lineno_list = [6, 7] def set_dygraph_info(self): self.line_num = 2 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py index 7f6d6cf1f3b..afccaca6938 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py @@ -57,6 +57,8 @@ def func_to_test3(): h = None i = False j = None + 1 + k: float = 1.0 + l: paddle.Tensor = paddle.to_tensor([1, 2]) result_var_type3 = { @@ -69,7 +71,9 @@ result_var_type3 = { 'g': {NodeVarType.STRING}, 'h': {NodeVarType.NONE}, 'i': {NodeVarType.BOOLEAN}, - 'j': {NodeVarType.UNKNOWN} + 'j': {NodeVarType.UNKNOWN}, + 'k': {NodeVarType.FLOAT}, + 'l': {NodeVarType.PADDLE_RETURN_TYPES} } @@ -139,13 +143,25 @@ result_var_type6 = { 'add': {NodeVarType.INT} } + +def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float='diff'): + a = True + + +result_var_type7 = { + 'a': {NodeVarType.BOOLEAN}, + 'b': {NodeVarType.FLOAT}, + 'c': {NodeVarType.TENSOR}, + 'd': {NodeVarType.STRING} +} + 
test_funcs = [ func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5, - func_to_test6 + func_to_test6, func_to_test7 ] result_var_type = [ result_var_type1, result_var_type2, result_var_type3, result_var_type4, - result_var_type5, result_var_type6 + result_var_type5, result_var_type6, result_var_type7 ] -- GitLab