未验证 提交 9ed59da4 编写于 作者: G guofei 提交者: GitHub

Modify english document and unittest of while_loop (#22615)

Modify english document and unittest of while_loop
上级 fc645d8a
......@@ -929,16 +929,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.
Args:
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping.
body(Callable): A callable returning a tuple or list of tensors and LoDTensorArrays of the same arity
(length and structure) and types as ``loops_vars`` .
loop_vars(list|tuple): A list or tuple of tensors and LoDTensorArrays that is passed to both ``cond`` and ``body`` .
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
as many arguments as ``loop_vars`` .
body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
(length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
name(str, optional): Normally there is no need for users to set this property. For more information, please
refer to :ref:`api_guide_Name`. Default is None.
Returns:
A list or tuple of tensors and LoDTensorArrays which returned by ``body`` .
A list or tuple of tensors or LoDTensorArrays which returned by ``body`` .
Returen type:
list(Variable)|tuple(Variable).
......@@ -951,6 +952,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
TypeError: If the type of ``cond`` returns is not a boolean variable.
TypeError: If the shape of ``cond`` returns is not equals 1.
ValueError: If the ``var_loops`` is empty.
ValueError: If the length or type of ``body`` returns is not same as ``loop_vars``.
Examples:
.. code-block:: python
......@@ -958,21 +960,22 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def cond(i):
return layers.less_than(i, ten)
def cond(i, ten):
return i < ten
def body(i):
return layers.increment(x=i, value=1, in_place=True)
def body(i, ten):
i = i + 1
return [i, ten]
main_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()
with fluid.program_guard(main_program, startup_program):
i = layers.fill_constant(shape=[1], dtype='int64', value=0) # loop counter
ten = layers.fill_constant(shape=[1], dtype='int64', value=10) # loop length
out = layers.while_loop(cond, body, [i])
i, ten = layers.while_loop(cond, body, [i, ten])
exe = fluid.Executor(fluid.CPUPlace())
res = exe.run(main_program, feed={}, fetch_list=out)
res = exe.run(main_program, feed={}, fetch_list=[i])
print(res) # [array([10])]
"""
helper = LayerHelper('while_loop', **locals())
......@@ -999,11 +1002,13 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
while_loop_block = While(pre_cond, is_test, name)
with while_loop_block.block():
output_vars = body(*loop_vars)
map_structure(assign, output_vars, loop_vars)
if len(loop_vars) == 1:
now_cond = cond(output_vars)
else:
if not isinstance(output_vars, (list, tuple)):
output_vars = [output_vars]
if len(output_vars) != len(loop_vars):
raise ValueError("body in while_loop should return the same arity "
"(length and structure) and types as loop_vars")
now_cond = cond(*output_vars)
map_structure(assign, output_vars, loop_vars)
assign(now_cond, pre_cond)
return loop_vars
......
......@@ -311,9 +311,19 @@ class TestApiWhileLoop_Error(unittest.TestCase):
def cond_returns_2d_tensor(i):
return layers.less_than(i, ten_2d)
def cond_receives_two_args(i, ten):
return layers.less_than(i, ten)
def body(i):
return layers.increment(i)
def body_returns_error_length(i):
i = layers.increment(i)
return [i, i]
def body_returns_error_type(i, ten):
return layers.increment(i)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
......@@ -367,6 +377,20 @@ class TestApiWhileLoop_Error(unittest.TestCase):
self.assertRaises(TypeError, type_error_shape_cond_returns_2d)
# The length of `body` returns in Op(while_loop) must be same as `loop_vars`
def value_error_body_returns_error_length():
out = layers.while_loop(cond_returns_bool_tensor,
body_returns_error_length, [data])
self.assertRaises(ValueError, value_error_body_returns_error_length)
# The type of `body` returns in Op(while_loop) must be same as `loop_vars`
def value_error_body_returns_error_type():
out = layers.while_loop(cond_receives_two_args,
body_returns_error_type, [data, ten])
self.assertRaises(ValueError, value_error_body_returns_error_type)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册