未验证 提交 c9f7cff0 编写于 作者: Z zhangkaihuo 提交者: GitHub

Add a new op: paddle.linalg.multi_dot (#35224)

上级 72b07726
此差异已折叠。
......@@ -99,6 +99,7 @@ from .tensor.linalg import cholesky # noqa: F401
from .tensor.linalg import bmm # noqa: F401
from .tensor.linalg import histogram # noqa: F401
from .tensor.linalg import mv # noqa: F401
from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import matrix_power # noqa: F401
from .tensor.logic import equal # noqa: F401
from .tensor.logic import greater_equal # noqa: F401
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from numpy.linalg import multi_dot
from op_test import OpTest
import paddle
paddle.enable_static()
#the unittest of multi_dot
#compare the result of paddle multi_dot and numpy multi_dot
class TestMultiDotOp(OpTest):
def setUp(self):
self.op_type = "multi_dot"
self.dtype = self.get_dtype()
self.get_inputs_and_outputs()
def get_dtype(self):
return "float64"
def get_inputs_and_outputs(self):
self.A = np.random.random((2, 8)).astype(self.dtype)
self.B = np.random.random((8, 4)).astype(self.dtype)
self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
self.outputs = {'Out': multi_dot([self.A, self.B])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
#(A*B)*C
class TestMultiDotOp3Mat(TestMultiDotOp):
def get_inputs_and_outputs(self):
self.A = np.random.random((2, 10)).astype(self.dtype)
self.B = np.random.random((10, 4)).astype(self.dtype)
self.C = np.random.random((4, 3)).astype(self.dtype)
self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}
def test_check_grad(self):
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')
#A*(B*C)
class TestMultiDotOp3Mat2(TestMultiDotOp):
def get_inputs_and_outputs(self):
self.A = np.random.random((3, 4)).astype(self.dtype)
self.B = np.random.random((4, 8)).astype(self.dtype)
self.C = np.random.random((8, 2)).astype(self.dtype)
self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}
def test_check_grad(self):
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')
class TestMultiDotOp4Mat(TestMultiDotOp):
def get_inputs_and_outputs(self):
self.A = np.random.random((8, 6)).astype(self.dtype)
self.B = np.random.random((6, 3)).astype(self.dtype)
self.C = np.random.random((3, 4)).astype(self.dtype)
self.D = np.random.random((4, 5)).astype(self.dtype)
self.inputs = {
'X':
[('x0', self.A), ('x1', self.B), ('x2', self.C), ('x3', self.D)]
}
self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}
def test_check_grad(self):
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')
self.check_grad(['x3'], 'Out')
class TestMultiDotOpFirst1D(TestMultiDotOp):
def get_inputs_and_outputs(self):
self.A = np.random.random((4)).astype(self.dtype)
self.B = np.random.random((4, 3)).astype(self.dtype)
self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
self.outputs = {'Out': multi_dot([self.A, self.B])}
class TestMultiDotOp3MatFirst1D(TestMultiDotOp3Mat):
def get_inputs_and_outputs(self):
self.A = np.random.random((4)).astype(self.dtype)
self.B = np.random.random((4, 3)).astype(self.dtype)
self.C = np.random.random((3, 3)).astype(self.dtype)
self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}
class TestMultiDotOp4MatFirst1D(TestMultiDotOp4Mat):
def get_inputs_and_outputs(self):
self.A = np.random.random((4)).astype(self.dtype)
self.B = np.random.random((4, 3)).astype(self.dtype)
self.C = np.random.random((3, 4)).astype(self.dtype)
self.D = np.random.random((4, 5)).astype(self.dtype)
self.inputs = {
'X':
[('x0', self.A), ('x1', self.B), ('x2', self.C), ('x3', self.D)]
}
self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}
class TestMultiDotOpLast1D(TestMultiDotOp):
def get_inputs_and_outputs(self):
self.A = np.random.random((3, 6)).astype(self.dtype)
self.B = np.random.random((6)).astype(self.dtype)
self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
self.outputs = {'Out': multi_dot([self.A, self.B])}
class TestMultiDotOp3MatLast1D(TestMultiDotOp3Mat):
def get_inputs_and_outputs(self):
self.A = np.random.random((2, 4)).astype(self.dtype)
self.B = np.random.random((4, 3)).astype(self.dtype)
self.C = np.random.random((3)).astype(self.dtype)
self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}
def test_check_grad(self):
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')
class TestMultiDotOp4MatLast1D(TestMultiDotOp4Mat):
def get_inputs_and_outputs(self):
self.A = np.random.random((2, 3)).astype(self.dtype)
self.B = np.random.random((3, 2)).astype(self.dtype)
self.C = np.random.random((2, 3)).astype(self.dtype)
self.D = np.random.random((3)).astype(self.dtype)
self.inputs = {
'X':
[('x0', self.A), ('x1', self.B), ('x2', self.C), ('x3', self.D)]
}
self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}
class TestMultiDotOpFirstAndLast1D(TestMultiDotOp):
def get_inputs_and_outputs(self):
self.A = np.random.random((4, )).astype(self.dtype)
self.B = np.random.random((4)).astype(self.dtype)
self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
self.outputs = {'Out': multi_dot([self.A, self.B])}
class TestMultiDotOp3MatFirstAndLast1D(TestMultiDotOp3Mat):
def get_inputs_and_outputs(self):
self.A = np.random.random((6, )).astype(self.dtype)
self.B = np.random.random((6, 4)).astype(self.dtype)
self.C = np.random.random((4)).astype(self.dtype)
self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}
class TestMultiDotOp4MatFirstAndLast1D(TestMultiDotOp4Mat):
def get_inputs_and_outputs(self):
self.A = np.random.random((3, )).astype(self.dtype)
self.B = np.random.random((3, 4)).astype(self.dtype)
self.C = np.random.random((4, 2)).astype(self.dtype)
self.D = np.random.random((2)).astype(self.dtype)
self.inputs = {
'X':
[('x0', self.A), ('x1', self.B), ('x2', self.C), ('x3', self.D)]
}
self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}
#####python API test#######
class TestMultiDotOpError(unittest.TestCase):
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
# The inputs type of multi_dot must be list matrix.
input1 = 12
self.assertRaises(TypeError, paddle.multi_dot, [input1, input1])
# The inputs dtype of multi_dot must be float64, float64 or float16.
input2 = paddle.static.data(
name='input2', shape=[10, 10], dtype="int32")
self.assertRaises(TypeError, paddle.multi_dot, [input2, input2])
# the number of tensor must be larger than 1
x0 = paddle.static.data(name='x0', shape=[3, 2], dtype="float64")
self.assertRaises(ValueError, paddle.multi_dot, [x0])
#the first tensor must be 1D or 2D
x1 = paddle.static.data(name='x1', shape=[3, 2, 3], dtype="float64")
x2 = paddle.static.data(name='x2', shape=[3, 2], dtype="float64")
self.assertRaises(ValueError, paddle.multi_dot, [x1, x2])
#the last tensor must be 1D or 2D
x3 = paddle.static.data(name='x3', shape=[3, 2], dtype="float64")
x4 = paddle.static.data(name='x4', shape=[3, 2, 2], dtype="float64")
self.assertRaises(ValueError, paddle.multi_dot, [x3, x4])
#the tensor must be 2D, except first and last tensor
x5 = paddle.static.data(name='x5', shape=[3, 2], dtype="float64")
x6 = paddle.static.data(name='x6', shape=[2], dtype="float64")
x7 = paddle.static.data(name='x7', shape=[2, 2], dtype="float64")
self.assertRaises(ValueError, paddle.multi_dot, [x5, x6, x7])
class APITestMultiDot(unittest.TestCase):
def test_out(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x0 = paddle.static.data(name='x0', shape=[3, 2], dtype="float64")
x1 = paddle.static.data(name='x1', shape=[2, 3], dtype='float64')
result = paddle.multi_dot([x0, x1])
exe = paddle.static.Executor(paddle.CPUPlace())
data1 = np.random.rand(3, 2).astype("float64")
data2 = np.random.rand(2, 3).astype("float64")
np_res = exe.run(feed={'x0': data1,
'x1': data2},
fetch_list=[result])
expected_result = np.linalg.multi_dot([data1, data2])
self.assertTrue(
np.allclose(
np_res, expected_result, atol=1e-5),
"two value is\
{}\n{}, check diff!".format(np_res, expected_result))
def test_dygraph_without_out(self):
paddle.disable_static()
device = paddle.CPUPlace()
input_array1 = np.random.rand(3, 4).astype("float64")
input_array2 = np.random.rand(4, 3).astype("float64")
data1 = paddle.to_tensor(input_array1)
data2 = paddle.to_tensor(input_array2)
out = paddle.multi_dot([data1, data2])
expected_result = np.linalg.multi_dot([input_array1, input_array2])
self.assertTrue(np.allclose(expected_result, out.numpy()))
if __name__ == "__main__":
unittest.main()
......@@ -28,4 +28,5 @@ NEED_TO_FIX_OP_LIST = [
'cvm',
'cudnn_lstm',
'rnn',
'multi_dot',
]
......@@ -16,6 +16,7 @@ from .tensor.linalg import cholesky # noqa: F401
from .tensor.linalg import norm # noqa: F401
from .tensor.linalg import matrix_power # noqa: F401
from .tensor import inverse as inv # noqa: F401
from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import matrix_rank
from .tensor.linalg import svd
......@@ -23,6 +24,7 @@ __all__ = [
'cholesky', #noqa
'norm',
'inv',
'multi_dot',
'matrix_rank',
'svd',
'matrix_power'
......
......@@ -45,6 +45,8 @@ from .linalg import bmm # noqa: F401
from .linalg import histogram # noqa: F401
from .linalg import mv # noqa: F401
from .linalg import matrix_power # noqa: F401
from .linalg import multi_dot # noqa: F401
from .linalg import svd # noqa: F401
from .logic import equal # noqa: F401
from .logic import greater_equal # noqa: F401
from .logic import greater_than # noqa: F401
......
......@@ -1171,3 +1171,83 @@ def matrix_power(x, n, name=None):
outputs={'Out': out},
attrs={'n': n})
return out
def multi_dot(x, name=None):
"""
Multi_dot is an operator that calculates multiple matrix multiplications.
Supports inputs of float, double and float16 dtypes. This function does not
support batched inputs.
The input tensor in [x] must be 2-D except for the first and last can be 1-D.
If the first tensor is a 1-D vector of shape(n, ) it is treated as row vector
of shape(1, n), similarly if the last tensor is a 1D vector of shape(n, ), it
is treated as a column vector of shape(n, 1).
If the first and last tensor are 2-D matrix, then the output is also 2-D matrix,
otherwise the output is a 1-D vector.
Multi_dot will select the lowest cost multiplication order for calculation. The
cost of multiplying two matrices with shapes (a, b) and (b, c) is a * b * c.
Given matrices A, B, C with shapes (20, 5), (5, 100), (100, 10) respectively,
we can calculate the cost of different multiplication orders as follows:
- Cost((AB)C) = 20x5x100 + 20x100x10 = 30000
- Cost(A(BC)) = 5x100x10 + 20x5x10 = 6000
In this case, multiplying B and C first, then multiply A, which is 5 times faster
than sequential calculation.
Args:
x ([Tensor]): The input tensors which is a list Tensor.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Tensor: The output Tensor.
Examples:
.. code-block:: python
import paddle
import numpy as np
# A * B
A_data = np.random.random([3, 4]).astype(np.float32)
B_data = np.random.random([4, 5]).astype(np.float32)
A = paddle.to_tensor(A_data)
B = paddle.to_tensor(B_data)
out = paddle.multi_dot([A, B])
print(out.numpy().shape)
# [3, 5]
# A * B * C
A_data = np.random.random([10, 5]).astype(np.float32)
B_data = np.random.random([5, 8]).astype(np.float32)
C_data = np.random.random([8, 7]).astype(np.float32)
A = paddle.to_tensor(A_data)
B = paddle.to_tensor(B_data)
C = paddle.to_tensor(C_data)
out = paddle.multi_dot([A, B, C])
print(out.numpy().shape)
# [10, 7]
"""
if in_dygraph_mode():
return _C_ops.multi_dot(x)
check_type(x, 'x', (list, tuple), 'multi_dot')
for id, item in enumerate(x):
check_variable_and_dtype(item, 'x[' + str(id) + ']',
['float16', 'float32', 'float64'], 'multi_dot')
if item.dtype != x[0].dtype:
raise TypeError(
"All the Tensors in the input must have the same data type.")
helper = LayerHelper('multi_dot', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type='multi_dot', inputs={"X": x}, outputs={"Out": out})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册