未验证 提交 506c1bb2 编写于 作者: L liym27 提交者: GitHub

[API 2.0: doc] fix doc of api case and delete/add alias (#2688)

* Fix doc and example code of api fluid.layers.case

* delete alias paddle.nn.switch_case and paddle.nn.control_flow.switch_case  

* add paddle.static.nn.switch_case
上级 a596f054
......@@ -457,7 +457,7 @@ paddle.tensor.linalg.cross paddle.cross,paddle.tensor.cross
paddle.fluid.layers.maxout paddle.nn.functional.maxout,paddle.nn.functional.activation.maxout
paddle.nn.layer.norm.InstanceNorm2d paddle.nn.InstanceNorm2d
paddle.fluid.layers.assign paddle.nn.functional.assign,paddle.nn.functional.common.assign
paddle.fluid.layers.case paddle.nn.case,paddle.nn.control_flow.case
paddle.fluid.layers.case paddle.static.nn.case
paddle.fluid.core.CUDAPlace paddle.CUDAPlace,paddle.framework.CUDAPlace
paddle.nn.functional.pooling.max_pool2d paddle.nn.functional.max_pool2d
paddle.fluid.layers.resize_bilinear paddle.nn.functional.resize_bilinear,paddle.nn.functional.vision.resize_bilinear
......
......@@ -6,7 +6,7 @@ case
.. py:function:: paddle.fluid.layers.case(pred_fn_pairs, default=None, name=None)
:api_attr: 声明式编程模式(静态图)
该OP的运行方式类似于python的if-elif-elif-else。
......@@ -19,13 +19,13 @@ case
返回:如果 ``pred_fn_pairs`` 中存在pred是True的元组(pred, fn),则返回第一个为True的pred的元组中fn的返回结果;如果 ``pred_fn_pairs`` 中不存在pred为True的元组(pred, fn) 且 ``default`` 不是None,则返回调用 ``default`` 的返回结果;
如果 ``pred_fn_pairs`` 中不存在pred为True的元组(pred, fn) 且 ``default`` 是None,则返回 ``pred_fn_pairs`` 中最后一个pred的返回结果。
返回类型:Variable|list(Variable)
返回类型:Tensor|list(Tensor)
抛出异常:
- ``TypeError`` - 如果 ``pred_fn_pairs`` 的类型不是list或tuple。
- ``TypeError`` - 如果 ``pred_fn_pairs`` 的元素的类型不是tuple。
- ``TypeError`` - 如果 ``pred_fn_pairs`` 的tuple类型的元素大小不是2。
- ``TypeError`` - 如果 ``pred_fn_pairs`` 中的2-tuple的第一个元素的类型不是Variable
- ``TypeError`` - 如果 ``pred_fn_pairs`` 中的2-tuple的第一个元素的类型不是Tensor
- ``TypeError`` - 如果 ``pred_fn_pairs`` 中的2-tuple的第二个元素不是可调用对象。
- ``TypeError`` - 当 ``default`` 不是None又不是可调用对象时。
......@@ -33,41 +33,40 @@ case
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle
paddle.enable_static()
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
return paddle.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
return paddle.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
return paddle.fill_constant(shape=[3], dtype='int32', value=3)
main_program = paddle.static.default_startup_program()
startup_program = paddle.static.default_main_program()
main_program = fluid.default_startup_program()
startup_program = fluid.default_main_program()
with fluid.program_guard(main_program, startup_program):
x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
with paddle.static.program_guard(main_program, startup_program):
x = paddle.fill_constant(shape=[1], dtype='float32', value=0.3)
y = paddle.fill_constant(shape=[1], dtype='float32', value=0.1)
z = paddle.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3
pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1
pred_3 = layers.equal(x, y) # false: 0.3 == 0.1
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
# Call fn_1 because pred_1 is True
out_1 = layers.case(
out_1 = paddle.static.nn.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
# Argument default is None and no pred in pred_fn_pairs is True. fn_3 will be called.
# because fn_3 is the last callable in pred_fn_pairs.
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
out_2 = paddle.static.nn.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
exe = fluid.Executor(fluid.CPUPlace())
exe = paddle.static.Executor(paddle.CPUPlace())
res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
print(res_1) # [[1. 1.]]
print(res_2) # [3 3 3]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册