test_matmul_v2_op.py 21.8 KB
Newer Older
S
ShenLiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
17 18
from op_test import OpTest, convert_float_to_uint16, get_numeric_gradient
from paddle.fluid.tests.unittests.testsuite import create_op
S
ShenLiang 已提交
19 20 21 22
import paddle.fluid.core as core

import paddle
import paddle.fluid as fluid
23
from paddle.fluid.framework import _test_eager_guard
S
ShenLiang 已提交
24 25 26 27 28 29 30 31


def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
    """Reference forward implementation using np.matmul.

    Args:
        X, Y: numpy arrays. A 1-D operand is used unchanged even when its
            transpose flag is set (transposing a vector is a no-op).
        transpose_X, transpose_Y: if True, swap the last two axes of the
            corresponding operand (when it has at least 2 dims) before
            multiplying.

    Returns:
        ``np.matmul`` of the (optionally transposed) operands. When that
        result would be a 0-d scalar (1-D times 1-D), a float64 array of
        shape (1,) holding the scalar is returned instead.
    """
    # np.matmul does not support the transpose flags, so we manually
    # transpose X and Y. Swapping the last two axes is a plain transpose for
    # 2-D inputs and transposes every matrix of the batch for higher ranks.
    if transpose_X and X.ndim > 1:
        X = np.swapaxes(X, -1, -2)
    if transpose_Y and Y.ndim > 1:
        Y = np.swapaxes(Y, -1, -2)

    Out = np.matmul(X, Y)
    if not Out.shape:
        # We do not support 0-dimensional Tensors (scalars). So where
        # np.matmul outputs a scalar, we must convert to a Tensor of
        # shape (1, ) instead.
        # Everywhere else, we are compatible with np.matmul.
        Out = np.array([Out], dtype="float64")
    return Out


class TestMatMulV2Op(OpTest):
    """
    case 1: X (100,) @ Y (100,), no transposes.

    Base OpTest for the ``matmul_v2`` operator. Subclasses override
    ``config`` to choose the input shapes and the ``trans_x``/``trans_y``
    attributes, and ``init_kernel_type`` to choose the dtype.
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (100,)
        self.trans_x = False
        self.trans_y = False

    def init_kernel_type(self):
        # ROCm builds run this test in float32, everything else in float64.
        self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"

    def setUp(self):
        self.init_kernel_type()
        self.config()
        self.op_type = "matmul_v2"
        if self.is_bfloat16_op():
            # bf16 inputs are generated in float32 and encoded with
            # convert_float_to_uint16 below.
            x = np.random.random(self.x_shape).astype(np.float32)
            y = np.random.random(self.y_shape).astype(np.float32)
        else:
            x = np.random.random(self.x_shape).astype(self.dtype)
            y = np.random.random(self.y_shape).astype(self.dtype)
            # Rescale from [0, 1) to -0.1 ~ 0.1.
            x = -0.1 + 0.2 * x
            y = -0.1 + 0.2 * y
        result = reference_matmul(x, y, self.trans_x, self.trans_y)
        if self.is_bfloat16_op():
            result = result.astype(np.float32)
            self.inputs = {
                'X': convert_float_to_uint16(x),
                'Y': convert_float_to_uint16(y),
            }
            # Keep the float32 inputs around so numeric gradients can be
            # computed from them.
            self.inputs_fp32 = {
                'X': x,
                'Y': y,
            }
        else:
            result = result.astype(self.dtype)
            self.inputs = {
                'X': x,
                'Y': y,
            }
        self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
        self.outputs = {'Out': result}

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad(self):
        if core.is_compiled_with_rocm():
            # ROCm runs in float32 (see init_kernel_type), so allow a looser
            # relative error on the gradient check.
            self.check_grad(
                ['X', 'Y'], 'Out', max_relative_error=1e-2, check_eager=False
            )
        else:
            self.check_grad(['X', 'Y'], 'Out', check_eager=False)


class TestMatMulOp2(TestMatMulV2Op):
    """
    case 2: X (100,) @ Y (1, 3, 2, 100) with trans_y.
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 3, 2, 100)
        self.trans_x = False
        self.trans_y = True


class TestMatMulOp3(TestMatMulV2Op):
    """
    case 3: X (100,) @ Y (1, 1, 100, 2), no transposes.
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp4(TestMatMulV2Op):
    """
    case 4: X (100,) @ Y (1, 2, 100, 2), no transposes.
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 2, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp5(TestMatMulV2Op):
    """
    case 5: X (1, 1, 100, 1) with trans_x @ Y (100,).
    """

    def config(self):
        self.x_shape = (1, 1, 100, 1)
        self.y_shape = (100,)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp6(TestMatMulV2Op):
    """
    case 6: X (1, 2, 102, 1) with trans_x @ Y (102,).
    """

    def config(self):
        self.x_shape = (1, 2, 102, 1)
        self.y_shape = (102,)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp7(TestMatMulV2Op):
    """
    case 7: X (1, 2, 1, 100) @ Y (100,), no transposes.
    """

    def config(self):
        self.x_shape = (1, 2, 1, 100)
        self.y_shape = (100,)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp8(TestMatMulV2Op):
    """
    case 8: X (1, 1, 2, 100) @ Y (1, 1, 100, 2), no transposes.
    """

    def config(self):
        self.x_shape = (1, 1, 2, 100)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp9(TestMatMulV2Op):
    """
    case 9: X (1, 1, 1, 100) @ Y (2, 1, 2, 100) with trans_y; the leading
    batch dimension of X broadcasts against Y's.
    """

    def config(self):
        self.x_shape = (1, 1, 1, 100)
        self.y_shape = (2, 1, 2, 100)
        self.trans_x = False
        self.trans_y = True


class TestMatMulOp10(TestMatMulV2Op):
    """
    case 10: X (1, 1, 25, 4) @ Y (1, 2, 4, 25), no transposes; the second
    batch dimension of X broadcasts against Y's.
    """

    def config(self):
        self.x_shape = (1, 1, 25, 4)
        self.y_shape = (1, 2, 4, 25)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp11(TestMatMulV2Op):
    """
    case 11: X (2, 1, 2, 100) @ Y (1, 1, 100, 2), no transposes; Y's
    leading batch dimension broadcasts against X's.
    """

    def config(self):
        self.x_shape = (2, 1, 2, 100)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp12(TestMatMulV2Op):
    """
    case 12: X (2, 1, 4, 25) with trans_x @ Y (1, 1, 4, 25).
    """

    def config(self):
        self.x_shape = (2, 1, 4, 25)
        self.y_shape = (1, 1, 4, 25)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp13(TestMatMulV2Op):
    """
    case 13: X (2, 2, 10, 10) with trans_x @ Y (2, 2, 10, 10).
    """

    def config(self):
        self.x_shape = (2, 2, 10, 10)
        self.y_shape = (2, 2, 10, 10)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp14(TestMatMulV2Op):
    """
    case 14_1: X (3, 1, 6, 6) with trans_x @ Y (1, 2, 6, 9); the batch
    dimensions (3, 1) and (1, 2) broadcast to (3, 2).
    """

    def config(self):
        self.x_shape = (3, 1, 6, 6)
        self.y_shape = (1, 2, 6, 9)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp15(TestMatMulV2Op):
    """
    case 14_2: X (3, 1, 6, 6) @ Y (1, 2, 6, 9), no transposes; the batch
    dimensions (3, 1) and (1, 2) broadcast to (3, 2).
    """

    def config(self):
        self.x_shape = (3, 1, 6, 6)
        self.y_shape = (1, 2, 6, 9)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp16(TestMatMulV2Op):
    """
    case 16 : to check the gradient for special case
    (1-D X of length 100, given as a bare int shape, @ 5-D Y (1, 2, 2, 100, 2)).
    """

    def config(self):
        self.x_shape = 100
        self.y_shape = (1, 2, 2, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp17(TestMatMulV2Op):
    """
    case 17 : to check the gradient for special case
    (3-D X (2, 1, 100) @ 1-D Y of length 100, given as a bare int shape).
    """

    def config(self):
        self.x_shape = (2, 1, 100)
        self.y_shape = 100
        self.trans_x = False
        self.trans_y = False


class TestMatMulOpBroadcast1(TestMatMulV2Op):
    """
    case 14_3: X (3, 1, 10, 10) @ Y (1, 2, 10, 10) with both trans_x and
    trans_y; the batch dimensions broadcast to (3, 2).
    """

    def config(self):
        self.x_shape = (3, 1, 10, 10)
        self.y_shape = (1, 2, 10, 10)
        self.trans_x = True
        self.trans_y = True


class TestMatMulOpBroadcast2(TestMatMulV2Op):
    """
    case 14_4: X (3, 1, 10, 10) @ Y (1, 2, 10, 10) with trans_y only; the
    batch dimensions broadcast to (3, 2).
    """

    def config(self):
        self.x_shape = (3, 1, 10, 10)
        self.y_shape = (1, 2, 10, 10)
        self.trans_x = False
        self.trans_y = True


# --------------------test matmul fp16--------------------


def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
    """Register a float16 variant of the OpTest class ``parent``.

    The generated class is named ``<parent.__name__>_Fp16``. It is stored in
    this module's globals so the test loader discovers it, and it is skipped
    unless Paddle is compiled with CUDA.

    Args:
        parent: OpTest subclass whose shape/transpose config is reused.
        atol: absolute tolerance for the forward output check.
        max_relative_error: tolerance passed to the gradient check.
    """

    @unittest.skipIf(
        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
    )
    class TestMatMulOpFp16Case(parent):
        def init_kernel_type(self):
            self.dtype = np.float16

        def test_check_output(self):
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
                # Devices without float16 support skip the check.
                if core.is_float16_supported(place):
                    self.check_output_with_place(
                        place, atol=atol, check_eager=False
                    )

        def test_check_grad(self):
            place = core.CUDAPlace(0)
            # Devices without float16 support skip the check.
            if core.is_float16_supported(place):
                self.check_grad_with_place(
                    place,
                    ['X', 'Y'],
                    'Out',
                    max_relative_error=max_relative_error,
                    check_eager=False,
                )

    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
    TestMatMulOpFp16Case.__name__ = cls_name
    globals()[cls_name] = TestMatMulOpFp16Case


create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMulOp7)
create_test_fp16_class(TestMatMulOp8)
create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMulOp13)
create_test_fp16_class(TestMatMulOp14)
create_test_fp16_class(TestMatMulOp15)
create_test_fp16_class(TestMatMulOp16)
create_test_fp16_class(TestMatMulOp17)

# --------------------test matmul bf16--------------------


def create_test_bf16_class(parent, atol=0.01):
    """Register a bfloat16 variant of the OpTest class ``parent``.

    The generated class ``<parent.__name__>_Bf16`` runs on CUDAPlace(0). It
    is skipped unless Paddle is compiled with CUDA and the device supports
    bfloat16. Gradients are compared against numeric gradients computed from
    the float32 copies of the inputs (``self.inputs_fp32``).

    Args:
        parent: OpTest subclass whose shape/transpose config is reused.
        atol: absolute tolerance for the forward output check.
    """

    @unittest.skipIf(
        not core.is_compiled_with_cuda()
        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
        "core is not compiled with CUDA and not support the bfloat16",
    )
    class TestMatMulOpBf16Case(parent):
        def get_numeric_grad(self, place, check_name):
            """Numerically differentiate ``Out`` w.r.t. input ``check_name``,
            feeding the float32 inputs rather than the bf16-encoded ones."""
            scope = core.Scope()
            self._check_grad_helper()
            op = create_op(
                scope, self.op_type, self.inputs, self.outputs, self.attrs
            )
            return get_numeric_gradient(
                place, scope, op, self.inputs_fp32, check_name, ['Out']
            )

        def init_kernel_type(self):
            # bf16 data is passed as uint16 (see convert_float_to_uint16).
            self.dtype = np.uint16

        def test_check_output(self):
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=atol)

        def test_check_grad_x(self):
            place = core.CUDAPlace(0)
            numeric_grads = self.get_numeric_grad(place, 'X')
            self.check_grad_with_place(
                place,
                ['X'],
                'Out',
                no_grad_set={'Y'},
                user_defined_grads=[numeric_grads],
            )

        def test_check_grad_y(self):
            place = core.CUDAPlace(0)
            numeric_grads = self.get_numeric_grad(place, 'Y')
            self.check_grad_with_place(
                place,
                ['Y'],
                'Out',
                no_grad_set={'X'},
                user_defined_grads=[numeric_grads],
            )

        def test_check_grad(self):
            # Disabled: the per-input checks above replace the inherited
            # joint gradient check.
            pass

    cls_name = "{0}_{1}".format(parent.__name__, "Bf16")
    TestMatMulOpBf16Case.__name__ = cls_name
    globals()[cls_name] = TestMatMulOpBf16Case


create_test_bf16_class(TestMatMulV2Op)
create_test_bf16_class(TestMatMulOp2)
create_test_bf16_class(TestMatMulOp3)
create_test_bf16_class(TestMatMulOp4)
create_test_bf16_class(TestMatMulOp5)
create_test_bf16_class(TestMatMulOp6)
create_test_bf16_class(TestMatMulOp7)
create_test_bf16_class(TestMatMulOp8)
create_test_bf16_class(TestMatMulOp9)
create_test_bf16_class(TestMatMulOp10)
create_test_bf16_class(TestMatMulOp11)
create_test_bf16_class(TestMatMulOp12)
create_test_bf16_class(TestMatMulOp13)
create_test_bf16_class(TestMatMulOp14)
create_test_bf16_class(TestMatMulOp15)
create_test_bf16_class(TestMatMulOp16)
create_test_bf16_class(TestMatMulOp17)


class TestMatMulV2API(unittest.TestCase):
    """API tests for ``paddle.matmul`` in static and dynamic graph modes."""

    def setUp(self):
        self.places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.places.append(fluid.CUDAPlace(0))

    def _set_gemm_half_precision_compute_type(self, enabled):
        """Set FLAGS_gemm_use_half_precision_compute_type for this test only.

        The flag's previous value is restored by a cleanup registered here,
        so the global flag does not leak into other tests, even when an
        assertion fails.
        """
        flag = 'FLAGS_gemm_use_half_precision_compute_type'
        previous = paddle.get_flags([flag])[flag]
        self.addCleanup(paddle.set_flags, {flag: previous})
        paddle.set_flags({flag: enabled})

    @staticmethod
    def _overflow_prone_fp16_inputs():
        """Build float16 inputs of shape [2, 8, 16] and [2, 16, 8].

        Columns of X alternate between roughly +60000 and -60000 and Y is all
        1.5. The individual products (about +/-90000) exceed the float16
        range, but consecutive pairs cancel, so the true result is small.
        """
        input_x = np.random.random([2, 8, 16]).astype("float16")
        input_y = np.random.random([2, 16, 8]).astype("float16")
        for i in range(0, 16, 2):
            input_x[:, :, i] += 60000
            input_x[:, :, i + 1] -= 60000
        input_y[:, :, :] = 1.5
        return input_x, input_y

    def check_static_result(self, place):
        """Run a [4, 3] x [3, 4] float32 matmul as a static program on
        ``place`` and compare the fetched result with ``np.matmul``."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float32")
            input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float32")

            result = paddle.matmul(input_x, input_y)

            x_np = np.random.random([4, 3]).astype("float32")
            y_np = np.random.random([3, 4]).astype("float32")

            exe = fluid.Executor(place)
            fetches = exe.run(
                fluid.default_main_program(),
                feed={"input_x": x_np, "input_y": y_np},
                fetch_list=[result],
            )
            # Loose tolerances leave room for a reduced-precision GPU gemm
            # while still catching wrong results.
            np.testing.assert_allclose(
                fetches[0], np.matmul(x_np, y_np), rtol=1e-03, atol=1e-03
            )

    def test_static(self):
        for place in self.places:
            self.check_static_result(place=place)

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                input_x = np.random.random([4, 3]).astype("float64")
                input_y = np.random.random([3, 4]).astype("float64")
                x = paddle.to_tensor(input_x)
                y = paddle.to_tensor(input_y)
                result = paddle.matmul(x, y)
                np.testing.assert_allclose(
                    result.numpy(), np.matmul(input_x, input_y), rtol=1e-05
                )

    def test_dygraph_fp16(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                with fluid.dygraph.guard(place):
                    input_x = np.random.random([4, 3]).astype("float16")
                    input_y = np.random.random([3, 4]).astype("float16")
                    x = paddle.to_tensor(input_x)
                    y = paddle.to_tensor(input_y)
                    result = paddle.matmul(x, y)
                    # Compare against a float32 reference computed from the
                    # same float16 inputs; tolerance covers fp16 rounding.
                    np.testing.assert_allclose(
                        result.numpy().astype("float32"),
                        np.matmul(
                            input_x.astype("float32"), input_y.astype("float32")
                        ),
                        rtol=1e-02,
                        atol=1e-02,
                    )

    def test_compute_type_fp32(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                with fluid.dygraph.guard(place):
                    # fp32 compute type: the overflowing intermediate
                    # products must not make the result non-finite.
                    self._set_gemm_half_precision_compute_type(False)
                    input_x, input_y = self._overflow_prone_fp16_inputs()

                    x = paddle.to_tensor(input_x)
                    y = paddle.to_tensor(input_y)
                    result = paddle.matmul(x, y)
                    result_np = np.matmul(input_x, input_y)
                    self.assertTrue(paddle.isfinite(result)[0, 0, 0])
                    self.assertTrue(np.isfinite(result_np)[0, 0, 0])
                    np.testing.assert_array_equal(result_np, result.numpy())

    def test_compute_type_fp16_nan(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                with fluid.dygraph.guard(place):
                    # fp16 compute type: the same inputs are expected to
                    # overflow and produce a non-finite result.
                    self._set_gemm_half_precision_compute_type(True)
                    input_x, input_y = self._overflow_prone_fp16_inputs()

                    x = paddle.to_tensor(input_x)
                    y = paddle.to_tensor(input_y)
                    result = paddle.matmul(x, y)
                    result_np = np.matmul(input_x, input_y)
                    self.assertFalse(
                        paddle.isfinite(result)[0, 0, 0]
                    )  # contains nan/inf
                    self.assertTrue(np.isfinite(result_np)[0, 0, 0])

    def test_api_eager_dygraph(self):
        with _test_eager_guard():
            self.test_dygraph()
            self.test_dygraph_fp16()

class TestComplexMatMulOp(OpTest):
    """matmul_v2 on complex (10, 10) inputs with analytic reference gradients.

    Subclasses override ``init_input_output`` and ``init_grad_input_output``
    to change the inputs, expected output and reference gradients.
    """

    def setUp(self):
        self.op_type = "matmul_v2"
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        # Real and imaginary parts are generated in this dtype.
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.y = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        # Upstream gradient G of all (1 + 1j); reference gradients are
        # dX = G @ conj(Y)^T and dY = conj(X)^T @ G.
        self.grad_out = np.ones((10, 10), self.dtype) + 1j * np.ones(
            (10, 10), self.dtype
        )
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )


class TestComplexMatMulOpBroadcast(OpTest):
    """matmul_v2 on complex X (10, 2, 5) and Y (5, 20), with Y broadcast
    over X's leading batch dimension."""

    def setUp(self):
        self.op_type = "matmul_v2"
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        # Real and imaginary parts are generated in this dtype.
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random((10, 2, 5)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 2, 5)).astype(self.dtype)
        self.y = np.random.random((5, 20)).astype(
            self.dtype
        ) + 1j * np.random.random((5, 20)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        # Upstream gradient G of all (1 + 1j) with shape (10, 2, 20).
        self.grad_out = np.ones((10, 2, 20), self.dtype) + 1j * np.ones(
            (10, 2, 20), self.dtype
        )
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        # Y is shared by every batch entry, so its gradient is the sum of
        # the per-batch gradients conj(X_b)^T @ G_b.
        self.grad_y = np.sum(
            np.matmul(np.conj(self.x).transpose(0, 2, 1), self.grad_out), axis=0
        )

    def test_check_output(self):
        self.check_output(check_eager=False)

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out],
            check_eager=False,
        )


class TestMatMulTypePromotion(TestComplexMatMulOp):
    """Real X (10, 10) times complex Y (10, 10)."""

    def init_input_output(self):
        self.x = np.random.random((10, 10)).astype(self.dtype)
        self.y = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        self.grad_out = np.ones((10, 10), self.dtype) + 1j * np.ones(
            (10, 10), self.dtype
        )
        # X is real, so only the real part of its gradient is kept.
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T).real
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)


class TestMatmulop(unittest.TestCase):
    """Check the ``@`` operator between a Paddle tensor and a numpy array."""

    def func_dygraph_matmul(self):
        """Compare ``tensor @ ndarray`` in dynamic mode with numpy's ``@``.

        Switches to dynamic mode for the check and always switches back to
        static mode afterwards, even if the assertion fails.
        """
        paddle.disable_static()
        try:
            np_a = np.random.random((2, 4)).astype(np.float32)
            np_b = np.random.random((4, 2)).astype(np.float32)

            tensor_a = paddle.to_tensor(np_a, dtype="float32")

            # normal case: tensor @ nparray
            expect_out = np_a @ np_b
            actual_out = tensor_a @ np_b
            np.testing.assert_allclose(
                actual_out.numpy(), expect_out, rtol=1e-05
            )
        finally:
            paddle.enable_static()

    def test_dygraph_matmul(self):
        # Must carry the ``test_`` prefix to be collected, and must not reuse
        # the helper's name, otherwise it shadows it and recurses.
        with _test_eager_guard():
            self.func_dygraph_matmul()


if __name__ == "__main__":
    # OpTest-based cases expect static-graph mode by default.
    paddle.enable_static()
    unittest.main()