Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9ed59da4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9ed59da4
编写于
2月 17, 2020
作者:
G
guofei
提交者:
GitHub
2月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modify english document and unittest of while_loop (#22615)
Modify english document and unittest of while_loop
上级
fc645d8a
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
44 addition
and
15 deletion
+44
-15
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+20
-15
python/paddle/fluid/tests/unittests/test_while_loop_op.py
python/paddle/fluid/tests/unittests/test_while_loop_op.py
+24
-0
未找到文件。
python/paddle/fluid/layers/control_flow.py
浏览文件 @
9ed59da4
...
...
@@ -929,16 +929,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.
Args:
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping.
body(Callable): A callable returning a tuple or list of tensors and LoDTensorArrays of the same arity
(length and structure) and types as ``loops_vars`` .
loop_vars(list|tuple): A list or tuple of tensors and LoDTensorArrays that is passed to both ``cond`` and ``body`` .
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
as many arguments as ``loop_vars`` .
body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
(length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
name(str, optional): Normally there is no need for users to set this property. For more information, please
refer to :ref:`api_guide_Name`. Default is None.
Returns:
A list or tuple of tensors
and
LoDTensorArrays which returned by ``body`` .
A list or tuple of tensors
or
LoDTensorArrays which returned by ``body`` .
Returen type:
list(Variable)|tuple(Variable).
...
...
@@ -951,6 +952,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
TypeError: If the type of ``cond`` returns is not a boolean variable.
TypeError: If the shape of ``cond`` returns is not equals 1.
ValueError: If the ``var_loops`` is empty.
ValueError: If the length or type of ``body`` returns is not same as ``loop_vars``.
Examples:
.. code-block:: python
...
...
@@ -958,21 +960,22 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def cond(i):
return
layers.less_than(i, ten)
def cond(i
, ten
):
return
i < ten
def body(i):
return layers.increment(x=i, value=1, in_place=True)
def body(i, ten):
i = i + 1
return [i, ten]
main_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()
with fluid.program_guard(main_program, startup_program):
i = layers.fill_constant(shape=[1], dtype='int64', value=0) # loop counter
ten = layers.fill_constant(shape=[1], dtype='int64', value=10) # loop length
out = layers.while_loop(cond, body, [i
])
i, ten = layers.while_loop(cond, body, [i, ten
])
exe = fluid.Executor(fluid.CPUPlace())
res = exe.run(main_program, feed={}, fetch_list=
out
)
res = exe.run(main_program, feed={}, fetch_list=
[i]
)
print(res) # [array([10])]
"""
helper
=
LayerHelper
(
'while_loop'
,
**
locals
())
...
...
@@ -999,11 +1002,13 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
while_loop_block
=
While
(
pre_cond
,
is_test
,
name
)
with
while_loop_block
.
block
():
output_vars
=
body
(
*
loop_vars
)
map_structure
(
assign
,
output_vars
,
loop_vars
)
if
len
(
loop_vars
)
==
1
:
now_cond
=
cond
(
output_vars
)
else
:
if
not
isinstance
(
output_vars
,
(
list
,
tuple
)):
output_vars
=
[
output_vars
]
if
len
(
output_vars
)
!=
len
(
loop_vars
):
raise
ValueError
(
"body in while_loop should return the same arity "
"(length and structure) and types as loop_vars"
)
now_cond
=
cond
(
*
output_vars
)
map_structure
(
assign
,
output_vars
,
loop_vars
)
assign
(
now_cond
,
pre_cond
)
return
loop_vars
...
...
python/paddle/fluid/tests/unittests/test_while_loop_op.py
浏览文件 @
9ed59da4
...
...
@@ -311,9 +311,19 @@ class TestApiWhileLoop_Error(unittest.TestCase):
def
cond_returns_2d_tensor
(
i
):
return
layers
.
less_than
(
i
,
ten_2d
)
def
cond_receives_two_args
(
i
,
ten
):
return
layers
.
less_than
(
i
,
ten
)
def
body
(
i
):
return
layers
.
increment
(
i
)
def
body_returns_error_length
(
i
):
i
=
layers
.
increment
(
i
)
return
[
i
,
i
]
def
body_returns_error_type
(
i
,
ten
):
return
layers
.
increment
(
i
)
main_program
=
Program
()
startup_program
=
Program
()
with
program_guard
(
main_program
,
startup_program
):
...
...
@@ -367,6 +377,20 @@ class TestApiWhileLoop_Error(unittest.TestCase):
self
.
assertRaises
(
TypeError
,
type_error_shape_cond_returns_2d
)
# The length of `body` returns in Op(while_loop) must be same as `loop_vars`
def
value_error_body_returns_error_length
():
out
=
layers
.
while_loop
(
cond_returns_bool_tensor
,
body_returns_error_length
,
[
data
])
self
.
assertRaises
(
ValueError
,
value_error_body_returns_error_length
)
# The type of `body` returns in Op(while_loop) must be same as `loop_vars`
def
value_error_body_returns_error_type
():
out
=
layers
.
while_loop
(
cond_receives_two_args
,
body_returns_error_type
,
[
data
,
ten
])
self
.
assertRaises
(
ValueError
,
value_error_body_returns_error_type
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录