test_matmul_v2_op.py 21.8 KB
Newer Older
S
ShenLiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

S
ShenLiang 已提交
17
import numpy as np
18
from op_test import OpTest, convert_float_to_uint16, get_numeric_gradient
S
ShenLiang 已提交
19 20 21

import paddle
import paddle.fluid as fluid
22
import paddle.fluid.core as core
23
from paddle.fluid.framework import _test_eager_guard
24
from paddle.fluid.tests.unittests.testsuite import create_op
S
ShenLiang 已提交
25 26 27 28 29 30 31 32


def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
    """Reference forward implementation using np.matmul.

    np.matmul has no transpose flags, so the requested operands are
    transposed here first: for inputs with two or more dimensions the last
    two axes are swapped, while 1-D inputs are left unchanged (transposing
    a vector is a no-op).

    Args:
        X (np.ndarray): left operand.
        Y (np.ndarray): right operand.
        transpose_X (bool): swap the last two axes of X before multiplying.
        transpose_Y (bool): swap the last two axes of Y before multiplying.

    Returns:
        np.ndarray: ``np.matmul`` of the (possibly transposed) operands.
        When np.matmul would return a 0-d scalar (vector times vector), a
        float64 array of shape ``(1,)`` is returned instead, because the
        operator does not produce 0-dimensional tensors.
    """

    def _maybe_transpose(arr, transpose):
        # Vectors are unchanged by transposition; for higher ranks only the
        # two innermost (matrix) axes are swapped, batch axes stay in place.
        if not transpose or arr.ndim < 2:
            return arr
        return np.swapaxes(arr, -1, -2)

    out = np.matmul(
        _maybe_transpose(X, transpose_X), _maybe_transpose(Y, transpose_Y)
    )
    if not out.shape:
        # We do not support 0-dimensional Tensors (scalars). So where
        # np.matmul outputs a scalar, we must convert to a Tensor of
        # shape (1, ) instead.
        # Everywhere else, we are compatible with np.matmul.
        out = np.array([out], dtype="float64")
    return out


class TestMatMulV2Op(OpTest):
    """
    case 1: 1-D X (100,) times 1-D Y (100,), no transposes.

    Base class of the matmul_v2 OpTest family. Subclasses override
    ``config`` to change the input shapes / transpose flags, and the
    generated fp16 / bf16 variants override ``init_kernel_type``.
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (100,)
        self.trans_x = False
        self.trans_y = False

    def init_kernel_type(self):
        # ROCm builds run this test in float32; otherwise float64 is used.
        self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"

    def setUp(self):
        self.init_kernel_type()
        self.config()
        self.op_type = "matmul_v2"
        if self.is_bfloat16_op():
            # bf16 inputs are sampled in float32; they are converted to the
            # uint16 storage format only when self.inputs is built below.
            x = np.random.random(self.x_shape).astype(np.float32)
            y = np.random.random(self.y_shape).astype(np.float32)
        else:
            x = np.random.random(self.x_shape).astype(self.dtype)
            y = np.random.random(self.y_shape).astype(self.dtype)
            # Rescale from [0, 1) to -0.1 ~ 0.1.
            x = -0.1 + 0.2 * x
            y = -0.1 + 0.2 * y
        result = reference_matmul(x, y, self.trans_x, self.trans_y)
        if self.is_bfloat16_op():
            result = result.astype(np.float32)
            self.inputs = {
                'X': convert_float_to_uint16(x),
                'Y': convert_float_to_uint16(y),
            }
            # The float32 originals are kept so bf16 subclasses can compute
            # numeric gradients from them.
            self.inputs_fp32 = {
                'X': x,
                'Y': y,
            }
        else:
            result = result.astype(self.dtype)
            self.inputs = {
                'X': x,
                'Y': y,
            }
        self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
        self.outputs = {'Out': result}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        if core.is_compiled_with_rocm():
            # ROCm runs in float32 (see init_kernel_type), so the gradient
            # check uses a looser relative tolerance there.
            self.check_grad(
                ['X', 'Y'], 'Out', max_relative_error=1e-2, check_eager=False
            )
        else:
            self.check_grad(['X', 'Y'], 'Out', check_eager=False)


class TestMatMulOp2(TestMatMulV2Op):
    """
    case 2: 1-D X (100,) times 4-D Y (1, 3, 2, 100), with trans_y.
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 3, 2, 100)
        self.trans_x = False
        self.trans_y = True


class TestMatMulOp3(TestMatMulV2Op):
    """
    case 3: 1-D X (100,) times 4-D Y (1, 1, 100, 2), no transposes.
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp4(TestMatMulV2Op):
    """
    case 4: 1-D X (100,) times 4-D Y (1, 2, 100, 2), no transposes.
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 2, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp5(TestMatMulV2Op):
    """
    case 5: 4-D X (1, 1, 100, 1) with trans_x times 1-D Y (100,).
    """

    def config(self):
        self.x_shape = (1, 1, 100, 1)
        self.y_shape = (100,)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp6(TestMatMulV2Op):
    """
    case 6: 4-D X (1, 2, 102, 1) with trans_x times 1-D Y (102,).
    """

    def config(self):
        self.x_shape = (1, 2, 102, 1)
        self.y_shape = (102,)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp7(TestMatMulV2Op):
    """
    case 7: 4-D X (1, 2, 1, 100) times 1-D Y (100,), no transposes.
    """

    def config(self):
        self.x_shape = (1, 2, 1, 100)
        self.y_shape = (100,)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp8(TestMatMulV2Op):
    """
    case 8: 4-D X (1, 1, 2, 100) times 4-D Y (1, 1, 100, 2), no transposes.
    """

    def config(self):
        self.x_shape = (1, 1, 2, 100)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp9(TestMatMulV2Op):
    """
    case 9: 4-D X (1, 1, 1, 100) times 4-D Y (2, 1, 2, 100) with trans_y;
    the batch dimensions broadcast.
    """

    def config(self):
        self.x_shape = (1, 1, 1, 100)
        self.y_shape = (2, 1, 2, 100)
        self.trans_x = False
        self.trans_y = True


class TestMatMulOp10(TestMatMulV2Op):
    """
    case 10: 4-D X (1, 1, 25, 4) times 4-D Y (1, 2, 4, 25), no transposes;
    the batch dimensions broadcast.
    """

    def config(self):
        self.x_shape = (1, 1, 25, 4)
        self.y_shape = (1, 2, 4, 25)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp11(TestMatMulV2Op):
    """
    case 11: 4-D X (2, 1, 2, 100) times 4-D Y (1, 1, 100, 2), no transposes;
    the batch dimensions broadcast.
    """

    def config(self):
        self.x_shape = (2, 1, 2, 100)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp12(TestMatMulV2Op):
    """
    case 12: 4-D X (2, 1, 4, 25) with trans_x times 4-D Y (1, 1, 4, 25);
    the batch dimensions broadcast.
    """

    def config(self):
        self.x_shape = (2, 1, 4, 25)
        self.y_shape = (1, 1, 4, 25)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp13(TestMatMulV2Op):
    """
    case 13: 4-D X (2, 2, 10, 10) with trans_x times 4-D Y (2, 2, 10, 10);
    identical batch dimensions.
    """

    def config(self):
        self.x_shape = (2, 2, 10, 10)
        self.y_shape = (2, 2, 10, 10)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp14(TestMatMulV2Op):
    """
    case 14_1: 4-D X (3, 1, 6, 6) with trans_x times 4-D Y (1, 2, 6, 9);
    the batch dimensions broadcast.
    """

    def config(self):
        self.x_shape = (3, 1, 6, 6)
        self.y_shape = (1, 2, 6, 9)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp15(TestMatMulV2Op):
    """
    case 14_2: 4-D X (3, 1, 6, 6) times 4-D Y (1, 2, 6, 9), no transposes;
    the batch dimensions broadcast.
    """

    def config(self):
        self.x_shape = (3, 1, 6, 6)
        self.y_shape = (1, 2, 6, 9)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp16(TestMatMulV2Op):
    """
    case 16 : to check the gradient for special case

    1-D X times 5-D Y (1, 2, 2, 100, 2), no transposes. ``x_shape`` is the
    plain int 100, which np.random.random turns into a (100,) array.
    """

    def config(self):
        self.x_shape = 100
        self.y_shape = (1, 2, 2, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp17(TestMatMulV2Op):
    """
    case 17 : to check the gradient for special case

    3-D X (2, 1, 100) times 1-D Y, no transposes. ``y_shape`` is the plain
    int 100, which np.random.random turns into a (100,) array.
    """

    def config(self):
        self.x_shape = (2, 1, 100)
        self.y_shape = 100
        self.trans_x = False
        self.trans_y = False


class TestMatMulOpBroadcast1(TestMatMulV2Op):
    """
    case 14_3: 4-D X (3, 1, 10, 10) times 4-D Y (1, 2, 10, 10) with both
    trans_x and trans_y; the batch dimensions broadcast.
    """

    def config(self):
        self.x_shape = (3, 1, 10, 10)
        self.y_shape = (1, 2, 10, 10)
        self.trans_x = True
        self.trans_y = True


class TestMatMulOpBroadcast2(TestMatMulV2Op):
    """
    case 14_4: 4-D X (3, 1, 10, 10) times 4-D Y (1, 2, 10, 10) with
    trans_y only; the batch dimensions broadcast.
    """

    def config(self):
        self.x_shape = (3, 1, 10, 10)
        self.y_shape = (1, 2, 10, 10)
        self.trans_x = False
        self.trans_y = True


# --------------------test matmul fp16--------------------


def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
    """Register a float16, CUDA-only variant of the OpTest class ``parent``.

    The generated class is named ``<parent name>_Fp16`` and is stored in
    this module's globals so unittest discovery picks it up. The class is
    skipped when Paddle is not compiled with CUDA, and its checks are no-ops
    when CUDAPlace(0) does not support float16.

    Args:
        parent: matmul_v2 OpTest class to derive the fp16 variant from.
        atol (float): absolute tolerance for the forward-output check.
        max_relative_error (float): tolerance for the gradient check.
    """

    @unittest.skipIf(
        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
    )
    class TestMatMulOpFp16Case(parent):
        def init_kernel_type(self):
            self.dtype = np.float16

        def test_check_output(self):
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
                if core.is_float16_supported(place):
                    self.check_output_with_place(
                        place, atol=atol, check_eager=False
                    )

        def test_check_grad(self):
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                self.check_grad_with_place(
                    place,
                    ['X', 'Y'],
                    'Out',
                    max_relative_error=max_relative_error,
                    check_eager=False,
                )

    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
    TestMatMulOpFp16Case.__name__ = cls_name
    # unittest builds test ids from __qualname__; without this every
    # generated class would report the same "<locals>" id.
    TestMatMulOpFp16Case.__qualname__ = cls_name
    globals()[cls_name] = TestMatMulOpFp16Case


create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMulOp7)
create_test_fp16_class(TestMatMulOp8)
create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMulOp13)
create_test_fp16_class(TestMatMulOp14)
create_test_fp16_class(TestMatMulOp15)
create_test_fp16_class(TestMatMulOp16)
create_test_fp16_class(TestMatMulOp17)

# --------------------test matmul bf16--------------------


def create_test_bf16_class(parent, atol=0.01):
    """Register a bfloat16, CUDA-only variant of the OpTest class ``parent``.

    The generated class is named ``<parent name>_Bf16`` and is stored in
    this module's globals so unittest discovery picks it up. Gradients are
    checked separately for X and Y against numeric gradients computed from
    ``self.inputs_fp32``, the float32 copies of the inputs that
    ``TestMatMulV2Op.setUp`` keeps for bf16 runs.

    Args:
        parent: matmul_v2 OpTest class to derive the bf16 variant from.
        atol (float): absolute tolerance for the forward-output check.
    """

    @unittest.skipIf(
        not core.is_compiled_with_cuda()
        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
        "core is not compiled with CUDA and not support the bfloat16",
    )
    class TestMatMulOpBf16Case(parent):
        def get_numeric_grad(self, place, check_name):
            """Numeric gradient of Out w.r.t. ``check_name`` (fp32 inputs)."""
            scope = core.Scope()
            # OpTest-internal preparation run before building the op;
            # presumably sets up grad-check state - see op_test.
            self._check_grad_helper()
            op = create_op(
                scope, self.op_type, self.inputs, self.outputs, self.attrs
            )
            return get_numeric_gradient(
                place, scope, op, self.inputs_fp32, check_name, ['Out']
            )

        def init_kernel_type(self):
            # bf16 tensors are stored as uint16.
            self.dtype = np.uint16

        def test_check_output(self):
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=atol)

        def test_check_grad_x(self):
            place = core.CUDAPlace(0)
            numeric_grads = self.get_numeric_grad(place, 'X')
            self.check_grad_with_place(
                place,
                ['X'],
                'Out',
                no_grad_set={'Y'},
                user_defined_grads=[numeric_grads],
            )

        def test_check_grad_y(self):
            place = core.CUDAPlace(0)
            numeric_grads = self.get_numeric_grad(place, 'Y')
            self.check_grad_with_place(
                place,
                ['Y'],
                'Out',
                no_grad_set={'X'},
                user_defined_grads=[numeric_grads],
            )

        def test_check_grad(self):
            # Disabled: the per-input checks above replace the parent's
            # combined gradient check for bf16.
            pass

    cls_name = "{0}_{1}".format(parent.__name__, "Bf16")
    TestMatMulOpBf16Case.__name__ = cls_name
    # unittest builds test ids from __qualname__; without this every
    # generated class would report the same "<locals>" id.
    TestMatMulOpBf16Case.__qualname__ = cls_name
    globals()[cls_name] = TestMatMulOpBf16Case


create_test_bf16_class(TestMatMulV2Op)
create_test_bf16_class(TestMatMulOp2)
create_test_bf16_class(TestMatMulOp3)
create_test_bf16_class(TestMatMulOp4)
create_test_bf16_class(TestMatMulOp5)
create_test_bf16_class(TestMatMulOp6)
create_test_bf16_class(TestMatMulOp7)
create_test_bf16_class(TestMatMulOp8)
create_test_bf16_class(TestMatMulOp9)
create_test_bf16_class(TestMatMulOp10)
create_test_bf16_class(TestMatMulOp11)
create_test_bf16_class(TestMatMulOp12)
create_test_bf16_class(TestMatMulOp13)
create_test_bf16_class(TestMatMulOp14)
create_test_bf16_class(TestMatMulOp15)
create_test_bf16_class(TestMatMulOp16)
create_test_bf16_class(TestMatMulOp17)


class TestMatMulV2API(unittest.TestCase):
    """API-level tests for ``paddle.matmul`` in static and dygraph modes."""

    # Global flag selecting the GEMM compute type for float16 inputs.
    _HALF_COMPUTE_FLAG = 'FLAGS_gemm_use_half_precision_compute_type'

    def setUp(self):
        self.places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.places.append(fluid.CUDAPlace(0))

    @staticmethod
    def _fp16_place():
        """Return ``CUDAPlace(0)`` if it supports float16, else ``None``."""
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                return place
        return None

    @classmethod
    def _swap_half_compute_flag(cls, value):
        """Set the half-precision compute-type flag; return its old value."""
        flag = cls._HALF_COMPUTE_FLAG
        previous = paddle.get_flags(flag)[flag]
        paddle.set_flags({flag: value})
        return previous

    @staticmethod
    def _overflowing_fp16_inputs():
        """Build float16 operands whose individual products overflow fp16.

        Even columns of X are shifted by +60000 and odd columns by -60000,
        and Y is filled with 1.5, so single products (~90000) exceed the
        float16 maximum (65504) while each +/- pair roughly cancels.
        """
        input_x = np.random.random([2, 8, 16]).astype("float16")
        input_y = np.random.random([2, 16, 8]).astype("float16")
        for i in range(0, 16, 2):
            input_x[:, :, i] += 60000
            input_x[:, :, i + 1] -= 60000
        input_y[:, :, :] = 1.5
        return input_x, input_y

    def check_static_result(self, place):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float32")
            input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float32")

            result = paddle.matmul(input_x, input_y)

            x_np = np.random.random([4, 3]).astype("float32")
            y_np = np.random.random([3, 4]).astype("float32")

            exe = fluid.Executor(place)
            fetches = exe.run(
                fluid.default_main_program(),
                feed={"input_x": x_np, "input_y": y_np},
                fetch_list=[result],
            )
            # Loose tolerance in case float32 GEMM on the device uses
            # reduced-precision math.
            np.testing.assert_allclose(
                fetches[0], np.matmul(x_np, y_np), rtol=1e-3, atol=1e-3
            )

    def test_static(self):
        for place in self.places:
            self.check_static_result(place=place)

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                input_x = np.random.random([4, 3]).astype("float64")
                input_y = np.random.random([3, 4]).astype("float64")
                x = paddle.to_tensor(input_x)
                y = paddle.to_tensor(input_y)
                result = paddle.matmul(x, y)
                np.testing.assert_allclose(
                    result.numpy(), np.matmul(input_x, input_y), rtol=1e-6
                )

    def test_dygraph_fp16(self):
        place = self._fp16_place()
        if place is None:
            return
        with fluid.dygraph.guard(place):
            input_x = np.random.random([4, 3]).astype("float16")
            input_y = np.random.random([3, 4]).astype("float16")
            x = paddle.to_tensor(input_x)
            y = paddle.to_tensor(input_y)
            result = paddle.matmul(x, y)
            # Compare in float32 with a tolerance sized for fp16 rounding.
            expected = np.matmul(
                input_x.astype("float32"), input_y.astype("float32")
            )
            np.testing.assert_allclose(
                result.numpy().astype("float32"),
                expected,
                rtol=1e-2,
                atol=1e-2,
            )

    def test_compute_type_fp32(self):
        place = self._fp16_place()
        if place is None:
            return
        with fluid.dygraph.guard(place):
            previous = self._swap_half_compute_flag(False)
            try:
                input_x, input_y = self._overflowing_fp16_inputs()
                x = paddle.to_tensor(input_x)
                y = paddle.to_tensor(input_y)
                result = paddle.matmul(x, y)
                result_np = np.matmul(input_x, input_y)
                self.assertTrue(paddle.isfinite(result)[0, 0, 0])
                self.assertTrue(np.isfinite(result_np)[0, 0, 0])
                np.testing.assert_array_equal(result_np, result.numpy())
            finally:
                # Restore the global flag even if an assertion failed.
                self._swap_half_compute_flag(previous)

    def test_compute_type_fp16_nan(self):
        place = self._fp16_place()
        if place is None:
            return
        with fluid.dygraph.guard(place):
            previous = self._swap_half_compute_flag(True)
            try:
                input_x, input_y = self._overflowing_fp16_inputs()
                x = paddle.to_tensor(input_x)
                y = paddle.to_tensor(input_y)
                result = paddle.matmul(x, y)
                result_np = np.matmul(input_x, input_y)
                self.assertFalse(
                    paddle.isfinite(result)[0, 0, 0]
                )  # contains nan/inf
                self.assertTrue(np.isfinite(result_np)[0, 0, 0])
            finally:
                # Restore the global flag even if an assertion failed.
                self._swap_half_compute_flag(previous)

    def test_api_eager_dygraph(self):
        with _test_eager_guard():
            self.test_dygraph()
            self.test_dygraph_fp16()

# --------------------test complex matmul--------------------


class TestComplexMatMulOp(OpTest):
    """matmul_v2 on two 10x10 complex128 matrices.

    Gradients are checked against analytic values: for Out = X @ Y with
    upstream gradient G, dX = G @ conj(Y)^T and dY = conj(X)^T @ G.
    """

    def setUp(self):
        self.op_type = "matmul_v2"
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        # dtype of the real and imaginary parts used to build the inputs.
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.y = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        # Upstream gradient of all ones (1 + 1j).
        self.grad_out = np.ones((10, 10), self.dtype) + 1j * np.ones(
            (10, 10), self.dtype
        )
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )


class TestComplexMatMulOpBroadcast(OpTest):
    """matmul_v2 of a (10, 2, 5) complex batch with a (5, 20) complex matrix.

    Y is shared by every batch entry of X, so its analytic gradient is the
    per-batch conj(X)^T @ G summed over the batch axis.
    """

    def setUp(self):
        self.op_type = "matmul_v2"
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        # dtype of the real and imaginary parts used to build the inputs.
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random((10, 2, 5)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 2, 5)).astype(self.dtype)
        self.y = np.random.random((5, 20)).astype(
            self.dtype
        ) + 1j * np.random.random((5, 20)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        # Upstream gradient of all ones (1 + 1j), shaped like Out.
        self.grad_out = np.ones((10, 2, 20), self.dtype) + 1j * np.ones(
            (10, 2, 20), self.dtype
        )
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        # Per-batch gradients (10, 5, 20) reduced over the batch axis.
        self.grad_y = np.sum(
            np.matmul(np.conj(self.x).transpose(0, 2, 1), self.grad_out), axis=0
        )

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )


class TestMatMulTypePromotion(TestComplexMatMulOp):
    """Real float64 X times complex Y; the expected output is complex.

    Because X is real, its expected gradient is the real part of the
    complex gradient G @ conj(Y)^T.
    """

    def init_input_output(self):
        self.x = np.random.random((10, 10)).astype(self.dtype)
        self.y = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        self.grad_out = np.ones((10, 10), self.dtype) + 1j * np.ones(
            (10, 10), self.dtype
        )
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T).real
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)


class TestMatmulop(unittest.TestCase):
    """Dygraph tests for the ``@`` operator on paddle Tensors."""

    def func_dygraph_matmul(self):
        """Check ``Tensor @ ndarray`` and ``Tensor @ Tensor`` against NumPy.

        Switches to dygraph mode for the check and always switches back to
        static mode afterwards, even when an assertion fails.
        """
        paddle.disable_static()
        try:
            np_a = np.random.random((2, 4)).astype(np.float32)
            np_b = np.random.random((4, 2)).astype(np.float32)

            tensor_a = paddle.to_tensor(np_a, dtype="float32")
            tensor_b = paddle.to_tensor(np_b, dtype="float32")

            expect_out = np_a @ np_b
            # Tolerances allow for float32 GEMM rounding differences.
            # normal case: tensor @ nparray
            np.testing.assert_allclose(
                (tensor_a @ np_b).numpy(), expect_out, rtol=1e-3, atol=1e-3
            )
            # normal case: tensor @ tensor
            np.testing.assert_allclose(
                (tensor_a @ tensor_b).numpy(),
                expect_out,
                rtol=1e-3,
                atol=1e-3,
            )
        finally:
            paddle.enable_static()

    def test_dygraph_matmul(self):
        """Run the check inside ``_test_eager_guard`` and again outside it."""
        with _test_eager_guard():
            self.func_dygraph_matmul()
        self.func_dygraph_matmul()


if __name__ == "__main__":
    # OpTest-based cases run in static-graph mode.
    paddle.enable_static()
    unittest.main()