// blas_impl.hip.h
//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/rocblas.h"
19 20
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h"
21 22 23

DECLARE_bool(enable_cublas_tensor_op_math);

24
namespace phi {
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
namespace funcs {

// Primary template: declared only, never defined here. The per-element-type
// specializations later in this file supply the actual rocBLAS wrappers.
template <typename T>
struct CUBlas;

template <>
struct CUBlas<float> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_sgemm(args...));
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_saxpy(args...));
  }

  template <typename... ARGS>
  static void SCAL(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_sscal(args...));
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_scopy(args...));
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_sgemv(args...));
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_sgemm_strided_batched(args...));
  }

  // HIP not supportted, refer to the doc here:
  // https://github.com/ROCm-Developer-Tools/HIP/blob/roc-3.5.x/docs/markdown/CUBLAS_API_supported_by_HIP.md
  template <typename... ARGS>
  static void GEMM_EX(ARGS... args) {
72
    PADDLE_THROW(phi::errors::Unimplemented(
73 74 75 76 77 78 79 80 81 82 83
        "cublasSgemmEx is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_strsm(args...));
  }

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... args) {
84
    PADDLE_THROW(phi::errors::Unimplemented(
85 86 87 88 89
        "cublasSgetrfBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... args) {
90
    PADDLE_THROW(phi::errors::Unimplemented(
91 92 93 94 95
        "cublasSgetriBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... args) {
96
    PADDLE_THROW(phi::errors::Unimplemented(
97 98 99 100 101
        "cublasSmatinvBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void TRSM_BATCH(ARGS... args) {
102
    PADDLE_THROW(phi::errors::Unimplemented(
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
        "cublasStrsmBatched is not supported on HIP platform."));
  }
};

template <>
struct CUBlas<double> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_dgemm(args...));
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_daxpy(args...));
  }

  template <typename... ARGS>
  static void SCAL(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_dscal(args...));
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_dcopy(args...));
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_dgemv(args...));
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_dgemm_strided_batched(args...));
  }

  template <typename... ARGS>
  static void GEMM_EX(ARGS... args) {
    PADDLE_THROW(
148
        phi::errors::Unimplemented("Currently there are not cublasDgemmEx."));
149 150 151 152 153 154 155 156 157 158
  }

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_dtrsm(args...));
  }

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... args) {
159
    PADDLE_THROW(phi::errors::Unimplemented(
160 161 162 163 164
        "cublasDgetrfBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... args) {
165
    PADDLE_THROW(phi::errors::Unimplemented(
166 167 168 169 170
        "cublasDgetriBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... args) {
171
    PADDLE_THROW(phi::errors::Unimplemented(
172 173 174 175 176
        "cublasDmatinvBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void TRSM_BATCH(ARGS... args) {
177
    PADDLE_THROW(phi::errors::Unimplemented(
178 179 180 181 182
        "cublasDtrsmBatched is not supported on HIP platform."));
  }
};

template <>
183 184
struct CUBlas<phi::dtype::float16> {
  using float16 = phi::dtype::float16;
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307

  static void GEMM(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float16 *alpha,
                   const float16 *A,
                   int lda,
                   const float16 *B,
                   int ldb,
                   const float16 *beta,
                   float16 *C,
                   int ldc) {
    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_hgemm(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const rocblas_half *>(alpha),
        reinterpret_cast<const rocblas_half *>(A),
        lda,
        reinterpret_cast<const rocblas_half *>(B),
        ldb,
        reinterpret_cast<const rocblas_half *>(beta),
        reinterpret_cast<rocblas_half *>(C),
        ldc));
  }

  static void GEMM_STRIDED_BATCH(rocblas_handle handle,
                                 rocblas_operation transa,
                                 rocblas_operation transb,
                                 int m,
                                 int n,
                                 int k,
                                 const float16 *alpha,
                                 const float16 *A,
                                 int lda,
                                 long long int strideA,  // NOLINT
                                 const float16 *B,       // NOLINT
                                 int ldb,
                                 long long int strideB,  // NOLINT
                                 const float16 *beta,
                                 float16 *C,
                                 int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_hgemm_strided_batched(
            handle,
            transa,
            transb,
            m,
            n,
            k,
            reinterpret_cast<const rocblas_half *>(alpha),
            reinterpret_cast<const rocblas_half *>(A),
            lda,
            strideA,
            reinterpret_cast<const rocblas_half *>(B),
            ldb,
            strideB,
            reinterpret_cast<const rocblas_half *>(beta),
            reinterpret_cast<rocblas_half *>(C),
            ldc,
            strideC,
            batchCount));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx,
                      rocblas_operation transa,
                      rocblas_operation transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      rocblas_datatype Atype,
                      int lda,
                      const void *B,
                      rocblas_datatype Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      rocblas_datatype Ctype,
                      int ldc,
                      rocblas_datatype computeType) {
    rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
    dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::rocblas_gemm_ex(handle,
                                                     transa,
                                                     transb,
                                                     m,
                                                     n,
                                                     k,
                                                     alpha,
                                                     A,
                                                     Atype,
                                                     lda,
                                                     B,
                                                     Btype,
                                                     ldb,
                                                     beta,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     computeType,
                                                     algo,
                                                     0,
                                                     0));
    });
  }
  template <typename... ARGS>
308
  static void GEMM_EX(phi::GPUContext *dev_ctx,
309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357
                      rocblas_operation transa,
                      rocblas_operation transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      rocblas_datatype Atype,
                      int lda,
                      const void *B,
                      rocblas_datatype Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      rocblas_datatype Ctype,
                      int ldc,
                      rocblas_datatype computeType) {
    rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
    dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::rocblas_gemm_ex(handle,
                                                     transa,
                                                     transb,
                                                     m,
                                                     n,
                                                     k,
                                                     alpha,
                                                     A,
                                                     Atype,
                                                     lda,
                                                     B,
                                                     Btype,
                                                     ldb,
                                                     beta,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     computeType,
                                                     algo,
                                                     0,
                                                     0));
    });
  }
};

template <>
358
struct CUBlas<phi::dtype::complex<float>> {
359 360 361 362
  static void GEMV(rocblas_handle handle,
                   rocblas_operation transa,
                   int m,
                   int n,
363 364
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *A,
365
                   int lda,
366
                   const phi::dtype::complex<float> *B,
367
                   int ldb,
368 369
                   const phi::dtype::complex<float> *beta,
                   phi::dtype::complex<float> *C,
370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
                   int ldc) {
    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_cgemv(
        handle,
        transa,
        m,
        n,
        reinterpret_cast<const rocblas_float_complex *>(alpha),
        reinterpret_cast<const rocblas_float_complex *>(A),
        lda,
        reinterpret_cast<const rocblas_float_complex *>(B),
        ldb,
        reinterpret_cast<const rocblas_float_complex *>(beta),
        reinterpret_cast<rocblas_float_complex *>(C),
        ldc));
  }

  static void AXPY(rocblas_handle handle,
                   int n,
388 389
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *X,
390
                   const int incX,
391
                   phi::dtype::complex<float> *Y,
392 393 394 395 396 397 398 399 400 401 402
                   const int incY) {
    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_caxpy(
        handle,
        n,
        reinterpret_cast<const rocblas_float_complex *>(alpha),
        reinterpret_cast<const rocblas_float_complex *>(X),
        incX,
        reinterpret_cast<rocblas_float_complex *>(Y),
        incY));
  }

403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420
  static void GEMM_STRIDED_BATCH(rocblas_handle handle,
                                 rocblas_operation transa,
                                 rocblas_operation transb,
                                 int m,
                                 int n,
                                 int k,
                                 const phi::dtype::complex<float> *alpha,
                                 const phi::dtype::complex<float> *A,
                                 int lda,
                                 long long int strideA,                // NOLINT
                                 const phi::dtype::complex<float> *B,  // NOLINT
                                 int ldb,
                                 long long int strideB,  // NOLINT
                                 const phi::dtype::complex<float> *beta,
                                 phi::dtype::complex<float> *C,
                                 int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_cgemm_strided_batched(
            handle,
            transa,
            transb,
            m,
            n,
            k,
            reinterpret_cast<const rocblas_float_complex *>(alpha),
            reinterpret_cast<const rocblas_float_complex *>(A),
            lda,
            strideA,
            reinterpret_cast<const rocblas_float_complex *>(B),
            ldb,
            strideB,
            reinterpret_cast<const rocblas_float_complex *>(beta),
            reinterpret_cast<rocblas_float_complex *>(C),
            ldc,
            strideC,
            batchCount));
  }

  static void GEMM(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
449 450
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *A,
451
                   int lda,
452
                   const phi::dtype::complex<float> *B,
453
                   int ldb,
454 455
                   const phi::dtype::complex<float> *beta,
                   phi::dtype::complex<float> *C,
456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524
                   int ldc) {
    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_cgemm(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const rocblas_float_complex *>(alpha),
        reinterpret_cast<const rocblas_float_complex *>(A),
        lda,
        reinterpret_cast<const rocblas_float_complex *>(B),
        ldb,
        reinterpret_cast<const rocblas_float_complex *>(beta),
        reinterpret_cast<rocblas_float_complex *>(C),
        ldc));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx,
                      rocblas_operation transa,
                      rocblas_operation transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      rocblas_datatype Atype,
                      int lda,
                      const void *B,
                      rocblas_datatype Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      rocblas_datatype Ctype,
                      int ldc,
                      rocblas_datatype computeType) {
    rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
    dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::rocblas_gemm_ex(handle,
                                                     transa,
                                                     transb,
                                                     m,
                                                     n,
                                                     k,
                                                     alpha,
                                                     A,
                                                     Atype,
                                                     lda,
                                                     B,
                                                     Btype,
                                                     ldb,
                                                     beta,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     computeType,
                                                     algo,
                                                     0,
                                                     0));
    });
  }
  template <typename... ARGS>
525
  static void GEMM_EX(phi::GPUContext *dev_ctx,
526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574
                      rocblas_operation transa,
                      rocblas_operation transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      rocblas_datatype Atype,
                      int lda,
                      const void *B,
                      rocblas_datatype Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      rocblas_datatype Ctype,
                      int ldc,
                      rocblas_datatype computeType) {
    rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
    dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::rocblas_gemm_ex(handle,
                                                     transa,
                                                     transb,
                                                     m,
                                                     n,
                                                     k,
                                                     alpha,
                                                     A,
                                                     Atype,
                                                     lda,
                                                     B,
                                                     Btype,
                                                     ldb,
                                                     beta,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     computeType,
                                                     algo,
                                                     0,
                                                     0));
    });
  }
};

template <>
575
struct CUBlas<phi::dtype::complex<double>> {
576 577 578 579
  static void GEMV(rocblas_handle handle,
                   rocblas_operation transa,
                   int m,
                   int n,
580 581
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *A,
582
                   int lda,
583
                   const phi::dtype::complex<double> *B,
584
                   int ldb,
585 586
                   const phi::dtype::complex<double> *beta,
                   phi::dtype::complex<double> *C,
587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604
                   int ldc) {
    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_zgemv(
        handle,
        transa,
        m,
        n,
        reinterpret_cast<const rocblas_double_complex *>(alpha),
        reinterpret_cast<const rocblas_double_complex *>(A),
        lda,
        reinterpret_cast<const rocblas_double_complex *>(B),
        ldb,
        reinterpret_cast<const rocblas_double_complex *>(beta),
        reinterpret_cast<rocblas_double_complex *>(C),
        ldc));
  }

  static void AXPY(rocblas_handle handle,
                   int n,
605 606
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *X,
607
                   const int incX,
608
                   phi::dtype::complex<double> *Y,
609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626
                   const int incY) {
    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_zaxpy(
        handle,
        n,
        reinterpret_cast<const rocblas_double_complex *>(alpha),
        reinterpret_cast<const rocblas_double_complex *>(X),
        incX,
        reinterpret_cast<rocblas_double_complex *>(Y),
        incY));
  }

  static void GEMM_STRIDED_BATCH(
      rocblas_handle handle,
      rocblas_operation transa,
      rocblas_operation transb,
      int m,
      int n,
      int k,
627 628
      const phi::dtype::complex<double> *alpha,
      const phi::dtype::complex<double> *A,
629
      int lda,
630 631
      long long int strideA,                 // NOLINT
      const phi::dtype::complex<double> *B,  // NOLINT
632 633
      int ldb,
      long long int strideB,  // NOLINT
634 635
      const phi::dtype::complex<double> *beta,
      phi::dtype::complex<double> *C,
636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666
      int ldc,
      long long int strideC,  // NOLINT
      int batchCount) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_zgemm_strided_batched(
            handle,
            transa,
            transb,
            m,
            n,
            k,
            reinterpret_cast<const rocblas_double_complex *>(alpha),
            reinterpret_cast<const rocblas_double_complex *>(A),
            lda,
            strideA,
            reinterpret_cast<const rocblas_double_complex *>(B),
            ldb,
            strideB,
            reinterpret_cast<const rocblas_double_complex *>(beta),
            reinterpret_cast<rocblas_double_complex *>(C),
            ldc,
            strideC,
            batchCount));
  }

  static void GEMM(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
667 668
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *A,
669
                   int lda,
670
                   const phi::dtype::complex<double> *B,
671
                   int ldb,
672 673
                   const phi::dtype::complex<double> *beta,
                   phi::dtype::complex<double> *C,
674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742
                   int ldc) {
    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::rocblas_zgemm(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const rocblas_double_complex *>(alpha),
        reinterpret_cast<const rocblas_double_complex *>(A),
        lda,
        reinterpret_cast<const rocblas_double_complex *>(B),
        ldb,
        reinterpret_cast<const rocblas_double_complex *>(beta),
        reinterpret_cast<rocblas_double_complex *>(C),
        ldc));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(paddle::platform::CUDADeviceContext *dev_ctx,
                      rocblas_operation transa,
                      rocblas_operation transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      rocblas_datatype Atype,
                      int lda,
                      const void *B,
                      rocblas_datatype Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      rocblas_datatype Ctype,
                      int ldc,
                      rocblas_datatype computeType) {
    rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
    dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::rocblas_gemm_ex(handle,
                                                     transa,
                                                     transb,
                                                     m,
                                                     n,
                                                     k,
                                                     alpha,
                                                     A,
                                                     Atype,
                                                     lda,
                                                     B,
                                                     Btype,
                                                     ldb,
                                                     beta,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     computeType,
                                                     algo,
                                                     0,
                                                     0));
    });
  }
  template <typename... ARGS>
743
  static void GEMM_EX(phi::GPUContext *dev_ctx,
744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832
                      rocblas_operation transa,
                      rocblas_operation transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      rocblas_datatype Atype,
                      int lda,
                      const void *B,
                      rocblas_datatype Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      rocblas_datatype Ctype,
                      int ldc,
                      rocblas_datatype computeType) {
    rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
    dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::rocblas_gemm_ex(handle,
                                                     transa,
                                                     transb,
                                                     m,
                                                     n,
                                                     k,
                                                     alpha,
                                                     A,
                                                     Atype,
                                                     lda,
                                                     B,
                                                     Btype,
                                                     ldb,
                                                     beta,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     C,
                                                     Ctype,
                                                     ldc,
                                                     computeType,
                                                     algo,
                                                     0,
                                                     0));
    });
  }
};

// Generic GEMM for the legacy CUDADeviceContext, dispatched to
// CUBlas<T>::GEMM through the context's CublasCall.
//
// transA / transB select whether A / B are used transposed, M / N / K are the
// GEMM dimensions and alpha / beta the scalars. No leading dimensions are
// taken from the caller: lda is K or M and ldb is N or K depending on the
// transpose flags, and C always uses a leading dimension of N.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                                                     CBLAS_TRANSPOSE transB,
                                                     int M,
                                                     int N,
                                                     int K,
                                                     T alpha,
                                                     const T *A,
                                                     const T *B,
                                                     T beta,
                                                     T *C) const {
  const bool a_untransposed = (transA == CblasNoTrans);
  const bool b_untransposed = (transB == CblasNoTrans);

  const int lda = a_untransposed ? K : M;
  const int ldb = b_untransposed ? N : K;

  // Anything other than CblasNoTrans is mapped to a plain transpose.
  const rocblas_operation op_a =
      a_untransposed ? rocblas_operation_none : rocblas_operation_transpose;
  const rocblas_operation op_b =
      b_untransposed ? rocblas_operation_none : rocblas_operation_transpose;

  // The BLAS backend follows Fortran order, unlike the cblas convention, so
  // the two operands (and M / N) trade places in the call below.
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GEMM(
        handle, op_b, op_a, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N);
  });
}
template <>
template <typename T>
833 834 835 836 837 838 839 840 841 842
void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                 CBLAS_TRANSPOSE transB,
                                 int M,
                                 int N,
                                 int K,
                                 T alpha,
                                 const T *A,
                                 const T *B,
                                 T beta,
                                 T *C) const {
843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GEMM(handle,
                    cuTransB,
                    cuTransA,
                    N,
                    M,
                    K,
                    &alpha,
                    B,
                    ldb,
                    A,
                    lda,
                    &beta,
                    C,
                    N);
  });
}

// float16 GEMM for the legacy CUDADeviceContext, routed through
// CUBlas<float16>::GEMM_EX.
//
// A, B and C are declared as rocblas_datatype_f16_r, while alpha / beta are
// widened to float and rocblas_datatype_f32_r is requested as compute type.
// Leading dimensions are derived from the transpose flags (lda is K or M,
// ldb is N or K); C uses N.
// Fails with InvalidArgument when the context reports a compute capability
// below 53.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA,
    CBLAS_TRANSPOSE transB,
    int M,
    int N,
    int K,
    phi::dtype::float16 alpha,
    const phi::dtype::float16 *A,
    const phi::dtype::float16 *B,
    phi::dtype::float16 beta,
    phi::dtype::float16 *C) const {
  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
      phi::errors::InvalidArgument(
          "cublas fp16 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

  const bool a_untransposed = (transA == CblasNoTrans);
  const bool b_untransposed = (transB == CblasNoTrans);
  const int lda = a_untransposed ? K : M;
  const int ldb = b_untransposed ? N : K;
  const rocblas_operation op_a =
      a_untransposed ? rocblas_operation_none : rocblas_operation_transpose;
  const rocblas_operation op_b =
      b_untransposed ? rocblas_operation_none : rocblas_operation_transpose;

  // The scalars are handed over as float, matching the f32 compute type.
  float alpha_f32 = static_cast<float>(alpha);
  float beta_f32 = static_cast<float>(beta);

  // context_ is const-qualified here; the cast yields the non-const pointer
  // that is passed to GEMM_EX.
  auto *mutable_ctx =
      const_cast<paddle::platform::CUDADeviceContext *>(&context_);

  // The BLAS backend follows Fortran order, unlike the cblas convention, so
  // B / A and N / M are passed in swapped positions.
  CUBlas<phi::dtype::float16>::GEMM_EX(mutable_ctx,
                                       op_b,
                                       op_a,
                                       N,
                                       M,
                                       K,
                                       &alpha_f32,
                                       B,
                                       rocblas_datatype_f16_r,
                                       ldb,
                                       A,
                                       rocblas_datatype_f16_r,
                                       lda,
                                       &beta_f32,
                                       C,
                                       rocblas_datatype_f16_r,
                                       N,
                                       rocblas_datatype_f32_r);
}
template <>
template <>
929 930 931 932 933 934 935 936 937 938
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::float16 alpha,
                                        const phi::dtype::float16 *A,
                                        const phi::dtype::float16 *B,
                                        phi::dtype::float16 beta,
                                        phi::dtype::float16 *C) const {
939 940 941 942 943 944 945 946 947 948 949 950 951 952 953
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
954
      phi::errors::InvalidArgument(
955 956 957 958 959 960 961
          "cublas fp16 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980
  auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
  CUBlas<phi::dtype::float16>::GEMM_EX(&cuda_ctx,
                                       cuTransB,
                                       cuTransA,
                                       N,
                                       M,
                                       K,
                                       &h_alpha,
                                       B,
                                       rocblas_datatype_f16_r,
                                       ldb,
                                       A,
                                       rocblas_datatype_f16_r,
                                       lda,
                                       &h_beta,
                                       C,
                                       rocblas_datatype_f16_r,
                                       N,
                                       rocblas_datatype_f32_r);
981 982 983 984 985 986 987 988 989 990
}

// bfloat16 GEMM for the legacy CUDADeviceContext, issued directly through
// rocblas_gemm_ex inside TensorCoreCublasCallIfAvailable.
//
// A, B and C are declared as rocblas_datatype_bf16_r, while alpha / beta are
// widened to float and rocblas_datatype_f32_r is requested as compute type.
// Leading dimensions are derived from the transpose flags (lda is K or M,
// ldb is N or K); C uses N.
// Fails with InvalidArgument when the context reports a compute capability
// below 80.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA,
    CBLAS_TRANSPOSE transB,
    int M,
    int N,
    int K,
    phi::dtype::bfloat16 alpha,
    const phi::dtype::bfloat16 *A,
    const phi::dtype::bfloat16 *B,
    phi::dtype::bfloat16 beta,
    phi::dtype::bfloat16 *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  // TODO(zhiqiu): 80 has the same meaning for rocm and cuda?
  // The message names bf16: this specialization handles bfloat16, not fp16.
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      80,
      phi::errors::InvalidArgument(
          "rocblas bf16 gemm requires GPU compute capability >= 80,"
          "but received %d",
          context_.GetComputeCapability()));

  // The scalars are handed over as float, matching the f32 compute type.
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);
  rocblas_gemm_algo algo = rocblas_gemm_algo_standard;

  context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
    // B / A and N / M are swapped because of the Fortran ordering noted
    // above. C is passed twice (input matrix and output matrix), so the
    // result is written in place. The two trailing zeros are the
    // solution_index and flags arguments of rocblas_gemm_ex.
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_gemm_ex(handle,
                                                   cuTransB,
                                                   cuTransA,
                                                   N,
                                                   M,
                                                   K,
                                                   &h_alpha,
                                                   B,
                                                   rocblas_datatype_bf16_r,
                                                   ldb,
                                                   A,
                                                   rocblas_datatype_bf16_r,
                                                   lda,
                                                   &h_beta,
                                                   C,
                                                   rocblas_datatype_bf16_r,
                                                   N,
                                                   C,
                                                   rocblas_datatype_bf16_r,
                                                   N,
                                                   rocblas_datatype_f32_r,
                                                   algo,
                                                   0,
                                                   0));
  });
}

template <>
template <>
1050 1051 1052 1053 1054 1055 1056 1057 1058 1059
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::bfloat16 alpha,
                                        const phi::dtype::bfloat16 *A,
                                        const phi::dtype::bfloat16 *B,
                                        phi::dtype::bfloat16 beta,
                                        phi::dtype::bfloat16 *C) const {
1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  // TODO(zhiqiu): 80 has the same meaning for rocm and cuda?
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      80,
1074
      phi::errors::InvalidArgument(
1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119
          "rocblas fp16 gemm requires GPU compute capability >= 80,"
          "but received %d",
          context_.GetComputeCapability()));

  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);
  rocblas_gemm_algo algo = rocblas_gemm_algo_standard;

  context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_gemm_ex(handle,
                                                   cuTransB,
                                                   cuTransA,
                                                   N,
                                                   M,
                                                   K,
                                                   &h_alpha,
                                                   B,
                                                   rocblas_datatype_bf16_r,
                                                   ldb,
                                                   A,
                                                   rocblas_datatype_bf16_r,
                                                   lda,
                                                   &h_beta,
                                                   C,
                                                   rocblas_datatype_bf16_r,
                                                   N,
                                                   C,
                                                   rocblas_datatype_bf16_r,
                                                   N,
                                                   rocblas_datatype_f32_r,
                                                   algo,
                                                   0,
                                                   0));
  });
}

// complex<float> GEMM for the legacy CUDADeviceContext, routed through
// CUBlas<complex<float>>::GEMM_EX.
//
// A, B, C and the compute type are all rocblas_datatype_f32_c; alpha / beta
// are repackaged as thrust::complex<float> before being passed by address.
// Leading dimensions are derived from the transpose flags (lda is K or M,
// ldb is N or K); C uses N.
// Fails with InvalidArgument when the context reports a compute capability
// below 53.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA,
    CBLAS_TRANSPOSE transB,
    int M,
    int N,
    int K,
    phi::dtype::complex<float> alpha,
    const phi::dtype::complex<float> *A,
    const phi::dtype::complex<float> *B,
    phi::dtype::complex<float> beta,
    phi::dtype::complex<float> *C) const {
  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
      phi::errors::InvalidArgument(
          "cublas complex64 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

  const bool a_untransposed = (transA == CblasNoTrans);
  const bool b_untransposed = (transB == CblasNoTrans);
  const int lda = a_untransposed ? K : M;
  const int ldb = b_untransposed ? N : K;
  // Anything other than CblasNoTrans is mapped to a plain transpose.
  const rocblas_operation op_a =
      a_untransposed ? rocblas_operation_none : rocblas_operation_transpose;
  const rocblas_operation op_b =
      b_untransposed ? rocblas_operation_none : rocblas_operation_transpose;

  // Rebuild the scalars from their real / imag parts as thrust complex values.
  thrust::complex<float> alpha_c(alpha.real, alpha.imag);
  thrust::complex<float> beta_c(beta.real, beta.imag);

  // context_ is const-qualified here; the cast yields the non-const pointer
  // that is passed to GEMM_EX.
  auto *mutable_ctx =
      const_cast<paddle::platform::CUDADeviceContext *>(&context_);

  // The BLAS backend follows Fortran order, unlike the cblas convention, so
  // B / A and N / M are passed in swapped positions.
  CUBlas<phi::dtype::complex<float>>::GEMM_EX(mutable_ctx,
                                              op_b,
                                              op_a,
                                              N,
                                              M,
                                              K,
                                              &alpha_c,
                                              B,
                                              rocblas_datatype_f32_c,
                                              ldb,
                                              A,
                                              rocblas_datatype_f32_c,
                                              lda,
                                              &beta_c,
                                              C,
                                              rocblas_datatype_f32_c,
                                              N,
                                              rocblas_datatype_f32_c);
}
template <>
template <>
1171 1172 1173 1174 1175 1176 1177 1178 1179 1180
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::complex<float> alpha,
                                        const phi::dtype::complex<float> *A,
                                        const phi::dtype::complex<float> *B,
                                        phi::dtype::complex<float> beta,
                                        phi::dtype::complex<float> *C) const {
1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
1196
      phi::errors::InvalidArgument(
1197 1198 1199 1200 1201 1202 1203 1204
          "cublas complex64 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

  thrust::complex<float> c_alpha =
      thrust::complex<float>(alpha.real, alpha.imag);
  thrust::complex<float> c_beta = thrust::complex<float>(beta.real, beta.imag);

1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223
  auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
  CUBlas<phi::dtype::complex<float>>::GEMM_EX(&cuda_ctx,
                                              cuTransB,
                                              cuTransA,
                                              N,
                                              M,
                                              K,
                                              &c_alpha,
                                              B,
                                              rocblas_datatype_f32_c,
                                              ldb,
                                              A,
                                              rocblas_datatype_f32_c,
                                              lda,
                                              &c_beta,
                                              C,
                                              rocblas_datatype_f32_c,
                                              N,
                                              rocblas_datatype_f32_c);
1224 1225 1226 1227 1228 1229 1230 1231 1232 1233
}

// complex<double> GEMM for the legacy CUDADeviceContext, routed through
// CUBlas<complex<double>>::GEMM_EX.
//
// A, B, C and the compute type are all rocblas_datatype_f64_c; alpha / beta
// are repackaged as thrust::complex<double> before being passed by address.
// Leading dimensions are derived from the transpose flags (lda is K or M,
// ldb is N or K); C uses N.
// Fails with InvalidArgument when the context reports a compute capability
// below 53.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA,
    CBLAS_TRANSPOSE transB,
    int M,
    int N,
    int K,
    phi::dtype::complex<double> alpha,
    const phi::dtype::complex<double> *A,
    const phi::dtype::complex<double> *B,
    phi::dtype::complex<double> beta,
    phi::dtype::complex<double> *C) const {
  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
      phi::errors::InvalidArgument(
          "cublas complex128 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

  const bool a_untransposed = (transA == CblasNoTrans);
  const bool b_untransposed = (transB == CblasNoTrans);
  const int lda = a_untransposed ? K : M;
  const int ldb = b_untransposed ? N : K;
  // Anything other than CblasNoTrans is mapped to a plain transpose.
  const rocblas_operation op_a =
      a_untransposed ? rocblas_operation_none : rocblas_operation_transpose;
  const rocblas_operation op_b =
      b_untransposed ? rocblas_operation_none : rocblas_operation_transpose;

  // Rebuild the scalars from their real / imag parts as thrust complex values.
  thrust::complex<double> alpha_c(alpha.real, alpha.imag);
  thrust::complex<double> beta_c(beta.real, beta.imag);

  // context_ is const-qualified here; the cast yields the non-const pointer
  // that is passed to GEMM_EX.
  auto *mutable_ctx =
      const_cast<paddle::platform::CUDADeviceContext *>(&context_);

  // The BLAS backend follows Fortran order, unlike the cblas convention, so
  // B / A and N / M are passed in swapped positions.
  CUBlas<phi::dtype::complex<double>>::GEMM_EX(mutable_ctx,
                                               op_b,
                                               op_a,
                                               N,
                                               M,
                                               K,
                                               &alpha_c,
                                               B,
                                               rocblas_datatype_f64_c,
                                               ldb,
                                               A,
                                               rocblas_datatype_f64_c,
                                               lda,
                                               &beta_c,
                                               C,
                                               rocblas_datatype_f64_c,
                                               N,
                                               rocblas_datatype_f64_c);
}
template <>
template <>
1286 1287 1288 1289 1290 1291 1292 1293 1294 1295
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::complex<double> alpha,
                                        const phi::dtype::complex<double> *A,
                                        const phi::dtype::complex<double> *B,
                                        phi::dtype::complex<double> beta,
                                        phi::dtype::complex<double> *C) const {
1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
1311
      phi::errors::InvalidArgument(
1312 1313 1314 1315 1316 1317 1318 1319 1320
          "cublas complex128 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

  thrust::complex<double> c_alpha =
      thrust::complex<double>(alpha.real, alpha.imag);
  thrust::complex<double> c_beta =
      thrust::complex<double>(beta.real, beta.imag);

1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339
  auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
  CUBlas<phi::dtype::complex<double>>::GEMM_EX(&cuda_ctx,
                                               cuTransB,
                                               cuTransA,
                                               N,
                                               M,
                                               K,
                                               &c_alpha,
                                               B,
                                               rocblas_datatype_f64_c,
                                               ldb,
                                               A,
                                               rocblas_datatype_f64_c,
                                               lda,
                                               &c_beta,
                                               C,
                                               rocblas_datatype_f64_c,
                                               N,
                                               rocblas_datatype_f64_c);
1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381
}

// Generic GEMM for the legacy CUDADeviceContext taking boolean transpose
// flags and caller-supplied leading dimensions (lda, ldb, ldc).
//
// The call is dispatched to CUBlas<T>::GEMM through the context's CublasCall;
// alpha and beta are passed by address.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::GEMM(bool transA,
                                                     bool transB,
                                                     int M,
                                                     int N,
                                                     int K,
                                                     T alpha,
                                                     const T *A,
                                                     int lda,
                                                     const T *B,
                                                     int ldb,
                                                     T beta,
                                                     T *C,
                                                     int ldc) const {
  // Translate a boolean "transposed" flag into the rocBLAS operation enum.
  const auto to_rocblas_op = [](bool transposed) {
    return transposed ? rocblas_operation_transpose : rocblas_operation_none;
  };
  const rocblas_operation op_a = to_rocblas_op(transA);
  const rocblas_operation op_b = to_rocblas_op(transB);

  // The BLAS backend follows Fortran order, unlike the cblas convention, so
  // the two operands (and M / N) trade places in the call below.
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GEMM(
        handle, op_b, op_a, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc);
  });
}
template <>
template <typename T>
1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394
void Blas<phi::GPUContext>::GEMM(bool transA,
                                 bool transB,
                                 int M,
                                 int N,
                                 int K,
                                 T alpha,
                                 const T *A,
                                 int lda,
                                 const T *B,
                                 int ldb,
                                 T beta,
                                 T *C,
                                 int ldc) const {
1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  rocblas_operation cuTransA =
      transA ? rocblas_operation_transpose : rocblas_operation_none;
  rocblas_operation cuTransB =
      transB ? rocblas_operation_transpose : rocblas_operation_none;
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GEMM(handle,
                    cuTransB,
                    cuTransA,
                    N,
                    M,
                    K,
                    &alpha,
                    B,
                    ldb,
                    A,
                    lda,
                    &beta,
                    C,
                    ldc);
  });
}

// float16 GEMM for the legacy CUDADeviceContext taking boolean transpose
// flags and caller-supplied leading dimensions (lda, ldb, ldc).
//
// The call goes to CUBlas<float16>::GEMM through the context's CublasCall;
// alpha and beta are passed by address as float16 values.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::GEMM(
    bool transA,
    bool transB,
    int M,
    int N,
    int K,
    phi::dtype::float16 alpha,
    const phi::dtype::float16 *A,
    int lda,
    const phi::dtype::float16 *B,
    int ldb,
    phi::dtype::float16 beta,
    phi::dtype::float16 *C,
    int ldc) const {
  // Translate a boolean "transposed" flag into the rocBLAS operation enum.
  const auto to_rocblas_op = [](bool transposed) {
    return transposed ? rocblas_operation_transpose : rocblas_operation_none;
  };
  const rocblas_operation op_a = to_rocblas_op(transA);
  const rocblas_operation op_b = to_rocblas_op(transB);

  // The BLAS backend follows Fortran order, unlike the cblas convention, so
  // the two operands (and M / N) trade places in the call below.
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<phi::dtype::float16>::GEMM(
        handle, op_b, op_a, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc);
  });
}
template <>
template <>
1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473
inline void Blas<phi::GPUContext>::GEMM(bool transA,
                                        bool transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::float16 alpha,
                                        const phi::dtype::float16 *A,
                                        int lda,
                                        const phi::dtype::float16 *B,
                                        int ldb,
                                        phi::dtype::float16 beta,
                                        phi::dtype::float16 *C,
                                        int ldc) const {
1474 1475 1476 1477 1478 1479 1480 1481
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  rocblas_operation cuTransA =
      transA ? rocblas_operation_transpose : rocblas_operation_none;
  rocblas_operation cuTransB =
      transB ? rocblas_operation_transpose : rocblas_operation_none;

  context_.CublasCall([&](rocblas_handle handle) {
1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495
    CUBlas<phi::dtype::float16>::GEMM(handle,
                                      cuTransB,
                                      cuTransA,
                                      N,
                                      M,
                                      K,
                                      &alpha,
                                      B,
                                      ldb,
                                      A,
                                      lda,
                                      &beta,
                                      C,
                                      ldc);
1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510
  });
}

// AXPY for paddle::platform::CUDADeviceContext: forwards n, alpha, x and y to
// CUBlas<T>::AXPY using an increment of 1 for both vectors.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::AXPY(int n,
                                                     T alpha,
                                                     const T *x,
                                                     T *y) const {
  constexpr int kUnitStride = 1;
  auto run_axpy = [&](rocblas_handle handle) {
    CUBlas<T>::AXPY(handle, n, &alpha, x, kUnitStride, y, kUnitStride);
  };
  context_.CublasCall(run_axpy);
}
template <>
template <typename T>
1511
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
  });
}

// SCAL for paddle::platform::CUDADeviceContext: forwards n, alpha and x to
// CUBlas<T>::SCAL with an increment of 1.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::SCAL(int n,
                                                     const T alpha,
                                                     T *x) const {
  constexpr int kUnitStride = 1;
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::SCAL(handle, n, &alpha, x, kUnitStride);
  });
}
template <>
template <typename T>
1527
void Blas<phi::GPUContext>::SCAL(int n, const T alpha, T *x) const {
1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541
  context_.CublasCall(
      [&](rocblas_handle handle) { CUBlas<T>::SCAL(handle, n, &alpha, x, 1); });
}

// VCOPY for paddle::platform::CUDADeviceContext: forwards n, x and y to
// CUBlas<T>::VCOPY with an increment of 1 on both vectors.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::VCOPY(int n,
                                                      const T *x,
                                                      T *y) const {
  constexpr int kUnitStride = 1;
  auto run_copy = [&](rocblas_handle handle) {
    CUBlas<T>::VCOPY(handle, n, x, kUnitStride, y, kUnitStride);
  };
  context_.CublasCall(run_copy);
}
template <>
template <typename T>
1542
void Blas<phi::GPUContext>::VCOPY(int n, const T *x, T *y) const {
1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565
  context_.CublasCall(
      [&](rocblas_handle handle) { CUBlas<T>::VCOPY(handle, n, x, 1, y, 1); });
}

// Generic GEMV for paddle::platform::CUDADeviceContext, forwarded to
// CUBlas<T>::GEMV with unit increments for B and C.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::GEMV(bool trans_a,
                                                     int M,
                                                     int N,
                                                     T alpha,
                                                     const T *A,
                                                     const T *B,
                                                     T beta,
                                                     T *C) const {
  // The backend flag is the inverse of trans_a; A is passed with dimensions
  // (N, M) and leading dimension N.
  const rocblas_operation op =
      trans_a ? rocblas_operation_none : rocblas_operation_transpose;
  constexpr int kUnitStride = 1;

  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GEMV(
        handle, op, N, M, &alpha, A, N, B, kUnitStride, &beta, C, kUnitStride);
  });
}
template <>
template <typename T>
1566 1567 1568 1569 1570 1571 1572 1573
void Blas<phi::GPUContext>::GEMV(bool trans_a,
                                 int M,
                                 int N,
                                 T alpha,
                                 const T *A,
                                 const T *B,
                                 T beta,
                                 T *C) const {
1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587
  rocblas_operation cuTransA =
      !trans_a ? rocblas_operation_transpose : rocblas_operation_none;

  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
  });
}

// float16 GEMV for paddle::platform::CUDADeviceContext.
// There is no half-precision gemv in the backend, so the product is expressed
// as a GEMM in which one operand has a single row or a single column.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::GEMV(
    bool trans_a,
    int M,
    int N,
    phi::dtype::float16 alpha,
    const phi::dtype::float16 *A,
    const phi::dtype::float16 *B,
    phi::dtype::float16 beta,
    phi::dtype::float16 *C) const {
  using half_t = phi::dtype::float16;
  if (!trans_a) {
    // (M x N) * (N x 1): A is the left operand, B the single-column right one.
    this->template GEMM<half_t>(
        CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C);
    return;
  }
  // (1 x M) * (M x N): B is the single-row left operand, A the right one.
  this->template GEMM<half_t>(
      CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C);
}
template <>
template <>
1604 1605 1606 1607 1608 1609 1610 1611
inline void Blas<phi::GPUContext>::GEMV(bool trans_a,
                                        int M,
                                        int N,
                                        phi::dtype::float16 alpha,
                                        const phi::dtype::float16 *A,
                                        const phi::dtype::float16 *B,
                                        phi::dtype::float16 beta,
                                        phi::dtype::float16 *C) const {
1612 1613
  // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it.
  if (trans_a) {
1614
    this->template GEMM<phi::dtype::float16>(
1615 1616
        CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C);
  } else {
1617
    this->template GEMM<phi::dtype::float16>(
1618 1619 1620 1621 1622 1623 1624 1625 1626 1627
        CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C);
  }
}

// bfloat16 GEMV for paddle::platform::CUDADeviceContext.
// rocblas has no bfloat16 gemv, so the product is routed through the bfloat16
// GEMM with one operand reduced to a single row or a single column.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::GEMV(
    bool trans_a,
    int M,
    int N,
    phi::dtype::bfloat16 alpha,
    const phi::dtype::bfloat16 *A,
    const phi::dtype::bfloat16 *B,
    phi::dtype::bfloat16 beta,
    phi::dtype::bfloat16 *C) const {
  using bf16_t = phi::dtype::bfloat16;
  if (!trans_a) {
    // (M x N) * (N x 1): A is the left operand, B the single-column right one.
    this->template GEMM<bf16_t>(
        CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C);
    return;
  }
  // (1 x M) * (M x N): B is the single-row left operand, A the right one.
  this->template GEMM<bf16_t>(
      CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C);
}
template <>
template <>
1644 1645 1646 1647 1648 1649 1650 1651
inline void Blas<phi::GPUContext>::GEMV(bool trans_a,
                                        int M,
                                        int N,
                                        phi::dtype::bfloat16 alpha,
                                        const phi::dtype::bfloat16 *A,
                                        const phi::dtype::bfloat16 *B,
                                        phi::dtype::bfloat16 beta,
                                        phi::dtype::bfloat16 *C) const {
1652 1653
  // Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it.
  if (trans_a) {
1654
    this->template GEMM<phi::dtype::bfloat16>(
1655 1656
        CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C);
  } else {
1657
    this->template GEMM<phi::dtype::bfloat16>(
1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713
        CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C);
  }
}

// Strided batched GEMM for paddle::platform::CUDADeviceContext.
//
// Runs batchCount GEMMs through CUBlas<T>::GEMM_STRIDED_BATCH. A and B advance
// by strideA / strideB elements per batch entry; C advances by M * N elements.
// Leading dimensions are derived from the transpose flags (no padding).
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA,
    CBLAS_TRANSPOSE transB,
    int M,
    int N,
    int K,
    T alpha,
    const T *A,
    const T *B,
    T beta,
    T *C,
    int batchCount,
    int64_t strideA,
    int64_t strideB) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  // Widen before multiplying: `M * N` evaluated in int can overflow for large
  // outputs even though the stride itself is stored as a 64-bit value.
  const int64_t strideC = static_cast<int64_t>(M) * N;
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GEMM_STRIDED_BATCH(handle,
                                  cuTransB,
                                  cuTransA,
                                  N,
                                  M,
                                  K,
                                  &alpha,
                                  B,
                                  ldb,
                                  strideB,
                                  A,
                                  lda,
                                  strideA,
                                  &beta,
                                  C,
                                  ldc,
                                  strideC,
                                  batchCount);
  });
}

template <>
template <typename T>
1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726
void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        T alpha,
                                        const T *A,
                                        const T *B,
                                        T beta,
                                        T *C,
                                        int batchCount,
                                        int64_t strideA,
                                        int64_t strideB) const {
1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  const int64_t strideC = M * N;
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GEMM_STRIDED_BATCH(handle,
                                  cuTransB,
                                  cuTransA,
                                  N,
                                  M,
                                  K,
                                  &alpha,
                                  B,
                                  ldb,
                                  strideB,
                                  A,
                                  lda,
                                  strideA,
                                  &beta,
                                  C,
                                  ldc,
                                  strideC,
                                  batchCount);
  });
}

// Strided batched GEMM for bfloat16 on paddle::platform::CUDADeviceContext.
//
// Calls rocblas_gemm_strided_batched_ex directly: inputs and output are
// bf16, alpha/beta are converted to float and the compute type is f32.
// C is passed as both the C and the D (result) matrix of the rocBLAS call.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA,
    CBLAS_TRANSPOSE transB,
    int M,
    int N,
    int K,
    phi::dtype::bfloat16 alpha,
    const phi::dtype::bfloat16 *A,
    const phi::dtype::bfloat16 *B,
    phi::dtype::bfloat16 beta,
    phi::dtype::bfloat16 *C,
    int batchCount,
    int64_t strideA,
    int64_t strideB) const {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  // Widen before multiplying: `M * N` evaluated in int can overflow for large
  // outputs even though the stride itself is stored as a 64-bit value.
  const int64_t strideC = static_cast<int64_t>(M) * N;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  // The scalars are handed over as float to match the f32 compute type.
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);
  rocblas_gemm_algo algo = rocblas_gemm_algo_standard;

  // Operands are swapped (B before A, N before M), as in the other GEMM
  // wrappers of this file.
  context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_gemm_strided_batched_ex(
            handle,
            cuTransB,
            cuTransA,
            N,
            M,
            K,
            &h_alpha,
            B,
            rocblas_datatype_bf16_r,
            ldb,
            strideB,
            A,
            rocblas_datatype_bf16_r,
            lda,
            strideA,
            &h_beta,
            C,
            rocblas_datatype_bf16_r,
            ldc,
            strideC,
            C,
            rocblas_datatype_bf16_r,
            ldc,
            strideC,
            batchCount,
            rocblas_datatype_f32_r,
            algo,
            0,
            0));
  });
}

template <>
template <>
1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               phi::dtype::bfloat16 alpha,
                                               const phi::dtype::bfloat16 *A,
                                               const phi::dtype::bfloat16 *B,
                                               phi::dtype::bfloat16 beta,
                                               phi::dtype::bfloat16 *C,
                                               int batchCount,
                                               int64_t strideA,
                                               int64_t strideB) const {
1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  const int64_t strideC = M * N;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);
  rocblas_gemm_algo algo = rocblas_gemm_algo_standard;

  context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::rocblas_gemm_strided_batched_ex(
            handle,
            cuTransB,
            cuTransA,
            N,
            M,
            K,
            &h_alpha,
            B,
            rocblas_datatype_bf16_r,
            ldb,
            strideB,
            A,
            rocblas_datatype_bf16_r,
            lda,
            strideA,
            &h_beta,
            C,
            rocblas_datatype_bf16_r,
            ldc,
            strideC,
            C,
            rocblas_datatype_bf16_r,
            ldc,
            strideC,
            batchCount,
            rocblas_datatype_f32_r,
            algo,
            0,
            0));
  });
}

// Pointer-array batched GEMM for paddle::platform::CUDADeviceContext.
// Implemented as batchCount independent GEMM<T> calls; the pointer arrays
// A, B and C are indexed directly inside this function.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA,
    CBLAS_TRANSPOSE transB,
    int M,
    int N,
    int K,
    T alpha,
    const T **A,
    const T **B,
    T beta,
    T **C,
    int batchCount) const {
  for (int batch = 0; batch < batchCount; ++batch) {
    const T *lhs = A[batch];
    const T *rhs = B[batch];
    T *out = C[batch];
    this->template GEMM<T>(transA, transB, M, N, K, alpha, lhs, rhs, beta, out);
  }
}

template <>
template <typename T>
1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922
void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        T alpha,
                                        const T **A,
                                        const T **B,
                                        T beta,
                                        T **C,
                                        int batchCount) const {
1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936
  for (int k = 0; k < batchCount; ++k) {
    this->template GEMM<T>(
        transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
  }
}

// Pointer-array batched GEMM, float16 specialization for
// paddle::platform::CUDADeviceContext. Issues one float16 GEMM per batch
// entry; the pointer arrays are indexed directly inside this function.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA,
    CBLAS_TRANSPOSE transB,
    int M,
    int N,
    int K,
    phi::dtype::float16 alpha,
    const phi::dtype::float16 **A,
    const phi::dtype::float16 **B,
    phi::dtype::float16 beta,
    phi::dtype::float16 **C,
    int batchCount) const {
  using half_t = phi::dtype::float16;
  for (int batch = 0; batch < batchCount; ++batch) {
    const half_t *lhs = A[batch];
    const half_t *rhs = B[batch];
    half_t *out = C[batch];
    this->template GEMM<half_t>(
        transA, transB, M, N, K, alpha, lhs, rhs, beta, out);
  }
}
template <>
template <>
1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               phi::dtype::float16 alpha,
                                               const phi::dtype::float16 **A,
                                               const phi::dtype::float16 **B,
                                               phi::dtype::float16 beta,
                                               phi::dtype::float16 **C,
                                               int batchCount) const {
1961
  for (int k = 0; k < batchCount; ++k) {
1962
    this->template GEMM<phi::dtype::float16>(
1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974
        transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
  }
}

// Pointer-array batched GEMM, bfloat16 specialization for
// paddle::platform::CUDADeviceContext. Issues one bfloat16 GEMM per batch
// entry; the pointer arrays are indexed directly inside this function.
template <>
template <>
inline void Blas<paddle::platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA,
    CBLAS_TRANSPOSE transB,
    int M,
    int N,
    int K,
    phi::dtype::bfloat16 alpha,
    const phi::dtype::bfloat16 **A,
    const phi::dtype::bfloat16 **B,
    phi::dtype::bfloat16 beta,
    phi::dtype::bfloat16 **C,
    int batchCount) const {
  using bf16_t = phi::dtype::bfloat16;
  for (int batch = 0; batch < batchCount; ++batch) {
    const bf16_t *lhs = A[batch];
    const bf16_t *rhs = B[batch];
    bf16_t *out = C[batch];
    this->template GEMM<bf16_t>(
        transA, transB, M, N, K, alpha, lhs, rhs, beta, out);
  }
}

template <>
template <>
1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               phi::dtype::bfloat16 alpha,
                                               const phi::dtype::bfloat16 **A,
                                               const phi::dtype::bfloat16 **B,
                                               phi::dtype::bfloat16 beta,
                                               phi::dtype::bfloat16 **C,
                                               int batchCount) const {
2000
  for (int k = 0; k < batchCount; ++k) {
2001
    this->template GEMM<phi::dtype::bfloat16>(
2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038
        transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
  }
}

// Triangular solve for paddle::platform::CUDADeviceContext, forwarded to
// CUBlas<T>::TRSM.
//
// The row-major problem `op ( A ) X = α B` is handed to the backend as
// `X' op ( A' ) = α B'` (' = transpose), so side and fill mode are flipped
// and the M / N extents are passed in swapped order.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side,
                                                     CBLAS_UPLO uplo,
                                                     CBLAS_TRANSPOSE transA,
                                                     CBLAS_DIAG diag,
                                                     int M,
                                                     int N,
                                                     T alpha,
                                                     const T *A,
                                                     int lda,
                                                     T *B,
                                                     int ldb) const {
  const rocblas_side flipped_side =
      (side != CblasLeft) ? rocblas_side_left : rocblas_side_right;
  const rocblas_fill flipped_fill =
      (uplo != CblasLower) ? rocblas_fill_lower : rocblas_fill_upper;
  // use CUBLAS_OP_C (conjugate transpose) for complex
  const rocblas_operation op_a = (transA != CblasNoTrans)
                                     ? rocblas_operation_transpose
                                     : rocblas_operation_none;
  const rocblas_diagonal diag_kind =
      (diag != CblasUnit) ? rocblas_diagonal_non_unit : rocblas_diagonal_unit;

  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::TRSM(handle,
                    flipped_side,
                    flipped_fill,
                    op_a,
                    diag_kind,
                    N,
                    M,
                    &alpha,
                    A,
                    lda,
                    B,
                    ldb);
  });
}
template <>
template <typename T>
2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049
void Blas<phi::GPUContext>::TRSM(CBLAS_SIDE side,
                                 CBLAS_UPLO uplo,
                                 CBLAS_TRANSPOSE transA,
                                 CBLAS_DIAG diag,
                                 int M,
                                 int N,
                                 T alpha,
                                 const T *A,
                                 int lda,
                                 T *B,
                                 int ldb) const {
2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078
  // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' )  =  α B'`
  // where ' stands for transpose
  rocblas_side cuSide =
      (side == CblasLeft) ? rocblas_side_right : rocblas_side_left;
  rocblas_fill cuUplo =
      (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower;
  // use CUBLAS_OP_C (conjugate transpose) for complex
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_diagonal cuDiag =
      (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit;

  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::TRSM(
        handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb);
  });
}

// Batched GETRF for paddle::platform::CUDADeviceContext: forwards to
// CUBlas<T>::GETRF_BATCH, passing n as both the matrix order and the leading
// dimension of every matrix in `a`.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::BatchedGETRF(
    int n, T **a, int *ipiv, int *info, int batch_size) const {
  const int lda = n;
  auto run_getrf = [&](rocblas_handle handle) {
    CUBlas<T>::GETRF_BATCH(handle, n, a, lda, ipiv, info, batch_size);
  };
  context_.CublasCall(run_getrf);
}
template <>
template <typename T>
2079
void Blas<phi::GPUContext>::BatchedGETRF(
2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093
    int n, T **a, int *ipiv, int *info, int batch_size) const {
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size);
  });
}

// Batched GETRI for paddle::platform::CUDADeviceContext.
//
// Forwards to CUBlas<T>::GETRI_BATCH with n used as the leading dimension of
// both the input matrices `a` and the output matrices `a_inv`. The call is
// rejected up front when `a_inv` and `a` are the same pointer array, because
// the routine cannot run in-place.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::BatchedGETRI(
    int n, const T **a, const int *ipiv, T **a_inv, int *info, int batch_size)
    const {
  // The message names the routine actually guarded here (getriBatched); the
  // previous text misspelled "function" and referred to getrfBatched.
  PADDLE_ENFORCE_NE(
      a_inv,
      a,
      phi::errors::InvalidArgument(
          "cuBLAS function 'cublas<S/D>getriBatched' cannot be executed "
          "in-place. The memory space of output matrix (address: %p) cannot "
          "overlap memory space of input matrix (address: %p).",
          a_inv,
          a));
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size);
  });
}
template <>
template <typename T>
2106 2107 2108 2109 2110 2111
void Blas<phi::GPUContext>::BatchedGETRI(int n,
                                         const T **a,
                                         const int *ipiv,
                                         T **a_inv,
                                         int *info,
                                         int batch_size) const {
2112 2113 2114
  PADDLE_ENFORCE_NE(
      a_inv,
      a,
2115
      phi::errors::InvalidArgument(
2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135
          "cuBLAS fuction 'cublas<S/D>getrfBatched' cannot be executed "
          "in-place. The memory space of output matrix (address: %p) cannot "
          "overlap memory space of input matrix (address: %p).",
          a_inv,
          a));
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size);
  });
}

// Batched matrix inverse for paddle::platform::CUDADeviceContext: forwards to
// CUBlas<T>::MATINV_BATCH, using n as the leading dimension of both the
// input (`a`) and output (`a_inv`) matrices.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::BatchedMatInv(
    int n, const T **a, T **a_inv, int *info, int batch_size) const {
  const int lda = n;
  const int lda_inv = n;
  auto run_matinv = [&](rocblas_handle handle) {
    CUBlas<T>::MATINV_BATCH(
        handle, n, a, lda, a_inv, lda_inv, info, batch_size);
  };
  context_.CublasCall(run_matinv);
}
template <>
template <typename T>
2136
void Blas<phi::GPUContext>::BatchedMatInv(
2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165
    int n, const T **a, T **a_inv, int *info, int batch_size) const {
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size);
  });
}

// Batched GETRS for paddle::platform::CUDADeviceContext: maps the cblas
// transpose flag to the rocblas one and forwards every other argument
// unchanged to CUBlas<T>::GETRS_BATCH.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::BatchedGETRS(
    CBLAS_TRANSPOSE trans,
    int n,
    int nrhs,
    const T **a,
    int lda,
    int *ipiv,
    T **b,
    int ldb,
    int *info,
    int batch_size) const {
  // Anything other than CblasNoTrans is treated as a plain transpose.
  rocblas_operation op = rocblas_operation_transpose;
  if (trans == CblasNoTrans) {
    op = rocblas_operation_none;
  }
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GETRS_BATCH(
        handle, op, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size);
  });
}
template <>
template <typename T>
2166 2167 2168 2169 2170 2171 2172 2173 2174 2175
void Blas<phi::GPUContext>::BatchedGETRS(CBLAS_TRANSPOSE trans,
                                         int n,
                                         int nrhs,
                                         const T **a,
                                         int lda,
                                         int *ipiv,
                                         T **b,
                                         int ldb,
                                         int *info,
                                         int batch_size) const {
2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 2222 2223 2224 2225 2226 2227 2228 2229 2230
  rocblas_operation cuTrans = (trans == CblasNoTrans)
                                  ? rocblas_operation_none
                                  : rocblas_operation_transpose;
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GETRS_BATCH(
        handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size);
  });
}

// Batched triangular solve for paddle::platform::CUDADeviceContext, forwarded
// to CUBlas<T>::TRSM_BATCH.
//
// The row-major problem `op ( A ) X = α B` is handed to the backend as
// `X' op ( A' ) = α B'` (' = transpose), so side and fill mode are flipped
// and the M / N extents are passed in swapped order.
template <>
template <typename T>
void Blas<paddle::platform::CUDADeviceContext>::BatchedTRSM(
    CBLAS_SIDE side,
    CBLAS_UPLO uplo,
    CBLAS_TRANSPOSE transA,
    CBLAS_DIAG diag,
    int M,
    int N,
    T alpha,
    const T **A,
    int lda,
    T **B,
    int ldb,
    int batch_size) const {
  const rocblas_side flipped_side =
      (side != CblasLeft) ? rocblas_side_left : rocblas_side_right;
  const rocblas_fill flipped_fill =
      (uplo != CblasLower) ? rocblas_fill_lower : rocblas_fill_upper;
  // use CUBLAS_OP_C (conjugate transpose) for complex
  const rocblas_operation op_a = (transA != CblasNoTrans)
                                     ? rocblas_operation_transpose
                                     : rocblas_operation_none;
  const rocblas_diagonal diag_kind =
      (diag != CblasUnit) ? rocblas_diagonal_non_unit : rocblas_diagonal_unit;

  auto run_trsm = [&](rocblas_handle handle) {
    CUBlas<T>::TRSM_BATCH(handle,
                          flipped_side,
                          flipped_fill,
                          op_a,
                          diag_kind,
                          N,
                          M,
                          &alpha,
                          A,
                          lda,
                          B,
                          ldb,
                          batch_size);
  };
  context_.CublasCall(run_trsm);
}
template <>
template <typename T>
2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242
void Blas<phi::GPUContext>::BatchedTRSM(CBLAS_SIDE side,
                                        CBLAS_UPLO uplo,
                                        CBLAS_TRANSPOSE transA,
                                        CBLAS_DIAG diag,
                                        int M,
                                        int N,
                                        T alpha,
                                        const T **A,
                                        int lda,
                                        T **B,
                                        int ldb,
                                        int batch_size) const {
2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273
  // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' )  =  α B'`
  // where ' stands for transpose
  rocblas_side cuSide =
      (side == CblasLeft) ? rocblas_side_right : rocblas_side_left;
  rocblas_fill cuUplo =
      (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower;
  // use CUBLAS_OP_C (conjugate transpose) for complex
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_diagonal cuDiag =
      (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit;

  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::TRSM_BATCH(handle,
                          cuSide,
                          cuUplo,
                          cuTransA,
                          cuDiag,
                          N,
                          M,
                          &alpha,
                          A,
                          lda,
                          B,
                          ldb,
                          batch_size);
  });
}

}  // namespace funcs
}  // namespace phi