#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16, get_numeric_gradient
from testsuite import create_op

import paddle
from paddle import fluid
from paddle.fluid import core


def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
    """Reference forward implementation using np.matmul."""
    # np.matmul does not support the transpose flags, so we manually
    # transpose X and Y appropriately.
    if transpose_X:
        if X.ndim == 1:
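            # a 1-D vector is unchanged by transposition; just normalize its shape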
            X = X.reshape((X.size,))
        elif X.ndim == 2:
            X = X.T
        else:
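            # build an axis permutation that swaps the last two dimensions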
            dim = list(range(len(X.shape)))
            dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
            X = np.transpose(X, tuple(dim))
    if transpose_Y:
        if Y.ndim == 1:
            Y = Y.reshape((Y.size,))
        else:
            dim = list(range(len(Y.shape)))
            dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
            Y = np.transpose(Y, tuple(dim))

    Out = np.matmul(X, Y)
    return Out


class TestMatMulV2Op(OpTest):
    """
    case 1
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (100,)
        self.trans_x = False
        self.trans_y = False

    def init_kernel_type(self):
        self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"

    def setUp(self):
        self.init_kernel_type()
        self.config()
        self.op_type = "matmul_v2"
        self.python_api = paddle.tensor.matmul
        if self.is_bfloat16_op():
            x = np.random.random(self.x_shape).astype(np.float32)
            y = np.random.random(self.y_shape).astype(np.float32)
        else:
            x = np.random.random(self.x_shape).astype(self.dtype)
            y = np.random.random(self.y_shape).astype(self.dtype)
            # rescale inputs into the range [-0.1, 0.1]
            x = -0.1 + 0.2 * x
            y = -0.1 + 0.2 * y
        result = reference_matmul(x, y, self.trans_x, self.trans_y)
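        # for bfloat16, keep fp32 copies of the inputs so the bf16 test cases
        # can compute numeric gradients against them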
        if self.is_bfloat16_op():
            result = result.astype(np.float32)
            self.inputs = {
                'X': convert_float_to_uint16(x),
                'Y': convert_float_to_uint16(y),
            }
            self.inputs_fp32 = {
                'X': x,
                'Y': y,
            }
        else:
            result = result.astype(self.dtype)
            self.inputs = {
                'X': x,
                'Y': y,
            }
        self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
        self.outputs = {'Out': result}

    def test_check_output(self):
        self.check_output(
            check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True
        )

    def test_check_grad(self):
        if core.is_compiled_with_rocm():
            self.check_grad(
                ['X', 'Y'],
                'Out',
                max_relative_error=1e-2,
                check_cinn=self.check_cinn
                if hasattr(self, 'check_cinn')
                else True,
            )
        else:
            self.check_grad(
                ['X', 'Y'],
                'Out',
                check_cinn=self.check_cinn
                if hasattr(self, 'check_cinn')
                else True,
            )


class TestMatMulOp2(TestMatMulV2Op):
    """
    case 2
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 3, 2, 100)
        self.trans_x = False
        self.trans_y = True


class TestMatMulOp3(TestMatMulV2Op):
    """
    case 3
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp4(TestMatMulV2Op):
    """
    case 4
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 2, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp5(TestMatMulV2Op):
    """
    case 5
    """

    def config(self):
        self.x_shape = (1, 1, 100, 1)
        self.y_shape = (100,)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp6(TestMatMulV2Op):
    """
    case 6
    """

    def config(self):
        self.x_shape = (1, 2, 102, 1)
        self.y_shape = (102,)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp7(TestMatMulV2Op):
    """
    case 7
    """

    def config(self):
        self.x_shape = (1, 2, 1, 100)
        self.y_shape = (100,)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp8(TestMatMulV2Op):
    """
    case 8
    """

    def config(self):
        self.x_shape = (1, 1, 2, 100)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp9(TestMatMulV2Op):
    """
    case 9
    """

    def config(self):
        self.x_shape = (1, 1, 1, 100)
        self.y_shape = (2, 1, 2, 100)
        self.trans_x = False
        self.trans_y = True


class TestMatMulOp10(TestMatMulV2Op):
    """
    case 10
    """

    def config(self):
        self.x_shape = (1, 1, 25, 4)
        self.y_shape = (1, 2, 4, 25)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp11(TestMatMulV2Op):
    """
    case 11
    """

    def config(self):
        self.x_shape = (2, 1, 2, 100)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp12(TestMatMulV2Op):
    """
    case 12
    """

    def config(self):
        self.x_shape = (2, 1, 4, 25)
        self.y_shape = (1, 1, 4, 25)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp13(TestMatMulV2Op):
    """
    case 13
    """

    def config(self):
        self.x_shape = (2, 2, 10, 10)
        self.y_shape = (2, 2, 10, 10)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp14(TestMatMulV2Op):
    """
    case 14_1
    """

    def config(self):
        self.x_shape = (3, 1, 6, 6)
        self.y_shape = (1, 2, 6, 9)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp15(TestMatMulV2Op):
    """
    case 14_2
    """

    def config(self):
        self.x_shape = (3, 1, 6, 6)
        self.y_shape = (1, 2, 6, 9)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp16(TestMatMulV2Op):
    """
    case 16: check the gradient for a special case
    """

    def config(self):
        self.x_shape = 100
        self.y_shape = (1, 2, 2, 100, 2)
        self.trans_x = False
        self.trans_y = False
        self.check_cinn = False


class TestMatMulOp17(TestMatMulV2Op):
    """
    case 17: check the gradient for a special case
    """

    def config(self):
        self.x_shape = (2, 1, 100)
        self.y_shape = 100
        self.trans_x = False
        self.trans_y = False


class TestMatMulOpBroadcast1(TestMatMulV2Op):
    """
    case 14_3
    """

    def config(self):
        self.x_shape = (3, 1, 10, 10)
        self.y_shape = (1, 2, 10, 10)
        self.trans_x = True
        self.trans_y = True


class TestMatMulOpBroadcast2(TestMatMulV2Op):
    """
    case 14_4
    """

    def config(self):
        self.x_shape = (3, 1, 10, 10)
        self.y_shape = (1, 2, 10, 10)
        self.trans_x = False
        self.trans_y = True


# --------------------test matmul fp16--------------------


def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
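    """Register an FP16 variant of the given test case as "<parent>_Fp16" (CUDA only)."""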
    @unittest.skipIf(
        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
    )
    class TestMatMulOpFp16Case(parent):
        def init_kernel_type(self):
            self.dtype = np.float16

        def test_check_output(self):
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
                if core.is_float16_supported(place):
                    self.check_output_with_place(
                        place,
                        atol=atol,
                        check_cinn=self.check_cinn
                        if hasattr(self, 'check_cinn')
                        else True,
                    )

        def test_check_grad(self):
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                self.check_grad_with_place(
                    place,
                    ['X', 'Y'],
                    'Out',
                    max_relative_error=max_relative_error,
                    check_cinn=self.check_cinn
                    if hasattr(self, 'check_cinn')
                    else True,
                )

    cls_name = "{}_{}".format(parent.__name__, "Fp16")
    TestMatMulOpFp16Case.__name__ = cls_name
    globals()[cls_name] = TestMatMulOpFp16Case


create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMulOp7)
create_test_fp16_class(TestMatMulOp8)
create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMulOp13)
create_test_fp16_class(TestMatMulOp14)
create_test_fp16_class(TestMatMulOp15)
create_test_fp16_class(TestMatMulOp16)
create_test_fp16_class(TestMatMulOp17)
create_test_fp16_class(TestMatMulOpBroadcast1)
create_test_fp16_class(TestMatMulOpBroadcast2)

# --------------------test matmul bf16--------------------


def create_test_bf16_class(parent, atol=0.01):
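    """Register a BF16 variant of the given test case as "<parent>_Bf16" (requires CUDA bfloat16 support)."""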
    @unittest.skipIf(
        not core.is_compiled_with_cuda()
        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
        "core is not compiled with CUDA and not support the bfloat16",
    )
    class TestMatMulOpBf16Case(parent):
        def get_numeric_grad(self, place, check_name):
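            # numerically differentiate the op w.r.t. check_name using the fp32 input copies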
            scope = core.Scope()
            self._check_grad_helper()
            op = create_op(
                scope, self.op_type, self.inputs, self.outputs, self.attrs
            )
            return get_numeric_gradient(
                place, scope, op, self.inputs_fp32, check_name, ['Out']
            )

        def init_kernel_type(self):
            self.dtype = np.uint16

        def test_check_output(self):
            place = core.CUDAPlace(0)
            self.check_output_with_place(
                place,
                atol=atol,
                check_cinn=self.check_cinn
                if hasattr(self, 'check_cinn')
                else True,
            )

        def test_check_grad_x(self):
            place = core.CUDAPlace(0)
            numeric_grads = self.get_numeric_grad(place, 'X')
            self.check_grad_with_place(
                place,
                ['X'],
                'Out',
                no_grad_set={'Y'},
                user_defined_grads=[numeric_grads],
                check_cinn=self.check_cinn
                if hasattr(self, 'check_cinn')
                else True,
            )

        def test_check_grad_y(self):
            place = core.CUDAPlace(0)
            numeric_grads = self.get_numeric_grad(place, 'Y')
            self.check_grad_with_place(
                place,
                ['Y'],
                'Out',
                no_grad_set={'X'},
                user_defined_grads=[numeric_grads],
                check_cinn=self.check_cinn
                if hasattr(self, 'check_cinn')
                else True,
            )

        def test_check_grad(self):
            pass

    cls_name = "{}_{}".format(parent.__name__, "Bf16")
    TestMatMulOpBf16Case.__name__ = cls_name
    globals()[cls_name] = TestMatMulOpBf16Case


create_test_bf16_class(TestMatMulV2Op)
create_test_bf16_class(TestMatMulOp2)
create_test_bf16_class(TestMatMulOp3)
create_test_bf16_class(TestMatMulOp4)
create_test_bf16_class(TestMatMulOp5)
create_test_bf16_class(TestMatMulOp6)
create_test_bf16_class(TestMatMulOp7)
create_test_bf16_class(TestMatMulOp8)
create_test_bf16_class(TestMatMulOp9)
create_test_bf16_class(TestMatMulOp10)
create_test_bf16_class(TestMatMulOp11)
create_test_bf16_class(TestMatMulOp12)
create_test_bf16_class(TestMatMulOp13)
create_test_bf16_class(TestMatMulOp14)
create_test_bf16_class(TestMatMulOp15)
create_test_bf16_class(TestMatMulOp16)
create_test_bf16_class(TestMatMulOp17)


class TestMatMulV2API(unittest.TestCase):
    def setUp(self):
        self.places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.places.append(fluid.CUDAPlace(0))

    def check_static_result(self, place):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            input_x = paddle.static.data(
                name="input_x", shape=[4, 3], dtype="float32"
            )
            input_y = paddle.static.data(
                name="input_y", shape=[3, 4], dtype="float32"
            )

            result = paddle.matmul(input_x, input_y)

            x_np = np.random.random([4, 3]).astype("float32")
            y_np = np.random.random([3, 4]).astype("float32")

            exe = fluid.Executor(place)
            fetches = exe.run(
                fluid.default_main_program(),
                feed={"input_x": x_np, "input_y": y_np},
                fetch_list=[result],
            )

    def test_static(self):
        for place in self.places:
            self.check_static_result(place=place)

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                input_x = np.random.random([4, 3]).astype("float64")
                input_y = np.random.random([3, 4]).astype("float64")
                x = paddle.to_tensor(input_x)
                y = paddle.to_tensor(input_y)
                result = paddle.matmul(x, y)

    def test_dygraph_fp16(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                with fluid.dygraph.guard(place):
                    input_x = np.random.random([4, 3]).astype("float16")
                    input_y = np.random.random([3, 4]).astype("float16")
                    x = paddle.to_tensor(input_x)
                    y = paddle.to_tensor(input_y)
                    result = paddle.matmul(x, y)

    def test_compute_type_fp32(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                with fluid.dygraph.guard(place):
546
                    paddle.set_flags(
                        {'FLAGS_gemm_use_half_precision_compute_type': False}
                    )
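                    # paired +/-60000 entries overflow fp16 (max ~65504) if products are
                    # accumulated in half precision, but cancel when computed in fp32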
                    input_x = np.random.random([2, 8, 16]).astype("float16")
                    input_y = np.random.random([2, 16, 8]).astype("float16")
                    for i in range(0, 16, 2):
                        input_x[:, :, i] += 60000
                        input_x[:, :, i + 1] -= 60000
                    input_y[:, :, :] = 1.5

                    x = paddle.to_tensor(input_x)
                    y = paddle.to_tensor(input_y)
                    result = paddle.matmul(x, y)
                    result_np = np.matmul(input_x, input_y)
                    self.assertTrue(paddle.isfinite(result)[0, 0, 0])
                    self.assertTrue(np.isfinite(result_np)[0, 0, 0])
                    np.testing.assert_array_equal(result_np, result.numpy())
                    paddle.set_flags(
                        {'FLAGS_gemm_use_half_precision_compute_type': True}
                    )

    def test_compute_type_fp16_nan(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                with fluid.dygraph.guard(place):
                    paddle.set_flags(
                        {'FLAGS_gemm_use_half_precision_compute_type': True}
                    )
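                    # same +/-60000 construction as above: with half-precision compute the
                    # intermediate values overflow fp16, so the result should be non-finite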
                    input_x = np.random.random([2, 8, 16]).astype("float16")
                    input_y = np.random.random([2, 16, 8]).astype("float16")
                    for i in range(0, 16, 2):
                        input_x[:, :, i] += 60000
                        input_x[:, :, i + 1] -= 60000
                    input_y[:, :, :] = 1.5

                    x = paddle.to_tensor(input_x)
                    y = paddle.to_tensor(input_y)
                    result = paddle.matmul(x, y)
                    result_np = np.matmul(input_x, input_y)
                    self.assertFalse(
                        paddle.isfinite(result)[0, 0, 0]
                    )  # contains nan/inf
                    self.assertTrue(np.isfinite(result_np)[0, 0, 0])
                    paddle.set_flags(
                        {'FLAGS_gemm_use_half_precision_compute_type': False}
                    )


class TestComplexMatMulOp(OpTest):
    def setUp(self):
        self.op_type = "matmul_v2"
        self.python_api = paddle.tensor.matmul
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.y = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        self.grad_out = np.ones((10, 10), self.dtype) + 1j * np.ones(
            (10, 10), self.dtype
        )
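        # analytic gradients for complex matmul: dX = dOut @ Y^H, dY = X^H @ dOut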
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)

    def test_check_output(self):
        self.check_output(check_cinn=False)

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_cinn=False,
        )

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set=set("X"),
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_cinn=False,
        )

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set=set('Y'),
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out],
            check_cinn=False,
        )


class TestComplexMatMulOpBroadcast(OpTest):
    def setUp(self):
        self.op_type = "matmul_v2"
        self.python_api = paddle.tensor.matmul
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random((10, 2, 5)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 2, 5)).astype(self.dtype)
        self.y = np.random.random((5, 20)).astype(
            self.dtype
        ) + 1j * np.random.random((5, 20)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        self.grad_out = np.ones((10, 2, 20), self.dtype) + 1j * np.ones(
            (10, 2, 20), self.dtype
        )
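        # Y is broadcast over the batch dimension, so its gradient sums over that axis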
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        self.grad_y = np.sum(
            np.matmul(np.conj(self.x).transpose(0, 2, 1), self.grad_out), axis=0
        )

    def test_check_output(self):
        self.check_output(check_cinn=False)

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_cinn=False,
        )

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set=set("X"),
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_cinn=False,
        )

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set=set('Y'),
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out],
            check_cinn=False,
        )


class TestMatMulTypePromotion(TestComplexMatMulOp):
    def init_input_output(self):
        self.x = np.random.random((10, 10)).astype(self.dtype)
734 735 736
        self.y = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
C
chentianyu03 已提交
737 738 739
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
740 741 742
        self.grad_out = np.ones((10, 10), self.dtype) + 1j * np.ones(
            (10, 10), self.dtype
        )
C
chentianyu03 已提交
743 744 745 746
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T).real
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)


class TestMatmulop(unittest.TestCase):
    def func_dygraph_matmul(self):
        paddle.disable_static()

        np_a = np.random.random((2, 4)).astype(np.float32)
        np_b = np.random.random((4, 2)).astype(np.float32)

        tensor_a = paddle.to_tensor(np_a, dtype="float32")
        tensor_b = paddle.to_tensor(np_b, dtype="float32")

        # normal case: tensor @ nparray
        expect_out = np_a @ np_b
        actual_out = tensor_a @ np_b
        np.testing.assert_allclose(actual_out, expect_out)

        paddle.enable_static()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()