diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py index b85a2137dad812509ad970282a34c142848d28e0..7bce234168c7eb6baa3d7b94eacbbe343ad6de5a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py @@ -20,7 +20,7 @@ from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor -from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node __all__ = ['BreakContinueTransformer'] @@ -140,7 +140,7 @@ class BreakContinueTransformer(BaseNodeVisitor): self._replace_if_stmt(loop_node_index, first_block_index, variable_name) # 4. For 'break' add break into condition of the loop. - assign_false_node = create_fill_constant_node(variable_name, False) + assign_false_node = create_bool_node(variable_name, False) self._add_stmt_before_cur_node(loop_node_index, assign_false_node) cond_var_node = gast.UnaryOp(op=gast.Not(), @@ -177,7 +177,7 @@ class BreakContinueTransformer(BaseNodeVisitor): self._replace_if_stmt(loop_node_index, first_block_index, variable_name) # 4. For 'continue', set continue to False at the beginning of each loop - assign_false_node = create_fill_constant_node(variable_name, False) + assign_false_node = create_bool_node(variable_name, False) loop_node.body.insert(0, assign_false_node) def _remove_stmts_after_break_continue(self, break_continue_node, @@ -221,7 +221,7 @@ class BreakContinueTransformer(BaseNodeVisitor): i = index_in_list(stmt_list, break_continue_node) if i == -1: return False - assign_true_node = create_fill_constant_node(break_continue_name, True) + assign_true_node = create_bool_node(break_continue_name, True) stmt_list[i:] = [assign_true_node] return True diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index 9bbce59fc54cefdf5d7cb2e74140e7f7b073afe9..28d7cff8cb0ca30cc7743f777c2ab54e977de98b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -76,3 +76,12 @@ def create_bool_as_type(x, value=True): return paddle.full(shape=[1], fill_value=value, dtype="bool") else: return value + + +def create_bool_node(name, value): + ''' + Create a assign stmt for name = value . + ''' + assert isinstance(value, bool) + node = "{} = {}".format(name, value) + return gast.parse(node).body[0] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py index 79b6880b0d871afae5dd66a6c8645d6aeeecf36e..6b4b2d46a12f6f9cd1847953f999d18050b5f07e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py @@ -101,7 +101,11 @@ def test_break_continue_in_for(x): x += 10086 a = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) - for i in range(1, 10, 1): + b = fluid.layers.fill_constant(shape=[1], dtype='int32', value=3) + # b = 10 + # TODO: add Raise Error and suggestion for usage: + # Py for contains break/continue depends on control-flow. + for i in range(b): if a <= 4: x += 1 a += 1