未验证 提交 db85f4cf 编写于 作者: H hutuxian 提交者: GitHub

Add dygraph implementation for multiplex op (#29049)

上级 b0d1ac16
......@@ -5719,7 +5719,7 @@ def row_conv(input, future_context_size, param_attr=None, act=None):
@templatedoc()
def multiplex(inputs, index):
def multiplex(inputs, index, name=None):
"""
Based on the given index parameter, the OP selects a specific row from each input Tensor to construct the output Tensor.
......@@ -5748,35 +5748,30 @@ def multiplex(inputs, index):
Args:
inputs (list): The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64. All input Tensor shapes should be the same and rank must be at least 2.
index (Variable): Used to select some rows in the input Tensor to construct an index of the output Tensor. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of input Tensors.
inputs (list): The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64. All input Tensor shapes should be the same and rank must be at least 2.
index (Tensor): Used to select some rows in the input Tensor to construct an index of the output Tensor. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of input Tensors.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Variable(Tensor): Output of multiplex OP, with data type being float32, float64, int32, int64.
Tensor: Output of multiplex OP, with data type being float32, float64, int32, int64.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
import numpy as np
x1 = fluid.data(name='x1', shape=[None, 2], dtype='float32')
x2 = fluid.data(name='x2', shape=[None, 2], dtype='float32')
index = fluid.data(name='index', shape=[None, 1], dtype='int32')
out = fluid.layers.multiplex(inputs=[x1, x2], index=index)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
img1 = np.array([[1, 2], [3, 4]]).astype(np.float32)
img2 = np.array([[5, 6], [7, 8]]).astype(np.float32)
index = np.array([[1], [0]]).astype(np.int32)
res = exe.run(fluid.default_main_program(), feed={'x1':img1, 'x2':img2, 'index':index}, fetch_list=[out])
inputs = [paddle.to_tensor(img1), paddle.to_tensor(img2)]
index = paddle.to_tensor(np.array([[1], [0]]).astype(np.int32))
res = paddle.multiplex(inputs, index)
print(res) # [array([[5., 6.], [3., 4.]], dtype=float32)]
"""
if in_dygraph_mode():
return core.ops.multiplex(index, inputs)
helper = LayerHelper('multiplex', **locals())
check_type(inputs, 'inputs', (list), 'multiplex')
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
......@@ -91,5 +92,16 @@ class TestMultiplexOpError(unittest.TestCase):
self.assertRaises(TypeError, test_type2)
class TestMultiplexODygrap(unittest.TestCase):
def test_multiplex_dygraph(self):
paddle.disable_static()
img1 = np.array([[1, 2], [3, 4]]).astype(np.float32)
img2 = np.array([[5, 6], [7, 8]]).astype(np.float32)
inputs = [paddle.to_tensor(img1), paddle.to_tensor(img2)]
index = paddle.to_tensor(np.array([[1], [0]]).astype(np.int32))
res = paddle.multiplex(inputs, index)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册