switch_case_cn.rst 3.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
.. _cn_api_fluid_layers_switch_case:

switch_case
-------------------------------

.. py:function:: paddle.fluid.layers.switch_case(branch_index, branch_fns, default=None, name=None)

该OP的运行方式类似于c++的switch/case。

参数:
    - **branch_index** (Variable)- 形状为[1]的Tensor,指定将要执行的分支。数据类型是 ``int32``, ``int64`` 或 ``uint8``。
    - **branch_fns** (dict|list|tuple) - 如果 ``branch_fns`` 是一个list或tuple,它的元素可以是 (int, callable) 二元组,即由整数和可调用对象构成的二元组,整数表示对应的可调用对象的键;也可以仅仅是可调用对象,它在list或者tuple中的实际索引值将作为该可调用对象的键。如果 ``branch_fns`` 是一个字典,那么它的键是整数,它的值是可调用对象。所有的可调用对象都返回相同结构的Tensor。
    - **default** (callable,可选) - 可调用对象,返回一个或多个张量。
    - **name** (str,可选) – 具体用法请参见 :ref:`api_guide_Name` ,一般无需设置,默认值:None。

返回:如果 ``branch_fns`` 中存在与 ``branch_index`` 匹配的可调用对象,则返回该可调用对象的返回结果;如果 ``branch_fns`` 中不存在与 ``branch_index`` 匹配的可调用对象且 ``default`` 不是None,则返回调用 ``default`` 的返回结果;
如果 ``branch_fns`` 中不存在与 ``branch_index`` 匹配的可调用对象且 ``default`` 是None,则返回 ``branch_fns`` 中键值最大的可调用对象的返回结果。

返回类型:Variable|list(Variable)

抛出异常:
    - ``TypeError`` - 如果 ``branch_index`` 的类型不是list或tuple。
    - ``TypeError`` - 如果 ``branch_index`` 的数据类型不是 ``int32``, ``int64`` 或 ``uint8``。
    - ``TypeError`` - 如果 ``branch_fns`` 的类型不是dict,list或tuple。
    - ``TypeError`` - 如果 ``branch_fns`` 的元素不是2-tuple。
    - ``TypeError`` - 如果 ``branch_fns`` 中的2-tuple的第一个元素的类型不是整数。
    - ``ValueError`` - 如果 ``branch_fns`` 中的2-tuple的第一个元素值不唯一。
    - ``TypeError`` - 如果 ``branch_fns`` 中的2-tuple的第二个元素不是可调用对象。
    - ``TypeError`` - 当 ``default`` 不是None又不是可调用对象时。


**代码示例**:

.. code-block:: python

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    def fn_1():
        return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)

    def fn_2():
        return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)

    def fn_3():
        return layers.fill_constant(shape=[3], dtype='int32', value=3)

    main_program = fluid.default_startup_program()
    startup_program = fluid.default_main_program()
    with fluid.program_guard(main_program, startup_program):
        index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
        index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)

        out_1 = layers.switch_case(
            branch_index=index_1,
            branch_fns={1: fn_1, 2: fn_2},
            default=fn_3)

        out_2 = layers.switch_case(
            branch_index=index_2,
            branch_fns=[(1, fn_1), (2, fn_2)],
            default=fn_3)

        # Argument default is None and no index matches. fn_3 will be called because of the max index 7.
        out_3 = layers.switch_case(
            branch_index=index_2,
            branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])

        exe = fluid.Executor(fluid.CPUPlace())
        res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
        print(res_1)  # [[1. 1.]]
        print(res_2)  # [[2 2] [2 2]]
        print(res_3)  # [3 3 3]