未验证 提交 eb9ef885 编写于 作者: H Huihuang Zheng 提交者: GitHub

Modify paddle.static.nn.cond doc (#36694)

Update `cond` English document
上级 9aeca2f1
...@@ -2316,10 +2316,13 @@ def cond(pred, true_fn=None, false_fn=None, name=None): ...@@ -2316,10 +2316,13 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
the same shape because of dataflow model of PaddlePaddle while the the same shape because of dataflow model of PaddlePaddle while the
tensors in the tuples or the lists can have different shapes. tensors in the tuples or the lists can have different shapes.
2. Any tensors or operations created outside of ``true_fn`` and 2. This API could be used under both static mode or dygraph mode. If it
``false_fn`` will be executed regardless of which branch is selected at is in dygraph mode, the API only runs one branch based on condition.
runtime. This has frequently surprised users who expected a lazy
semantics. For example: 3. If it is in static mode, any tensors or operations created outside
or inside of ``true_fn`` and ``false_fn`` will be in net building
regardless of which branch is selected at runtime. This has frequently
surprised users who expected a lazy semantics. For example:
.. code-block:: python .. code-block:: python
...@@ -2328,9 +2331,11 @@ def cond(pred, true_fn=None, false_fn=None, name=None): ...@@ -2328,9 +2331,11 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
a = paddle.zeros((1, 1)) a = paddle.zeros((1, 1))
b = paddle.zeros((1, 1)) b = paddle.zeros((1, 1))
c = a * b c = a * b
out = paddle.nn.cond(a < b, lambda: a + c, lambda: b * b) out = paddle.static.nn.cond(a < b, lambda: a + c, lambda: b * b)
No matter whether ``a < b`` , ``c = a * b`` will run. No matter whether ``a < b`` , ``c = a * b`` will be in net building and
run. ``a + c`` and ``b * b`` will be in net building, but only one
branch will be executed during runtime.
Args: Args:
pred(Tensor): A boolean tensor whose numel should be 1. The boolean pred(Tensor): A boolean tensor whose numel should be 1. The boolean
...@@ -2366,24 +2371,24 @@ def cond(pred, true_fn=None, false_fn=None, name=None): ...@@ -2366,24 +2371,24 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
# return 3, 2 # return 3, 2
# #
def true_func(): def true_func():
return paddle.fill_constant(shape=[1, 2], dtype='int32', return paddle.full(shape=[1, 2], dtype='int32',
value=1), paddle.fill_constant(shape=[2, 3], fill_value=1), paddle.full(shape=[2, 3],
dtype='bool', dtype='bool',
value=True) fill_value=True)
def false_func(): def false_func():
return paddle.fill_constant(shape=[3, 4], dtype='float32', return paddle.full(shape=[3, 4], dtype='float32',
value=3), paddle.fill_constant(shape=[4, 5], fill_value=3), paddle.full(shape=[4, 5],
dtype='int64', dtype='int64',
value=2) fill_value=2)
x = paddle.fill_constant(shape=[1], dtype='float32', value=0.1) x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
y = paddle.fill_constant(shape=[1], dtype='float32', value=0.23) y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
pred = paddle.less_than(x=x, y=y, name=None) pred = paddle.less_than(x=x, y=y, name=None)
ret = paddle.nn.cond(pred, true_func, false_func) ret = paddle.static.nn.cond(pred, true_func, false_func)
# ret is a tuple containing 2 tensors # ret is a tuple containing 2 tensors
# ret[0] = [[1 1]] # ret[0] = [[1 1]]
# ret[1] = [[ True True True] # ret[1] = [[ True True True]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册