From e388e603125f72852e073dbd9dbeae790c46f739 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Mon, 12 Oct 2020 20:12:47 +0800 Subject: [PATCH] Refine cond API English Doc for 2.0RC (#27708) As the title --- python/paddle/fluid/layers/control_flow.py | 63 +++++++++------------- 1 file changed, 26 insertions(+), 37 deletions(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 411ac6e51b..0c77917c78 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -2297,11 +2297,6 @@ def copy_var_to_parent_block(var, layer_helper): def cond(pred, true_fn=None, false_fn=None, name=None): """ - :api_attr: Static Graph - :alias_main: paddle.nn.cond - :alias: paddle.nn.cond,paddle.nn.control_flow.cond - :old_api: paddle.fluid.layers.cond - This API returns ``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` . Users could also set ``true_fn`` or ``false_fn`` to ``None`` if do nothing and this API will treat the callable simply returns @@ -2323,17 +2318,18 @@ def cond(pred, true_fn=None, false_fn=None, name=None): semantics. For example: .. code-block:: python - - import paddle.fluid as fluid - a = fluid.data(name='a', shape=[-1, 1], dtype='float32') - b = fluid.data(name='b', shape=[-1, 1], dtype='float32') + + import paddle + + a = paddle.zeros((1, 1)) + b = paddle.zeros((1, 1)) c = a * b - out = fluid.layers.cond(a < b, lambda: a + c, lambda: b * b) + out = paddle.nn.cond(a < b, lambda: a + c, lambda: b * b) No matter whether ``a < b`` , ``c = a * b`` will run. Args: - pred(Variable): A boolean tensor whose numel should be 1. The boolean + pred(Tensor): A boolean tensor whose numel should be 1. The boolean value determines whether to return the result of ``true_fn`` or ``false_fn`` . true_fn(callable, optional): A callable to be performed if ``pred`` is @@ -2345,7 +2341,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None): refer to :ref:`api_guide_Name` . Returns: - Variable|list(Variable)|tuple(Variable): returns ``true_fn()`` if the + Tensor|list(Tensor)|tuple(Tensor): returns ``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` . Raises: @@ -2356,10 +2352,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None): Examples: .. code-block:: python - import paddle.fluid as fluid - import paddle.fluid.layers as layers - from paddle.fluid.executor import Executor - from paddle.fluid.framework import Program, program_guard + import paddle # # pseudocode: @@ -2369,32 +2362,28 @@ def cond(pred, true_fn=None, false_fn=None, name=None): # return 3, 2 # + def true_func(): - return layers.fill_constant( - shape=[1, 2], dtype='int32', value=1), layers.fill_constant( - shape=[2, 3], dtype='bool', value=True) + return paddle.fill_constant(shape=[1, 2], dtype='int32', + value=1), paddle.fill_constant(shape=[2, 3], + dtype='bool', + value=True) + def false_func(): - return layers.fill_constant( - shape=[3, 4], dtype='float32', value=3), layers.fill_constant( - shape=[4, 5], dtype='int64', value=2) - - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): - x = layers.fill_constant(shape=[1], dtype='float32', value=0.1) - y = layers.fill_constant(shape=[1], dtype='float32', value=0.23) - pred = layers.less_than(x, y) - out = layers.cond(pred, true_func, false_func) - # out is a tuple containing 2 tensors - - place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( - ) else fluid.CPUPlace() - exe = fluid.Executor(place) - ret = exe.run(main_program, fetch_list=out) + return paddle.fill_constant(shape=[3, 4], dtype='float32', + value=3), paddle.fill_constant(shape=[4, 5], + dtype='int64', + value=2) + + x = paddle.fill_constant(shape=[1], dtype='float32', value=0.1) + y = paddle.fill_constant(shape=[1], dtype='float32', value=0.23) + pred = paddle.less_than(x=x, y=y, name=None) + ret = paddle.nn.cond(pred, true_func, false_func) + # ret is a tuple containing 2 tensors # ret[0] = [[1 1]] # ret[1] = [[ True True True] - # [ True True True]] + # [ True True True]] """ if in_dygraph_mode(): -- GitLab