未验证 提交 d7ad0e42 编写于 作者: H haozi 提交者: GitHub

dy2static: Remove the dependency of AstNodeWrapper (#54237)

* [IR] delete_wrapper

* [IR] modify the ast_transformer core logic

* [fix] modify the test_origin_info
上级 9f76d050
...@@ -34,6 +34,6 @@ from .convert_operators import convert_shape_compare # noqa: F401 ...@@ -34,6 +34,6 @@ from .convert_operators import convert_shape_compare # noqa: F401
from .assert_transformer import AssertTransformer from .assert_transformer import AssertTransformer
from .ast_transformer import DygraphToStaticAst from .ast_transformer import DygraphToStaticAst
from .program_translator import convert_to_static from .program_translator import convert_to_static
from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor from .static_analysis import NodeVarType, StaticAnalysisVisitor
__all__ = [] __all__ = []
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from paddle.jit.dy2static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.utils import gast from paddle.utils import gast
...@@ -26,12 +25,8 @@ class AssertTransformer(BaseTransformer): ...@@ -26,12 +25,8 @@ class AssertTransformer(BaseTransformer):
A class transforms python assert to convert_assert. A class transforms python assert to convert_assert.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance( self.root = root
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of AssertTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
......
...@@ -36,7 +36,6 @@ from .ifelse_transformer import IfElseTransformer ...@@ -36,7 +36,6 @@ from .ifelse_transformer import IfElseTransformer
from .logical_transformer import LogicalTransformer from .logical_transformer import LogicalTransformer
from .loop_transformer import LoopTransformer from .loop_transformer import LoopTransformer
from .return_transformer import ReturnTransformer from .return_transformer import ReturnTransformer
from .static_analysis import StaticAnalysisVisitor
from .tensor_shape_transformer import TensorShapeTransformer from .tensor_shape_transformer import TensorShapeTransformer
from .tensorhook_transformer import RegisterHookTransformer from .tensorhook_transformer import RegisterHookTransformer
from .typehint_transformer import TypeHintTransformer from .typehint_transformer import TypeHintTransformer
...@@ -69,28 +68,25 @@ class DygraphToStaticAst(BaseTransformer): ...@@ -69,28 +68,25 @@ class DygraphToStaticAst(BaseTransformer):
self.translator_logger = logging_utils.TranslatorLogger() self.translator_logger = logging_utils.TranslatorLogger()
def get_static_ast(self, root): def get_static_ast(self, root):
# save root for some analysis may need global AST
self.root = root self.root = root
self.static_analysis_visitor = StaticAnalysisVisitor(root)
self.static_analysis_root = (
self.static_analysis_visitor.get_node_wrapper_root()
)
self.decorate_func_name = None self.decorate_func_name = None
self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root
def _apply(self, transformer, node_wrapper, log_level): # inplace transfer
transformer(node_wrapper).transform() self.transfer_from_node_type(self.root)
return self.root
def _apply(self, transformer, node, log_level):
transformer(node).transform()
self.translator_logger.log_transformed_code( self.translator_logger.log_transformed_code(
log_level, self.root, transformer.__name__ log_level, self.root, transformer.__name__
) )
def transfer_from_node_type(self, node_wrapper): def transfer_from_node_type(self, node):
self.translator_logger.log( self.translator_logger.log(
1, f"Source code: \n{ast_to_source_code(self.root)}" 1, f"Source code: \n{ast_to_source_code(self.root)}"
) )
# Generic transformation # Generic transformation
self.visit(node_wrapper.node) self.visit(node)
transformers = [ transformers = [
RegisterHookTransformer, RegisterHookTransformer,
...@@ -114,7 +110,7 @@ class DygraphToStaticAst(BaseTransformer): ...@@ -114,7 +110,7 @@ class DygraphToStaticAst(BaseTransformer):
apply_optimization(transformers) apply_optimization(transformers)
for index, transformer in enumerate(transformers): for index, transformer in enumerate(transformers):
self._apply(transformer, node_wrapper, log_level=index + 1) self._apply(transformer, node, log_level=index + 1)
self.translator_logger.log_transformed_code( self.translator_logger.log_transformed_code(
logging_utils.LOG_AllTransformer, self.root, "All Transformers" logging_utils.LOG_AllTransformer, self.root, "All Transformers"
......
...@@ -139,9 +139,8 @@ class ForLoopTuplePreTransformer(BaseTransformer): ...@@ -139,9 +139,8 @@ class ForLoopTuplePreTransformer(BaseTransformer):
>>> body >>> body
""" """
def __init__(self, wrapper_root): def __init__(self, root):
self.wrapper_root = wrapper_root self.root = root
self.root = wrapper_root.node
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
......
...@@ -18,7 +18,6 @@ from paddle.utils import gast ...@@ -18,7 +18,6 @@ from paddle.utils import gast
from . import utils from . import utils
from .base_transformer import BaseTransformer from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
__all__ = [] __all__ = []
...@@ -28,13 +27,8 @@ class BasicApiTransformer(BaseTransformer): ...@@ -28,13 +27,8 @@ class BasicApiTransformer(BaseTransformer):
Class to transform basic API from dygraph to static graph. Class to transform basic API from dygraph to static graph.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance( self.root = root
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.class_node_dict = {} self.class_node_dict = {}
def transform(self): def transform(self):
...@@ -43,7 +37,7 @@ class BasicApiTransformer(BaseTransformer): ...@@ -43,7 +37,7 @@ class BasicApiTransformer(BaseTransformer):
attribute_transformer = AttributeJstTransformer(self.root) attribute_transformer = AttributeJstTransformer(self.root)
attribute_transformer.transform() attribute_transformer.transform()
self.visit(self.root) self.visit(self.root)
return self.wrapper_root return self.root
def visit_Assign(self, node): def visit_Assign(self, node):
if self._update_class_node_dict(node): if self._update_class_node_dict(node):
...@@ -138,13 +132,8 @@ class NameloadJstTransformer(BaseTransformer): ...@@ -138,13 +132,8 @@ class NameloadJstTransformer(BaseTransformer):
NOTE: we only deal with ctx=Load() case. NOTE: we only deal with ctx=Load() case.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance( self.root = root
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
......
...@@ -111,11 +111,10 @@ class BreakContinueTransformer(BaseNodeVisitor): ...@@ -111,11 +111,10 @@ class BreakContinueTransformer(BaseNodeVisitor):
In general, we recommend to inheriting NodeTransformer to modify node! In general, we recommend to inheriting NodeTransformer to modify node!
""" """
def __init__(self, wrapper_root): def __init__(self, root):
super().__init__() super().__init__()
self.wrapper_root = wrapper_root self.root = root
self.root = wrapper_root.node
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
...@@ -336,11 +335,10 @@ class BreakTransformOptimizer(BaseNodeVisitor): ...@@ -336,11 +335,10 @@ class BreakTransformOptimizer(BaseNodeVisitor):
usually brings very heavy overhead. usually brings very heavy overhead.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
super().__init__() super().__init__()
self.wrapper_root = wrapper_root self.root = root
self.root = wrapper_root.node
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from paddle.jit.dy2static.utils import ast_to_source_code, is_paddle_api from paddle.jit.dy2static.utils import ast_to_source_code, is_paddle_api
from paddle.utils import gast from paddle.utils import gast
...@@ -29,12 +28,8 @@ class CallTransformer(BaseTransformer): ...@@ -29,12 +28,8 @@ class CallTransformer(BaseTransformer):
This class transforms function calls into Static Graph Ast. This class transforms function calls into Static Graph Ast.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance( self.root = root
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of CallTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def _no_need_convert_call(self, node): def _no_need_convert_call(self, node):
""" """
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from paddle.jit.dy2static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.utils import gast from paddle.utils import gast
...@@ -26,15 +25,12 @@ class CastTransformer(BaseTransformer): ...@@ -26,15 +25,12 @@ class CastTransformer(BaseTransformer):
This class transforms type casting into Static Graph Ast. This class transforms type casting into Static Graph Ast.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance( self.root = root
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of CastTransformer."
self._root = wrapper_root.node
self._castable_type = {'bool', 'int', 'float'} self._castable_type = {'bool', 'int', 'float'}
def transform(self): def transform(self):
self.visit(self._root) self.visit(self.root)
def visit_Call(self, node): def visit_Call(self, node):
self.generic_visit(node) self.generic_visit(node)
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from .base_transformer import BaseTransformer from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
from .utils import FunctionNameLivenessAnalysis from .utils import FunctionNameLivenessAnalysis
from .variable_trans_func import create_undefined_var from .variable_trans_func import create_undefined_var
...@@ -23,12 +22,8 @@ __all__ = [] ...@@ -23,12 +22,8 @@ __all__ = []
class CreateVariableTransformer(BaseTransformer): class CreateVariableTransformer(BaseTransformer):
""" """ """ """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance(wrapper_root, AstNodeWrapper), ( self.root = root
"Type of input node should be AstNodeWrapper, but received %s ."
% type(wrapper_root)
)
self.root = wrapper_root.node
FunctionNameLivenessAnalysis(self.root) FunctionNameLivenessAnalysis(self.root)
def transform(self): def transform(self):
......
...@@ -19,7 +19,6 @@ import warnings ...@@ -19,7 +19,6 @@ import warnings
from paddle.utils import gast from paddle.utils import gast
from .base_transformer import BaseTransformer from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
from .utils import RE_PYMODULE, RE_PYNAME, ast_to_source_code from .utils import RE_PYMODULE, RE_PYNAME, ast_to_source_code
__all__ = [] __all__ = []
...@@ -40,12 +39,8 @@ class DecoratorTransformer(BaseTransformer): ...@@ -40,12 +39,8 @@ class DecoratorTransformer(BaseTransformer):
Transform decorators. Transform decorators.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance(wrapper_root, AstNodeWrapper), ( self.root = root
"Type of input node should be AstNodeWrapper, but received %s ."
% type(wrapper_root)
)
self.root = wrapper_root.node
self.ancestor_nodes = [] self.ancestor_nodes = []
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from paddle.utils import gast from paddle.utils import gast
from .base_transformer import BaseTransformer from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
__all__ = [] __all__ = []
...@@ -25,12 +24,8 @@ class EarlyReturnTransformer(BaseTransformer): ...@@ -25,12 +24,8 @@ class EarlyReturnTransformer(BaseTransformer):
Transform if/else return statement of Dygraph into Static Graph. Transform if/else return statement of Dygraph into Static Graph.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance(wrapper_root, AstNodeWrapper), ( self.root = root
"Type of input node should be AstNodeWrapper, but received %s ."
% type(wrapper_root)
)
self.root = wrapper_root.node
def transform(self): def transform(self):
""" """
......
...@@ -16,7 +16,6 @@ import copy ...@@ -16,7 +16,6 @@ import copy
from collections import defaultdict from collections import defaultdict
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from paddle.jit.dy2static.utils import ( from paddle.jit.dy2static.utils import (
FOR_ITER_INDEX_PREFIX, FOR_ITER_INDEX_PREFIX,
FOR_ITER_ITERATOR_PREFIX, FOR_ITER_ITERATOR_PREFIX,
...@@ -57,12 +56,8 @@ class IfElseTransformer(BaseTransformer): ...@@ -57,12 +56,8 @@ class IfElseTransformer(BaseTransformer):
Transform if/else statement of Dygraph into Static Graph. Transform if/else statement of Dygraph into Static Graph.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance(wrapper_root, AstNodeWrapper), ( self.root = root
"Type of input node should be AstNodeWrapper, but received %s ."
% type(wrapper_root)
)
self.root = wrapper_root.node
FunctionNameLivenessAnalysis( FunctionNameLivenessAnalysis(
self.root self.root
) # name analysis of current ast tree. ) # name analysis of current ast tree.
......
...@@ -48,9 +48,8 @@ class LogicalTransformer(BaseTransformer): ...@@ -48,9 +48,8 @@ class LogicalTransformer(BaseTransformer):
a = _jst.And(lambda:x>1, lambda:y<1) a = _jst.And(lambda:x>1, lambda:y<1)
""" """
def __init__(self, wrapper_root): def __init__(self, root):
self.wrapper_root = wrapper_root self.root = root
self.root = wrapper_root.node
def transform(self): def transform(self):
return self.visit(self.root) return self.visit(self.root)
......
...@@ -24,7 +24,7 @@ from .base_transformer import ( ...@@ -24,7 +24,7 @@ from .base_transformer import (
ForNodeVisitor, ForNodeVisitor,
) )
from .ifelse_transformer import ARGS_NAME from .ifelse_transformer import ARGS_NAME
from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor from .static_analysis import NodeVarType, StaticAnalysisVisitor
from .utils import ( from .utils import (
FOR_BODY_PREFIX, FOR_BODY_PREFIX,
FOR_CONDITION_PREFIX, FOR_CONDITION_PREFIX,
...@@ -507,16 +507,12 @@ class LoopTransformer(BaseTransformer): ...@@ -507,16 +507,12 @@ class LoopTransformer(BaseTransformer):
This class transforms python while/for statement into Static Graph Ast This class transforms python while/for statement into Static Graph Ast
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance( self.root = root
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
FunctionNameLivenessAnalysis(self.root) FunctionNameLivenessAnalysis(self.root)
def transform(self): def transform(self):
ForLoopTuplePreTransformer(self.wrapper_root).transform() ForLoopTuplePreTransformer(self.root).transform()
self.visit(self.root) self.visit(self.root)
def visit_While(self, node): def visit_While(self, node):
......
...@@ -132,17 +132,17 @@ class FunctionCache: ...@@ -132,17 +132,17 @@ class FunctionCache:
# but actually they are methods in different classes. # but actually they are methods in different classes.
# Maybe use (__class__, source_code) as key # Maybe use (__class__, source_code) as key
if source_code in self._code_to_ast_caches: if source_code in self._code_to_ast_caches:
root_wrapper = self._code_to_ast_caches[source_code] root = self._code_to_ast_caches[source_code]
else: else:
root = gast.parse(source_code) root = gast.parse(source_code)
root = attach_origin_info(root, func) root = attach_origin_info(root, func)
root_wrapper = self._dygraph_to_static.get_static_ast(root) root = self._dygraph_to_static.get_static_ast(root)
self._code_to_ast_caches[source_code] = root_wrapper self._code_to_ast_caches[source_code] = root
# Get static function from AST # Get static function from AST
static_func, file_name = ast_to_func(root_wrapper.node, func) static_func, file_name = ast_to_func(root, func)
create_and_update_origin_info_map(root_wrapper.node, static_func) create_and_update_origin_info_map(root, static_func)
return static_func return static_func
def exist(self, func): def exist(self, func):
...@@ -1680,10 +1680,10 @@ class ProgramTranslator: ...@@ -1680,10 +1680,10 @@ class ProgramTranslator:
# Transform AST # Transform AST
dygraph_to_static = DygraphToStaticAst() dygraph_to_static = DygraphToStaticAst()
root_wrapper = dygraph_to_static.get_static_ast(root) root = dygraph_to_static.get_static_ast(root)
# Get source_code # Get source_code
source_code = ast_to_source_code(root_wrapper.node) source_code = ast_to_source_code(root)
return source_code return source_code
def get_program_cache(self): def get_program_cache(self):
......
...@@ -130,9 +130,8 @@ class ReturnTransformer(BaseTransformer): ...@@ -130,9 +130,8 @@ class ReturnTransformer(BaseTransformer):
SingleReturnTransformer don't care the nested function def. SingleReturnTransformer don't care the nested function def.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
self.wrapper_root = wrapper_root self.root = root
self.root = wrapper_root.node
pre_transformer = ReplaceReturnNoneTransformer(self.root) pre_transformer = ReplaceReturnNoneTransformer(self.root)
pre_transformer.transform() pre_transformer.transform()
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from paddle.utils import gast from paddle.utils import gast
from .base_transformer import BaseTransformer from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
from .utils import ast_to_source_code from .utils import ast_to_source_code
__all__ = [] __all__ = []
...@@ -27,12 +26,8 @@ class TensorShapeTransformer(BaseTransformer): ...@@ -27,12 +26,8 @@ class TensorShapeTransformer(BaseTransformer):
All 'xxx.shape' will be converted int '_jst.Shape(x)'. All 'xxx.shape' will be converted int '_jst.Shape(x)'.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance( self.root = root
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
......
...@@ -20,10 +20,10 @@ from .base_transformer import BaseTransformer ...@@ -20,10 +20,10 @@ from .base_transformer import BaseTransformer
class RegisterHookTransformer(BaseTransformer): class RegisterHookTransformer(BaseTransformer):
def __init__(self, wrapper_root): def __init__(self, root):
self.register_hook_pos_map = collections.defaultdict(list) self.register_hook_pos_map = collections.defaultdict(list)
self.assignment_pos_map = collections.defaultdict(list) self.assignment_pos_map = collections.defaultdict(list)
self.root = wrapper_root.node self.root = root
def transform(self): def transform(self):
""" """
......
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# limitations under the License. # limitations under the License.
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from .base_transformer import BaseTransformer from .base_transformer import BaseTransformer
__all__ = [] __all__ = []
...@@ -26,12 +24,8 @@ class TypeHintTransformer(BaseTransformer): ...@@ -26,12 +24,8 @@ class TypeHintTransformer(BaseTransformer):
Please put it behind other transformers because other transformer may relay on typehints. Please put it behind other transformers because other transformer may relay on typehints.
""" """
def __init__(self, wrapper_root): def __init__(self, root):
assert isinstance( self.root = root
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of TypeHintTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
......
...@@ -96,7 +96,7 @@ class TestOriginInfo(unittest.TestCase): ...@@ -96,7 +96,7 @@ class TestOriginInfo(unittest.TestCase):
dygraph_ast = attach_origin_info(dygraph_ast, self.dygraph_func) dygraph_ast = attach_origin_info(dygraph_ast, self.dygraph_func)
# step2 # step2
transformed_ast = DygraphToStaticAst().get_static_ast(dygraph_ast).node transformed_ast = DygraphToStaticAst().get_static_ast(dygraph_ast)
# step3 # step3
self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func) self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册