# test_multi_dot_op.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from numpy.linalg import multi_dot

import paddle
from paddle.fluid.framework import _test_eager_guard
from op_test import OpTest, skip_check_grad_ci

# Put Paddle into static-graph mode for the tests in this module; the
# dygraph API test switches modes explicitly where it needs to.
paddle.enable_static()


# Unit tests for the multi_dot operator: each case compares the result of
# Paddle's multi_dot against numpy.linalg.multi_dot on the same inputs.
class TestMultiDotOp(OpTest):
    """Two 2-D matrices: (2, 8) @ (8, 4).

    Subclasses override ``get_inputs_and_outputs`` (and optionally
    ``get_dtype``) to exercise other operand counts and shapes.
    """

    def setUp(self):
        self.op_type = "multi_dot"
        self.python_api = paddle.linalg.multi_dot
        self.dtype = self.get_dtype()
        self.get_inputs_and_outputs()

    def get_dtype(self):
        # Hook so a subclass can rerun the same cases in another dtype.
        return "float64"

    def get_inputs_and_outputs(self):
        self.A = np.random.random((2, 8)).astype(self.dtype)
        self.B = np.random.random((8, 4)).astype(self.dtype)
        self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
        self.outputs = {'Out': multi_dot([self.A, self.B])}

    def test_check_output(self):
        self.check_output(check_eager=True)

    def test_check_grad(self):
        # Check the gradient w.r.t. each operand on its own, in input order.
        # Iterating over self.inputs['X'] means any subclass that adds
        # operands is covered without having to override this method.
        for name, _ in self.inputs['X']:
            self.check_grad([name], 'Out', check_eager=True)


# (A*B)*C: for these shapes the left association is cheaper
# (2*10*4 + 2*4*3 = 104 multiplies vs. 10*4*3 + 2*10*3 = 180 for A*(B*C)).
class TestMultiDotOp3Mat(TestMultiDotOp):
    """Three 2-D matrices: (2, 10) @ (10, 4) @ (4, 3)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((2, 10)).astype(self.dtype)
        self.B = np.random.random((10, 4)).astype(self.dtype)
        self.C = np.random.random((4, 3)).astype(self.dtype)
        self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
        self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}

    def test_check_grad(self):
        # One gradient check per operand, in input order.
        for name in ('x0', 'x1', 'x2'):
            self.check_grad([name], 'Out', check_eager=True)


# A*(B*C): for these shapes the right association is cheaper
# (4*8*2 + 3*4*2 = 88 multiplies vs. 3*4*8 + 3*8*2 = 144 for (A*B)*C).
class TestMultiDotOp3Mat2(TestMultiDotOp):
    """Three 2-D matrices: (3, 4) @ (4, 8) @ (8, 2)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((3, 4)).astype(self.dtype)
        self.B = np.random.random((4, 8)).astype(self.dtype)
        self.C = np.random.random((8, 2)).astype(self.dtype)
        self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
        self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}

    def test_check_grad(self):
        # One gradient check per operand, in input order.
        for name in ('x0', 'x1', 'x2'):
            self.check_grad([name], 'Out', check_eager=True)


class TestMultiDotOp4Mat(TestMultiDotOp):
    """Four 2-D matrices: (8, 6) @ (6, 3) @ (3, 4) @ (4, 5)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((8, 6)).astype(self.dtype)
        self.B = np.random.random((6, 3)).astype(self.dtype)
        self.C = np.random.random((3, 4)).astype(self.dtype)
        self.D = np.random.random((4, 5)).astype(self.dtype)
        self.inputs = {
            'X': [('x0', self.A), ('x1', self.B), ('x2', self.C),
                  ('x3', self.D)]
        }
        self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}

    def test_check_grad(self):
        # One gradient check per operand, in input order.
        for name in ('x0', 'x1', 'x2', 'x3'):
            self.check_grad([name], 'Out', check_eager=True)


class TestMultiDotOpFirst1D(TestMultiDotOp):
    """Leading 1-D operand: (4,) @ (4, 3)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((4,)).astype(self.dtype)
        self.B = np.random.random((4, 3)).astype(self.dtype)
        self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
        self.outputs = {'Out': multi_dot([self.A, self.B])}


class TestMultiDotOp3MatFirst1D(TestMultiDotOp3Mat):
    """Leading 1-D operand, three inputs: (4,) @ (4, 3) @ (3, 3)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((4,)).astype(self.dtype)
        self.B = np.random.random((4, 3)).astype(self.dtype)
        self.C = np.random.random((3, 3)).astype(self.dtype)
        self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
        self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}


class TestMultiDotOp4MatFirst1D(TestMultiDotOp4Mat):
    """Leading 1-D operand, four inputs: (4,) @ (4, 3) @ (3, 4) @ (4, 5)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((4,)).astype(self.dtype)
        self.B = np.random.random((4, 3)).astype(self.dtype)
        self.C = np.random.random((3, 4)).astype(self.dtype)
        self.D = np.random.random((4, 5)).astype(self.dtype)
        self.inputs = {
            'X': [('x0', self.A), ('x1', self.B), ('x2', self.C),
                  ('x3', self.D)]
        }
        self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}


class TestMultiDotOpLast1D(TestMultiDotOp):
    """Trailing 1-D operand: (3, 6) @ (6,)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((3, 6)).astype(self.dtype)
        self.B = np.random.random((6,)).astype(self.dtype)
        self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
        self.outputs = {'Out': multi_dot([self.A, self.B])}


class TestMultiDotOp3MatLast1D(TestMultiDotOp3Mat):
    """Trailing 1-D operand, three inputs: (2, 4) @ (4, 3) @ (3,).

    The per-operand gradient checks for x0..x2 are inherited from
    TestMultiDotOp3Mat.
    """

    def get_inputs_and_outputs(self):
        self.A = np.random.random((2, 4)).astype(self.dtype)
        self.B = np.random.random((4, 3)).astype(self.dtype)
        self.C = np.random.random((3,)).astype(self.dtype)
        self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
        self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}


class TestMultiDotOp4MatLast1D(TestMultiDotOp4Mat):
    """Trailing 1-D operand, four inputs: (2, 3) @ (3, 2) @ (2, 3) @ (3,)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((2, 3)).astype(self.dtype)
        self.B = np.random.random((3, 2)).astype(self.dtype)
        self.C = np.random.random((2, 3)).astype(self.dtype)
        self.D = np.random.random((3,)).astype(self.dtype)
        self.inputs = {
            'X': [('x0', self.A), ('x1', self.B), ('x2', self.C),
                  ('x3', self.D)]
        }
        self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}


class TestMultiDotOpFirstAndLast1D(TestMultiDotOp):
    """Both operands 1-D: (4,) @ (4,), which reduces to an inner product."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((4,)).astype(self.dtype)
        self.B = np.random.random((4,)).astype(self.dtype)
        self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
        self.outputs = {'Out': multi_dot([self.A, self.B])}


class TestMultiDotOp3MatFirstAndLast1D(TestMultiDotOp3Mat):
    """1-D first and last operands, three inputs: (6,) @ (6, 4) @ (4,)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((6,)).astype(self.dtype)
        self.B = np.random.random((6, 4)).astype(self.dtype)
        self.C = np.random.random((4,)).astype(self.dtype)
        self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
        self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}


class TestMultiDotOp4MatFirstAndLast1D(TestMultiDotOp4Mat):
    """1-D first and last operands, four inputs: (3,) @ (3, 4) @ (4, 2) @ (2,)."""

    def get_inputs_and_outputs(self):
        self.A = np.random.random((3,)).astype(self.dtype)
        self.B = np.random.random((3, 4)).astype(self.dtype)
        self.C = np.random.random((4, 2)).astype(self.dtype)
        self.D = np.random.random((2,)).astype(self.dtype)
        self.inputs = {
            'X': [('x0', self.A), ('x1', self.B), ('x2', self.C),
                  ('x3', self.D)]
        }
        self.outputs = {'Out': multi_dot([self.A, self.B, self.C, self.D])}


##### Python API tests #####
class TestMultiDotOpError(unittest.TestCase):
    """Invalid inputs to paddle.linalg.multi_dot must raise."""

    def test_errors(self):
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):
            # The inputs of multi_dot must be a list of tensors, not
            # Python scalars.
            input1 = 12
            self.assertRaises(TypeError, paddle.linalg.multi_dot,
                              [input1, input1])

            # The input dtype must be floating point (presumably float16,
            # float32 or float64 -- confirm against the API); int32 is
            # rejected with TypeError.
            input2 = paddle.static.data(name='input2',
                                        shape=[10, 10],
                                        dtype="int32")
            self.assertRaises(TypeError, paddle.linalg.multi_dot,
                              [input2, input2])

            # At least two tensors are required.
            x0 = paddle.static.data(name='x0', shape=[3, 2], dtype="float64")
            self.assertRaises(ValueError, paddle.linalg.multi_dot, [x0])

            # The first tensor must be 1-D or 2-D; 3-D is rejected.
            x1 = paddle.static.data(name='x1', shape=[3, 2, 3], dtype="float64")
            x2 = paddle.static.data(name='x2', shape=[3, 2], dtype="float64")
            self.assertRaises(ValueError, paddle.linalg.multi_dot, [x1, x2])

            # The last tensor must be 1-D or 2-D; 3-D is rejected.
            x3 = paddle.static.data(name='x3', shape=[3, 2], dtype="float64")
            x4 = paddle.static.data(name='x4', shape=[3, 2, 2], dtype="float64")
            self.assertRaises(ValueError, paddle.linalg.multi_dot, [x3, x4])

            # Every tensor other than the first and last must be 2-D; a 1-D
            # tensor in the middle is rejected.
            x5 = paddle.static.data(name='x5', shape=[3, 2], dtype="float64")
            x6 = paddle.static.data(name='x6', shape=[2], dtype="float64")
            x7 = paddle.static.data(name='x7', shape=[2, 2], dtype="float64")
            self.assertRaises(ValueError, paddle.linalg.multi_dot, [x5, x6, x7])


class APITestMultiDot(unittest.TestCase):
    """Compare paddle.linalg.multi_dot with numpy in static and dygraph mode."""

    def test_out(self):
        # Static graph: build the program, feed random data, fetch the result.
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x0 = paddle.static.data(name='x0', shape=[3, 2], dtype="float64")
            x1 = paddle.static.data(name='x1', shape=[2, 3], dtype='float64')
            result = paddle.linalg.multi_dot([x0, x1])
            exe = paddle.static.Executor(paddle.CPUPlace())
            data1 = np.random.rand(3, 2).astype("float64")
            data2 = np.random.rand(2, 3).astype("float64")
            np_res, = exe.run(feed={'x0': data1, 'x1': data2},
                              fetch_list=[result])
            expected_result = np.linalg.multi_dot([data1, data2])

        np.testing.assert_allclose(
            np_res,
            expected_result,
            rtol=1e-05,
            atol=1e-05,
            err_msg='two value is            {}\n{}, check diff!'.format(
                np_res, expected_result))

    def test_dygraph_without_out(self):
        # Dynamic graph: call the API directly on eager tensors.
        # NOTE(review): this leaves Paddle in dygraph mode afterwards; tests
        # that need static mode call paddle.enable_static() themselves.
        paddle.disable_static()
        input_array1 = np.random.rand(3, 4).astype("float64")
        input_array2 = np.random.rand(4, 3).astype("float64")
        data1 = paddle.to_tensor(input_array1)
        data2 = paddle.to_tensor(input_array2)
        out = paddle.linalg.multi_dot([data1, data2])
        expected_result = np.linalg.multi_dot([input_array1, input_array2])
        np.testing.assert_allclose(expected_result, out.numpy(), rtol=1e-05)

    def test_dygraph_final_state_api(self):
        # Rerun the dygraph check under the eager ("final state") mode.
        with _test_eager_guard():
            self.test_dygraph_without_out()

if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()