// File: blas_impl.cu.h
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "gflags/gflags.h"
#include "glog/logging.h"

#include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h"

DECLARE_bool(enable_cublas_tensor_op_math);
DECLARE_bool(gemm_use_half_precision_compute_type);

namespace phi {
namespace funcs {

// Primary template for the per-element-type cuBLAS dispatch table. It is only
// declared here; the explicit specializations that follow supply the actual
// wrappers. NOTE(review): presumably using CUBlas<T> for a T without a
// specialization is an incomplete-type compile error — confirm no other
// specialization is defined elsewhere.
template <typename T>
struct CUBlas;

template <>
struct CUBlas<float> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
37
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemm(args...));
38 39 40 41
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
42
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSaxpy(args...));
43 44 45 46
  }

  template <typename... ARGS>
  static void SCAL(ARGS... args) {
47
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSscal(args...));
48 49 50 51
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
52
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasScopy(args...));
53 54 55 56
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
57
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemv(args...));
58 59 60 61 62 63
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE_GPU_SUCCESS(
64
        phi::dynload::cublasSgemmStridedBatched(args...));
65
#else
66
    PADDLE_THROW(phi::errors::Unimplemented(
67 68 69 70 71 72 73
        "SgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
74
  static void GEMM_EX(phi::GPUContext *dev_ctx,
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
                      cublasOperation_t transa,
                      cublasOperation_t transb,
                      int m,
                      int n,
                      int k,
                      const float *alpha,
                      const void *A,
                      cudaDataType_t Atype,
                      int lda,
                      const void *B,
                      cudaDataType_t Btype,
                      int ldb,
                      const float *beta,
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc) {
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
#if CUDA_VERSION >= 8000
    VLOG(5) << "use_tensor_op_math: "
            << (dev_ctx->tensor_core_available() ? "True" : "False");
    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmEx(handle,
                                                             transa,
                                                             transb,
                                                             m,
                                                             n,
                                                             k,
                                                             alpha,
                                                             A,
                                                             Atype,
                                                             lda,
                                                             B,
                                                             Btype,
                                                             ldb,
                                                             beta,
                                                             C,
                                                             Ctype,
                                                             ldc));
115 116
    });
#else
117
    PADDLE_THROW(phi::errors::Unimplemented(
118 119 120 121 122 123
        "cublasSgemmEx is not supported on cuda <= 7.5"));
#endif
  }

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
124
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsm(args...));
125 126 127 128
  }

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... args) {
129
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrfBatched(args...));
130 131 132 133
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... args) {
134
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetriBatched(args...));
135 136 137 138
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... args) {
139
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSmatinvBatched(args...));
140 141 142 143
  }

  template <typename... ARGS>
  static void GETRS_BATCH(ARGS... args) {
144
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrsBatched(args...));
145 146 147 148
  }

  template <typename... ARGS>
  static void TRSM_BATCH(ARGS... args) {
149
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsmBatched(args...));
150 151 152 153 154 155 156
  }
};

template <>
struct CUBlas<double> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
157
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemm(args...));
158 159 160 161
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
162
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDaxpy(args...));
163 164 165 166
  }

  template <typename... ARGS>
  static void SCAL(ARGS... args) {
167
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDscal(args...));
168 169 170 171
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
172
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDcopy(args...));
173 174 175 176
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
177
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemv(args...));
178 179 180 181 182 183
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE_GPU_SUCCESS(
184
        phi::dynload::cublasDgemmStridedBatched(args...));
185
#else
186
    PADDLE_THROW(phi::errors::Unimplemented(
187 188 189 190 191 192 193
        "DgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  template <typename... ARGS>
  static void GEMM_EX(ARGS... args) {
    PADDLE_THROW(
194
        phi::errors::Unimplemented("Currently there are not cublasDgemmEx."));
195 196 197 198
  }

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
199
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsm(args...));
200 201 202 203
  }

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... args) {
204
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrfBatched(args...));
205 206 207 208
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... args) {
209
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetriBatched(args...));
210 211 212 213
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... args) {
214
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDmatinvBatched(args...));
215 216 217 218
  }

  template <typename... ARGS>
  static void GETRS_BATCH(ARGS... args) {
219
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrsBatched(args...));
220 221 222 223
  }

  template <typename... ARGS>
  static void TRSM_BATCH(ARGS... args) {
224
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsmBatched(args...));
225 226 227 228
  }
};

template <>
229 230
struct CUBlas<phi::dtype::float16> {
  using float16 = phi::dtype::float16;
231 232 233 234 235 236 237 238 239 240 241 242 243 244 245

  static void GEMM(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float16 *alpha,
                   const float16 *A,
                   int lda,
                   const float16 *B,
                   int ldb,
                   const float16 *beta,
                   float16 *C,
                   int ldc) {
246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
    PADDLE_ENFORCE_GPU_SUCCESS(
        phi::dynload::cublasHgemm(handle,
                                  transa,
                                  transb,
                                  m,
                                  n,
                                  k,
                                  reinterpret_cast<const __half *>(alpha),
                                  reinterpret_cast<const __half *>(A),
                                  lda,
                                  reinterpret_cast<const __half *>(B),
                                  ldb,
                                  reinterpret_cast<const __half *>(beta),
                                  reinterpret_cast<__half *>(C),
                                  ldc));
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
  }

  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb,
                                 int m,
                                 int n,
                                 int k,
                                 const float16 *alpha,
                                 const float16 *A,
                                 int lda,
                                 long long int strideA,  // NOLINT
                                 const float16 *B,       // NOLINT
                                 int ldb,
                                 long long int strideB,  // NOLINT
                                 const float16 *beta,
                                 float16 *C,
                                 int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
#if CUDA_VERSION >= 8000
282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasHgemmStridedBatched(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const __half *>(alpha),
        reinterpret_cast<const __half *>(A),
        lda,
        strideA,
        reinterpret_cast<const __half *>(B),
        ldb,
        strideB,
        reinterpret_cast<const __half *>(beta),
        reinterpret_cast<__half *>(C),
        ldc,
        strideC,
        batchCount));
301
#else
302
    PADDLE_THROW(phi::errors::Unimplemented(
303 304 305 306 307 308 309
        "HgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
310
  static void GEMM_EX(phi::GPUContext *dev_ctx,
311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
                      cublasOperation_t transa,
                      cublasOperation_t transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      cudaDataType_t Atype,
                      int lda,
                      const void *B,
                      cudaDataType_t Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif  // CUDA_VERSION >= 9000

    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                            transa,
                                                            transb,
                                                            m,
                                                            n,
                                                            k,
                                                            alpha,
                                                            A,
                                                            Atype,
                                                            lda,
                                                            B,
                                                            Btype,
                                                            ldb,
                                                            beta,
                                                            C,
                                                            Ctype,
                                                            ldc,
                                                            computeType,
                                                            algo));
359 360
    });
#else
361
    PADDLE_THROW(phi::errors::Unimplemented(
362 363 364 365 366 367
        "cublasGemmEx is not supported on cuda <= 7.5"));
#endif
  }
};

template <>
368
struct CUBlas<phi::dtype::complex<float>> {
369 370 371 372
  static void GEMV(cublasHandle_t handle,
                   cublasOperation_t transa,
                   int m,
                   int n,
373 374
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *A,
375
                   int lda,
376
                   const phi::dtype::complex<float> *B,
377
                   int ldb,
378 379
                   const phi::dtype::complex<float> *beta,
                   phi::dtype::complex<float> *C,
380
                   int ldc) {
381
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemv(
382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397
        handle,
        transa,
        m,
        n,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A),
        lda,
        reinterpret_cast<const cuFloatComplex *>(B),
        ldb,
        reinterpret_cast<const cuFloatComplex *>(beta),
        reinterpret_cast<cuFloatComplex *>(C),
        ldc));
  }

  static void AXPY(cublasHandle_t handle,
                   int n,
398 399
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *X,
400
                   const int incX,
401
                   phi::dtype::complex<float> *Y,
402
                   const int incY) {
403
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCaxpy(
404 405 406 407 408 409 410 411 412
        handle,
        n,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(X),
        incX,
        reinterpret_cast<cuFloatComplex *>(Y),
        incY));
  }

413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb,
                                 int m,
                                 int n,
                                 int k,
                                 const phi::dtype::complex<float> *alpha,
                                 const phi::dtype::complex<float> *A,
                                 int lda,
                                 long long int strideA,                // NOLINT
                                 const phi::dtype::complex<float> *B,  // NOLINT
                                 int ldb,
                                 long long int strideB,  // NOLINT
                                 const phi::dtype::complex<float> *beta,
                                 phi::dtype::complex<float> *C,
                                 int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
431
#if CUDA_VERSION >= 8000
432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemmStridedBatched(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A),
        lda,
        strideA,
        reinterpret_cast<const cuFloatComplex *>(B),
        ldb,
        strideB,
        reinterpret_cast<const cuFloatComplex *>(beta),
        reinterpret_cast<cuFloatComplex *>(C),
        ldc,
        strideC,
        batchCount));
451
#else
452
    PADDLE_THROW(phi::errors::Unimplemented(
453 454 455 456 457 458 459 460 461 462
        "CgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  static void GEMM(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
463 464
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *A,
465
                   int lda,
466
                   const phi::dtype::complex<float> *B,
467
                   int ldb,
468 469
                   const phi::dtype::complex<float> *beta,
                   phi::dtype::complex<float> *C,
470
                   int ldc) {
471
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemm(
472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A),
        lda,
        reinterpret_cast<const cuFloatComplex *>(B),
        ldb,
        reinterpret_cast<const cuFloatComplex *>(beta),
        reinterpret_cast<cuFloatComplex *>(C),
        ldc));
  }

  static void TRSM(cublasHandle_t handle,
                   cublasSideMode_t side,
                   cublasFillMode_t uplo,
                   cublasOperation_t transa,
                   cublasDiagType_t diag,
                   int m,
                   int n,
495 496
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *A,
497
                   int lda,
498
                   phi::dtype::complex<float> *B,
499
                   int ldb) {
500
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsm(
501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517
        handle,
        side,
        uplo,
        transa,
        diag,
        m,
        n,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A),
        lda,
        reinterpret_cast<cuFloatComplex *>(B),
        ldb));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
518
  static void GEMM_EX(phi::GPUContext *dev_ctx,
519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547
                      cublasOperation_t transa,
                      cublasOperation_t transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      cudaDataType_t Atype,
                      int lda,
                      const void *B,
                      cudaDataType_t Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif  // CUDA_VERSION >= 9000

    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                            transa,
                                                            transb,
                                                            m,
                                                            n,
                                                            k,
                                                            alpha,
                                                            A,
                                                            Atype,
                                                            lda,
                                                            B,
                                                            Btype,
                                                            ldb,
                                                            beta,
                                                            C,
                                                            Ctype,
                                                            ldc,
                                                            computeType,
                                                            algo));
567 568
    });
#else
569
    PADDLE_THROW(phi::errors::Unimplemented(
570 571 572 573 574 575 576 577 578 579 580
        "cublasGemmEx is not supported on cuda <= 7.5"));
#endif
  }

  static void TRSM_BATCH(cublasHandle_t handle,
                         cublasSideMode_t side,
                         cublasFillMode_t uplo,
                         cublasOperation_t transa,
                         cublasDiagType_t diag,
                         int m,
                         int n,
581 582
                         const phi::dtype::complex<float> *alpha,
                         const phi::dtype::complex<float> **A,
583
                         int lda,
584
                         phi::dtype::complex<float> **B,
585 586
                         int ldb,
                         int batch_size) {
587
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsmBatched(
588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604
        handle,
        side,
        uplo,
        transa,
        diag,
        m,
        n,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex **>(A),
        lda,
        reinterpret_cast<cuFloatComplex **>(B),
        ldb,
        batch_size));
  }
};

template <>
605
struct CUBlas<phi::dtype::complex<double>> {
606 607 608 609
  static void GEMV(cublasHandle_t handle,
                   cublasOperation_t transa,
                   int m,
                   int n,
610 611
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *A,
612
                   int lda,
613
                   const phi::dtype::complex<double> *B,
614
                   int ldb,
615 616
                   const phi::dtype::complex<double> *beta,
                   phi::dtype::complex<double> *C,
617
                   int ldc) {
618
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemv(
619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634
        handle,
        transa,
        m,
        n,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A),
        lda,
        reinterpret_cast<const cuDoubleComplex *>(B),
        ldb,
        reinterpret_cast<const cuDoubleComplex *>(beta),
        reinterpret_cast<cuDoubleComplex *>(C),
        ldc));
  }

  static void AXPY(cublasHandle_t handle,
                   int n,
635 636
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *X,
637
                   const int incX,
638
                   phi::dtype::complex<double> *Y,
639
                   const int incY) {
640
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZaxpy(
641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656
        handle,
        n,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(X),
        incX,
        reinterpret_cast<cuDoubleComplex *>(Y),
        incY));
  }

  static void GEMM_STRIDED_BATCH(
      cublasHandle_t handle,
      cublasOperation_t transa,
      cublasOperation_t transb,
      int m,
      int n,
      int k,
657 658
      const phi::dtype::complex<double> *alpha,
      const phi::dtype::complex<double> *A,
659
      int lda,
660 661
      long long int strideA,                 // NOLINT
      const phi::dtype::complex<double> *B,  // NOLINT
662 663
      int ldb,
      long long int strideB,  // NOLINT
664 665
      const phi::dtype::complex<double> *beta,
      phi::dtype::complex<double> *C,
666 667 668 669
      int ldc,
      long long int strideC,  // NOLINT
      int batchCount) {
#if CUDA_VERSION >= 8000
670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemmStridedBatched(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A),
        lda,
        strideA,
        reinterpret_cast<const cuDoubleComplex *>(B),
        ldb,
        strideB,
        reinterpret_cast<const cuDoubleComplex *>(beta),
        reinterpret_cast<cuDoubleComplex *>(C),
        ldc,
        strideC,
        batchCount));
689
#else
690
    PADDLE_THROW(phi::errors::Unimplemented(
691 692 693 694 695 696 697 698 699 700
        "CgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  static void GEMM(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
701 702
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *A,
703
                   int lda,
704
                   const phi::dtype::complex<double> *B,
705
                   int ldb,
706 707
                   const phi::dtype::complex<double> *beta,
                   phi::dtype::complex<double> *C,
708
                   int ldc) {
709
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemm(
710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A),
        lda,
        reinterpret_cast<const cuDoubleComplex *>(B),
        ldb,
        reinterpret_cast<const cuDoubleComplex *>(beta),
        reinterpret_cast<cuDoubleComplex *>(C),
        ldc));
  }

  static void TRSM(cublasHandle_t handle,
                   cublasSideMode_t side,
                   cublasFillMode_t uplo,
                   cublasOperation_t transa,
                   cublasDiagType_t diag,
                   int m,
                   int n,
733 734
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *A,
735
                   int lda,
736
                   phi::dtype::complex<double> *B,
737
                   int ldb) {
738
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsm(
739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759
        handle,
        side,
        uplo,
        transa,
        diag,
        m,
        n,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A),
        lda,
        reinterpret_cast<cuDoubleComplex *>(B),
        ldb));
  }

  static void TRSM_BATCH(cublasHandle_t handle,
                         cublasSideMode_t side,
                         cublasFillMode_t uplo,
                         cublasOperation_t transa,
                         cublasDiagType_t diag,
                         int m,
                         int n,
760 761
                         const phi::dtype::complex<double> *alpha,
                         const phi::dtype::complex<double> **A,
762
                         int lda,
763
                         phi::dtype::complex<double> **B,
764 765
                         int ldb,
                         int batch_size) {
766
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsmBatched(
767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784
        handle,
        side,
        uplo,
        transa,
        diag,
        m,
        n,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex **>(A),
        lda,
        reinterpret_cast<cuDoubleComplex **>(B),
        ldb,
        batch_size));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  //
  // Forwards to cublasGemmEx. The operand storage types (Atype/Btype/Ctype)
  // and `computeType` are passed through unchanged, so `alpha`/`beta` must
  // point to scalars of the type cuBLAS expects for `computeType`.
  // On CUDA >= 9.0 the tensor-op default algorithm is requested whenever the
  // device context reports tensor cores as available; otherwise the plain
  // default algorithm is used. The call is issued through
  // TensorCoreCublasCallIfAvailable, and a non-success cuBLAS status is raised
  // by PADDLE_ENFORCE_GPU_SUCCESS.
  // Throws Unimplemented when built against CUDA < 8.0, which has no
  // cublasGemmEx.
  //
  // NOTE(review): the ARGS pack is not used by this signature; presumably it
  // is kept for parity with the other CUBlas<...>::GEMM_EX declarations.
  template <typename... ARGS>
  static void GEMM_EX(phi::GPUContext *dev_ctx,
                      cublasOperation_t transa,
                      cublasOperation_t transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      cudaDataType_t Atype,
                      int lda,
                      const void *B,
                      cudaDataType_t Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif  // CUDA_VERSION >= 9000

    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                            transa,
                                                            transb,
                                                            m,
                                                            n,
                                                            k,
                                                            alpha,
                                                            A,
                                                            Atype,
                                                            lda,
                                                            B,
                                                            Btype,
                                                            ldb,
                                                            beta,
                                                            C,
                                                            Ctype,
                                                            ldc,
                                                            computeType,
                                                            algo));
    });
#else
    PADDLE_THROW(phi::errors::Unimplemented(
        "cublasGemmEx is not supported on cuda <= 7.5"));
#endif
  }
};

// Row-major GEMM: C = alpha * op(A) * op(B) + beta * C.
//   op(A) is M x K (A is stored K x M when transA != CblasNoTrans),
//   op(B) is K x N (B is stored N x K when transB != CblasNoTrans),
//   C is M x N.
// All matrices are densely packed: the leading dimensions are derived from
// the shapes, and C uses a leading dimension of N.
// Because cuBLAS is column-major, the row-major product is computed as the
// column-major product C^T = op(B)^T * op(A)^T. That is why the B/A operands,
// the N/M sizes and the transpose flags are swapped in the calls below.
// When FLAGS_enable_cublas_tensor_op_math is set and T is float (CUDA >= 8.0),
// the call goes through GEMM_EX with fp32 storage; otherwise CUBlas<T>::GEMM
// is used.
template <>
template <typename T>
void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                 CBLAS_TRANSPOSE transB,
                                 int M,
                                 int N,
                                 int K,
                                 T alpha,
                                 const T *A,
                                 const T *B,
                                 T beta,
                                 T *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

#if CUDA_VERSION >= 8000
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    // GEMM_EX takes a mutable context pointer while this method is const.
    auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
    CUBlas<T>::GEMM_EX(&cuda_ctx,
                       cuTransB,
                       cuTransA,
                       N,
                       M,
                       K,
                       &alpha,
                       B,
                       CUDA_R_32F,
                       ldb,
                       A,
                       CUDA_R_32F,
                       lda,
                       &beta,
                       C,
                       CUDA_R_32F,
                       N);
  } else {
#endif  // CUDA_VERSION >= 8000
    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM(handle,
                      cuTransB,
                      cuTransA,
                      N,
                      M,
                      K,
                      &alpha,
                      B,
                      ldb,
                      A,
                      lda,
                      &beta,
                      C,
                      N);
    });

#if CUDA_VERSION >= 8000
  }
#endif  // CUDA_VERSION >= 8000
}

// fp16 specialization of the row-major GEMM
// C = alpha * op(A) * op(B) + beta * C, with the same shape and operand-swap
// conventions as the generic CBLAS_TRANSPOSE overload.
// Enforces compute capability >= 53 (InvalidArgument otherwise).
// On CUDA >= 8.0 it uses GEMM_EX with fp16 storage and fp32 compute (float
// alpha/beta). On older CUDA it falls back to CUBlas<float16>::GEMM with fp16
// alpha/beta.
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::float16 alpha,
                                        const phi::dtype::float16 *A,
                                        const phi::dtype::float16 *B,
                                        phi::dtype::float16 beta,
                                        phi::dtype::float16 *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
      phi::errors::InvalidArgument(
          "cublas fp16 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

#if CUDA_VERSION >= 8000
  // cublasHgemm does true FP16 computation which is slow for non-Volta
  // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
  // input/output in fp16, computation in fp32, which can also be accelerated
  // using tensor cores in Volta GPUs.
  // alpha/beta are passed as float to match the CUDA_R_32F compute type.
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);
  // GEMM_EX takes a mutable context pointer while this method is const.
  auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
  CUBlas<phi::dtype::float16>::GEMM_EX(&cuda_ctx,
                                       cuTransB,
                                       cuTransA,
                                       N,
                                       M,
                                       K,
                                       &h_alpha,
                                       B,
                                       CUDA_R_16F,
                                       ldb,
                                       A,
                                       CUDA_R_16F,
                                       lda,
                                       &h_beta,
                                       C,
                                       CUDA_R_16F,
                                       N,
                                       CUDA_R_32F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to the plain
  // fp16 GEMM, which takes fp16 alpha/beta and the original operand pointers.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<phi::dtype::float16>::GEMM(handle,
                                      cuTransB,
                                      cuTransA,
                                      N,
                                      M,
                                      K,
                                      &alpha,
                                      B,
                                      ldb,
                                      A,
                                      lda,
                                      &beta,
                                      C,
                                      N);
  });
#endif  // CUDA_VERSION >= 8000
}

// bf16 specialization of the row-major GEMM
// C = alpha * op(A) * op(B) + beta * C, with the same shape and operand-swap
// conventions as the generic CBLAS_TRANSPOSE overload.
// Requires CUDA >= 11.0 (Unimplemented otherwise) and compute capability
// >= 80 (InvalidArgument otherwise). It calls cublasGemmEx directly with bf16
// storage, fp32 compute and float alpha/beta. The tensor-op default algorithm
// is selected when the context reports tensor cores as available.
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::bfloat16 alpha,
                                        const phi::dtype::bfloat16 *A,
                                        const phi::dtype::bfloat16 *B,
                                        phi::dtype::bfloat16 beta,
                                        phi::dtype::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      80,
      phi::errors::InvalidArgument(
          "cublas bf16 gemm requires GPU compute capability >= 80,"
          "but received %d",
          context_.GetComputeCapability()));

  // alpha/beta are passed as float to match the CUDA_R_32F compute type.
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
  bool use_tensor_op_math = context_.tensor_core_available();
  if (use_tensor_op_math) {
    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
  }
  VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

  context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                          cuTransB,
                                                          cuTransA,
                                                          N,
                                                          M,
                                                          K,
                                                          &h_alpha,
                                                          B,
                                                          CUDA_R_16BF,
                                                          ldb,
                                                          A,
                                                          CUDA_R_16BF,
                                                          lda,
                                                          &h_beta,
                                                          C,
                                                          CUDA_R_16BF,
                                                          N,
                                                          CUDA_R_32F,
                                                          algo));
  });
#else
  // raise error
  PADDLE_THROW(phi::errors::Unimplemented(
      "cublasGemmEx with bfloat16 is not supported on cuda <= 11"));

#endif  // CUDA_VERSION >= 11000
}

// complex64 specialization of the row-major GEMM
// C = alpha * op(A) * op(B) + beta * C, with the same shape and operand-swap
// conventions as the generic CBLAS_TRANSPOSE overload.
// Note: only CUBLAS_OP_N / CUBLAS_OP_T are produced, never a conjugate
// transpose.
// On CUDA >= 8.0 it uses GEMM_EX with CUDA_C_32F storage and compute. On older
// CUDA it falls back to CUBlas<complex<float>>::GEMM.
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::complex<float> alpha,
                                        const phi::dtype::complex<float> *A,
                                        const phi::dtype::complex<float> *B,
                                        phi::dtype::complex<float> beta,
                                        phi::dtype::complex<float> *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  // NOTE(review): this sm_53 requirement appears to be inherited from the fp16
  // path; confirm whether complex64 GEMM actually needs it.
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
      phi::errors::InvalidArgument(
          "cublas complex64 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

#if CUDA_VERSION >= 8000
  // Use cublasGemmEx with complex64 storage and a complex64 compute type. The
  // scalars are converted to thrust::complex<float> before their addresses
  // are handed to cuBLAS.
  thrust::complex<float> c_alpha =
      thrust::complex<float>(alpha.real, alpha.imag);
  thrust::complex<float> c_beta = thrust::complex<float>(beta.real, beta.imag);
  // GEMM_EX takes a mutable context pointer while this method is const.
  auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
  CUBlas<phi::dtype::complex<float>>::GEMM_EX(&cuda_ctx,
                                              cuTransB,
                                              cuTransA,
                                              N,
                                              M,
                                              K,
                                              &c_alpha,
                                              B,
                                              CUDA_C_32F,
                                              ldb,
                                              A,
                                              CUDA_C_32F,
                                              lda,
                                              &c_beta,
                                              C,
                                              CUDA_C_32F,
                                              N,
                                              CUDA_C_32F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to the typed
  // complex64 GEMM on the original operand pointers and scalars.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<phi::dtype::complex<float>>::GEMM(handle,
                                             cuTransB,
                                             cuTransA,
                                             N,
                                             M,
                                             K,
                                             &alpha,
                                             B,
                                             ldb,
                                             A,
                                             lda,
                                             &beta,
                                             C,
                                             N);
  });
#endif  // CUDA_VERSION >= 8000
}

// complex128 specialization of the row-major GEMM
// C = alpha * op(A) * op(B) + beta * C, with the same shape and operand-swap
// conventions as the generic CBLAS_TRANSPOSE overload.
// Note: only CUBLAS_OP_N / CUBLAS_OP_T are produced, never a conjugate
// transpose.
// On CUDA >= 8.0 it uses GEMM_EX with CUDA_C_64F storage and compute. On older
// CUDA it falls back to CUBlas<complex<double>>::GEMM.
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::complex<double> alpha,
                                        const phi::dtype::complex<double> *A,
                                        const phi::dtype::complex<double> *B,
                                        phi::dtype::complex<double> beta,
                                        phi::dtype::complex<double> *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  // NOTE(review): this sm_53 requirement appears to be inherited from the fp16
  // path; confirm whether complex128 GEMM actually needs it.
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
      phi::errors::InvalidArgument(
          "cublas complex128 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

#if CUDA_VERSION >= 8000
  // Use cublasGemmEx with complex128 storage and a complex128 compute type.
  // The scalars are converted to thrust::complex<double> before their
  // addresses are handed to cuBLAS.
  thrust::complex<double> c_alpha =
      thrust::complex<double>(alpha.real, alpha.imag);
  thrust::complex<double> c_beta =
      thrust::complex<double>(beta.real, beta.imag);
  // GEMM_EX takes a mutable context pointer while this method is const.
  auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
  CUBlas<phi::dtype::complex<double>>::GEMM_EX(&cuda_ctx,
                                               cuTransB,
                                               cuTransA,
                                               N,
                                               M,
                                               K,
                                               &c_alpha,
                                               B,
                                               CUDA_C_64F,
                                               ldb,
                                               A,
                                               CUDA_C_64F,
                                               lda,
                                               &c_beta,
                                               C,
                                               CUDA_C_64F,
                                               N,
                                               CUDA_C_64F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to the typed
  // complex128 GEMM on the original operand pointers and scalars.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<phi::dtype::complex<double>>::GEMM(handle,
                                              cuTransB,
                                              cuTransA,
                                              N,
                                              M,
                                              K,
                                              &alpha,
                                              B,
                                              ldb,
                                              A,
                                              lda,
                                              &beta,
                                              C,
                                              N);
  });
#endif  // CUDA_VERSION >= 8000
}

// Row-major GEMM with caller-supplied leading dimensions:
// C = alpha * op(A) * op(B) + beta * C, where op(X) = X^T when the
// corresponding flag is true. op(A) is M x K, op(B) is K x N and C is M x N.
// lda/ldb/ldc are the row strides of the stored A/B/C.
// As in the CBLAS_TRANSPOSE overload, the row-major product is computed as the
// column-major product C^T = op(B)^T * op(A)^T, so the operands, sizes and
// flags are swapped when calling cuBLAS.
// When FLAGS_enable_cublas_tensor_op_math is set and T is float (CUDA >= 8.0),
// the call goes through GEMM_EX with fp32 storage; otherwise CUBlas<T>::GEMM
// is used.
template <>
template <typename T>
void Blas<phi::GPUContext>::GEMM(bool transA,
                                 bool transB,
                                 int M,
                                 int N,
                                 int K,
                                 T alpha,
                                 const T *A,
                                 int lda,
                                 const T *B,
                                 int ldb,
                                 T beta,
                                 T *C,
                                 int ldc) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

#if CUDA_VERSION >= 8000
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    // GEMM_EX takes a mutable context pointer while this method is const.
    auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
    CUBlas<T>::GEMM_EX(&cuda_ctx,
                       cuTransB,
                       cuTransA,
                       N,
                       M,
                       K,
                       &alpha,
                       B,
                       CUDA_R_32F,
                       ldb,
                       A,
                       CUDA_R_32F,
                       lda,
                       &beta,
                       C,
                       CUDA_R_32F,
                       ldc);
  } else {
#endif  // CUDA_VERSION >= 8000

    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM(handle,
                      cuTransB,
                      cuTransA,
                      N,
                      M,
                      K,
                      &alpha,
                      B,
                      ldb,
                      A,
                      lda,
                      &beta,
                      C,
                      ldc);
    });

#if CUDA_VERSION >= 8000
  }
#endif  // CUDA_VERSION >= 8000
}

// fp16 specialization of the leading-dimension GEMM
// C = alpha * op(A) * op(B) + beta * C, with the same conventions as the
// generic bool-flag overload.
// This path calls CUBlas<phi::dtype::float16>::GEMM directly with fp16
// alpha/beta. It performs no compute-capability check and does not go through
// GEMM_EX.
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
                                        bool transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::float16 alpha,
                                        const phi::dtype::float16 *A,
                                        int lda,
                                        const phi::dtype::float16 *B,
                                        int ldb,
                                        phi::dtype::float16 beta,
                                        phi::dtype::float16 *C,
                                        int ldc) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<phi::dtype::float16>::GEMM(handle,
                                      cuTransB,
                                      cuTransA,
                                      N,
                                      M,
                                      K,
                                      &alpha,
                                      B,
                                      ldb,
                                      A,
                                      lda,
                                      &beta,
                                      C,
                                      ldc);
  });
}

// bf16 specialization of the leading-dimension GEMM
// C = alpha * op(A) * op(B) + beta * C, with the same conventions as the
// generic bool-flag overload.
// Requires CUDA >= 11.0 (Unimplemented otherwise) and compute capability
// >= 80 (InvalidArgument otherwise). It calls cublasGemmEx directly with bf16
// storage, fp32 compute and float alpha/beta. The tensor-op default algorithm
// is selected when the context reports tensor cores as available.
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
                                        bool transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::bfloat16 alpha,
                                        const phi::dtype::bfloat16 *A,
                                        int lda,
                                        const phi::dtype::bfloat16 *B,
                                        int ldb,
                                        phi::dtype::bfloat16 beta,
                                        phi::dtype::bfloat16 *C,
                                        int ldc) const {
#if CUDA_VERSION >= 11000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      80,
      phi::errors::InvalidArgument(
          "cublas bf16 gemm requires GPU compute capability >= 80,"
          "but received %d",
          context_.GetComputeCapability()));

  // alpha/beta are passed as float to match the CUDA_R_32F compute type.
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
  bool use_tensor_op_math = context_.tensor_core_available();
  if (use_tensor_op_math) {
    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
  }
  VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

  context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                          cuTransB,
                                                          cuTransA,
                                                          N,
                                                          M,
                                                          K,
                                                          &h_alpha,
                                                          B,
                                                          CUDA_R_16BF,
                                                          ldb,
                                                          A,
                                                          CUDA_R_16BF,
                                                          lda,
                                                          &h_beta,
                                                          C,
                                                          CUDA_R_16BF,
                                                          ldc,
                                                          CUDA_R_32F,
                                                          algo));
  });
#else
  // raise error
  PADDLE_THROW(phi::errors::Unimplemented(
      "cublasGemmEx with bfloat16 is not supported on cuda <= 11"));

#endif  // CUDA_VERSION >= 11000
}

// y = alpha * x + y over n elements with unit stride on both vectors, issued
// through the context's cuBLAS handle via CUBlas<T>::AXPY.
template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
  });
}

// x = alpha * x in place over n elements with unit stride, issued through the
// context's cuBLAS handle via CUBlas<T>::SCAL.
template <>
template <typename T>
void Blas<phi::GPUContext>::SCAL(int n, const T alpha, T *x) const {
  context_.CublasCall(
      [&](cublasHandle_t handle) { CUBlas<T>::SCAL(handle, n, &alpha, x, 1); });
}

// Copies n elements from x to y (unit stride on both), issued through the
// context's cuBLAS handle via CUBlas<T>::VCOPY.
template <>
template <typename T>
void Blas<phi::GPUContext>::VCOPY(int n, const T *x, T *y) const {
  context_.CublasCall(
      [&](cublasHandle_t handle) { CUBlas<T>::VCOPY(handle, n, x, 1, y, 1); });
}

// Matrix-vector product C = alpha * op(A) * B + beta * C, where A is a densely
// packed row-major M x N matrix (row stride N).
//   trans_a == false: op(A) = A,   B has N elements, C has M elements.
//   trans_a == true:  op(A) = A^T, B has M elements, C has N elements.
// cuBLAS is column-major and sees the buffer as the N x M matrix A^T, which is
// why the transpose flag is inverted and the sizes are passed as (N, M).
template <>
template <typename T>
void Blas<phi::GPUContext>::GEMV(bool trans_a,
                                 int M,
                                 int N,
                                 T alpha,
                                 const T *A,
                                 const T *B,
                                 T beta,
                                 T *C) const {
  cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
  });
}

// fp16 matrix-vector product with the same contract as the generic GEMV
// (A is row-major M x N).
// cublas has no half-precision gemv, so the product is expressed as a
// degenerate GEMM<float16>:
//   trans_a == true:  C (1 x N) = B (1 x M) * A (M x N), i.e. C = A^T * B.
//   trans_a == false: C (M x 1) = A (M x N) * B (N x 1).
template <>
template <>
inline void Blas<phi::GPUContext>::GEMV(bool trans_a,
                                        int M,
                                        int N,
                                        phi::dtype::float16 alpha,
                                        const phi::dtype::float16 *A,
                                        const phi::dtype::float16 *B,
                                        phi::dtype::float16 beta,
                                        phi::dtype::float16 *C) const {
  if (trans_a) {
    this->template GEMM<phi::dtype::float16>(
        CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C);
  } else {
    this->template GEMM<phi::dtype::float16>(
        CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C);
  }
}

// bf16 matrix-vector product with the same contract as the generic GEMV
// (A is row-major M x N).
// cublas has no bfloat16 gemv, so the product is expressed as a degenerate
// GEMM<bfloat16>, which inherits that GEMM's CUDA >= 11 / sm_80 requirements:
//   trans_a == true:  C (1 x N) = B (1 x M) * A (M x N), i.e. C = A^T * B.
//   trans_a == false: C (M x 1) = A (M x N) * B (N x 1).
template <>
template <>
inline void Blas<phi::GPUContext>::GEMV(bool trans_a,
                                        int M,
                                        int N,
                                        phi::dtype::bfloat16 alpha,
                                        const phi::dtype::bfloat16 *A,
                                        const phi::dtype::bfloat16 *B,
                                        phi::dtype::bfloat16 beta,
                                        phi::dtype::bfloat16 *C) const {
  if (trans_a) {
    this->template GEMM<phi::dtype::bfloat16>(
        CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C);
  } else {
    this->template GEMM<phi::dtype::bfloat16>(
        CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C);
  }
}

template <>
template <typename T>
1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481
void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        T alpha,
                                        const T *A,
                                        const T *B,
                                        T beta,
                                        T *C,
                                        int batchCount,
                                        int64_t strideA,
                                        int64_t strideB) const {
1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int64_t strideC = M * N;

#if CUDA_VERSION >= 9010
  if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same<T, float>::value)) ||
1495
      std::is_same<T, phi::dtype::float16>::value) {
1496 1497 1498 1499 1500 1501 1502
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
    bool use_tensor_op_math = context_.tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
1503 1504
    VLOG(4) << "use_half_precision_compute_type: "
            << FLAGS_gemm_use_half_precision_compute_type;
1505 1506

    auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
S
sneaxiy 已提交
1507 1508 1509 1510 1511
#if CUDA_VERSION >= 11000
    auto compute_type = CUBLAS_COMPUTE_32F;
#else
    auto compute_type = CUDA_R_32F;
#endif
1512 1513 1514 1515 1516 1517 1518 1519 1520 1521

    float h_alpha = static_cast<float>(alpha);
    float h_beta = static_cast<float>(beta);
    void *a = static_cast<void *>(&h_alpha);
    void *b = static_cast<void *>(&h_beta);
    // set ComputeType as CUDA_R_32F for fp16, for better accuracy
    if (FLAGS_gemm_use_half_precision_compute_type == true &&
        std::is_same<T, phi::dtype::float16>::value) {
      a = static_cast<void *>(&alpha);
      b = static_cast<void *>(&beta);
S
sneaxiy 已提交
1522 1523 1524
#if CUDA_VERSION >= 11000
      compute_type = CUBLAS_COMPUTE_16F;
#else
1525
      compute_type = CUDA_R_16F;
S
sneaxiy 已提交
1526
#endif
1527 1528
    }

1529 1530
    context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(
1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553
          phi::dynload::cublasGemmStridedBatchedEx(handle,
                                                   cuTransB,
                                                   cuTransA,
                                                   N,
                                                   M,
                                                   K,
                                                   a,
                                                   B,
                                                   fp,
                                                   ldb,
                                                   strideB,
                                                   A,
                                                   fp,
                                                   lda,
                                                   strideA,
                                                   b,
                                                   C,
                                                   fp,
                                                   ldc,
                                                   strideC,
                                                   batchCount,
                                                   compute_type,
                                                   algo));
1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585
    });
  } else {
#endif  // CUDA_VERSION >= 9010

    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM_STRIDED_BATCH(handle,
                                    cuTransB,
                                    cuTransA,
                                    N,
                                    M,
                                    K,
                                    &alpha,
                                    B,
                                    ldb,
                                    strideB,
                                    A,
                                    lda,
                                    strideA,
                                    &beta,
                                    C,
                                    ldc,
                                    strideC,
                                    batchCount);
    });

#if CUDA_VERSION >= 9010
  }
#endif  // CUDA_VERSION >= 9010
}

template <>
template <>
1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               phi::dtype::bfloat16 alpha,
                                               const phi::dtype::bfloat16 *A,
                                               const phi::dtype::bfloat16 *B,
                                               phi::dtype::bfloat16 beta,
                                               phi::dtype::bfloat16 *C,
                                               int batchCount,
                                               int64_t strideA,
                                               int64_t strideB) const {
1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622
#if CUDA_VERSION >= 11000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int64_t strideC = M * N;

  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
  bool use_tensor_op_math = context_.tensor_core_available();
  if (use_tensor_op_math) {
    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
  }
  VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

  context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(
1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645
        phi::dynload::cublasGemmStridedBatchedEx(handle,
                                                 cuTransB,
                                                 cuTransA,
                                                 N,
                                                 M,
                                                 K,
                                                 &h_alpha,
                                                 B,
                                                 CUDA_R_16BF,
                                                 ldb,
                                                 strideB,
                                                 A,
                                                 CUDA_R_16BF,
                                                 lda,
                                                 strideA,
                                                 &h_beta,
                                                 C,
                                                 CUDA_R_16BF,
                                                 ldc,
                                                 strideC,
                                                 batchCount,
                                                 CUBLAS_COMPUTE_32F,
                                                 algo));
1646 1647 1648
  });
#else
  // raise error
1649
  PADDLE_THROW(phi::errors::Unimplemented(
1650 1651 1652 1653 1654 1655 1656
      "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
      "11"));
#endif  // CUDA_VERSION >= 11000
}

template <>
template <typename T>
1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667
void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        T alpha,
                                        const T **A,
                                        const T **B,
                                        T beta,
                                        T **C,
                                        int batchCount) const {
1668 1669 1670 1671 1672 1673 1674 1675
  for (int k = 0; k < batchCount; ++k) {
    this->template GEMM<T>(
        transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
  }
}

template <>
template <>
1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               phi::dtype::float16 alpha,
                                               const phi::dtype::float16 **A,
                                               const phi::dtype::float16 **B,
                                               phi::dtype::float16 beta,
                                               phi::dtype::float16 **C,
                                               int batchCount) const {
1687
  for (int k = 0; k < batchCount; ++k) {
1688
    this->template GEMM<phi::dtype::float16>(
1689 1690 1691 1692 1693 1694
        transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
  }
}

template <>
template <>
1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               phi::dtype::bfloat16 alpha,
                                               const phi::dtype::bfloat16 **A,
                                               const phi::dtype::bfloat16 **B,
                                               phi::dtype::bfloat16 beta,
                                               phi::dtype::bfloat16 **C,
                                               int batchCount) const {
1706
  for (int k = 0; k < batchCount; ++k) {
1707
    this->template GEMM<phi::dtype::bfloat16>(
1708 1709 1710 1711 1712 1713
        transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
  }
}

template <>
template <typename T>
1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724
void Blas<phi::GPUContext>::TRSM(CBLAS_SIDE side,
                                 CBLAS_UPLO uplo,
                                 CBLAS_TRANSPOSE transA,
                                 CBLAS_DIAG diag,
                                 int M,
                                 int N,
                                 T alpha,
                                 const T *A,
                                 int lda,
                                 T *B,
                                 int ldb) const {
1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744
  // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' )  =  α B'`
  // where ' stands for transpose
  cublasSideMode_t cuSide =
      (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
  cublasFillMode_t cuUplo =
      (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  // use CUBLAS_OP_C (conjugate transpose) for complex
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasDiagType_t cuDiag =
      (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::TRSM(
        handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb);
  });
}

template <>
template <typename T>
1745
void Blas<phi::GPUContext>::BatchedGETRF(
1746 1747 1748 1749 1750 1751 1752 1753
    int n, T **a, int *ipiv, int *info, int batch_size) const {
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size);
  });
}

template <>
template <typename T>
1754 1755 1756 1757 1758 1759
void Blas<phi::GPUContext>::BatchedGETRI(int n,
                                         const T **a,
                                         const int *ipiv,
                                         T **a_inv,
                                         int *info,
                                         int batch_size) const {
1760 1761 1762
  PADDLE_ENFORCE_NE(
      a_inv,
      a,
1763
      phi::errors::InvalidArgument(
1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775
          "cuBLAS fuction 'cublas<S/D>getrfBatched' cannot be executed "
          "in-place. The memory space of output matrix (address: %p) cannot "
          "overlap memory space of input matrix (address: %p).",
          a_inv,
          a));
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size);
  });
}

template <>
template <typename T>
1776
void Blas<phi::GPUContext>::BatchedMatInv(
1777 1778 1779 1780 1781 1782 1783 1784
    int n, const T **a, T **a_inv, int *info, int batch_size) const {
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size);
  });
}

template <>
template <typename T>
1785 1786 1787 1788 1789 1790 1791 1792 1793 1794
void Blas<phi::GPUContext>::BatchedGETRS(CBLAS_TRANSPOSE trans,
                                         int n,
                                         int nrhs,
                                         const T **a,
                                         int lda,
                                         int *ipiv,
                                         T **b,
                                         int ldb,
                                         int *info,
                                         int batch_size) const {
1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805
  // use CUBLAS_OP_C (conjugate transpose) for complex
  cublasOperation_t cuTrans =
      (trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRS_BATCH(
        handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size);
  });
}

template <>
template <typename T>
1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817
void Blas<phi::GPUContext>::BatchedTRSM(CBLAS_SIDE side,
                                        CBLAS_UPLO uplo,
                                        CBLAS_TRANSPOSE transA,
                                        CBLAS_DIAG diag,
                                        int M,
                                        int N,
                                        T alpha,
                                        const T **A,
                                        int lda,
                                        T **B,
                                        int ldb,
                                        int batch_size) const {
1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847
  // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' )  =  α B'`
  // where ' stands for transpose
  cublasSideMode_t cuSide =
      (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
  cublasFillMode_t cuUplo =
      (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  // use CUBLAS_OP_C (conjugate transpose) for complex
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasDiagType_t cuDiag =
      (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::TRSM_BATCH(handle,
                          cuSide,
                          cuUplo,
                          cuTransA,
                          cuDiag,
                          N,
                          M,
                          &alpha,
                          A,
                          lda,
                          B,
                          ldb,
                          batch_size);
  });
}

}  // namespace funcs
1848
}  // namespace phi