未验证 提交 b2c1247c 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Polish visit function in transformer (#44083)

* Polish visit function in transformer

* Call generic_visit first in visit_While/For

* Remove comments

* Polish utils.py, move some transformer to base_transformer
上级 9900b42b
...@@ -18,10 +18,10 @@ from paddle.utils import gast ...@@ -18,10 +18,10 @@ from paddle.utils import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForNodeVisitor
__all__ = ['BreakContinueTransformer'] __all__ = ['BreakContinueTransformer']
......
...@@ -21,8 +21,8 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe ...@@ -21,8 +21,8 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import SplitAssignTransformer
class ListTransformer(BaseTransformer): class ListTransformer(BaseTransformer):
......
...@@ -25,14 +25,14 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi ...@@ -25,14 +25,14 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes, create_get_args_node, create_set_args_node from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes, create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForLoopTuplePreTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForNodeVisitor
__all__ = ['LoopTransformer', 'NameVisitor'] __all__ = ['LoopTransformer', 'NameVisitor']
...@@ -489,14 +489,15 @@ class LoopTransformer(BaseTransformer): ...@@ -489,14 +489,15 @@ class LoopTransformer(BaseTransformer):
self.name_visitor = NameVisitor(self.root) self.name_visitor = NameVisitor(self.root)
self.visit(self.root) self.visit(self.root)
def visit(self, node): def visit_While(self, node):
self.generic_visit(node) self.generic_visit(node)
# All parent nodes that may contain gast.While/gast.For new_stmts = self.get_while_stmt_nodes(node)
if hasattr(node, 'body'): return new_stmts
self.replace_stmt_list(node.body)
if hasattr(node, 'orelse'): def visit_For(self, node):
self.replace_stmt_list(node.orelse) self.generic_visit(node)
return node new_stmts = self.get_for_stmt_nodes(node)
return new_stmts
def replace_stmt_list(self, body_list): def replace_stmt_list(self, body_list):
if not isinstance(body_list, list): if not isinstance(body_list, list):
......
...@@ -20,16 +20,13 @@ import inspect ...@@ -20,16 +20,13 @@ import inspect
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap
from paddle.fluid.dygraph.dygraph_to_static.utils import ORIGI_INFO
from paddle.fluid.framework import Program from paddle.fluid.framework import Program
try: try:
from collections.abc import Sequence from collections.abc import Sequence
except: except:
from collections import Sequence from collections import Sequence
# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node.
ORIGI_INFO = "Original information of source code for ast node."
ORIGI_INFO_MAP = "Original information map of source code."
class Location(object): class Location(object):
""" """
......
...@@ -188,9 +188,7 @@ class ReturnTransformer(BaseTransformer): ...@@ -188,9 +188,7 @@ class ReturnTransformer(BaseTransformer):
Self-defined visit for appending ancestor Self-defined visit for appending ancestor
""" """
self.ancestor_nodes.append(node) self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__ ret = super(ReturnTransformer, self).visit(node)
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop() self.ancestor_nodes.pop()
return ret return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册