未验证 提交 9ed59da4 编写于 作者: G guofei 提交者: GitHub

Modify english document and unittest of while_loop (#22615)

Modify english document and unittest of while_loop
上级 fc645d8a
...@@ -929,16 +929,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): ...@@ -929,16 +929,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False. while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.
Args: Args:
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
body(Callable): A callable returning a tuple or list of tensors and LoDTensorArrays of the same arity as many arguments as ``loop_vars`` .
(length and structure) and types as ``loops_vars`` . body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
loop_vars(list|tuple): A list or tuple of tensors and LoDTensorArrays that is passed to both ``cond`` and ``body`` . (length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False. is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
name(str, optional): Normally there is no need for users to set this property. For more information, please name(str, optional): Normally there is no need for users to set this property. For more information, please
refer to :ref:`api_guide_Name`. Default is None. refer to :ref:`api_guide_Name`. Default is None.
Returns: Returns:
A list or tuple of tensors and LoDTensorArrays which returned by ``body`` . A list or tuple of tensors or LoDTensorArrays which returned by ``body`` .
Returen type: Returen type:
list(Variable)|tuple(Variable). list(Variable)|tuple(Variable).
...@@ -951,6 +952,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): ...@@ -951,6 +952,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
TypeError: If the type of ``cond`` returns is not a boolean variable. TypeError: If the type of ``cond`` returns is not a boolean variable.
TypeError: If the shape of ``cond`` returns is not equals 1. TypeError: If the shape of ``cond`` returns is not equals 1.
ValueError: If the ``var_loops`` is empty. ValueError: If the ``var_loops`` is empty.
ValueError: If the length or type of ``body`` returns is not same as ``loop_vars``.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -958,21 +960,22 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): ...@@ -958,21 +960,22 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
def cond(i): def cond(i, ten):
return layers.less_than(i, ten) return i < ten
def body(i): def body(i, ten):
return layers.increment(x=i, value=1, in_place=True) i = i + 1
return [i, ten]
main_program = fluid.default_main_program() main_program = fluid.default_main_program()
startup_program = fluid.default_startup_program() startup_program = fluid.default_startup_program()
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
i = layers.fill_constant(shape=[1], dtype='int64', value=0) # loop counter i = layers.fill_constant(shape=[1], dtype='int64', value=0) # loop counter
ten = layers.fill_constant(shape=[1], dtype='int64', value=10) # loop length ten = layers.fill_constant(shape=[1], dtype='int64', value=10) # loop length
out = layers.while_loop(cond, body, [i]) i, ten = layers.while_loop(cond, body, [i, ten])
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
res = exe.run(main_program, feed={}, fetch_list=out) res = exe.run(main_program, feed={}, fetch_list=[i])
print(res) # [array([10])] print(res) # [array([10])]
""" """
helper = LayerHelper('while_loop', **locals()) helper = LayerHelper('while_loop', **locals())
...@@ -999,11 +1002,13 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): ...@@ -999,11 +1002,13 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
while_loop_block = While(pre_cond, is_test, name) while_loop_block = While(pre_cond, is_test, name)
with while_loop_block.block(): with while_loop_block.block():
output_vars = body(*loop_vars) output_vars = body(*loop_vars)
map_structure(assign, output_vars, loop_vars) if not isinstance(output_vars, (list, tuple)):
if len(loop_vars) == 1: output_vars = [output_vars]
now_cond = cond(output_vars) if len(output_vars) != len(loop_vars):
else: raise ValueError("body in while_loop should return the same arity "
"(length and structure) and types as loop_vars")
now_cond = cond(*output_vars) now_cond = cond(*output_vars)
map_structure(assign, output_vars, loop_vars)
assign(now_cond, pre_cond) assign(now_cond, pre_cond)
return loop_vars return loop_vars
......
...@@ -311,9 +311,19 @@ class TestApiWhileLoop_Error(unittest.TestCase): ...@@ -311,9 +311,19 @@ class TestApiWhileLoop_Error(unittest.TestCase):
def cond_returns_2d_tensor(i): def cond_returns_2d_tensor(i):
return layers.less_than(i, ten_2d) return layers.less_than(i, ten_2d)
def cond_receives_two_args(i, ten):
return layers.less_than(i, ten)
def body(i): def body(i):
return layers.increment(i) return layers.increment(i)
def body_returns_error_length(i):
i = layers.increment(i)
return [i, i]
def body_returns_error_type(i, ten):
return layers.increment(i)
main_program = Program() main_program = Program()
startup_program = Program() startup_program = Program()
with program_guard(main_program, startup_program): with program_guard(main_program, startup_program):
...@@ -367,6 +377,20 @@ class TestApiWhileLoop_Error(unittest.TestCase): ...@@ -367,6 +377,20 @@ class TestApiWhileLoop_Error(unittest.TestCase):
self.assertRaises(TypeError, type_error_shape_cond_returns_2d) self.assertRaises(TypeError, type_error_shape_cond_returns_2d)
# The length of `body` returns in Op(while_loop) must be same as `loop_vars`
def value_error_body_returns_error_length():
out = layers.while_loop(cond_returns_bool_tensor,
body_returns_error_length, [data])
self.assertRaises(ValueError, value_error_body_returns_error_length)
# The type of `body` returns in Op(while_loop) must be same as `loop_vars`
def value_error_body_returns_error_type():
out = layers.while_loop(cond_receives_two_args,
body_returns_error_type, [data, ten])
self.assertRaises(ValueError, value_error_body_returns_error_type)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册