“e6c01375c3ebdf169a55b260b0540bbc6f9d5e2f”上不存在“doc/Compile_CN.md”
未验证 提交 3f170dd8 编写于 作者: L liym27 提交者: GitHub

[API 2.0] Fix example code of api 'switch_case' and add/delete alias (#27578)

* Fix example code of api `fluid.layers.switch_case` to use api2.0

* delete `paddle.nn.switch_case` alias and add `paddle.static.nn.switch_case`
上级 c5b6e44b
...@@ -3609,18 +3609,18 @@ def switch_case(branch_index, branch_fns, default=None, name=None): ...@@ -3609,18 +3609,18 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
This operator is like a C++ switch/case statement. This operator is like a C++ switch/case statement.
Args: Args:
branch_index(Variable): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``. branch_index(Tensor): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``.
branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors. branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors.
default(callable, optional): Callable that returns a structure of Tensors. default(callable, optional): Callable that returns a structure of Tensors.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
Variable|list(Variable): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``, Tensor|list(Tensor): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``,
or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``, or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``,
or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``. or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``.
Raises: Raises:
TypeError: If the type of ``branch_index`` is not Variable. TypeError: If the type of ``branch_index`` is not Tensor.
TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``. TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``.
TypeError: If the type of ``branch_fns`` is not dict, list or tuple. TypeError: If the type of ``branch_fns`` is not dict, list or tuple.
TypeError: If the elements of ``branch_fns`` is not 2-tuple. TypeError: If the elements of ``branch_fns`` is not 2-tuple.
...@@ -3632,40 +3632,41 @@ def switch_case(branch_index, branch_fns, default=None, name=None): ...@@ -3632,40 +3632,41 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
import paddle.fluid.layers as layers
paddle.enable_static()
def fn_1(): def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1) return paddle.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2(): def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2) return paddle.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3(): def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3) return paddle.fill_constant(shape=[3], dtype='int32', value=3)
main_program = fluid.default_startup_program() main_program = paddle.static.default_startup_program()
startup_program = fluid.default_main_program() startup_program = paddle.static.default_main_program()
with fluid.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1) index_1 = paddle.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2) index_2 = paddle.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case( out_1 = paddle.static.nn.switch_case(
branch_index=index_1, branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2}, branch_fns={1: fn_1, 2: fn_2},
default=fn_3) default=fn_3)
out_2 = layers.switch_case( out_2 = paddle.static.nn.switch_case(
branch_index=index_2, branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)], branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3) default=fn_3)
# Argument default is None and no index matches. fn_3 will be called because of the max index 7. # Argument default is None and no index matches. fn_3 will be called because of the max index 7.
out_3 = layers.switch_case( out_3 = paddle.static.nn.switch_case(
branch_index=index_2, branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)]) branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
exe = fluid.Executor(fluid.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3]) res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
print(res_1) # [[1. 1.]] print(res_1) # [[1. 1.]]
print(res_2) # [[2 2] [2 2]] print(res_2) # [[2 2] [2 2]]
......
...@@ -41,7 +41,6 @@ from .clip import clip_by_norm #DEFINE_ALIAS ...@@ -41,7 +41,6 @@ from .clip import clip_by_norm #DEFINE_ALIAS
from .control_flow import cond #DEFINE_ALIAS from .control_flow import cond #DEFINE_ALIAS
# from .control_flow import DynamicRNN #DEFINE_ALIAS # from .control_flow import DynamicRNN #DEFINE_ALIAS
# from .control_flow import StaticRNN #DEFINE_ALIAS # from .control_flow import StaticRNN #DEFINE_ALIAS
from .control_flow import switch_case #DEFINE_ALIAS
from .control_flow import while_loop #DEFINE_ALIAS from .control_flow import while_loop #DEFINE_ALIAS
# from .control_flow import rnn #DEFINE_ALIAS # from .control_flow import rnn #DEFINE_ALIAS
# from .decode import BeamSearchDecoder #DEFINE_ALIAS # from .decode import BeamSearchDecoder #DEFINE_ALIAS
......
...@@ -16,13 +16,10 @@ ...@@ -16,13 +16,10 @@
from ..fluid.layers import cond #DEFINE_ALIAS from ..fluid.layers import cond #DEFINE_ALIAS
from ..fluid.layers import while_loop #DEFINE_ALIAS from ..fluid.layers import while_loop #DEFINE_ALIAS
from ..fluid.layers import switch_case #DEFINE_ALIAS
__all__ = [ __all__ = [
'cond', 'cond',
# 'DynamicRNN', # 'DynamicRNN',
# 'StaticRNN', # 'StaticRNN',
'switch_case',
'while_loop', 'while_loop',
# 'rnn' # 'rnn'
] ]
...@@ -35,6 +35,7 @@ __all__ = [ ...@@ -35,6 +35,7 @@ __all__ = [
'prelu', 'prelu',
'row_conv', 'row_conv',
'spectral_norm', 'spectral_norm',
'switch_case',
] ]
from ...fluid.layers import fc #DEFINE_ALIAS from ...fluid.layers import fc #DEFINE_ALIAS
...@@ -58,5 +59,6 @@ from ...fluid.layers import nce #DEFINE_ALIAS ...@@ -58,5 +59,6 @@ from ...fluid.layers import nce #DEFINE_ALIAS
from ...fluid.layers import prelu #DEFINE_ALIAS from ...fluid.layers import prelu #DEFINE_ALIAS
from ...fluid.layers import row_conv #DEFINE_ALIAS from ...fluid.layers import row_conv #DEFINE_ALIAS
from ...fluid.layers import spectral_norm #DEFINE_ALIAS from ...fluid.layers import spectral_norm #DEFINE_ALIAS
from ...fluid.layers import switch_case #DEFINE_ALIAS
from ...fluid.input import embedding #DEFINE_ALIAS from ...fluid.input import embedding #DEFINE_ALIAS
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册