提交 46f9184a 编写于 作者: G guofei 提交者: Huihuang Zheng

Modify the while_loop API (#21844)

上级 682ca642
...@@ -208,7 +208,7 @@ bool GetCondData(const framework::LoDTensor &cond) { ...@@ -208,7 +208,7 @@ bool GetCondData(const framework::LoDTensor &cond) {
framework::TensorCopySync(cond, platform::CPUPlace(), cpu_cond.get()); framework::TensorCopySync(cond, platform::CPUPlace(), cpu_cond.get());
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"This version of PaddlePaddle doen NOT support GPU but got GPU tensor " "This version of PaddlePaddle does NOT support GPU but got GPU tensor "
"Cond in WhileOp. Please compile WITH_GPU option")); "Cond in WhileOp. Please compile WITH_GPU option"));
#endif #endif
return cpu_cond->data<bool>()[0]; return cpu_cond->data<bool>()[0];
......
...@@ -829,7 +829,7 @@ class While(object): ...@@ -829,7 +829,7 @@ class While(object):
Args: Args:
cond(Variable): A Tensor whose data type is bool controlling whether to continue looping. cond(Variable): A Tensor whose data type is bool controlling whether to continue looping.
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is None. is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Examples: Examples:
...@@ -919,7 +919,7 @@ class While(object): ...@@ -919,7 +919,7 @@ class While(object):
"is_test": self.is_test}) "is_test": self.is_test})
def while_loop(cond, body, loop_vars, name=None): def while_loop(cond, body, loop_vars, is_test=False, name=None):
""" """
while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False. while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.
...@@ -928,6 +928,7 @@ def while_loop(cond, body, loop_vars, name=None): ...@@ -928,6 +928,7 @@ def while_loop(cond, body, loop_vars, name=None):
body(Callable): A callable returning a tuple or list of tensors of the same arity (length and structure) body(Callable): A callable returning a tuple or list of tensors of the same arity (length and structure)
and types as ``loops_vars`` . and types as ``loops_vars`` .
loop_vars(list|tuple): A list or tuple of tensors that is passed to both ``cond`` and ``body`` . loop_vars(list|tuple): A list or tuple of tensors that is passed to both ``cond`` and ``body`` .
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
name(str, optional): Normally there is no need for users to set this property. For more information, please name(str, optional): Normally there is no need for users to set this property. For more information, please
refer to :ref:`api_guide_Name`. Default is None. refer to :ref:`api_guide_Name`. Default is None.
...@@ -991,7 +992,7 @@ def while_loop(cond, body, loop_vars, name=None): ...@@ -991,7 +992,7 @@ def while_loop(cond, body, loop_vars, name=None):
"the shape of the variable returned by cond should be []," "the shape of the variable returned by cond should be [],"
"but given shape as {0}.".format(list(pre_cond.shape))) "but given shape as {0}.".format(list(pre_cond.shape)))
while_loop_block = While(pre_cond) while_loop_block = While(pre_cond, is_test, name)
with while_loop_block.block(): with while_loop_block.block():
output_vars = body(*loop_vars) output_vars = body(*loop_vars)
if len(loop_vars) == 1: if len(loop_vars) == 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册