From eb9ef8850c88c63ca061006a2d7250de6e41922e Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Tue, 26 Oct 2021 14:08:25 +0800 Subject: [PATCH] Modify paddle.static.nn.cond doc (#36694) Update `cond` English document --- python/paddle/fluid/layers/control_flow.py | 41 ++++++++++++---------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index f444b5e9c0e..af2316a9a44 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -2316,10 +2316,13 @@ def cond(pred, true_fn=None, false_fn=None, name=None): the same shape because of dataflow model of PaddlePaddle while the tensors in the tuples or the lists can have different shapes. - 2. Any tensors or operations created outside of ``true_fn`` and - ``false_fn`` will be executed regardless of which branch is selected at - runtime. This has frequently surprised users who expected a lazy - semantics. For example: + 2. This API could be used under both static mode or dygraph mode. If it + is in dygraph mode, the API only runs one branch based on condition. + + 3. If it is in static mode, any tensors or operations created outside + or inside of ``true_fn`` and ``false_fn`` will be in net building + regardless of which branch is selected at runtime. This has frequently + surprised users who expected a lazy semantics. For example: .. code-block:: python @@ -2328,9 +2331,11 @@ def cond(pred, true_fn=None, false_fn=None, name=None): a = paddle.zeros((1, 1)) b = paddle.zeros((1, 1)) c = a * b - out = paddle.nn.cond(a < b, lambda: a + c, lambda: b * b) + out = paddle.static.nn.cond(a < b, lambda: a + c, lambda: b * b) - No matter whether ``a < b`` , ``c = a * b`` will run. + No matter whether ``a < b`` , ``c = a * b`` will be in net building and + run. ``a + c`` and ``b * b`` will be in net building, but only one + branch will be executed during runtime. Args: pred(Tensor): A boolean tensor whose numel should be 1. The boolean @@ -2366,24 +2371,24 @@ def cond(pred, true_fn=None, false_fn=None, name=None): # return 3, 2 # - def true_func(): - return paddle.fill_constant(shape=[1, 2], dtype='int32', - value=1), paddle.fill_constant(shape=[2, 3], - dtype='bool', - value=True) + return paddle.full(shape=[1, 2], dtype='int32', + fill_value=1), paddle.full(shape=[2, 3], + dtype='bool', + fill_value=True) def false_func(): - return paddle.fill_constant(shape=[3, 4], dtype='float32', - value=3), paddle.fill_constant(shape=[4, 5], - dtype='int64', - value=2) + return paddle.full(shape=[3, 4], dtype='float32', + fill_value=3), paddle.full(shape=[4, 5], + dtype='int64', + fill_value=2) + - x = paddle.fill_constant(shape=[1], dtype='float32', value=0.1) - y = paddle.fill_constant(shape=[1], dtype='float32', value=0.23) + x = paddle.full(shape=[1], dtype='float32', fill_value=0.1) + y = paddle.full(shape=[1], dtype='float32', fill_value=0.23) pred = paddle.less_than(x=x, y=y, name=None) - ret = paddle.nn.cond(pred, true_func, false_func) + ret = paddle.static.nn.cond(pred, true_func, false_func) # ret is a tuple containing 2 tensors # ret[0] = [[1 1]] # ret[1] = [[ True True True] -- GitLab