未验证 提交 d26f581c 编写于 作者: C Chen Weihang 提交者: GitHub

fix some detail problems, test=develop (#24614)

上级 6b464f96
......@@ -19,7 +19,7 @@ import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['BreakContinueTransformer']
......@@ -67,7 +67,7 @@ class ForToWhileTransformer(gast.NodeTransformer):
node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes"
# 1. parse current gast.For node
current_for_node_parser = ForNodeParser(node)
current_for_node_parser = ForNodeVisitor(node)
stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None:
return [node]
......
......@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
......@@ -374,7 +374,7 @@ class LoopTransformer(gast.NodeTransformer):
# 2). cond_stmt: node, condition node to judge whether continue loop
# 3). body_stmts: list[node], updated loop body, sometimes we should change
# the original statement in body, not just append new statement
current_for_node_parser = ForNodeParser(node)
current_for_node_parser = ForNodeVisitor(node)
stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None:
return [node]
......
......@@ -628,7 +628,7 @@ class IsControlFlowVisitor(gast.NodeVisitor):
class NameNodeReplaceTransformer(gast.NodeTransformer):
"""
This class transform specfice gast.Name node to replace node
This class replaces specified gast.Name node by replace_node.
"""
def __init__(self, root_node, target_name, replace_node):
......@@ -644,9 +644,9 @@ class NameNodeReplaceTransformer(gast.NodeTransformer):
return node
class ForNodeParser(object):
class ForNodeVisitor(object):
"""
This class parse python for statement, get transformed 3 statement components of for node
This class parses python for statement, get transformed 3 statement components of for node
three key statements:
1). init_stmts: list[node], prepare nodes of for loop, may not only one
2). cond_stmt: node, condition node to judge whether continue loop
......@@ -664,7 +664,7 @@ class ForNodeParser(object):
def __init__(self, for_node):
assert isinstance(
for_node, gast.For
), "Input node for the initialization of ForNodeParser is not gast.For node."
), "Input node for the initialization of ForNodeVisitor is not gast.For node."
# 1. original for node
self.node = for_node
......@@ -806,9 +806,7 @@ class ForNodeParser(object):
],
value=self.iter_args[0])
else:
# TODO: slice bug, only support int32 index
index_init_node = get_constant_variable_node(
self.iter_idx_name, 0, dtype='int32')
index_init_node = get_constant_variable_node(self.iter_idx_name, 0)
return index_init_node
def _build_var_shape_assign_node(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册