Switch_cn.rst 2.6 KB
Newer Older
H
Hao Wang 已提交
1 2 3 4 5 6 7
.. _cn_api_fluid_layers_Switch:

Switch
-------------------------------

.. py:class:: paddle.fluid.layers.Switch (name=None)

P
Pei Yang 已提交
8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
该类用于实现Switch分支控制功能。Switch分支包含多个case分支和一个default分支,Switch控制流会依次检查各case分支条件是否满足,并仅执行第一个满足条件的case分支后面的语句。若不存在满足条件的case分支,则仅执行default分支后面的语句。

成员函数:
    - **case(cond)** - Switch的case分支,其参数cond为bool型的标量Variable。只有当前case分支的cond为True,且之前的case分支的cond均为False,该case分支后的语句才会执行,且不再执行之后的case后的语句。
    - **default()** - Switch的default分支。当所有case分支的cond均为False时,执行default分支后的语句。

注意:case和default函数只能用于Switch的scope内部,示例如下:

..  code-block:: python

    with fluid.layers.Switch() as switch:
        with switch.case(cond1):
            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1)
        with switch.case(cond2):
            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=2)
        with switch.default():
            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)

参数:
P
Pei Yang 已提交
27
    - **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name` ,一般无需设置,默认值为None。
H
Hao Wang 已提交
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46

**代码示例**

..  code-block:: python

    import paddle.fluid as fluid

    lr = fluid.layers.create_global_var(
        shape=[1],
        value=0.0,
        dtype='float32',
        persistable=True,
        name="learning_rate")
    zero_var = fluid.layers.fill_constant(
        shape=[1], dtype='float32', value=0.0)
    one_var = fluid.layers.fill_constant(
        shape=[1], dtype='float32', value=1.0)
    two_var = fluid.layers.fill_constant(
        shape=[1], dtype='float32', value=2.0)
P
Pei Yang 已提交
47 48 49
    
    # 将参数中的begin设为非0值,则进入Switch的default分支,输出数组中的数字将为2
    global_step = fluid.layers.autoincreased_step_counter(counter_name='@LR_DECAY_COUNTER@', begin=0, step=1) 
H
Hao Wang 已提交
50 51 52 53 54 55 56

    with fluid.layers.control_flow.Switch() as switch:
        with switch.case(global_step == zero_var):
            fluid.layers.assign(input=one_var, output=lr)
        with switch.default():
            fluid.layers.assign(input=two_var, output=lr)

P
Pei Yang 已提交
57 58
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
H
Hao Wang 已提交
59

P
Pei Yang 已提交
60 61
    res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[lr])
    print(res) # [array([1.], dtype=float32)]
H
Hao Wang 已提交
62 63