diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py index e9280954bf741935603ce5234af1010f6bbb4ce2..8d24060a88e20545c606b93c9c030da2cdb04586 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py @@ -19,7 +19,7 @@ import gast from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list -from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser +from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node __all__ = ['BreakContinueTransformer'] @@ -67,7 +67,7 @@ class ForToWhileTransformer(gast.NodeTransformer): node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes" # 1. parse current gast.For node - current_for_node_parser = ForNodeParser(node) + current_for_node_parser = ForNodeVisitor(node) stmts_tuple = current_for_node_parser.parse() if stmts_tuple is None: return [node] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 6c8b27625b5a450ccf194760a91b0cc1982978cf..d4d1ff6ba2db46298270b3dba36748bd6f92d3e8 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -28,7 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform -from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser +from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node @@ -374,7 +374,7 @@ class LoopTransformer(gast.NodeTransformer): # 2). cond_stmt: node, condition node to judge whether continue loop # 3). body_stmts: list[node], updated loop body, sometimes we should change # the original statement in body, not just append new statement - current_for_node_parser = ForNodeParser(node) + current_for_node_parser = ForNodeVisitor(node) stmts_tuple = current_for_node_parser.parse() if stmts_tuple is None: return [node] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 0372dab799bfb6fda2e6b1ffe0397d1250522acf..bad09a160d24bc7ca6d0166b8d0820122400b5eb 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -628,7 +628,7 @@ class IsControlFlowVisitor(gast.NodeVisitor): class NameNodeReplaceTransformer(gast.NodeTransformer): """ - This class transform specfice gast.Name node to replace node + This class replaces specified gast.Name node by replace_node. """ def __init__(self, root_node, target_name, replace_node): @@ -644,9 +644,9 @@ class NameNodeReplaceTransformer(gast.NodeTransformer): return node -class ForNodeParser(object): +class ForNodeVisitor(object): """ - This class parse python for statement, get transformed 3 statement components of for node + This class parses python for statement, get transformed 3 statement components of for node three key statements: 1). init_stmts: list[node], prepare nodes of for loop, may not only one 2). cond_stmt: node, condition node to judge whether continue loop @@ -664,7 +664,7 @@ class ForNodeParser(object): def __init__(self, for_node): assert isinstance( for_node, gast.For - ), "Input node for the initialization of ForNodeParser is not gast.For node." + ), "Input node for the initialization of ForNodeVisitor is not gast.For node." # 1. original for node self.node = for_node @@ -806,9 +806,7 @@ class ForNodeParser(object): ], value=self.iter_args[0]) else: - # TODO: slice bug, only support int32 index - index_init_node = get_constant_variable_node( - self.iter_idx_name, 0, dtype='int32') + index_init_node = get_constant_variable_node(self.iter_idx_name, 0) return index_init_node def _build_var_shape_assign_node(self):