未验证 提交 d7ad0e42 编写于 作者: H haozi 提交者: GitHub

dy2static: Remove the dependency of AstNodeWrapper (#54237)

* [IR] delete_wrapper

* [IR] modify the ast_transformer core logic

* [fix] modify the test_origin_info
上级 9f76d050
......@@ -34,6 +34,6 @@ from .convert_operators import convert_shape_compare # noqa: F401
from .assert_transformer import AssertTransformer
from .ast_transformer import DygraphToStaticAst
from .program_translator import convert_to_static
from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from .static_analysis import NodeVarType, StaticAnalysisVisitor
__all__ = []
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.utils import gast
......@@ -26,12 +25,8 @@ class AssertTransformer(BaseTransformer):
A class transforms python assert to convert_assert.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of AssertTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
def transform(self):
self.visit(self.root)
......
......@@ -36,7 +36,6 @@ from .ifelse_transformer import IfElseTransformer
from .logical_transformer import LogicalTransformer
from .loop_transformer import LoopTransformer
from .return_transformer import ReturnTransformer
from .static_analysis import StaticAnalysisVisitor
from .tensor_shape_transformer import TensorShapeTransformer
from .tensorhook_transformer import RegisterHookTransformer
from .typehint_transformer import TypeHintTransformer
......@@ -69,28 +68,25 @@ class DygraphToStaticAst(BaseTransformer):
self.translator_logger = logging_utils.TranslatorLogger()
def get_static_ast(self, root):
# save root for some analysis may need global AST
self.root = root
self.static_analysis_visitor = StaticAnalysisVisitor(root)
self.static_analysis_root = (
self.static_analysis_visitor.get_node_wrapper_root()
)
self.decorate_func_name = None
self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root
def _apply(self, transformer, node_wrapper, log_level):
transformer(node_wrapper).transform()
# inplace transfer
self.transfer_from_node_type(self.root)
return self.root
def _apply(self, transformer, node, log_level):
transformer(node).transform()
self.translator_logger.log_transformed_code(
log_level, self.root, transformer.__name__
)
def transfer_from_node_type(self, node_wrapper):
def transfer_from_node_type(self, node):
self.translator_logger.log(
1, f"Source code: \n{ast_to_source_code(self.root)}"
)
# Generic transformation
self.visit(node_wrapper.node)
self.visit(node)
transformers = [
RegisterHookTransformer,
......@@ -114,7 +110,7 @@ class DygraphToStaticAst(BaseTransformer):
apply_optimization(transformers)
for index, transformer in enumerate(transformers):
self._apply(transformer, node_wrapper, log_level=index + 1)
self._apply(transformer, node, log_level=index + 1)
self.translator_logger.log_transformed_code(
logging_utils.LOG_AllTransformer, self.root, "All Transformers"
......
......@@ -139,9 +139,8 @@ class ForLoopTuplePreTransformer(BaseTransformer):
>>> body
"""
def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
def transform(self):
self.visit(self.root)
......
......@@ -18,7 +18,6 @@ from paddle.utils import gast
from . import utils
from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
__all__ = []
......@@ -28,13 +27,8 @@ class BasicApiTransformer(BaseTransformer):
Class to transform basic API from dygraph to static graph.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
self.class_node_dict = {}
def transform(self):
......@@ -43,7 +37,7 @@ class BasicApiTransformer(BaseTransformer):
attribute_transformer = AttributeJstTransformer(self.root)
attribute_transformer.transform()
self.visit(self.root)
return self.wrapper_root
return self.root
def visit_Assign(self, node):
if self._update_class_node_dict(node):
......@@ -138,13 +132,8 @@ class NameloadJstTransformer(BaseTransformer):
NOTE: we only deal with ctx=Load() case.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
def transform(self):
self.visit(self.root)
......
......@@ -111,11 +111,10 @@ class BreakContinueTransformer(BaseNodeVisitor):
In general, we recommend to inheriting NodeTransformer to modify node!
"""
def __init__(self, wrapper_root):
def __init__(self, root):
super().__init__()
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.root = root
def transform(self):
self.visit(self.root)
......@@ -336,11 +335,10 @@ class BreakTransformOptimizer(BaseNodeVisitor):
usually brings very heavy overhead.
"""
def __init__(self, wrapper_root):
def __init__(self, root):
super().__init__()
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.root = root
def transform(self):
self.visit(self.root)
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from paddle.jit.dy2static.utils import ast_to_source_code, is_paddle_api
from paddle.utils import gast
......@@ -29,12 +28,8 @@ class CallTransformer(BaseTransformer):
This class transforms function calls into Static Graph Ast.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of CallTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
def _no_need_convert_call(self, node):
"""
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.utils import gast
......@@ -26,15 +25,12 @@ class CastTransformer(BaseTransformer):
This class transforms type casting into Static Graph Ast.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of CastTransformer."
self._root = wrapper_root.node
def __init__(self, root):
self.root = root
self._castable_type = {'bool', 'int', 'float'}
def transform(self):
self.visit(self._root)
self.visit(self.root)
def visit_Call(self, node):
self.generic_visit(node)
......
......@@ -13,7 +13,6 @@
# limitations under the License.
from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
from .utils import FunctionNameLivenessAnalysis
from .variable_trans_func import create_undefined_var
......@@ -23,12 +22,8 @@ __all__ = []
class CreateVariableTransformer(BaseTransformer):
""" """
def __init__(self, wrapper_root):
assert isinstance(wrapper_root, AstNodeWrapper), (
"Type of input node should be AstNodeWrapper, but received %s ."
% type(wrapper_root)
)
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
FunctionNameLivenessAnalysis(self.root)
def transform(self):
......
......@@ -19,7 +19,6 @@ import warnings
from paddle.utils import gast
from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
from .utils import RE_PYMODULE, RE_PYNAME, ast_to_source_code
__all__ = []
......@@ -40,12 +39,8 @@ class DecoratorTransformer(BaseTransformer):
Transform decorators.
"""
def __init__(self, wrapper_root):
assert isinstance(wrapper_root, AstNodeWrapper), (
"Type of input node should be AstNodeWrapper, but received %s ."
% type(wrapper_root)
)
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
self.ancestor_nodes = []
......
......@@ -15,7 +15,6 @@
from paddle.utils import gast
from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
__all__ = []
......@@ -25,12 +24,8 @@ class EarlyReturnTransformer(BaseTransformer):
Transform if/else return statement of Dygraph into Static Graph.
"""
def __init__(self, wrapper_root):
assert isinstance(wrapper_root, AstNodeWrapper), (
"Type of input node should be AstNodeWrapper, but received %s ."
% type(wrapper_root)
)
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
def transform(self):
"""
......
......@@ -16,7 +16,6 @@ import copy
from collections import defaultdict
from paddle.fluid import unique_name
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from paddle.jit.dy2static.utils import (
FOR_ITER_INDEX_PREFIX,
FOR_ITER_ITERATOR_PREFIX,
......@@ -57,12 +56,8 @@ class IfElseTransformer(BaseTransformer):
Transform if/else statement of Dygraph into Static Graph.
"""
def __init__(self, wrapper_root):
assert isinstance(wrapper_root, AstNodeWrapper), (
"Type of input node should be AstNodeWrapper, but received %s ."
% type(wrapper_root)
)
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
FunctionNameLivenessAnalysis(
self.root
) # name analysis of current ast tree.
......
......@@ -48,9 +48,8 @@ class LogicalTransformer(BaseTransformer):
a = _jst.And(lambda:x>1, lambda:y<1)
"""
def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
def transform(self):
return self.visit(self.root)
......
......@@ -24,7 +24,7 @@ from .base_transformer import (
ForNodeVisitor,
)
from .ifelse_transformer import ARGS_NAME
from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from .static_analysis import NodeVarType, StaticAnalysisVisitor
from .utils import (
FOR_BODY_PREFIX,
FOR_CONDITION_PREFIX,
......@@ -507,16 +507,12 @@ class LoopTransformer(BaseTransformer):
This class transforms python while/for statement into Static Graph Ast
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
FunctionNameLivenessAnalysis(self.root)
def transform(self):
ForLoopTuplePreTransformer(self.wrapper_root).transform()
ForLoopTuplePreTransformer(self.root).transform()
self.visit(self.root)
def visit_While(self, node):
......
......@@ -132,17 +132,17 @@ class FunctionCache:
# but actually they are methods in different classes.
# Maybe use (__class__, source_code) as key
if source_code in self._code_to_ast_caches:
root_wrapper = self._code_to_ast_caches[source_code]
root = self._code_to_ast_caches[source_code]
else:
root = gast.parse(source_code)
root = attach_origin_info(root, func)
root_wrapper = self._dygraph_to_static.get_static_ast(root)
self._code_to_ast_caches[source_code] = root_wrapper
root = self._dygraph_to_static.get_static_ast(root)
self._code_to_ast_caches[source_code] = root
# Get static function from AST
static_func, file_name = ast_to_func(root_wrapper.node, func)
static_func, file_name = ast_to_func(root, func)
create_and_update_origin_info_map(root_wrapper.node, static_func)
create_and_update_origin_info_map(root, static_func)
return static_func
def exist(self, func):
......@@ -1680,10 +1680,10 @@ class ProgramTranslator:
# Transform AST
dygraph_to_static = DygraphToStaticAst()
root_wrapper = dygraph_to_static.get_static_ast(root)
root = dygraph_to_static.get_static_ast(root)
# Get source_code
source_code = ast_to_source_code(root_wrapper.node)
source_code = ast_to_source_code(root)
return source_code
def get_program_cache(self):
......
......@@ -130,9 +130,8 @@ class ReturnTransformer(BaseTransformer):
SingleReturnTransformer don't care the nested function def.
"""
def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
pre_transformer = ReplaceReturnNoneTransformer(self.root)
pre_transformer.transform()
......
......@@ -15,7 +15,6 @@
from paddle.utils import gast
from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
from .utils import ast_to_source_code
__all__ = []
......@@ -27,12 +26,8 @@ class TensorShapeTransformer(BaseTransformer):
All 'xxx.shape' will be converted int '_jst.Shape(x)'.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
def transform(self):
self.visit(self.root)
......
......@@ -20,10 +20,10 @@ from .base_transformer import BaseTransformer
class RegisterHookTransformer(BaseTransformer):
def __init__(self, wrapper_root):
def __init__(self, root):
self.register_hook_pos_map = collections.defaultdict(list)
self.assignment_pos_map = collections.defaultdict(list)
self.root = wrapper_root.node
self.root = root
def transform(self):
"""
......
......@@ -13,8 +13,6 @@
# limitations under the License.
from paddle.jit.dy2static.static_analysis import AstNodeWrapper
from .base_transformer import BaseTransformer
__all__ = []
......@@ -26,12 +24,8 @@ class TypeHintTransformer(BaseTransformer):
Please put it behind other transformers because other transformer may relay on typehints.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of TypeHintTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def __init__(self, root):
self.root = root
def transform(self):
self.visit(self.root)
......
......@@ -96,7 +96,7 @@ class TestOriginInfo(unittest.TestCase):
dygraph_ast = attach_origin_info(dygraph_ast, self.dygraph_func)
# step2
transformed_ast = DygraphToStaticAst().get_static_ast(dygraph_ast).node
transformed_ast = DygraphToStaticAst().get_static_ast(dygraph_ast)
# step3
self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册