From db85f4cf8f5912eb4f0797569cef7e3bf1b77b7a Mon Sep 17 00:00:00 2001 From: hutuxian Date: Thu, 26 Nov 2020 11:36:53 +0800 Subject: [PATCH] Add dygraph implementation for multiplex op (#29049) --- python/paddle/fluid/layers/nn.py | 31 ++++++++----------- .../tests/unittests/test_multiplex_op.py | 12 +++++++ 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6b1e782239c..9bbec75ba0c 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5719,7 +5719,7 @@ def row_conv(input, future_context_size, param_attr=None, act=None): @templatedoc() -def multiplex(inputs, index): +def multiplex(inputs, index, name=None): """ Based on the given index parameter, the OP selects a specific row from each input Tensor to construct the output Tensor. @@ -5748,35 +5748,30 @@ def multiplex(inputs, index): Args: - inputs (list): The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64. All input Tensor shapes should be the same and rank must be at least 2. - index (Variable): Used to select some rows in the input Tensor to construct an index of the output Tensor. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of input Tensors. - + inputs (list): The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64. All input Tensor shapes should be the same and rank must be at least 2. + index (Tensor): Used to select some rows in the input Tensor to construct an index of the output Tensor. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of input Tensors. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. Returns: - Variable(Tensor): Output of multiplex OP, with data type being float32, float64, int32, int64. + Tensor: Output of multiplex OP, with data type being float32, float64, int32, int64. Examples: .. code-block:: python - import paddle.fluid as fluid + import paddle import numpy as np - - x1 = fluid.data(name='x1', shape=[None, 2], dtype='float32') - x2 = fluid.data(name='x2', shape=[None, 2], dtype='float32') - index = fluid.data(name='index', shape=[None, 1], dtype='int32') - out = fluid.layers.multiplex(inputs=[x1, x2], index=index) - - exe = fluid.Executor(fluid.CPUPlace()) - exe.run(fluid.default_startup_program()) - img1 = np.array([[1, 2], [3, 4]]).astype(np.float32) img2 = np.array([[5, 6], [7, 8]]).astype(np.float32) - index = np.array([[1], [0]]).astype(np.int32) - - res = exe.run(fluid.default_main_program(), feed={'x1':img1, 'x2':img2, 'index':index}, fetch_list=[out]) + inputs = [paddle.to_tensor(img1), paddle.to_tensor(img2)] + index = paddle.to_tensor(np.array([[1], [0]]).astype(np.int32)) + res = paddle.multiplex(inputs, index) print(res) # [array([[5., 6.], [3., 4.]], dtype=float32)] """ + if in_dygraph_mode(): + return core.ops.multiplex(index, inputs) helper = LayerHelper('multiplex', **locals()) check_type(inputs, 'inputs', (list), 'multiplex') diff --git a/python/paddle/fluid/tests/unittests/test_multiplex_op.py b/python/paddle/fluid/tests/unittests/test_multiplex_op.py index 47c648d44b6..a840586d78d 100644 --- a/python/paddle/fluid/tests/unittests/test_multiplex_op.py +++ b/python/paddle/fluid/tests/unittests/test_multiplex_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid as fluid @@ -91,5 +92,16 @@ class TestMultiplexOpError(unittest.TestCase): self.assertRaises(TypeError, test_type2) +class TestMultiplexODygrap(unittest.TestCase): + def test_multiplex_dygraph(self): + paddle.disable_static() + img1 = np.array([[1, 2], [3, 4]]).astype(np.float32) + img2 = np.array([[5, 6], [7, 8]]).astype(np.float32) + inputs = [paddle.to_tensor(img1), paddle.to_tensor(img2)] + index = paddle.to_tensor(np.array([[1], [0]]).astype(np.int32)) + res = paddle.multiplex(inputs, index) + paddle.enable_static() + + if __name__ == '__main__': unittest.main() -- GitLab