未验证 提交 3f170dd8 编写于 作者: L liym27 提交者: GitHub

[API 2.0] Fix example code of api 'switch_case' and add/delete alias (#27578)

* Fix example code of api `fluid.layers.switch_case` to use api2.0

* delete `paddle.nn.switch_case` alias and add `paddle.static.nn.switch_case`
上级 c5b6e44b
......@@ -3609,18 +3609,18 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
This operator is like a C++ switch/case statement.
Args:
branch_index(Variable): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``.
branch_index(Tensor): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``.
branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors.
default(callable, optional): Callable that returns a structure of Tensors.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable|list(Variable): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``,
Tensor|list(Tensor): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``,
or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``,
or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``.
Raises:
TypeError: If the type of ``branch_index`` is not Variable.
TypeError: If the type of ``branch_index`` is not Tensor.
TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``.
TypeError: If the type of ``branch_fns`` is not dict, list or tuple.
TypeError: If the elements of ``branch_fns`` is not 2-tuple.
......@@ -3632,40 +3632,41 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle
paddle.enable_static()
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
return paddle.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
return paddle.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
return paddle.fill_constant(shape=[3], dtype='int32', value=3)
main_program = fluid.default_startup_program()
startup_program = fluid.default_main_program()
with fluid.program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
main_program = paddle.static.default_startup_program()
startup_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program):
index_1 = paddle.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = paddle.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3)
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3)
# Argument default is None and no index matches. fn_3 will be called because of the max index 7.
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
exe = fluid.Executor(fluid.CPUPlace())
exe = paddle.static.Executor(paddle.CPUPlace())
res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
print(res_1) # [[1. 1.]]
print(res_2) # [[2 2] [2 2]]
......
......@@ -41,7 +41,6 @@ from .clip import clip_by_norm #DEFINE_ALIAS
from .control_flow import cond #DEFINE_ALIAS
# from .control_flow import DynamicRNN #DEFINE_ALIAS
# from .control_flow import StaticRNN #DEFINE_ALIAS
from .control_flow import switch_case #DEFINE_ALIAS
from .control_flow import while_loop #DEFINE_ALIAS
# from .control_flow import rnn #DEFINE_ALIAS
# from .decode import BeamSearchDecoder #DEFINE_ALIAS
......
......@@ -16,13 +16,10 @@
from ..fluid.layers import cond #DEFINE_ALIAS
from ..fluid.layers import while_loop #DEFINE_ALIAS
from ..fluid.layers import switch_case #DEFINE_ALIAS
__all__ = [
'cond',
# 'DynamicRNN',
# 'StaticRNN',
'switch_case',
'while_loop',
# 'rnn'
]
......@@ -35,6 +35,7 @@ __all__ = [
'prelu',
'row_conv',
'spectral_norm',
'switch_case',
]
from ...fluid.layers import fc #DEFINE_ALIAS
......@@ -58,5 +59,6 @@ from ...fluid.layers import nce #DEFINE_ALIAS
from ...fluid.layers import prelu #DEFINE_ALIAS
from ...fluid.layers import row_conv #DEFINE_ALIAS
from ...fluid.layers import spectral_norm #DEFINE_ALIAS
from ...fluid.layers import switch_case #DEFINE_ALIAS
from ...fluid.input import embedding #DEFINE_ALIAS
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册