From 3f170dd83da487b7fc9e13dd9145ddf56a9a0fc4 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Sun, 27 Sep 2020 18:45:25 +0800 Subject: [PATCH] [API 2.0] Fix example code of api 'switch_case' and add/delete alias (#27578) * Fix example code of api `fluid.layers.switch_case` to use api2.0 * delete `paddle.nn.switch_case` alias and add `paddle.static.nn.switch_case` --- python/paddle/fluid/layers/control_flow.py | 35 +++++++++++----------- python/paddle/nn/__init__.py | 1 - python/paddle/nn/control_flow.py | 3 -- python/paddle/static/nn/__init__.py | 2 ++ 4 files changed, 20 insertions(+), 21 deletions(-) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 013a842e112..498e7126d67 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -3609,18 +3609,18 @@ def switch_case(branch_index, branch_fns, default=None, name=None): This operator is like a C++ switch/case statement. Args: - branch_index(Variable): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``. + branch_index(Tensor): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``. branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors. default(callable, optional): Callable that returns a structure of Tensors. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - Variable|list(Variable): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``, + Tensor|list(Tensor): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``, or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``, or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``. Raises: - TypeError: If the type of ``branch_index`` is not Variable. + TypeError: If the type of ``branch_index`` is not Tensor. TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``. TypeError: If the type of ``branch_fns`` is not dict, list or tuple. TypeError: If the elements of ``branch_fns`` is not 2-tuple. @@ -3632,40 +3632,41 @@ def switch_case(branch_index, branch_fns, default=None, name=None): Examples: .. code-block:: python - import paddle.fluid as fluid - import paddle.fluid.layers as layers + import paddle + + paddle.enable_static() def fn_1(): - return layers.fill_constant(shape=[1, 2], dtype='float32', value=1) + return paddle.fill_constant(shape=[1, 2], dtype='float32', value=1) def fn_2(): - return layers.fill_constant(shape=[2, 2], dtype='int32', value=2) + return paddle.fill_constant(shape=[2, 2], dtype='int32', value=2) def fn_3(): - return layers.fill_constant(shape=[3], dtype='int32', value=3) + return paddle.fill_constant(shape=[3], dtype='int32', value=3) - main_program = fluid.default_startup_program() - startup_program = fluid.default_main_program() - with fluid.program_guard(main_program, startup_program): - index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1) - index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2) + main_program = paddle.static.default_startup_program() + startup_program = paddle.static.default_main_program() + with paddle.static.program_guard(main_program, startup_program): + index_1 = paddle.fill_constant(shape=[1], dtype='int32', value=1) + index_2 = paddle.fill_constant(shape=[1], dtype='int32', value=2) - out_1 = layers.switch_case( + out_1 = paddle.static.nn.switch_case( branch_index=index_1, branch_fns={1: fn_1, 2: fn_2}, default=fn_3) - out_2 = layers.switch_case( + out_2 = paddle.static.nn.switch_case( branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)], default=fn_3) # Argument default is None and no index matches. fn_3 will be called because of the max index 7. - out_3 = layers.switch_case( + out_3 = paddle.static.nn.switch_case( branch_index=index_2, branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)]) - exe = fluid.Executor(fluid.CPUPlace()) + exe = paddle.static.Executor(paddle.CPUPlace()) res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3]) print(res_1) # [[1. 1.]] print(res_2) # [[2 2] [2 2]] diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 47a8668362e..b79b965f5b9 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -41,7 +41,6 @@ from .clip import clip_by_norm #DEFINE_ALIAS from .control_flow import cond #DEFINE_ALIAS # from .control_flow import DynamicRNN #DEFINE_ALIAS # from .control_flow import StaticRNN #DEFINE_ALIAS -from .control_flow import switch_case #DEFINE_ALIAS from .control_flow import while_loop #DEFINE_ALIAS # from .control_flow import rnn #DEFINE_ALIAS # from .decode import BeamSearchDecoder #DEFINE_ALIAS diff --git a/python/paddle/nn/control_flow.py b/python/paddle/nn/control_flow.py index 85f2fbcbe6e..a78b65c3c6c 100644 --- a/python/paddle/nn/control_flow.py +++ b/python/paddle/nn/control_flow.py @@ -16,13 +16,10 @@ from ..fluid.layers import cond #DEFINE_ALIAS from ..fluid.layers import while_loop #DEFINE_ALIAS -from ..fluid.layers import switch_case #DEFINE_ALIAS - __all__ = [ 'cond', # 'DynamicRNN', # 'StaticRNN', - 'switch_case', 'while_loop', # 'rnn' ] diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index 51d295d050e..510e11312f4 100644 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -35,6 +35,7 @@ __all__ = [ 'prelu', 'row_conv', 'spectral_norm', + 'switch_case', ] from ...fluid.layers import fc #DEFINE_ALIAS @@ -58,5 +59,6 @@ from ...fluid.layers import nce #DEFINE_ALIAS from ...fluid.layers import prelu #DEFINE_ALIAS from ...fluid.layers import row_conv #DEFINE_ALIAS from ...fluid.layers import spectral_norm #DEFINE_ALIAS +from ...fluid.layers import switch_case #DEFINE_ALIAS from ...fluid.input import embedding #DEFINE_ALIAS -- GitLab