From cf8d42bb03fb2c3f69f10bab2d898b99115ac0ea Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 1 Jul 2022 14:03:53 +0800 Subject: [PATCH] [Dy2Stat]Polish break/continue statement transformer logic (#43489) * [Dy2Stat]Polish break/continue statement transformer logic --- .../dygraph_to_static/break_continue_transformer.py | 8 ++++---- .../dygraph/dygraph_to_static/variable_trans_func.py | 9 +++++++++ .../unittests/dygraph_to_static/test_break_continue.py | 6 +++++- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py index b85a2137da..7bce234168 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py @@ -20,7 +20,7 @@ from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor -from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node __all__ = ['BreakContinueTransformer'] @@ -140,7 +140,7 @@ class BreakContinueTransformer(BaseNodeVisitor): self._replace_if_stmt(loop_node_index, first_block_index, variable_name) # 4. For 'break' add break into condition of the loop. - assign_false_node = create_fill_constant_node(variable_name, False) + assign_false_node = create_bool_node(variable_name, False) self._add_stmt_before_cur_node(loop_node_index, assign_false_node) cond_var_node = gast.UnaryOp(op=gast.Not(), @@ -177,7 +177,7 @@ class BreakContinueTransformer(BaseNodeVisitor): self._replace_if_stmt(loop_node_index, first_block_index, variable_name) # 4. For 'continue', set continue to False at the beginning of each loop - assign_false_node = create_fill_constant_node(variable_name, False) + assign_false_node = create_bool_node(variable_name, False) loop_node.body.insert(0, assign_false_node) def _remove_stmts_after_break_continue(self, break_continue_node, @@ -221,7 +221,7 @@ class BreakContinueTransformer(BaseNodeVisitor): i = index_in_list(stmt_list, break_continue_node) if i == -1: return False - assign_true_node = create_fill_constant_node(break_continue_name, True) + assign_true_node = create_bool_node(break_continue_name, True) stmt_list[i:] = [assign_true_node] return True diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index 9bbce59fc5..28d7cff8cb 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -76,3 +76,12 @@ def create_bool_as_type(x, value=True): return paddle.full(shape=[1], fill_value=value, dtype="bool") else: return value + + +def create_bool_node(name, value): + ''' + Create a assign stmt for name = value . + ''' + assert isinstance(value, bool) + node = "{} = {}".format(name, value) + return gast.parse(node).body[0] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py index 79b6880b0d..6b4b2d46a1 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py @@ -101,7 +101,11 @@ def test_break_continue_in_for(x): x += 10086 a = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) - for i in range(1, 10, 1): + b = fluid.layers.fill_constant(shape=[1], dtype='int32', value=3) + # b = 10 + # TODO: add Raise Error and suggestion for usage: + # Py for contains break/continue depends on control-flow. + for i in range(b): if a <= 4: x += 1 a += 1 -- GitLab