diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index a997afbf4d35f3365605b264ac144a96e16394c7..c24704bfadf883aab44ba3af1fa594a36e36bad8 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -1851,24 +1851,45 @@ def cond(pred, true_fn=None, false_fn=None, name=None): list of tensors. Note: - The tuples or lists in ``true_fn`` and ``false_fn`` must have same - shape because of dataflow model of PaddlePaddle while the tensors in the - tuples or the lists can have different shapes. + 1. The tuples or lists returned by ``true_fn`` and ``false_fn`` must have + the same shape because of dataflow model of PaddlePaddle while the + tensors in the tuples or the lists can have different shapes. + + 2. Any tensors or operations created outside of ``true_fn`` and + ``false_fn`` will be executed regardless of which branch is selected at + runtime. This has frequently surprised users who expected a lazy + semantics. For example: + + .. code-block:: python + + import paddle.fluid as fluid + a = fluid.data(name='a', shape=[-1, 1], dtype='float32') + b = fluid.data(name='b', shape=[-1, 1], dtype='float32') + c = a * b + out = fluid.layers.cond(a < b, lambda: a + c, lambda: b * b) + + No matter whether ``a < b`` , ``c = a * b`` will run. Args: pred(Variable): A boolean tensor whose numel should be 1. The boolean value determines whether to return the result of ``true_fn`` or - ``false_fn`` - true_fn(callable): A callable to be performed if ``pred`` is true - false_fn(callable): A callable to be performed if ``pred`` is false - name(str, optional): The default value is ``None``. Normally users + ``false_fn`` . + true_fn(callable, optional): A callable to be performed if ``pred`` is + true. The default value is ``None`` . + false_fn(callable, optional): A callable to be performed if ``pred`` is + false. The default value is ``None`` . + name(str, optional): The default value is ``None`` . Normally users don't have to set this parameter. For more information, please - refer to :ref:`api_guide_Name`. + refer to :ref:`api_guide_Name` . + + Returns: + Variable|list(Variable)|tuple(Variable): returns ``true_fn()`` if the + predicate ``pred`` is true else ``false_fn()`` . Raises: TypeError: if ``true_fn`` or ``false_fn`` is not callable. - ValueError: if ``true_fn`` and ``false_fn`` doesn't return the same - nest structure of tensors. + ValueError: if ``true_fn`` and ``false_fn`` don't return the same nest + structure of tensors. Examples: .. code-block:: python