diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 425d43e8f1e38f7c82e52479db25b537fe60f39e..14dda7a0ea4c22f07288fb220c5ac58668bc428c 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -929,16 +929,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False. Args: - cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. - body(Callable): A callable returning a tuple or list of tensors and LoDTensorArrays of the same arity - (length and structure) and types as ``loops_vars`` . - loop_vars(list|tuple): A list or tuple of tensors and LoDTensorArrays that is passed to both ``cond`` and ``body`` . + cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes + as many arguments as ``loop_vars`` . + body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity + (length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` . + loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` . is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False. name(str, optional): Normally there is no need for users to set this property. For more information, please refer to :ref:`api_guide_Name`. Default is None. Returns: - A list or tuple of tensors and LoDTensorArrays which returned by ``body`` . + A list or tuple of tensors or LoDTensorArrays which returned by ``body`` . Returen type: list(Variable)|tuple(Variable). @@ -951,6 +952,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): TypeError: If the type of ``cond`` returns is not a boolean variable. TypeError: If the shape of ``cond`` returns is not equals 1. ValueError: If the ``var_loops`` is empty. + ValueError: If the length or type of ``body`` returns is not same as ``loop_vars``. Examples: .. code-block:: python @@ -958,21 +960,22 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): import paddle.fluid as fluid import paddle.fluid.layers as layers - def cond(i): - return layers.less_than(i, ten) + def cond(i, ten): + return i < ten - def body(i): - return layers.increment(x=i, value=1, in_place=True) + def body(i, ten): + i = i + 1 + return [i, ten] main_program = fluid.default_main_program() startup_program = fluid.default_startup_program() with fluid.program_guard(main_program, startup_program): i = layers.fill_constant(shape=[1], dtype='int64', value=0) # loop counter ten = layers.fill_constant(shape=[1], dtype='int64', value=10) # loop length - out = layers.while_loop(cond, body, [i]) + i, ten = layers.while_loop(cond, body, [i, ten]) exe = fluid.Executor(fluid.CPUPlace()) - res = exe.run(main_program, feed={}, fetch_list=out) + res = exe.run(main_program, feed={}, fetch_list=[i]) print(res) # [array([10])] """ helper = LayerHelper('while_loop', **locals()) @@ -999,11 +1002,13 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): while_loop_block = While(pre_cond, is_test, name) with while_loop_block.block(): output_vars = body(*loop_vars) + if not isinstance(output_vars, (list, tuple)): + output_vars = [output_vars] + if len(output_vars) != len(loop_vars): + raise ValueError("body in while_loop should return the same arity " + "(length and structure) and types as loop_vars") + now_cond = cond(*output_vars) map_structure(assign, output_vars, loop_vars) - if len(loop_vars) == 1: - now_cond = cond(output_vars) - else: - now_cond = cond(*output_vars) assign(now_cond, pre_cond) return loop_vars diff --git a/python/paddle/fluid/tests/unittests/test_while_loop_op.py b/python/paddle/fluid/tests/unittests/test_while_loop_op.py index ab2e2bd3a9427aa3cea8cff80eeda0048ce88f2e..6d86f604a1eb0dd1d2e568fe139e3082900ae0ea 100644 --- a/python/paddle/fluid/tests/unittests/test_while_loop_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_loop_op.py @@ -311,9 +311,19 @@ class TestApiWhileLoop_Error(unittest.TestCase): def cond_returns_2d_tensor(i): return layers.less_than(i, ten_2d) + def cond_receives_two_args(i, ten): + return layers.less_than(i, ten) + def body(i): return layers.increment(i) + def body_returns_error_length(i): + i = layers.increment(i) + return [i, i] + + def body_returns_error_type(i, ten): + return layers.increment(i) + main_program = Program() startup_program = Program() with program_guard(main_program, startup_program): @@ -367,6 +377,20 @@ class TestApiWhileLoop_Error(unittest.TestCase): self.assertRaises(TypeError, type_error_shape_cond_returns_2d) + # The length of `body` returns in Op(while_loop) must be same as `loop_vars` + def value_error_body_returns_error_length(): + out = layers.while_loop(cond_returns_bool_tensor, + body_returns_error_length, [data]) + + self.assertRaises(ValueError, value_error_body_returns_error_length) + + # The type of `body` returns in Op(while_loop) must be same as `loop_vars` + def value_error_body_returns_error_type(): + out = layers.while_loop(cond_receives_two_args, + body_returns_error_type, [data, ten]) + + self.assertRaises(ValueError, value_error_body_returns_error_type) + if __name__ == '__main__': unittest.main()