未验证 提交 b2c1247c 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Polish visit function in transformer (#44083)

* Polish visit function in transformer

* Call generic_visit first in visit_While/For

* Remove comments

* Polish utils.py, move some transformer to base_transformer
上级 9900b42b
......@@ -18,10 +18,10 @@ from paddle.utils import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForNodeVisitor
__all__ = ['BreakContinueTransformer']
......
......@@ -21,8 +21,8 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import SplitAssignTransformer
class ListTransformer(BaseTransformer):
......
......@@ -25,14 +25,14 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes, create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForLoopTuplePreTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForNodeVisitor
__all__ = ['LoopTransformer', 'NameVisitor']
......@@ -489,14 +489,15 @@ class LoopTransformer(BaseTransformer):
self.name_visitor = NameVisitor(self.root)
self.visit(self.root)
def visit(self, node):
def visit_While(self, node):
self.generic_visit(node)
# All parent nodes that may contain gast.While/gast.For
if hasattr(node, 'body'):
self.replace_stmt_list(node.body)
if hasattr(node, 'orelse'):
self.replace_stmt_list(node.orelse)
return node
new_stmts = self.get_while_stmt_nodes(node)
return new_stmts
def visit_For(self, node):
self.generic_visit(node)
new_stmts = self.get_for_stmt_nodes(node)
return new_stmts
def replace_stmt_list(self, body_list):
if not isinstance(body_list, list):
......
......@@ -20,16 +20,13 @@ import inspect
from paddle.utils import gast
from paddle.fluid import core
from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap
from paddle.fluid.dygraph.dygraph_to_static.utils import ORIGI_INFO
from paddle.fluid.framework import Program
try:
from collections.abc import Sequence
except:
from collections import Sequence
# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node.
ORIGI_INFO = "Original information of source code for ast node."
ORIGI_INFO_MAP = "Original information map of source code."
class Location(object):
"""
......
......@@ -188,9 +188,7 @@ class ReturnTransformer(BaseTransformer):
Self-defined visit for appending ancestor
"""
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
ret = super(ReturnTransformer, self).visit(node)
self.ancestor_nodes.pop()
return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册