未验证 提交 aceb68f4 编写于 作者: L liym27 提交者: GitHub

add Chinese document of control flow API case and switch_case (#1629)

* add Chinese document of control flow API case and switch_case. test=develop

* polish TypeError according to reviews. test=develop
上级 3abf97c4
...@@ -36,6 +36,7 @@ fluid.layers ...@@ -36,6 +36,7 @@ fluid.layers
layers/box_decoder_and_assign.rst layers/box_decoder_and_assign.rst
layers/bpr_loss.rst layers/bpr_loss.rst
layers/brelu.rst layers/brelu.rst
layers/case.rst
layers/cast.rst layers/cast.rst
layers/Categorical.rst layers/Categorical.rst
layers/ceil.rst layers/ceil.rst
...@@ -290,6 +291,7 @@ fluid.layers ...@@ -290,6 +291,7 @@ fluid.layers
layers/sums.rst layers/sums.rst
layers/swish.rst layers/swish.rst
layers/Switch.rst layers/Switch.rst
layers/switch_case.rst
layers/tanh.rst layers/tanh.rst
layers/tanh_shrink.rst layers/tanh_shrink.rst
layers/target_assign.rst layers/target_assign.rst
......
...@@ -41,6 +41,7 @@ fluid.layers ...@@ -41,6 +41,7 @@ fluid.layers
layers_cn/bpr_loss_cn.rst layers_cn/bpr_loss_cn.rst
layers_cn/brelu_cn.rst layers_cn/brelu_cn.rst
layers_cn/BeamSearchDecoder_cn.rst layers_cn/BeamSearchDecoder_cn.rst
layers_cn/case_cn.rst
layers_cn/cast_cn.rst layers_cn/cast_cn.rst
layers_cn/Categorical_cn.rst layers_cn/Categorical_cn.rst
layers_cn/ceil_cn.rst layers_cn/ceil_cn.rst
...@@ -296,6 +297,7 @@ fluid.layers ...@@ -296,6 +297,7 @@ fluid.layers
layers_cn/sums_cn.rst layers_cn/sums_cn.rst
layers_cn/swish_cn.rst layers_cn/swish_cn.rst
layers_cn/Switch_cn.rst layers_cn/Switch_cn.rst
layers_cn/switch_case_cn.rst
layers_cn/tanh_cn.rst layers_cn/tanh_cn.rst
layers_cn/tanh_shrink_cn.rst layers_cn/tanh_shrink_cn.rst
layers_cn/target_assign_cn.rst layers_cn/target_assign_cn.rst
......
.. _cn_api_fluid_layers_case:
case
-------------------------------
.. py:function:: paddle.fluid.layers.case(pred_fn_pairs, default=None, name=None)
该OP的运行方式类似于python的if-elif-elif-else。
参数:
- **pred_fn_pairs** (list|tuple) - 一个list或者tuple,元素是二元组(pred, fn)。其中 ``pred`` 是形状为[1]的布尔型 Tensor,``fn`` 是一个可调用对象。所有的可调用对象都返回相同结构的Tensor。
- **default** (callable,可选) - 可调用对象,返回一个或多个张量。
- **name** (str,可选) – 具体用法请参见 :ref:`api_guide_Name` ,一般无需设置,默认值:None。
返回:如果 ``pred_fn_pairs`` 中存在pred是True的元组(pred, fn),则返回第一个为True的pred的元组中fn的返回结果;如果 ``pred_fn_pairs`` 中不存在pred为True的元组(pred, fn) 且 ``default`` 不是None,则返回调用 ``default`` 的返回结果;
如果 ``pred_fn_pairs`` 中不存在pred为True的元组(pred, fn) 且 ``default`` 是None,则返回 ``pred_fn_pairs`` 中最后一个pred的返回结果。
返回类型:Variable|list(Variable)
抛出异常:
- ``TypeError`` - 如果 ``pred_fn_pairs`` 的类型不是list或tuple。
- ``TypeError`` - 如果 ``pred_fn_pairs`` 的元素的类型不是tuple。
- ``TypeError`` - 如果 ``pred_fn_pairs`` 的tuple类型的元素大小不是2。
- ``TypeError`` - 如果 ``pred_fn_pairs`` 中的2-tuple的第一个元素的类型不是Variable。
- ``TypeError`` - 如果 ``pred_fn_pairs`` 中的2-tuple的第二个元素不是可调用对象。
- ``TypeError`` - 当 ``default`` 不是None又不是可调用对象时。
**代码示例**:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
main_program = fluid.default_startup_program()
startup_program = fluid.default_main_program()
with fluid.program_guard(main_program, startup_program):
x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3
pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1
pred_3 = layers.equal(x, y) # false: 0.3 == 0.1
# Call fn_1 because pred_1 is True
out_1 = layers.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
# Argument default is None and no pred in pred_fn_pairs is True. fn_3 will be called.
# because fn_3 is the last callable in pred_fn_pairs.
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
exe = fluid.Executor(fluid.CPUPlace())
res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
print(res_1) # [[1. 1.]]
print(res_2) # [3 3 3]
.. _cn_api_fluid_layers_switch_case:
switch_case
-------------------------------
.. py:function:: paddle.fluid.layers.switch_case(branch_index, branch_fns, default=None, name=None)
该OP的运行方式类似于c++的switch/case。
参数:
- **branch_index** (Variable)- 形状为[1]的Tensor,指定将要执行的分支。数据类型是 ``int32``, ``int64`` 或 ``uint8``。
- **branch_fns** (dict|list|tuple) - 如果 ``branch_fns`` 是一个list或tuple,它的元素可以是 (int, callable) 二元组,即由整数和可调用对象构成的二元组,整数表示对应的可调用对象的键;也可以仅仅是可调用对象,它在list或者tuple中的实际索引值将作为该可调用对象的键。如果 ``branch_fns`` 是一个字典,那么它的键是整数,它的值是可调用对象。所有的可调用对象都返回相同结构的Tensor。
- **default** (callable,可选) - 可调用对象,返回一个或多个张量。
- **name** (str,可选) – 具体用法请参见 :ref:`api_guide_Name` ,一般无需设置,默认值:None。
返回:如果 ``branch_fns`` 中存在与 ``branch_index`` 匹配的可调用对象,则返回该可调用对象的返回结果;如果 ``branch_fns`` 中不存在与 ``branch_index`` 匹配的可调用对象且 ``default`` 不是None,则返回调用 ``default`` 的返回结果;
如果 ``branch_fns`` 中不存在与 ``branch_index`` 匹配的可调用对象且 ``default`` 是None,则返回 ``branch_fns`` 中键值最大的可调用对象的返回结果。
返回类型:Variable|list(Variable)
抛出异常:
- ``TypeError`` - 如果 ``branch_index`` 的类型不是list或tuple。
- ``TypeError`` - 如果 ``branch_index`` 的数据类型不是 ``int32``, ``int64`` 或 ``uint8``。
- ``TypeError`` - 如果 ``branch_fns`` 的类型不是dict,list或tuple。
- ``TypeError`` - 如果 ``branch_fns`` 的元素不是2-tuple。
- ``TypeError`` - 如果 ``branch_fns`` 中的2-tuple的第一个元素的类型不是整数。
- ``ValueError`` - 如果 ``branch_fns`` 中的2-tuple的第一个元素值不唯一。
- ``TypeError`` - 如果 ``branch_fns`` 中的2-tuple的第二个元素不是可调用对象。
- ``TypeError`` - 当 ``default`` 不是None又不是可调用对象时。
**代码示例**:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
main_program = fluid.default_startup_program()
startup_program = fluid.default_main_program()
with fluid.program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3)
out_2 = layers.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3)
# Argument default is None and no index matches. fn_3 will be called because of the max index 7.
out_3 = layers.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
exe = fluid.Executor(fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
print(res_1) # [[1. 1.]]
print(res_2) # [[2 2] [2 2]]
print(res_3) # [3 3 3]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册