From 60aaa7158b3782da56e2c638a43100b731d7331b Mon Sep 17 00:00:00 2001 From: songyouwei Date: Tue, 25 Feb 2020 23:18:43 +0800 Subject: [PATCH] dygraph support while_loop op (#22714) * dygraph support while_loop op test=develop * refine assign test=develop --- python/paddle/fluid/layers/control_flow.py | 16 ++++++++- .../fluid/tests/unittests/test_layers.py | 35 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 140219c28f..6bc9d29a90 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -18,7 +18,7 @@ from ..wrapped_decorator import signature_safe_contextmanager from .layer_function_generator import autodoc, templatedoc from .tensor import assign, cast, fill_constant from .. import core -from ..framework import Program, Variable, Operator +from ..framework import Program, Variable, Operator, in_dygraph_mode from ..layer_helper import LayerHelper, unique_name from .nn import logical_and, logical_not, logical_or from .utils import assert_same_structure, map_structure @@ -999,6 +999,20 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): "the shape of the variable returned by cond should be []," "but given shape as {0}.".format(list(pre_cond.shape))) + if in_dygraph_mode(): + now_cond = pre_cond.numpy()[0] + while (now_cond): + output_vars = body(*loop_vars) + if not isinstance(output_vars, (list, tuple)): + output_vars = [output_vars] + if len(output_vars) != len(loop_vars): + raise ValueError( + "body in while_loop should return the same arity " + "(length and structure) and types as loop_vars") + now_cond = cond(*output_vars).numpy()[0] + loop_vars = output_vars + return loop_vars + while_loop_block = While(pre_cond, is_test, name) with while_loop_block.block(): output_vars = body(*loop_vars) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 967aa4789f..91b17cc527 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1393,6 +1393,41 @@ class TestLayer(LayerTest): self.assertTrue(np.allclose(static_ret, dy_ret_rlt)) + def test_while_loop(self): + with self.static_graph(): + i = layers.fill_constant(shape=[1], dtype='int64', value=0) + ten = layers.fill_constant(shape=[1], dtype='int64', value=10) + + def cond(i): + return layers.less_than(i, ten) + + def body(i): + return i + 1 + + out = layers.while_loop(cond, body, [i]) + static_ret = self.get_static_graph_result(feed={}, fetch_list=out) + + with self.dynamic_graph(): + i = layers.fill_constant(shape=[1], dtype='int64', value=0) + ten = layers.fill_constant(shape=[1], dtype='int64', value=10) + + def cond(i): + return layers.less_than(i, ten) + + def body(i): + return i + 1 + + dy_ret = layers.while_loop(cond, body, [i]) + with self.assertRaises(ValueError): + j = layers.fill_constant(shape=[1], dtype='int64', value=0) + + def body2(i): + return i + 1, i + 2 + + layers.while_loop(cond, body2, [j]) + + self.assertTrue(np.array_equal(static_ret[0], dy_ret[0].numpy())) + def test_compare(self): value_a = np.arange(3) value_b = np.arange(3) -- GitLab