diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 9859feb9d9079219f79e4bdbf2bb33e766021fcc..4e5a3f7b7085137fbe59bc0dc362f7d21e7bc75a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -693,7 +693,7 @@ class LoopTransformer(gast.NodeTransformer): new_body = node.body new_body.append( gast.Return(value=generate_name_node( - loop_var_names, ctx=gast.Load()))) + loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True))) body_func_node = gast.FunctionDef( name=unique_name.generate(WHILE_BODY_PREFIX), args=gast.arguments( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 8116c04f2034fe3b6d1070f4ef4b067b314be64c..747ed8f5dfd42729c50b746c7a119b1e54c68d4d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -121,6 +121,13 @@ def for_loop_dyfunc_not_support(max_len): return ret +def for_break_single_return(max_len): + for i in range(3): + if i == 2: + break + return i + + def while_loop_bool_op(x): i = fluid.dygraph.to_variable(x) @@ -297,7 +304,10 @@ class TestTransformWhileLoop(unittest.TestCase): ret = declarative(self.dyfunc)(tensor_x) else: ret = self.dyfunc(tensor_x) - return ret.numpy() + if hasattr(ret, "numpy"): + return ret.numpy() + else: + return ret def test_ast_to_func(self): static_numpy = self._run_static() @@ -320,6 +330,11 @@ class TestTransformWhileLoopWithNone(TestTransformWhileLoop): self.dyfunc = while_loop_dyfunc_with_none +class TestForBreakSingleReturn(TestTransformWhileLoop): + def _init_dyfunc(self): + self.dyfunc = for_break_single_return + + class TestWhileLoopBoolOp(TestTransformWhileLoop): def _init_dyfunc(self): self.dyfunc = while_loop_bool_op