Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9ed59da4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9ed59da4
编写于
2月 17, 2020
作者:
G
guofei
提交者:
GitHub
2月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modify english document and unittest of while_loop (#22615)
Modify english document and unittest of while_loop
上级
fc645d8a
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
44 addition
and
15 deletion
+44
-15
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+20
-15
python/paddle/fluid/tests/unittests/test_while_loop_op.py
python/paddle/fluid/tests/unittests/test_while_loop_op.py
+24
-0
未找到文件。
python/paddle/fluid/layers/control_flow.py
浏览文件 @
9ed59da4
...
@@ -929,16 +929,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
...
@@ -929,16 +929,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.
while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.
Args:
Args:
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping.
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
body(Callable): A callable returning a tuple or list of tensors and LoDTensorArrays of the same arity
as many arguments as ``loop_vars`` .
(length and structure) and types as ``loops_vars`` .
body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
loop_vars(list|tuple): A list or tuple of tensors and LoDTensorArrays that is passed to both ``cond`` and ``body`` .
(length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
name(str, optional): Normally there is no need for users to set this property. For more information, please
name(str, optional): Normally there is no need for users to set this property. For more information, please
refer to :ref:`api_guide_Name`. Default is None.
refer to :ref:`api_guide_Name`. Default is None.
Returns:
Returns:
A list or tuple of tensors
and
LoDTensorArrays which returned by ``body`` .
A list or tuple of tensors
or
LoDTensorArrays which returned by ``body`` .
Returen type:
Returen type:
list(Variable)|tuple(Variable).
list(Variable)|tuple(Variable).
...
@@ -951,6 +952,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
...
@@ -951,6 +952,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
TypeError: If the type of ``cond`` returns is not a boolean variable.
TypeError: If the type of ``cond`` returns is not a boolean variable.
TypeError: If the shape of ``cond`` returns is not equals 1.
TypeError: If the shape of ``cond`` returns is not equals 1.
ValueError: If the ``var_loops`` is empty.
ValueError: If the ``var_loops`` is empty.
ValueError: If the length or type of ``body`` returns is not same as ``loop_vars``.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
...
@@ -958,21 +960,22 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
...
@@ -958,21 +960,22 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
import paddle.fluid as fluid
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.layers as layers
def cond(i):
def cond(i
, ten
):
return
layers.less_than(i, ten)
return
i < ten
def body(i):
def body(i, ten):
return layers.increment(x=i, value=1, in_place=True)
i = i + 1
return [i, ten]
main_program = fluid.default_main_program()
main_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()
startup_program = fluid.default_startup_program()
with fluid.program_guard(main_program, startup_program):
with fluid.program_guard(main_program, startup_program):
i = layers.fill_constant(shape=[1], dtype='int64', value=0) # loop counter
i = layers.fill_constant(shape=[1], dtype='int64', value=0) # loop counter
ten = layers.fill_constant(shape=[1], dtype='int64', value=10) # loop length
ten = layers.fill_constant(shape=[1], dtype='int64', value=10) # loop length
out = layers.while_loop(cond, body, [i
])
i, ten = layers.while_loop(cond, body, [i, ten
])
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(fluid.CPUPlace())
res = exe.run(main_program, feed={}, fetch_list=
out
)
res = exe.run(main_program, feed={}, fetch_list=
[i]
)
print(res) # [array([10])]
print(res) # [array([10])]
"""
"""
helper
=
LayerHelper
(
'while_loop'
,
**
locals
())
helper
=
LayerHelper
(
'while_loop'
,
**
locals
())
...
@@ -999,11 +1002,13 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
...
@@ -999,11 +1002,13 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
while_loop_block
=
While
(
pre_cond
,
is_test
,
name
)
while_loop_block
=
While
(
pre_cond
,
is_test
,
name
)
with
while_loop_block
.
block
():
with
while_loop_block
.
block
():
output_vars
=
body
(
*
loop_vars
)
output_vars
=
body
(
*
loop_vars
)
map_structure
(
assign
,
output_vars
,
loop_vars
)
if
not
isinstance
(
output_vars
,
(
list
,
tuple
)):
if
len
(
loop_vars
)
==
1
:
output_vars
=
[
output_vars
]
now_cond
=
cond
(
output_vars
)
if
len
(
output_vars
)
!=
len
(
loop_vars
):
else
:
raise
ValueError
(
"body in while_loop should return the same arity "
"(length and structure) and types as loop_vars"
)
now_cond
=
cond
(
*
output_vars
)
now_cond
=
cond
(
*
output_vars
)
map_structure
(
assign
,
output_vars
,
loop_vars
)
assign
(
now_cond
,
pre_cond
)
assign
(
now_cond
,
pre_cond
)
return
loop_vars
return
loop_vars
...
...
python/paddle/fluid/tests/unittests/test_while_loop_op.py
浏览文件 @
9ed59da4
...
@@ -311,9 +311,19 @@ class TestApiWhileLoop_Error(unittest.TestCase):
...
@@ -311,9 +311,19 @@ class TestApiWhileLoop_Error(unittest.TestCase):
def
cond_returns_2d_tensor
(
i
):
def
cond_returns_2d_tensor
(
i
):
return
layers
.
less_than
(
i
,
ten_2d
)
return
layers
.
less_than
(
i
,
ten_2d
)
def
cond_receives_two_args
(
i
,
ten
):
return
layers
.
less_than
(
i
,
ten
)
def
body
(
i
):
def
body
(
i
):
return
layers
.
increment
(
i
)
return
layers
.
increment
(
i
)
def
body_returns_error_length
(
i
):
i
=
layers
.
increment
(
i
)
return
[
i
,
i
]
def
body_returns_error_type
(
i
,
ten
):
return
layers
.
increment
(
i
)
main_program
=
Program
()
main_program
=
Program
()
startup_program
=
Program
()
startup_program
=
Program
()
with
program_guard
(
main_program
,
startup_program
):
with
program_guard
(
main_program
,
startup_program
):
...
@@ -367,6 +377,20 @@ class TestApiWhileLoop_Error(unittest.TestCase):
...
@@ -367,6 +377,20 @@ class TestApiWhileLoop_Error(unittest.TestCase):
self
.
assertRaises
(
TypeError
,
type_error_shape_cond_returns_2d
)
self
.
assertRaises
(
TypeError
,
type_error_shape_cond_returns_2d
)
# The length of `body` returns in Op(while_loop) must be same as `loop_vars`
def
value_error_body_returns_error_length
():
out
=
layers
.
while_loop
(
cond_returns_bool_tensor
,
body_returns_error_length
,
[
data
])
self
.
assertRaises
(
ValueError
,
value_error_body_returns_error_length
)
# The type of `body` returns in Op(while_loop) must be same as `loop_vars`
def
value_error_body_returns_error_type
():
out
=
layers
.
while_loop
(
cond_receives_two_args
,
body_returns_error_type
,
[
data
,
ten
])
self
.
assertRaises
(
ValueError
,
value_error_body_returns_error_type
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录