#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16, get_numeric_gradient
from testsuite import create_op

import paddle
from paddle import fluid
from paddle.fluid import core


def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
    """Reference forward implementation of matmul_v2 built on np.matmul.

    np.matmul has no transpose flags, so the requested transposes are applied
    here first: a 1-D operand is left unchanged, and an operand with two or
    more dimensions has its last two axes swapped.

    Args:
        X, Y: numpy arrays to multiply.
        transpose_X, transpose_Y: whether to swap the last two axes of X / Y
            before multiplying.

    Returns:
        ``np.matmul`` of the (possibly transposed) operands. A 0-d result
        (e.g. 1-D times 1-D) is returned as an array of shape (1,), with dtype
        float64 for real inputs and complex128 for complex inputs.
    """

    def _maybe_transpose(arr, transpose):
        # A 1-D vector has no axes to swap; otherwise swap the last two.
        if not transpose or arr.ndim < 2:
            return arr
        return np.swapaxes(arr, -1, -2)

    X = _maybe_transpose(X, transpose_X)
    Y = _maybe_transpose(Y, transpose_Y)

    Out = np.matmul(X, Y)
    if not Out.shape:
        # We do not support 0-dimensional Tensors (scalars). So where
        # np.matmul outputs a scalar, we must convert to a Tensor of
        # shape (1, ) instead.
        # Everywhere else, we are compatible with np.matmul.
        # Promote to at least float64; complex results keep their imaginary
        # part instead of being forced into a real float64 array.
        Out = np.array([Out], dtype=np.result_type(Out.dtype, np.float64))
    return Out


class TestMatMulV2Op(OpTest):
    """
    case 1: 1-D x (100,) times 1-D y (100,), no transposes.

    Base class for the matmul_v2 OpTest cases in this file. Subclasses
    override ``config`` to pick ``x_shape``/``y_shape`` and the
    ``trans_x``/``trans_y`` attributes, and ``init_kernel_type`` to pick the
    dtype. The expected output is computed by ``reference_matmul``.
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (100,)
        self.trans_x = False
        self.trans_y = False

    def init_kernel_type(self):
        # ROCm builds run this test in float32, all others in float64.
        self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"

    def setUp(self):
        self.init_kernel_type()
        self.config()
        self.op_type = "matmul_v2"
        self.python_api = paddle.tensor.matmul
        if self.is_bfloat16_op():
            # bfloat16 cases draw float32 data in [0, 1); it is converted to
            # uint16 storage below and also kept as float32 in inputs_fp32.
            x = np.random.random(self.x_shape).astype(np.float32)
            y = np.random.random(self.y_shape).astype(np.float32)
        else:
            x = np.random.random(self.x_shape).astype(self.dtype)
            y = np.random.random(self.y_shape).astype(self.dtype)
            # Rescale from [0, 1) to [-0.1, 0.1).
            x = -0.1 + 0.2 * x
            y = -0.1 + 0.2 * y
        result = reference_matmul(x, y, self.trans_x, self.trans_y)
        if self.is_bfloat16_op():
            result = result.astype(np.float32)
            self.inputs = {
                'X': convert_float_to_uint16(x),
                'Y': convert_float_to_uint16(y),
            }
            # Float32 copies of the inputs; the bf16 variants use them to
            # compute numeric gradients.
            self.inputs_fp32 = {
                'X': x,
                'Y': y,
            }
        else:
            result = result.astype(self.dtype)
            self.inputs = {
                'X': x,
                'Y': y,
            }
        self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
        self.outputs = {'Out': result}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        if core.is_compiled_with_rocm():
            # Looser tolerance for the float32 ROCm configuration.
            self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-2)
        else:
            self.check_grad(['X', 'Y'], 'Out')


class TestMatMulOp2(TestMatMulV2Op):
    """
    case 2: 1-D x (100,) times 4-D y (1, 3, 2, 100) with trans_y,
    giving out (1, 3, 2).
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 3, 2, 100)
        self.trans_x = False
        self.trans_y = True


class TestMatMulOp3(TestMatMulV2Op):
    """
    case 3: 1-D x (100,) times 4-D y (1, 1, 100, 2), giving out (1, 1, 2).
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp4(TestMatMulV2Op):
    """
    case 4: 1-D x (100,) times 4-D y (1, 2, 100, 2), giving out (1, 2, 2).
    """

    def config(self):
        self.x_shape = (100,)
        self.y_shape = (1, 2, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp5(TestMatMulV2Op):
    """
    case 5: 4-D x (1, 1, 100, 1) with trans_x times 1-D y (100,),
    giving out (1, 1, 1).
    """

    def config(self):
        self.x_shape = (1, 1, 100, 1)
        self.y_shape = (100,)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp6(TestMatMulV2Op):
    """
    case 6: 4-D x (1, 2, 102, 1) with trans_x times 1-D y (102,),
    giving out (1, 2, 1).
    """

    def config(self):
        self.x_shape = (1, 2, 102, 1)
        self.y_shape = (102,)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp7(TestMatMulV2Op):
    """
    case 7: 4-D x (1, 2, 1, 100) times 1-D y (100,), giving out (1, 2, 1).
    """

    def config(self):
        self.x_shape = (1, 2, 1, 100)
        self.y_shape = (100,)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp8(TestMatMulV2Op):
    """
    case 8: 4-D x (1, 1, 2, 100) times 4-D y (1, 1, 100, 2),
    giving out (1, 1, 2, 2).
    """

    def config(self):
        self.x_shape = (1, 1, 2, 100)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp9(TestMatMulV2Op):
    """
    case 9: x (1, 1, 1, 100) times y (2, 1, 2, 100) with trans_y; the batch
    dims broadcast to (2, 1), giving out (2, 1, 1, 2).
    """

    def config(self):
        self.x_shape = (1, 1, 1, 100)
        self.y_shape = (2, 1, 2, 100)
        self.trans_x = False
        self.trans_y = True


class TestMatMulOp10(TestMatMulV2Op):
    """
    case 10: x (1, 1, 25, 4) times y (1, 2, 4, 25); the batch dims broadcast,
    giving out (1, 2, 25, 25).
    """

    def config(self):
        self.x_shape = (1, 1, 25, 4)
        self.y_shape = (1, 2, 4, 25)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp11(TestMatMulV2Op):
    """
    case 11: x (2, 1, 2, 100) times y (1, 1, 100, 2); the batch dims
    broadcast, giving out (2, 1, 2, 2).
    """

    def config(self):
        self.x_shape = (2, 1, 2, 100)
        self.y_shape = (1, 1, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp12(TestMatMulV2Op):
    """
    case 12: x (2, 1, 4, 25) with trans_x times y (1, 1, 4, 25),
    giving out (2, 1, 25, 25).
    """

    def config(self):
        self.x_shape = (2, 1, 4, 25)
        self.y_shape = (1, 1, 4, 25)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp13(TestMatMulV2Op):
    """
    case 13: x (2, 2, 10, 10) with trans_x times y (2, 2, 10, 10),
    giving out (2, 2, 10, 10).
    """

    def config(self):
        self.x_shape = (2, 2, 10, 10)
        self.y_shape = (2, 2, 10, 10)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp14(TestMatMulV2Op):
    """
    case 14_1: x (3, 1, 6, 6) with trans_x times y (1, 2, 6, 9); the batch
    dims broadcast to (3, 2), giving out (3, 2, 6, 9).
    """

    def config(self):
        self.x_shape = (3, 1, 6, 6)
        self.y_shape = (1, 2, 6, 9)
        self.trans_x = True
        self.trans_y = False


class TestMatMulOp15(TestMatMulV2Op):
    """
    case 14_2: x (3, 1, 6, 6) times y (1, 2, 6, 9) without transposes; the
    batch dims broadcast to (3, 2), giving out (3, 2, 6, 9).
    """

    def config(self):
        self.x_shape = (3, 1, 6, 6)
        self.y_shape = (1, 2, 6, 9)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp16(TestMatMulV2Op):
    """
    case 16 : to check the gradient for special case
    (1-D x of length 100 times 5-D y (1, 2, 2, 100, 2), giving out (1, 2, 2, 2)).
    """

    def config(self):
        # Written as a tuple for consistency with the other cases;
        # np.random.random((100,)) draws exactly what np.random.random(100) did.
        self.x_shape = (100,)
        self.y_shape = (1, 2, 2, 100, 2)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOp17(TestMatMulV2Op):
    """
    case 17 : to check the gradient for special case
    (3-D x (2, 1, 100) times 1-D y of length 100, giving out (2, 1)).
    """

    def config(self):
        self.x_shape = (2, 1, 100)
        # Written as a tuple for consistency; draws the same data as 100.
        self.y_shape = (100,)
        self.trans_x = False
        self.trans_y = False


class TestMatMulOpBroadcast1(TestMatMulV2Op):
    """
    case 14_3: x (3, 1, 10, 10) times y (1, 2, 10, 10) with both transposes;
    the batch dims broadcast, giving out (3, 2, 10, 10).
    """

    def config(self):
        self.x_shape = (3, 1, 10, 10)
        self.y_shape = (1, 2, 10, 10)
        self.trans_x = True
        self.trans_y = True


class TestMatMulOpBroadcast2(TestMatMulV2Op):
    """
    case 14_4: x (3, 1, 10, 10) times y (1, 2, 10, 10) with trans_y only;
    the batch dims broadcast, giving out (3, 2, 10, 10).
    """

    def config(self):
        self.x_shape = (3, 1, 10, 10)
        self.y_shape = (1, 2, 10, 10)
        self.trans_x = False
        self.trans_y = True


# --------------------test matmul fp16--------------------


def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
    """Register a float16 variant of the OpTest subclass ``parent``.

    The generated class is named ``<parent>_Fp16``. It is skipped when Paddle
    is not compiled with CUDA, and it is stored in this module's globals so
    unittest discovers it.

    Args:
        parent: the test class to derive from.
        atol: absolute tolerance for the forward output check.
        max_relative_error: relative tolerance for the gradient check.

    Returns:
        The generated test class (callers may ignore it).
    """

    @unittest.skipIf(
        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
    )
    class TestMatMulOpFp16Case(parent):
        def init_kernel_type(self):
            self.dtype = np.float16

        def test_check_output(self):
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
                if core.is_float16_supported(place):
                    self.check_output_with_place(place, atol=atol)

        def test_check_grad(self):
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                self.check_grad_with_place(
                    place,
                    ['X', 'Y'],
                    'Out',
                    max_relative_error=max_relative_error,
                )

    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
    TestMatMulOpFp16Case.__name__ = cls_name
    globals()[cls_name] = TestMatMulOpFp16Case
    return TestMatMulOpFp16Case


create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMulOp7)
create_test_fp16_class(TestMatMulOp8)
create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMulOp13)
create_test_fp16_class(TestMatMulOp14)
create_test_fp16_class(TestMatMulOp15)
create_test_fp16_class(TestMatMulOp16)
create_test_fp16_class(TestMatMulOp17)
create_test_fp16_class(TestMatMulOpBroadcast1)
create_test_fp16_class(TestMatMulOpBroadcast2)

# --------------------test matmul bf16--------------------


def create_test_bf16_class(parent, atol=0.01):
    """Register a bfloat16 variant of the OpTest subclass ``parent``.

    The generated class is named ``<parent>_Bf16`` and is stored in this
    module's globals so unittest discovers it. It is skipped unless Paddle is
    compiled with CUDA and ``core.is_bfloat16_supported`` reports support on
    CUDAPlace(0).

    Gradients are checked one input at a time against numeric gradients that
    are evaluated on ``self.inputs_fp32`` (the float32 copies that
    ``TestMatMulV2Op.setUp`` stores for bfloat16 runs).

    Args:
        parent: the test class to derive from.
        atol: absolute tolerance for the forward output check.

    Returns:
        The generated test class (callers may ignore it).
    """

    @unittest.skipIf(
        not core.is_compiled_with_cuda()
        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
        "core is not compiled with CUDA and not support the bfloat16",
    )
    class TestMatMulOpBf16Case(parent):
        def get_numeric_grad(self, place, check_name):
            # Numeric gradient of 'Out' w.r.t. ``check_name``, computed by
            # running the op on the float32 inputs.
            scope = core.Scope()
            self._check_grad_helper()
            op = create_op(
                scope, self.op_type, self.inputs, self.outputs, self.attrs
            )
            return get_numeric_gradient(
                place, scope, op, self.inputs_fp32, check_name, ['Out']
            )

        def init_kernel_type(self):
            # bfloat16 data is stored as uint16 in this file
            # (see convert_float_to_uint16 in TestMatMulV2Op.setUp).
            self.dtype = np.uint16

        def test_check_output(self):
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=atol)

        def test_check_grad_x(self):
            place = core.CUDAPlace(0)
            numeric_grads = self.get_numeric_grad(place, 'X')
            self.check_grad_with_place(
                place,
                ['X'],
                'Out',
                no_grad_set={'Y'},
                user_defined_grads=[numeric_grads],
            )

        def test_check_grad_y(self):
            place = core.CUDAPlace(0)
            numeric_grads = self.get_numeric_grad(place, 'Y')
            self.check_grad_with_place(
                place,
                ['Y'],
                'Out',
                no_grad_set={'X'},
                user_defined_grads=[numeric_grads],
            )

        def test_check_grad(self):
            # Disabled: X and Y are checked separately above.
            pass

    cls_name = "{0}_{1}".format(parent.__name__, "Bf16")
    TestMatMulOpBf16Case.__name__ = cls_name
    globals()[cls_name] = TestMatMulOpBf16Case
    return TestMatMulOpBf16Case


# NOTE(review): unlike the fp16 list, no bf16 variants are registered for
# TestMatMulOpBroadcast1/2 — confirm whether that is intentional.
create_test_bf16_class(TestMatMulV2Op)
create_test_bf16_class(TestMatMulOp2)
create_test_bf16_class(TestMatMulOp3)
create_test_bf16_class(TestMatMulOp4)
create_test_bf16_class(TestMatMulOp5)
create_test_bf16_class(TestMatMulOp6)
create_test_bf16_class(TestMatMulOp7)
create_test_bf16_class(TestMatMulOp8)
create_test_bf16_class(TestMatMulOp9)
create_test_bf16_class(TestMatMulOp10)
create_test_bf16_class(TestMatMulOp11)
create_test_bf16_class(TestMatMulOp12)
create_test_bf16_class(TestMatMulOp13)
create_test_bf16_class(TestMatMulOp14)
create_test_bf16_class(TestMatMulOp15)
create_test_bf16_class(TestMatMulOp16)
create_test_bf16_class(TestMatMulOp17)


class TestMatMulV2API(unittest.TestCase):
    """API-level tests for ``paddle.matmul`` in static-graph and dygraph mode."""

    # Global flag that the compute-type tests toggle; always restored afterwards.
    _HALF_COMPUTE_FLAG = 'FLAGS_gemm_use_half_precision_compute_type'

    def setUp(self):
        self.places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.places.append(fluid.CUDAPlace(0))

    def check_static_result(self, place):
        """Run a [4, 3] x [3, 4] static program on ``place`` and check it."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            input_x = paddle.static.data(
                name="input_x", shape=[4, 3], dtype="float32"
            )
            input_y = paddle.static.data(
                name="input_y", shape=[3, 4], dtype="float32"
            )

            result = paddle.matmul(input_x, input_y)

            x_np = np.random.random([4, 3]).astype("float32")
            y_np = np.random.random([3, 4]).astype("float32")

            exe = fluid.Executor(place)
            fetches = exe.run(
                fluid.default_main_program(),
                feed={"input_x": x_np, "input_y": y_np},
                fetch_list=[result],
            )
            # Loose tolerance so reduced-precision float32 GEMM paths on GPU
            # (e.g. TF32, if enabled) do not cause spurious failures.
            np.testing.assert_allclose(
                fetches[0], np.matmul(x_np, y_np), rtol=1e-02
            )

    def test_static(self):
        for place in self.places:
            self.check_static_result(place=place)

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                input_x = np.random.random([4, 3]).astype("float64")
                input_y = np.random.random([3, 4]).astype("float64")
                x = paddle.to_tensor(input_x)
                y = paddle.to_tensor(input_y)
                result = paddle.matmul(x, y)
                np.testing.assert_allclose(
                    result.numpy(), np.matmul(input_x, input_y), rtol=1e-05
                )

    def test_dygraph_fp16(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                with fluid.dygraph.guard(place):
                    input_x = np.random.random([4, 3]).astype("float16")
                    input_y = np.random.random([3, 4]).astype("float16")
                    x = paddle.to_tensor(input_x)
                    y = paddle.to_tensor(input_y)
                    result = paddle.matmul(x, y)
                    # Compare against a float32 reference with fp16-sized
                    # tolerances.
                    np.testing.assert_allclose(
                        result.numpy().astype("float32"),
                        np.matmul(
                            input_x.astype("float32"), input_y.astype("float32")
                        ),
                        rtol=1e-02,
                        atol=1e-02,
                    )

    @staticmethod
    def _overflowing_fp16_inputs():
        """Return fp16 inputs whose individual products exceed the fp16 range.

        Columns of x alternate +60000 / -60000 offsets and y is all 1.5, so
        each product is about +/-90000 (beyond the fp16 maximum of 65504),
        while the offsets cancel pairwise in the full sum.
        """
        input_x = np.random.random([2, 8, 16]).astype("float16")
        input_y = np.random.random([2, 16, 8]).astype("float16")
        for i in range(0, 16, 2):
            input_x[:, :, i] += 60000
            input_x[:, :, i + 1] -= 60000
        input_y[:, :, :] = 1.5
        return input_x, input_y

    def test_compute_type_fp32(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                with fluid.dygraph.guard(place):
                    saved_flags = paddle.get_flags([self._HALF_COMPUTE_FLAG])
                    paddle.set_flags({self._HALF_COMPUTE_FLAG: False})
                    try:
                        input_x, input_y = self._overflowing_fp16_inputs()
                        x = paddle.to_tensor(input_x)
                        y = paddle.to_tensor(input_y)
                        result = paddle.matmul(x, y)
                        result_np = np.matmul(input_x, input_y)
                        # Without half-precision compute the GPU result stays
                        # finite and matches numpy exactly.
                        self.assertTrue(paddle.isfinite(result)[0, 0, 0])
                        self.assertTrue(np.isfinite(result_np)[0, 0, 0])
                        np.testing.assert_array_equal(
                            result_np, result.numpy()
                        )
                    finally:
                        # Restore the caller's value even if an assert fails.
                        paddle.set_flags(saved_flags)

    def test_compute_type_fp16_nan(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                with fluid.dygraph.guard(place):
                    saved_flags = paddle.get_flags([self._HALF_COMPUTE_FLAG])
                    paddle.set_flags({self._HALF_COMPUTE_FLAG: True})
                    try:
                        input_x, input_y = self._overflowing_fp16_inputs()
                        x = paddle.to_tensor(input_x)
                        y = paddle.to_tensor(input_y)
                        result = paddle.matmul(x, y)
                        result_np = np.matmul(input_x, input_y)
                        # With half-precision compute the GPU result is
                        # expected to contain nan/inf; numpy's stays finite.
                        self.assertFalse(paddle.isfinite(result)[0, 0, 0])
                        self.assertTrue(np.isfinite(result_np)[0, 0, 0])
                    finally:
                        # Restore the caller's value even if an assert fails.
                        paddle.set_flags(saved_flags)


class TestComplexMatMulOp(OpTest):
    """matmul_v2 on complex 10x10 operands, checked with analytic gradients.

    ``init_base_dtype`` sets float64, and the inputs are built as
    ``real + 1j * real`` (complex128). For Out = X @ Y with upstream gradient
    dOut, the expected gradients are dX = dOut @ conj(Y)^T and
    dY = conj(X)^T @ dOut.
    """

    def setUp(self):
        self.op_type = "matmul_v2"
        self.python_api = paddle.tensor.matmul
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.y = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        self.grad_out = np.ones((10, 10), self.dtype) + 1j * np.ones(
            (10, 10), self.dtype
        )
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
        )

    # The "ingore" spelling in the next two method names is kept so existing
    # test IDs stay stable.
    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out],
        )

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out],
        )


class TestComplexMatMulOpBroadcast(TestComplexMatMulOp):
    """Complex matmul_v2 where a 3-D X (10, 2, 5) meets a 2-D Y (5, 20).

    ``setUp``, ``init_base_dtype`` and all output/gradient checks are
    inherited unchanged from TestComplexMatMulOp (this class previously
    duplicated them verbatim); only the data and expected gradients differ.
    Y is broadcast across X's leading dimension, so dY is summed over it.
    """

    def init_input_output(self):
        self.x = np.random.random((10, 2, 5)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 2, 5)).astype(self.dtype)
        self.y = np.random.random((5, 20)).astype(
            self.dtype
        ) + 1j * np.random.random((5, 20)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        self.grad_out = np.ones((10, 2, 20), self.dtype) + 1j * np.ones(
            (10, 2, 20), self.dtype
        )
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        # Per-batch conj(X)^T @ dOut, reduced over the broadcast axis 0.
        self.grad_y = np.sum(
            np.matmul(np.conj(self.x).transpose(0, 2, 1), self.grad_out), axis=0
        )


class TestMatMulTypePromotion(TestComplexMatMulOp):
    """Real float64 X times complex Y; the op must promote to complex.

    Because X is real, its expected gradient is the real part of
    dOut @ conj(Y)^T.
    """

    def init_input_output(self):
        self.x = np.random.random((10, 10)).astype(self.dtype)
        self.y = np.random.random((10, 10)).astype(
            self.dtype
        ) + 1j * np.random.random((10, 10)).astype(self.dtype)
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        self.grad_out = np.ones((10, 10), self.dtype) + 1j * np.ones(
            (10, 10), self.dtype
        )
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T).real
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)


class TestMatmulop(unittest.TestCase):
    """Dygraph check that ``Tensor @ numpy.ndarray`` matches numpy."""

    def func_dygraph_matmul(self):
        paddle.disable_static()
        try:
            np_a = np.random.random((2, 4)).astype(np.float32)
            np_b = np.random.random((4, 2)).astype(np.float32)

            tensor_a = paddle.to_tensor(np_a, dtype="float32")

            # normal case: tensor @ nparray
            expect_out = np_a @ np_b
            actual_out = tensor_a @ np_b
            # Loose tolerance so reduced-precision float32 GEMM paths on GPU
            # (e.g. TF32, if enabled) do not cause spurious failures.
            np.testing.assert_allclose(
                actual_out.numpy(), expect_out, rtol=1e-02
            )
        finally:
            # Return to static mode even if the comparison fails.
            paddle.enable_static()

    def test_dygraph_matmul(self):
        # func_dygraph_matmul lacks the "test" prefix, so unittest never
        # collected it; this wrapper makes it run.
        self.func_dygraph_matmul()


if __name__ == "__main__":
    # Run the suite in static-graph mode (check_static_result builds static
    # programs; TestMatmulop switches to dygraph and back on its own).
    paddle.enable_static()
    unittest.main()