From 287cbde8614cea6a94cf9290365cd3d9dff72288 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 28 Mar 2022 10:36:20 +0800 Subject: [PATCH] [Dy2Stat] Fix ForLoop Transformation with single return (#40683) * [Dy2Stat] Fix ForLoop Transformation with single return * [Dy2Stat] Fix ForLoop Transformation with single return --- .../dygraph_to_static/loop_transformer.py | 2 +- .../unittests/dygraph_to_static/test_loop.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 9859feb9d90..4e5a3f7b708 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -693,7 +693,7 @@ class LoopTransformer(gast.NodeTransformer): new_body = node.body new_body.append( gast.Return(value=generate_name_node( - loop_var_names, ctx=gast.Load()))) + loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True))) body_func_node = gast.FunctionDef( name=unique_name.generate(WHILE_BODY_PREFIX), args=gast.arguments( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 8116c04f203..747ed8f5dfd 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -121,6 +121,13 @@ def for_loop_dyfunc_not_support(max_len): return ret +def for_break_single_return(max_len): + for i in range(3): + if i == 2: + break + return i + + def while_loop_bool_op(x): i = fluid.dygraph.to_variable(x) @@ -297,7 +304,10 @@ class TestTransformWhileLoop(unittest.TestCase): ret = declarative(self.dyfunc)(tensor_x) else: ret = self.dyfunc(tensor_x) - return ret.numpy() + if hasattr(ret, "numpy"): + return ret.numpy() + else: + return ret def test_ast_to_func(self): static_numpy = self._run_static() @@ -320,6 +330,11 @@ class TestTransformWhileLoopWithNone(TestTransformWhileLoop): self.dyfunc = while_loop_dyfunc_with_none +class TestForBreakSingleReturn(TestTransformWhileLoop): + def _init_dyfunc(self): + self.dyfunc = for_break_single_return + + class TestWhileLoopBoolOp(TestTransformWhileLoop): def _init_dyfunc(self): self.dyfunc = while_loop_bool_op -- GitLab