From d7ad0e427f10d02512bd4136abb6a25eda9befbc Mon Sep 17 00:00:00 2001 From: haozi <64006169+NotHaozi@users.noreply.github.com> Date: Fri, 2 Jun 2023 14:36:00 +0800 Subject: [PATCH] dy2static: Remove the dependency of AstNodeWrapper (#54237) * [IR] delete_wrapper * [IR] modify the ast_transformer core logic * [fix] modify the test_origin_info --- python/paddle/jit/dy2static/__init__.py | 2 +- .../jit/dy2static/assert_transformer.py | 9 ++------ .../paddle/jit/dy2static/ast_transformer.py | 22 ++++++++----------- .../paddle/jit/dy2static/base_transformer.py | 5 ++--- .../jit/dy2static/basic_api_transformer.py | 21 +++++------------- .../dy2static/break_continue_transformer.py | 10 ++++----- .../paddle/jit/dy2static/call_transformer.py | 9 ++------ .../paddle/jit/dy2static/cast_transformer.py | 10 +++------ .../dy2static/create_variable_transformer.py | 9 ++------ .../jit/dy2static/decorator_transformer.py | 9 ++------ .../jit/dy2static/early_return_transformer.py | 9 ++------ .../jit/dy2static/ifelse_transformer.py | 9 ++------ .../jit/dy2static/logical_transformer.py | 5 ++--- .../paddle/jit/dy2static/loop_transformer.py | 12 ++++------ .../jit/dy2static/program_translator.py | 14 ++++++------ .../jit/dy2static/return_transformer.py | 5 ++--- .../jit/dy2static/tensor_shape_transformer.py | 9 ++------ .../jit/dy2static/tensorhook_transformer.py | 4 ++-- .../jit/dy2static/typehint_transformer.py | 10 ++------- test/dygraph_to_static/test_origin_info.py | 2 +- 20 files changed, 58 insertions(+), 127 deletions(-) diff --git a/python/paddle/jit/dy2static/__init__.py b/python/paddle/jit/dy2static/__init__.py index bc91a4c1674..136ab6dab7d 100644 --- a/python/paddle/jit/dy2static/__init__.py +++ b/python/paddle/jit/dy2static/__init__.py @@ -34,6 +34,6 @@ from .convert_operators import convert_shape_compare # noqa: F401 from .assert_transformer import AssertTransformer from .ast_transformer import DygraphToStaticAst from .program_translator import convert_to_static -from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor +from .static_analysis import NodeVarType, StaticAnalysisVisitor __all__ = [] diff --git a/python/paddle/jit/dy2static/assert_transformer.py b/python/paddle/jit/dy2static/assert_transformer.py index f77a3f87f98..acf17b58618 100644 --- a/python/paddle/jit/dy2static/assert_transformer.py +++ b/python/paddle/jit/dy2static/assert_transformer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.jit.dy2static.static_analysis import AstNodeWrapper from paddle.jit.dy2static.utils import ast_to_source_code from paddle.utils import gast @@ -26,12 +25,8 @@ class AssertTransformer(BaseTransformer): A class transforms python assert to convert_assert. """ - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of AssertTransformer." - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root def transform(self): self.visit(self.root) diff --git a/python/paddle/jit/dy2static/ast_transformer.py b/python/paddle/jit/dy2static/ast_transformer.py index 66d334dc7fc..bdb3e5b59cc 100644 --- a/python/paddle/jit/dy2static/ast_transformer.py +++ b/python/paddle/jit/dy2static/ast_transformer.py @@ -36,7 +36,6 @@ from .ifelse_transformer import IfElseTransformer from .logical_transformer import LogicalTransformer from .loop_transformer import LoopTransformer from .return_transformer import ReturnTransformer -from .static_analysis import StaticAnalysisVisitor from .tensor_shape_transformer import TensorShapeTransformer from .tensorhook_transformer import RegisterHookTransformer from .typehint_transformer import TypeHintTransformer @@ -69,28 +68,25 @@ class DygraphToStaticAst(BaseTransformer): self.translator_logger = logging_utils.TranslatorLogger() def get_static_ast(self, root): - # save root for some analysis may need global AST self.root = root - self.static_analysis_visitor = StaticAnalysisVisitor(root) - self.static_analysis_root = ( - self.static_analysis_visitor.get_node_wrapper_root() - ) self.decorate_func_name = None - self.transfer_from_node_type(self.static_analysis_root) - return self.static_analysis_root - def _apply(self, transformer, node_wrapper, log_level): - transformer(node_wrapper).transform() + # inplace transfer + self.transfer_from_node_type(self.root) + return self.root + + def _apply(self, transformer, node, log_level): + transformer(node).transform() self.translator_logger.log_transformed_code( log_level, self.root, transformer.__name__ ) - def transfer_from_node_type(self, node_wrapper): + def transfer_from_node_type(self, node): self.translator_logger.log( 1, f"Source code: \n{ast_to_source_code(self.root)}" ) # Generic transformation - self.visit(node_wrapper.node) + self.visit(node) transformers = [ RegisterHookTransformer, @@ -114,7 +110,7 @@ class DygraphToStaticAst(BaseTransformer): apply_optimization(transformers) for index, transformer in enumerate(transformers): - self._apply(transformer, node_wrapper, log_level=index + 1) + self._apply(transformer, node, log_level=index + 1) self.translator_logger.log_transformed_code( logging_utils.LOG_AllTransformer, self.root, "All Transformers" diff --git a/python/paddle/jit/dy2static/base_transformer.py b/python/paddle/jit/dy2static/base_transformer.py index c019f87dee0..cddea923760 100644 --- a/python/paddle/jit/dy2static/base_transformer.py +++ b/python/paddle/jit/dy2static/base_transformer.py @@ -139,9 +139,8 @@ class ForLoopTuplePreTransformer(BaseTransformer): >>> body """ - def __init__(self, wrapper_root): - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root def transform(self): self.visit(self.root) diff --git a/python/paddle/jit/dy2static/basic_api_transformer.py b/python/paddle/jit/dy2static/basic_api_transformer.py index a2c9823a353..f188df92cd9 100644 --- a/python/paddle/jit/dy2static/basic_api_transformer.py +++ b/python/paddle/jit/dy2static/basic_api_transformer.py @@ -18,7 +18,6 @@ from paddle.utils import gast from . import utils from .base_transformer import BaseTransformer -from .static_analysis import AstNodeWrapper __all__ = [] @@ -28,13 +27,8 @@ class BasicApiTransformer(BaseTransformer): Class to transform basic API from dygraph to static graph. """ - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer." - - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root self.class_node_dict = {} def transform(self): @@ -43,7 +37,7 @@ class BasicApiTransformer(BaseTransformer): attribute_transformer = AttributeJstTransformer(self.root) attribute_transformer.transform() self.visit(self.root) - return self.wrapper_root + return self.root def visit_Assign(self, node): if self._update_class_node_dict(node): @@ -138,13 +132,8 @@ class NameloadJstTransformer(BaseTransformer): NOTE: we only deal with ctx=Load() case. """ - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer." - - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root def transform(self): self.visit(self.root) diff --git a/python/paddle/jit/dy2static/break_continue_transformer.py b/python/paddle/jit/dy2static/break_continue_transformer.py index 728fc02eff1..6e1199bf0ed 100644 --- a/python/paddle/jit/dy2static/break_continue_transformer.py +++ b/python/paddle/jit/dy2static/break_continue_transformer.py @@ -111,11 +111,10 @@ class BreakContinueTransformer(BaseNodeVisitor): In general, we recommend to inheriting NodeTransformer to modify node! """ - def __init__(self, wrapper_root): + def __init__(self, root): super().__init__() - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + self.root = root def transform(self): self.visit(self.root) @@ -336,11 +335,10 @@ class BreakTransformOptimizer(BaseNodeVisitor): usually brings very heavy overhead. """ - def __init__(self, wrapper_root): + def __init__(self, root): super().__init__() - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + self.root = root def transform(self): self.visit(self.root) diff --git a/python/paddle/jit/dy2static/call_transformer.py b/python/paddle/jit/dy2static/call_transformer.py index e4418a62206..7ed0a4681bc 100644 --- a/python/paddle/jit/dy2static/call_transformer.py +++ b/python/paddle/jit/dy2static/call_transformer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.jit.dy2static.static_analysis import AstNodeWrapper from paddle.jit.dy2static.utils import ast_to_source_code, is_paddle_api from paddle.utils import gast @@ -29,12 +28,8 @@ class CallTransformer(BaseTransformer): This class transforms function calls into Static Graph Ast. """ - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of CallTransformer." - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root def _no_need_convert_call(self, node): """ diff --git a/python/paddle/jit/dy2static/cast_transformer.py b/python/paddle/jit/dy2static/cast_transformer.py index da169786cd3..c556f374104 100644 --- a/python/paddle/jit/dy2static/cast_transformer.py +++ b/python/paddle/jit/dy2static/cast_transformer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.jit.dy2static.static_analysis import AstNodeWrapper from paddle.jit.dy2static.utils import ast_to_source_code from paddle.utils import gast @@ -26,15 +25,12 @@ class CastTransformer(BaseTransformer): This class transforms type casting into Static Graph Ast. """ - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of CastTransformer." - self._root = wrapper_root.node + def __init__(self, root): + self.root = root self._castable_type = {'bool', 'int', 'float'} def transform(self): - self.visit(self._root) + self.visit(self.root) def visit_Call(self, node): self.generic_visit(node) diff --git a/python/paddle/jit/dy2static/create_variable_transformer.py b/python/paddle/jit/dy2static/create_variable_transformer.py index 4290b589631..f0d5583834b 100644 --- a/python/paddle/jit/dy2static/create_variable_transformer.py +++ b/python/paddle/jit/dy2static/create_variable_transformer.py @@ -13,7 +13,6 @@ # limitations under the License. from .base_transformer import BaseTransformer -from .static_analysis import AstNodeWrapper from .utils import FunctionNameLivenessAnalysis from .variable_trans_func import create_undefined_var @@ -23,12 +22,8 @@ __all__ = [] class CreateVariableTransformer(BaseTransformer): """ """ - def __init__(self, wrapper_root): - assert isinstance(wrapper_root, AstNodeWrapper), ( - "Type of input node should be AstNodeWrapper, but received %s ." - % type(wrapper_root) - ) - self.root = wrapper_root.node + def __init__(self, root): + self.root = root FunctionNameLivenessAnalysis(self.root) def transform(self): diff --git a/python/paddle/jit/dy2static/decorator_transformer.py b/python/paddle/jit/dy2static/decorator_transformer.py index f7391f301db..a61c25dd608 100644 --- a/python/paddle/jit/dy2static/decorator_transformer.py +++ b/python/paddle/jit/dy2static/decorator_transformer.py @@ -19,7 +19,6 @@ import warnings from paddle.utils import gast from .base_transformer import BaseTransformer -from .static_analysis import AstNodeWrapper from .utils import RE_PYMODULE, RE_PYNAME, ast_to_source_code __all__ = [] @@ -40,12 +39,8 @@ class DecoratorTransformer(BaseTransformer): Transform decorators. """ - def __init__(self, wrapper_root): - assert isinstance(wrapper_root, AstNodeWrapper), ( - "Type of input node should be AstNodeWrapper, but received %s ." - % type(wrapper_root) - ) - self.root = wrapper_root.node + def __init__(self, root): + self.root = root self.ancestor_nodes = [] diff --git a/python/paddle/jit/dy2static/early_return_transformer.py b/python/paddle/jit/dy2static/early_return_transformer.py index fd2b7865305..4613f2b6ecb 100644 --- a/python/paddle/jit/dy2static/early_return_transformer.py +++ b/python/paddle/jit/dy2static/early_return_transformer.py @@ -15,7 +15,6 @@ from paddle.utils import gast from .base_transformer import BaseTransformer -from .static_analysis import AstNodeWrapper __all__ = [] @@ -25,12 +24,8 @@ class EarlyReturnTransformer(BaseTransformer): Transform if/else return statement of Dygraph into Static Graph. """ - def __init__(self, wrapper_root): - assert isinstance(wrapper_root, AstNodeWrapper), ( - "Type of input node should be AstNodeWrapper, but received %s ." - % type(wrapper_root) - ) - self.root = wrapper_root.node + def __init__(self, root): + self.root = root def transform(self): """ diff --git a/python/paddle/jit/dy2static/ifelse_transformer.py b/python/paddle/jit/dy2static/ifelse_transformer.py index ba64b246b35..0986bc1933d 100644 --- a/python/paddle/jit/dy2static/ifelse_transformer.py +++ b/python/paddle/jit/dy2static/ifelse_transformer.py @@ -16,7 +16,6 @@ import copy from collections import defaultdict from paddle.fluid import unique_name -from paddle.jit.dy2static.static_analysis import AstNodeWrapper from paddle.jit.dy2static.utils import ( FOR_ITER_INDEX_PREFIX, FOR_ITER_ITERATOR_PREFIX, @@ -57,12 +56,8 @@ class IfElseTransformer(BaseTransformer): Transform if/else statement of Dygraph into Static Graph. """ - def __init__(self, wrapper_root): - assert isinstance(wrapper_root, AstNodeWrapper), ( - "Type of input node should be AstNodeWrapper, but received %s ." - % type(wrapper_root) - ) - self.root = wrapper_root.node + def __init__(self, root): + self.root = root FunctionNameLivenessAnalysis( self.root ) # name analysis of current ast tree. diff --git a/python/paddle/jit/dy2static/logical_transformer.py b/python/paddle/jit/dy2static/logical_transformer.py index a31ddcd44e9..90002c6e4bd 100644 --- a/python/paddle/jit/dy2static/logical_transformer.py +++ b/python/paddle/jit/dy2static/logical_transformer.py @@ -48,9 +48,8 @@ class LogicalTransformer(BaseTransformer): a = _jst.And(lambda:x>1, lambda:y<1) """ - def __init__(self, wrapper_root): - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root def transform(self): return self.visit(self.root) diff --git a/python/paddle/jit/dy2static/loop_transformer.py b/python/paddle/jit/dy2static/loop_transformer.py index cabb1d41c94..043d5be4b76 100644 --- a/python/paddle/jit/dy2static/loop_transformer.py +++ b/python/paddle/jit/dy2static/loop_transformer.py @@ -24,7 +24,7 @@ from .base_transformer import ( ForNodeVisitor, ) from .ifelse_transformer import ARGS_NAME -from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor +from .static_analysis import NodeVarType, StaticAnalysisVisitor from .utils import ( FOR_BODY_PREFIX, FOR_CONDITION_PREFIX, @@ -507,16 +507,12 @@ class LoopTransformer(BaseTransformer): This class transforms python while/for statement into Static Graph Ast """ - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of LoopTransformer." - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root FunctionNameLivenessAnalysis(self.root) def transform(self): - ForLoopTuplePreTransformer(self.wrapper_root).transform() + ForLoopTuplePreTransformer(self.root).transform() self.visit(self.root) def visit_While(self, node): diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 38801a36606..44d16ae5b7c 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -132,17 +132,17 @@ class FunctionCache: # but actually they are methods in different classes. # Maybe use (__class__, source_code) as key if source_code in self._code_to_ast_caches: - root_wrapper = self._code_to_ast_caches[source_code] + root = self._code_to_ast_caches[source_code] else: root = gast.parse(source_code) root = attach_origin_info(root, func) - root_wrapper = self._dygraph_to_static.get_static_ast(root) - self._code_to_ast_caches[source_code] = root_wrapper + root = self._dygraph_to_static.get_static_ast(root) + self._code_to_ast_caches[source_code] = root # Get static function from AST - static_func, file_name = ast_to_func(root_wrapper.node, func) + static_func, file_name = ast_to_func(root, func) - create_and_update_origin_info_map(root_wrapper.node, static_func) + create_and_update_origin_info_map(root, static_func) return static_func def exist(self, func): @@ -1680,10 +1680,10 @@ class ProgramTranslator: # Transform AST dygraph_to_static = DygraphToStaticAst() - root_wrapper = dygraph_to_static.get_static_ast(root) + root = dygraph_to_static.get_static_ast(root) # Get source_code - source_code = ast_to_source_code(root_wrapper.node) + source_code = ast_to_source_code(root) return source_code def get_program_cache(self): diff --git a/python/paddle/jit/dy2static/return_transformer.py b/python/paddle/jit/dy2static/return_transformer.py index 48d48f949dc..e16608fc446 100644 --- a/python/paddle/jit/dy2static/return_transformer.py +++ b/python/paddle/jit/dy2static/return_transformer.py @@ -130,9 +130,8 @@ class ReturnTransformer(BaseTransformer): SingleReturnTransformer don't care the nested function def. """ - def __init__(self, wrapper_root): - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root pre_transformer = ReplaceReturnNoneTransformer(self.root) pre_transformer.transform() diff --git a/python/paddle/jit/dy2static/tensor_shape_transformer.py b/python/paddle/jit/dy2static/tensor_shape_transformer.py index 6efff3b0561..13b81608f79 100644 --- a/python/paddle/jit/dy2static/tensor_shape_transformer.py +++ b/python/paddle/jit/dy2static/tensor_shape_transformer.py @@ -15,7 +15,6 @@ from paddle.utils import gast from .base_transformer import BaseTransformer -from .static_analysis import AstNodeWrapper from .utils import ast_to_source_code __all__ = [] @@ -27,12 +26,8 @@ class TensorShapeTransformer(BaseTransformer): All 'xxx.shape' will be converted int '_jst.Shape(x)'. """ - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer." - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root def transform(self): self.visit(self.root) diff --git a/python/paddle/jit/dy2static/tensorhook_transformer.py b/python/paddle/jit/dy2static/tensorhook_transformer.py index ff0c4b67ca1..9add0c9d825 100644 --- a/python/paddle/jit/dy2static/tensorhook_transformer.py +++ b/python/paddle/jit/dy2static/tensorhook_transformer.py @@ -20,10 +20,10 @@ from .base_transformer import BaseTransformer class RegisterHookTransformer(BaseTransformer): - def __init__(self, wrapper_root): + def __init__(self, root): self.register_hook_pos_map = collections.defaultdict(list) self.assignment_pos_map = collections.defaultdict(list) - self.root = wrapper_root.node + self.root = root def transform(self): """ diff --git a/python/paddle/jit/dy2static/typehint_transformer.py b/python/paddle/jit/dy2static/typehint_transformer.py index 68a395973ca..da2f625ce5a 100644 --- a/python/paddle/jit/dy2static/typehint_transformer.py +++ b/python/paddle/jit/dy2static/typehint_transformer.py @@ -13,8 +13,6 @@ # limitations under the License. -from paddle.jit.dy2static.static_analysis import AstNodeWrapper - from .base_transformer import BaseTransformer __all__ = [] @@ -26,12 +24,8 @@ class TypeHintTransformer(BaseTransformer): Please put it behind other transformers because other transformer may relay on typehints. """ - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of TypeHintTransformer." - self.wrapper_root = wrapper_root - self.root = wrapper_root.node + def __init__(self, root): + self.root = root def transform(self): self.visit(self.root) diff --git a/test/dygraph_to_static/test_origin_info.py b/test/dygraph_to_static/test_origin_info.py index 90436f6d72b..2ae8aa80398 100644 --- a/test/dygraph_to_static/test_origin_info.py +++ b/test/dygraph_to_static/test_origin_info.py @@ -96,7 +96,7 @@ class TestOriginInfo(unittest.TestCase): dygraph_ast = attach_origin_info(dygraph_ast, self.dygraph_func) # step2 - transformed_ast = DygraphToStaticAst().get_static_ast(dygraph_ast).node + transformed_ast = DygraphToStaticAst().get_static_ast(dygraph_ast) # step3 self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func) -- GitLab