未验证 提交 c542d571 编写于 作者: H Huihuang Zheng 提交者: GitHub

Modify paddle.static.nn.cond doc (#36694) (#36767)

Update `cond` English document
上级 b080d986
......@@ -2316,10 +2316,13 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
the same shape because of dataflow model of PaddlePaddle while the
tensors in the tuples or the lists can have different shapes.
2. Any tensors or operations created outside of ``true_fn`` and
``false_fn`` will be executed regardless of which branch is selected at
runtime. This has frequently surprised users who expected a lazy
semantics. For example:
2. This API could be used under both static mode or dygraph mode. If it
is in dygraph mode, the API only runs one branch based on condition.
3. If it is in static mode, any tensors or operations created outside
or inside of ``true_fn`` and ``false_fn`` will be in net building
regardless of which branch is selected at runtime. This has frequently
surprised users who expected a lazy semantics. For example:
.. code-block:: python
......@@ -2328,9 +2331,11 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
a = paddle.zeros((1, 1))
b = paddle.zeros((1, 1))
c = a * b
out = paddle.nn.cond(a < b, lambda: a + c, lambda: b * b)
out = paddle.static.nn.cond(a < b, lambda: a + c, lambda: b * b)
No matter whether ``a < b`` , ``c = a * b`` will run.
No matter whether ``a < b`` , ``c = a * b`` will be in net building and
run. ``a + c`` and ``b * b`` will be in net building, but only one
branch will be executed during runtime.
Args:
pred(Tensor): A boolean tensor whose numel should be 1. The boolean
......@@ -2366,24 +2371,24 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
# return 3, 2
#
def true_func():
return paddle.fill_constant(shape=[1, 2], dtype='int32',
value=1), paddle.fill_constant(shape=[2, 3],
return paddle.full(shape=[1, 2], dtype='int32',
fill_value=1), paddle.full(shape=[2, 3],
dtype='bool',
value=True)
fill_value=True)
def false_func():
return paddle.fill_constant(shape=[3, 4], dtype='float32',
value=3), paddle.fill_constant(shape=[4, 5],
return paddle.full(shape=[3, 4], dtype='float32',
fill_value=3), paddle.full(shape=[4, 5],
dtype='int64',
value=2)
fill_value=2)
x = paddle.fill_constant(shape=[1], dtype='float32', value=0.1)
y = paddle.fill_constant(shape=[1], dtype='float32', value=0.23)
x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
pred = paddle.less_than(x=x, y=y, name=None)
ret = paddle.nn.cond(pred, true_func, false_func)
ret = paddle.static.nn.cond(pred, true_func, false_func)
# ret is a tuple containing 2 tensors
# ret[0] = [[1 1]]
# ret[1] = [[ True True True]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册