未验证 提交 d26f581c 编写于 作者: C Chen Weihang 提交者: GitHub

fix some detail problems, test=develop (#24614)

上级 6b464f96
...@@ -19,7 +19,7 @@ import gast ...@@ -19,7 +19,7 @@ import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['BreakContinueTransformer'] __all__ = ['BreakContinueTransformer']
...@@ -67,7 +67,7 @@ class ForToWhileTransformer(gast.NodeTransformer): ...@@ -67,7 +67,7 @@ class ForToWhileTransformer(gast.NodeTransformer):
node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes" node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes"
# 1. parse current gast.For node # 1. parse current gast.For node
current_for_node_parser = ForNodeParser(node) current_for_node_parser = ForNodeVisitor(node)
stmts_tuple = current_for_node_parser.parse() stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None: if stmts_tuple is None:
return [node] return [node]
......
...@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node ...@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
...@@ -374,7 +374,7 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -374,7 +374,7 @@ class LoopTransformer(gast.NodeTransformer):
# 2). cond_stmt: node, condition node to judge whether continue loop # 2). cond_stmt: node, condition node to judge whether continue loop
# 3). body_stmts: list[node], updated loop body, sometimes we should change # 3). body_stmts: list[node], updated loop body, sometimes we should change
# the original statement in body, not just append new statement # the original statement in body, not just append new statement
current_for_node_parser = ForNodeParser(node) current_for_node_parser = ForNodeVisitor(node)
stmts_tuple = current_for_node_parser.parse() stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None: if stmts_tuple is None:
return [node] return [node]
......
...@@ -628,7 +628,7 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -628,7 +628,7 @@ class IsControlFlowVisitor(gast.NodeVisitor):
class NameNodeReplaceTransformer(gast.NodeTransformer): class NameNodeReplaceTransformer(gast.NodeTransformer):
""" """
This class transform specfice gast.Name node to replace node This class replaces specified gast.Name node by replace_node.
""" """
def __init__(self, root_node, target_name, replace_node): def __init__(self, root_node, target_name, replace_node):
...@@ -644,9 +644,9 @@ class NameNodeReplaceTransformer(gast.NodeTransformer): ...@@ -644,9 +644,9 @@ class NameNodeReplaceTransformer(gast.NodeTransformer):
return node return node
class ForNodeParser(object): class ForNodeVisitor(object):
""" """
This class parse python for statement, get transformed 3 statement components of for node This class parses python for statement, get transformed 3 statement components of for node
three key statements: three key statements:
1). init_stmts: list[node], prepare nodes of for loop, may not only one 1). init_stmts: list[node], prepare nodes of for loop, may not only one
2). cond_stmt: node, condition node to judge whether continue loop 2). cond_stmt: node, condition node to judge whether continue loop
...@@ -664,7 +664,7 @@ class ForNodeParser(object): ...@@ -664,7 +664,7 @@ class ForNodeParser(object):
def __init__(self, for_node): def __init__(self, for_node):
assert isinstance( assert isinstance(
for_node, gast.For for_node, gast.For
), "Input node for the initialization of ForNodeParser is not gast.For node." ), "Input node for the initialization of ForNodeVisitor is not gast.For node."
# 1. original for node # 1. original for node
self.node = for_node self.node = for_node
...@@ -806,9 +806,7 @@ class ForNodeParser(object): ...@@ -806,9 +806,7 @@ class ForNodeParser(object):
], ],
value=self.iter_args[0]) value=self.iter_args[0])
else: else:
# TODO: slice bug, only support int32 index index_init_node = get_constant_variable_node(self.iter_idx_name, 0)
index_init_node = get_constant_variable_node(
self.iter_idx_name, 0, dtype='int32')
return index_init_node return index_init_node
def _build_var_shape_assign_node(self): def _build_var_shape_assign_node(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册