未验证 提交 cf8d42bb 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Polish break/continue statement transformer logic (#43489)

* [Dy2Stat]Polish break/continue statement transformer logic
上级 76156d12
......@@ -20,7 +20,7 @@ from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node
__all__ = ['BreakContinueTransformer']
......@@ -140,7 +140,7 @@ class BreakContinueTransformer(BaseNodeVisitor):
self._replace_if_stmt(loop_node_index, first_block_index, variable_name)
# 4. For 'break' add break into condition of the loop.
assign_false_node = create_fill_constant_node(variable_name, False)
assign_false_node = create_bool_node(variable_name, False)
self._add_stmt_before_cur_node(loop_node_index, assign_false_node)
cond_var_node = gast.UnaryOp(op=gast.Not(),
......@@ -177,7 +177,7 @@ class BreakContinueTransformer(BaseNodeVisitor):
self._replace_if_stmt(loop_node_index, first_block_index, variable_name)
# 4. For 'continue', set continue to False at the beginning of each loop
assign_false_node = create_fill_constant_node(variable_name, False)
assign_false_node = create_bool_node(variable_name, False)
loop_node.body.insert(0, assign_false_node)
def _remove_stmts_after_break_continue(self, break_continue_node,
......@@ -221,7 +221,7 @@ class BreakContinueTransformer(BaseNodeVisitor):
i = index_in_list(stmt_list, break_continue_node)
if i == -1:
return False
assign_true_node = create_fill_constant_node(break_continue_name, True)
assign_true_node = create_bool_node(break_continue_name, True)
stmt_list[i:] = [assign_true_node]
return True
......
......@@ -76,3 +76,12 @@ def create_bool_as_type(x, value=True):
return paddle.full(shape=[1], fill_value=value, dtype="bool")
else:
return value
def create_bool_node(name, value):
'''
Create a assign stmt for name = value .
'''
assert isinstance(value, bool)
node = "{} = {}".format(name, value)
return gast.parse(node).body[0]
......@@ -101,7 +101,11 @@ def test_break_continue_in_for(x):
x += 10086
a = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
for i in range(1, 10, 1):
b = fluid.layers.fill_constant(shape=[1], dtype='int32', value=3)
# b = 10
# TODO: add Raise Error and suggestion for usage:
# Py for contains break/continue depends on control-flow.
for i in range(b):
if a <= 4:
x += 1
a += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册