未验证 提交 db85f4cf 编写于 作者: H hutuxian 提交者: GitHub

Add dygraph implementation for multiplex op (#29049)

上级 b0d1ac16
...@@ -5719,7 +5719,7 @@ def row_conv(input, future_context_size, param_attr=None, act=None): ...@@ -5719,7 +5719,7 @@ def row_conv(input, future_context_size, param_attr=None, act=None):
@templatedoc() @templatedoc()
def multiplex(inputs, index): def multiplex(inputs, index, name=None):
""" """
Based on the given index parameter, the OP selects a specific row from each input Tensor to construct the output Tensor. Based on the given index parameter, the OP selects a specific row from each input Tensor to construct the output Tensor.
...@@ -5749,34 +5749,29 @@ def multiplex(inputs, index): ...@@ -5749,34 +5749,29 @@ def multiplex(inputs, index):
Args: Args:
inputs (list): The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64. All input Tensor shapes should be the same and rank must be at least 2. inputs (list): The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64. All input Tensor shapes should be the same and rank must be at least 2.
index (Variable): Used to select some rows in the input Tensor to construct an index of the output Tensor. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of input Tensors. index (Tensor): Used to select some rows in the input Tensor to construct an index of the output Tensor. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of input Tensors.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns: Returns:
Variable(Tensor): Output of multiplex OP, with data type being float32, float64, int32, int64. Tensor: Output of multiplex OP, with data type being float32, float64, int32, int64.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
import numpy as np import numpy as np
x1 = fluid.data(name='x1', shape=[None, 2], dtype='float32')
x2 = fluid.data(name='x2', shape=[None, 2], dtype='float32')
index = fluid.data(name='index', shape=[None, 1], dtype='int32')
out = fluid.layers.multiplex(inputs=[x1, x2], index=index)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
img1 = np.array([[1, 2], [3, 4]]).astype(np.float32) img1 = np.array([[1, 2], [3, 4]]).astype(np.float32)
img2 = np.array([[5, 6], [7, 8]]).astype(np.float32) img2 = np.array([[5, 6], [7, 8]]).astype(np.float32)
index = np.array([[1], [0]]).astype(np.int32) inputs = [paddle.to_tensor(img1), paddle.to_tensor(img2)]
index = paddle.to_tensor(np.array([[1], [0]]).astype(np.int32))
res = exe.run(fluid.default_main_program(), feed={'x1':img1, 'x2':img2, 'index':index}, fetch_list=[out]) res = paddle.multiplex(inputs, index)
print(res) # [array([[5., 6.], [3., 4.]], dtype=float32)] print(res) # [array([[5., 6.], [3., 4.]], dtype=float32)]
""" """
if in_dygraph_mode():
return core.ops.multiplex(index, inputs)
helper = LayerHelper('multiplex', **locals()) helper = LayerHelper('multiplex', **locals())
check_type(inputs, 'inputs', (list), 'multiplex') check_type(inputs, 'inputs', (list), 'multiplex')
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -91,5 +92,16 @@ class TestMultiplexOpError(unittest.TestCase): ...@@ -91,5 +92,16 @@ class TestMultiplexOpError(unittest.TestCase):
self.assertRaises(TypeError, test_type2) self.assertRaises(TypeError, test_type2)
class TestMultiplexODygrap(unittest.TestCase):
def test_multiplex_dygraph(self):
paddle.disable_static()
img1 = np.array([[1, 2], [3, 4]]).astype(np.float32)
img2 = np.array([[5, 6], [7, 8]]).astype(np.float32)
inputs = [paddle.to_tensor(img1), paddle.to_tensor(img2)]
index = paddle.to_tensor(np.array([[1], [0]]).astype(np.int32))
res = paddle.multiplex(inputs, index)
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册