From 46f9184aff563a4b704a5f1dc23e3e3df0c87beb Mon Sep 17 00:00:00 2001 From: guofei <52460041+gfwm2013@users.noreply.github.com> Date: Tue, 24 Dec 2019 09:58:59 +0800 Subject: [PATCH] Modify the while_loop API (#21844) --- paddle/fluid/operators/controlflow/while_op_helper.cc | 2 +- python/paddle/fluid/layers/control_flow.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc index e9a7dc43828..6ac41af8326 100644 --- a/paddle/fluid/operators/controlflow/while_op_helper.cc +++ b/paddle/fluid/operators/controlflow/while_op_helper.cc @@ -208,7 +208,7 @@ bool GetCondData(const framework::LoDTensor &cond) { framework::TensorCopySync(cond, platform::CPUPlace(), cpu_cond.get()); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "This version of PaddlePaddle doen NOT support GPU but got GPU tensor " + "This version of PaddlePaddle does NOT support GPU but got GPU tensor " "Cond in WhileOp. Please compile WITH_GPU option")); #endif return cpu_cond->data()[0]; diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index c24704bfadf..72a515e5707 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -829,7 +829,7 @@ class While(object): Args: cond(Variable): A Tensor whose data type is bool controlling whether to continue looping. - is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is None. + is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Examples: @@ -919,7 +919,7 @@ class While(object): "is_test": self.is_test}) -def while_loop(cond, body, loop_vars, name=None): +def while_loop(cond, body, loop_vars, is_test=False, name=None): """ while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False. @@ -928,6 +928,7 @@ def while_loop(cond, body, loop_vars, name=None): body(Callable): A callable returning a tuple or list of tensors of the same arity (length and structure) and types as ``loops_vars`` . loop_vars(list|tuple): A list or tuple of tensors that is passed to both ``cond`` and ``body`` . + is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False. name(str, optional): Normally there is no need for users to set this property. For more information, please refer to :ref:`api_guide_Name`. Default is None. @@ -991,7 +992,7 @@ def while_loop(cond, body, loop_vars, name=None): "the shape of the variable returned by cond should be []," "but given shape as {0}.".format(list(pre_cond.shape))) - while_loop_block = While(pre_cond) + while_loop_block = While(pre_cond, is_test, name) with while_loop_block.block(): output_vars = body(*loop_vars) if len(loop_vars) == 1: -- GitLab