diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 013a842e1123dd7e330e7e34c776e1c66026456d..498e7126d67c75056386da44a90ef90fe8416edd 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -3609,18 +3609,18 @@ def switch_case(branch_index, branch_fns, default=None, name=None): This operator is like a C++ switch/case statement. Args: - branch_index(Variable): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``. + branch_index(Tensor): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``. branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors. default(callable, optional): Callable that returns a structure of Tensors. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - Variable|list(Variable): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``, + Tensor|list(Tensor): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``, or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``, or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``. Raises: - TypeError: If the type of ``branch_index`` is not Variable. + TypeError: If the type of ``branch_index`` is not Tensor. TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``. TypeError: If the type of ``branch_fns`` is not dict, list or tuple. TypeError: If the elements of ``branch_fns`` is not 2-tuple. @@ -3632,40 +3632,41 @@ def switch_case(branch_index, branch_fns, default=None, name=None): Examples: .. code-block:: python - import paddle.fluid as fluid - import paddle.fluid.layers as layers + import paddle + + paddle.enable_static() def fn_1(): - return layers.fill_constant(shape=[1, 2], dtype='float32', value=1) + return paddle.fill_constant(shape=[1, 2], dtype='float32', value=1) def fn_2(): - return layers.fill_constant(shape=[2, 2], dtype='int32', value=2) + return paddle.fill_constant(shape=[2, 2], dtype='int32', value=2) def fn_3(): - return layers.fill_constant(shape=[3], dtype='int32', value=3) + return paddle.fill_constant(shape=[3], dtype='int32', value=3) - main_program = fluid.default_startup_program() - startup_program = fluid.default_main_program() - with fluid.program_guard(main_program, startup_program): - index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1) - index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2) + main_program = paddle.static.default_startup_program() + startup_program = paddle.static.default_main_program() + with paddle.static.program_guard(main_program, startup_program): + index_1 = paddle.fill_constant(shape=[1], dtype='int32', value=1) + index_2 = paddle.fill_constant(shape=[1], dtype='int32', value=2) - out_1 = layers.switch_case( + out_1 = paddle.static.nn.switch_case( branch_index=index_1, branch_fns={1: fn_1, 2: fn_2}, default=fn_3) - out_2 = layers.switch_case( + out_2 = paddle.static.nn.switch_case( branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)], default=fn_3) # Argument default is None and no index matches. fn_3 will be called because of the max index 7. - out_3 = layers.switch_case( + out_3 = paddle.static.nn.switch_case( branch_index=index_2, branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)]) - exe = fluid.Executor(fluid.CPUPlace()) + exe = paddle.static.Executor(paddle.CPUPlace()) res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3]) print(res_1) # [[1. 1.]] print(res_2) # [[2 2] [2 2]] diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 47a8668362e5e0b3901cda602b483d3e96bce29a..b79b965f5b9023b09df6dbf905561f192145dbf0 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -41,7 +41,6 @@ from .clip import clip_by_norm #DEFINE_ALIAS from .control_flow import cond #DEFINE_ALIAS # from .control_flow import DynamicRNN #DEFINE_ALIAS # from .control_flow import StaticRNN #DEFINE_ALIAS -from .control_flow import switch_case #DEFINE_ALIAS from .control_flow import while_loop #DEFINE_ALIAS # from .control_flow import rnn #DEFINE_ALIAS # from .decode import BeamSearchDecoder #DEFINE_ALIAS diff --git a/python/paddle/nn/control_flow.py b/python/paddle/nn/control_flow.py index 85f2fbcbe6eccf0052a10fce2960211be2244af4..a78b65c3c6c82ce65c66ce5d43889642beb51d0e 100644 --- a/python/paddle/nn/control_flow.py +++ b/python/paddle/nn/control_flow.py @@ -16,13 +16,10 @@ from ..fluid.layers import cond #DEFINE_ALIAS from ..fluid.layers import while_loop #DEFINE_ALIAS -from ..fluid.layers import switch_case #DEFINE_ALIAS - __all__ = [ 'cond', # 'DynamicRNN', # 'StaticRNN', - 'switch_case', 'while_loop', # 'rnn' ] diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index 51d295d050ea8dc1ecf225666888956208a359f4..510e11312f4ce1d037e687b18f79d36b0b8f1104 100644 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -35,6 +35,7 @@ __all__ = [ 'prelu', 'row_conv', 'spectral_norm', + 'switch_case', ] from ...fluid.layers import fc #DEFINE_ALIAS @@ -58,5 +59,6 @@ from ...fluid.layers import nce #DEFINE_ALIAS from ...fluid.layers import prelu #DEFINE_ALIAS from ...fluid.layers import row_conv #DEFINE_ALIAS from ...fluid.layers import spectral_norm #DEFINE_ALIAS +from ...fluid.layers import switch_case #DEFINE_ALIAS from ...fluid.input import embedding #DEFINE_ALIAS