// File: blas_impl.cu.h
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h"

// Command-line flags declared (not defined) here, following the gflags
// DECLARE_* convention; the definitions live in another translation unit.
// Neither FLAGS_* variable is referenced in this part of the file — presumably
// GEMM helpers later in the file read them (TODO confirm).
DECLARE_bool(enable_cublas_tensor_op_math);
DECLARE_bool(gemm_use_half_precision_compute_type);

namespace phi {
namespace funcs {

// Primary template for the per-element-type cuBLAS bindings. Only a
// declaration is given here; the usable bindings are the explicit
// specializations that follow (float, double, float16, complex<float>,
// complex<double>).
template <class T>
struct CUBlas;

template <>
struct CUBlas<float> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
34
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemm(args...));
35 36 37 38
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
39
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSaxpy(args...));
40 41 42 43
  }

  template <typename... ARGS>
  static void SCAL(ARGS... args) {
44
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSscal(args...));
45 46 47 48
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
49
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasScopy(args...));
50 51 52 53
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
54
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemv(args...));
55 56 57 58 59 60
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE_GPU_SUCCESS(
61
        phi::dynload::cublasSgemmStridedBatched(args...));
62
#else
63
    PADDLE_THROW(phi::errors::Unimplemented(
64 65 66 67 68 69 70
        "SgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
71
  static void GEMM_EX(phi::GPUContext *dev_ctx,
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
                      cublasOperation_t transa,
                      cublasOperation_t transb,
                      int m,
                      int n,
                      int k,
                      const float *alpha,
                      const void *A,
                      cudaDataType_t Atype,
                      int lda,
                      const void *B,
                      cudaDataType_t Btype,
                      int ldb,
                      const float *beta,
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc) {
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
#if CUDA_VERSION >= 8000
    VLOG(5) << "use_tensor_op_math: "
            << (dev_ctx->tensor_core_available() ? "True" : "False");
    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmEx(handle,
                                                             transa,
                                                             transb,
                                                             m,
                                                             n,
                                                             k,
                                                             alpha,
                                                             A,
                                                             Atype,
                                                             lda,
                                                             B,
                                                             Btype,
                                                             ldb,
                                                             beta,
                                                             C,
                                                             Ctype,
                                                             ldc));
112 113
    });
#else
114
    PADDLE_THROW(phi::errors::Unimplemented(
115 116 117 118 119 120
        "cublasSgemmEx is not supported on cuda <= 7.5"));
#endif
  }

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
121
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsm(args...));
122 123 124 125
  }

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... args) {
126
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrfBatched(args...));
127 128 129 130
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... args) {
131
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetriBatched(args...));
132 133 134 135
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... args) {
136
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSmatinvBatched(args...));
137 138 139 140
  }

  template <typename... ARGS>
  static void GETRS_BATCH(ARGS... args) {
141
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrsBatched(args...));
142 143 144 145
  }

  template <typename... ARGS>
  static void TRSM_BATCH(ARGS... args) {
146
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsmBatched(args...));
147 148 149 150 151 152 153
  }
};

template <>
struct CUBlas<double> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
154
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemm(args...));
155 156 157 158
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
159
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDaxpy(args...));
160 161 162 163
  }

  template <typename... ARGS>
  static void SCAL(ARGS... args) {
164
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDscal(args...));
165 166 167 168
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
169
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDcopy(args...));
170 171 172 173
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
174
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemv(args...));
175 176 177 178 179 180
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE_GPU_SUCCESS(
181
        phi::dynload::cublasDgemmStridedBatched(args...));
182
#else
183
    PADDLE_THROW(phi::errors::Unimplemented(
184 185 186 187 188 189 190
        "DgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  template <typename... ARGS>
  static void GEMM_EX(ARGS... args) {
    PADDLE_THROW(
191
        phi::errors::Unimplemented("Currently there are not cublasDgemmEx."));
192 193 194 195
  }

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
196
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsm(args...));
197 198 199 200
  }

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... args) {
201
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrfBatched(args...));
202 203 204 205
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... args) {
206
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetriBatched(args...));
207 208 209 210
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... args) {
211
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDmatinvBatched(args...));
212 213 214 215
  }

  template <typename... ARGS>
  static void GETRS_BATCH(ARGS... args) {
216
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrsBatched(args...));
217 218 219 220
  }

  template <typename... ARGS>
  static void TRSM_BATCH(ARGS... args) {
221
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsmBatched(args...));
222 223 224 225
  }
};

template <>
226 227
struct CUBlas<phi::dtype::float16> {
  using float16 = phi::dtype::float16;
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242

  static void GEMM(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float16 *alpha,
                   const float16 *A,
                   int lda,
                   const float16 *B,
                   int ldb,
                   const float16 *beta,
                   float16 *C,
                   int ldc) {
243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
    PADDLE_ENFORCE_GPU_SUCCESS(
        phi::dynload::cublasHgemm(handle,
                                  transa,
                                  transb,
                                  m,
                                  n,
                                  k,
                                  reinterpret_cast<const __half *>(alpha),
                                  reinterpret_cast<const __half *>(A),
                                  lda,
                                  reinterpret_cast<const __half *>(B),
                                  ldb,
                                  reinterpret_cast<const __half *>(beta),
                                  reinterpret_cast<__half *>(C),
                                  ldc));
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
  }

  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb,
                                 int m,
                                 int n,
                                 int k,
                                 const float16 *alpha,
                                 const float16 *A,
                                 int lda,
                                 long long int strideA,  // NOLINT
                                 const float16 *B,       // NOLINT
                                 int ldb,
                                 long long int strideB,  // NOLINT
                                 const float16 *beta,
                                 float16 *C,
                                 int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
#if CUDA_VERSION >= 8000
279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasHgemmStridedBatched(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const __half *>(alpha),
        reinterpret_cast<const __half *>(A),
        lda,
        strideA,
        reinterpret_cast<const __half *>(B),
        ldb,
        strideB,
        reinterpret_cast<const __half *>(beta),
        reinterpret_cast<__half *>(C),
        ldc,
        strideC,
        batchCount));
298
#else
299
    PADDLE_THROW(phi::errors::Unimplemented(
300 301 302 303 304 305 306
        "HgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
307
  static void GEMM_EX(phi::GPUContext *dev_ctx,
308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336
                      cublasOperation_t transa,
                      cublasOperation_t transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      cudaDataType_t Atype,
                      int lda,
                      const void *B,
                      cudaDataType_t Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif  // CUDA_VERSION >= 9000

    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                            transa,
                                                            transb,
                                                            m,
                                                            n,
                                                            k,
                                                            alpha,
                                                            A,
                                                            Atype,
                                                            lda,
                                                            B,
                                                            Btype,
                                                            ldb,
                                                            beta,
                                                            C,
                                                            Ctype,
                                                            ldc,
                                                            computeType,
                                                            algo));
356 357
    });
#else
358
    PADDLE_THROW(phi::errors::Unimplemented(
359 360 361 362 363 364
        "cublasGemmEx is not supported on cuda <= 7.5"));
#endif
  }
};

template <>
365
struct CUBlas<phi::dtype::complex<float>> {
366 367 368 369
  static void GEMV(cublasHandle_t handle,
                   cublasOperation_t transa,
                   int m,
                   int n,
370 371
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *A,
372
                   int lda,
373
                   const phi::dtype::complex<float> *B,
374
                   int ldb,
375 376
                   const phi::dtype::complex<float> *beta,
                   phi::dtype::complex<float> *C,
377
                   int ldc) {
378
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemv(
379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394
        handle,
        transa,
        m,
        n,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A),
        lda,
        reinterpret_cast<const cuFloatComplex *>(B),
        ldb,
        reinterpret_cast<const cuFloatComplex *>(beta),
        reinterpret_cast<cuFloatComplex *>(C),
        ldc));
  }

  static void AXPY(cublasHandle_t handle,
                   int n,
395 396
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *X,
397
                   const int incX,
398
                   phi::dtype::complex<float> *Y,
399
                   const int incY) {
400
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCaxpy(
401 402 403 404 405 406 407 408 409
        handle,
        n,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(X),
        incX,
        reinterpret_cast<cuFloatComplex *>(Y),
        incY));
  }

410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427
  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb,
                                 int m,
                                 int n,
                                 int k,
                                 const phi::dtype::complex<float> *alpha,
                                 const phi::dtype::complex<float> *A,
                                 int lda,
                                 long long int strideA,                // NOLINT
                                 const phi::dtype::complex<float> *B,  // NOLINT
                                 int ldb,
                                 long long int strideB,  // NOLINT
                                 const phi::dtype::complex<float> *beta,
                                 phi::dtype::complex<float> *C,
                                 int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
428
#if CUDA_VERSION >= 8000
429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemmStridedBatched(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A),
        lda,
        strideA,
        reinterpret_cast<const cuFloatComplex *>(B),
        ldb,
        strideB,
        reinterpret_cast<const cuFloatComplex *>(beta),
        reinterpret_cast<cuFloatComplex *>(C),
        ldc,
        strideC,
        batchCount));
448
#else
449
    PADDLE_THROW(phi::errors::Unimplemented(
450 451 452 453 454 455 456 457 458 459
        "CgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  static void GEMM(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
460 461
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *A,
462
                   int lda,
463
                   const phi::dtype::complex<float> *B,
464
                   int ldb,
465 466
                   const phi::dtype::complex<float> *beta,
                   phi::dtype::complex<float> *C,
467
                   int ldc) {
468
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemm(
469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A),
        lda,
        reinterpret_cast<const cuFloatComplex *>(B),
        ldb,
        reinterpret_cast<const cuFloatComplex *>(beta),
        reinterpret_cast<cuFloatComplex *>(C),
        ldc));
  }

  static void TRSM(cublasHandle_t handle,
                   cublasSideMode_t side,
                   cublasFillMode_t uplo,
                   cublasOperation_t transa,
                   cublasDiagType_t diag,
                   int m,
                   int n,
492 493
                   const phi::dtype::complex<float> *alpha,
                   const phi::dtype::complex<float> *A,
494
                   int lda,
495
                   phi::dtype::complex<float> *B,
496
                   int ldb) {
497
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsm(
498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514
        handle,
        side,
        uplo,
        transa,
        diag,
        m,
        n,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A),
        lda,
        reinterpret_cast<cuFloatComplex *>(B),
        ldb));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
515
  static void GEMM_EX(phi::GPUContext *dev_ctx,
516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544
                      cublasOperation_t transa,
                      cublasOperation_t transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      cudaDataType_t Atype,
                      int lda,
                      const void *B,
                      cudaDataType_t Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif  // CUDA_VERSION >= 9000

    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                            transa,
                                                            transb,
                                                            m,
                                                            n,
                                                            k,
                                                            alpha,
                                                            A,
                                                            Atype,
                                                            lda,
                                                            B,
                                                            Btype,
                                                            ldb,
                                                            beta,
                                                            C,
                                                            Ctype,
                                                            ldc,
                                                            computeType,
                                                            algo));
564 565
    });
#else
566
    PADDLE_THROW(phi::errors::Unimplemented(
567 568 569 570 571 572 573 574 575 576 577
        "cublasGemmEx is not supported on cuda <= 7.5"));
#endif
  }

  static void TRSM_BATCH(cublasHandle_t handle,
                         cublasSideMode_t side,
                         cublasFillMode_t uplo,
                         cublasOperation_t transa,
                         cublasDiagType_t diag,
                         int m,
                         int n,
578 579
                         const phi::dtype::complex<float> *alpha,
                         const phi::dtype::complex<float> **A,
580
                         int lda,
581
                         phi::dtype::complex<float> **B,
582 583
                         int ldb,
                         int batch_size) {
584
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsmBatched(
585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601
        handle,
        side,
        uplo,
        transa,
        diag,
        m,
        n,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex **>(A),
        lda,
        reinterpret_cast<cuFloatComplex **>(B),
        ldb,
        batch_size));
  }
};

template <>
602
struct CUBlas<phi::dtype::complex<double>> {
603 604 605 606
  static void GEMV(cublasHandle_t handle,
                   cublasOperation_t transa,
                   int m,
                   int n,
607 608
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *A,
609
                   int lda,
610
                   const phi::dtype::complex<double> *B,
611
                   int ldb,
612 613
                   const phi::dtype::complex<double> *beta,
                   phi::dtype::complex<double> *C,
614
                   int ldc) {
615
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemv(
616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631
        handle,
        transa,
        m,
        n,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A),
        lda,
        reinterpret_cast<const cuDoubleComplex *>(B),
        ldb,
        reinterpret_cast<const cuDoubleComplex *>(beta),
        reinterpret_cast<cuDoubleComplex *>(C),
        ldc));
  }

  static void AXPY(cublasHandle_t handle,
                   int n,
632 633
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *X,
634
                   const int incX,
635
                   phi::dtype::complex<double> *Y,
636
                   const int incY) {
637
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZaxpy(
638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653
        handle,
        n,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(X),
        incX,
        reinterpret_cast<cuDoubleComplex *>(Y),
        incY));
  }

  static void GEMM_STRIDED_BATCH(
      cublasHandle_t handle,
      cublasOperation_t transa,
      cublasOperation_t transb,
      int m,
      int n,
      int k,
654 655
      const phi::dtype::complex<double> *alpha,
      const phi::dtype::complex<double> *A,
656
      int lda,
657 658
      long long int strideA,                 // NOLINT
      const phi::dtype::complex<double> *B,  // NOLINT
659 660
      int ldb,
      long long int strideB,  // NOLINT
661 662
      const phi::dtype::complex<double> *beta,
      phi::dtype::complex<double> *C,
663 664 665 666
      int ldc,
      long long int strideC,  // NOLINT
      int batchCount) {
#if CUDA_VERSION >= 8000
667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemmStridedBatched(
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A),
        lda,
        strideA,
        reinterpret_cast<const cuDoubleComplex *>(B),
        ldb,
        strideB,
        reinterpret_cast<const cuDoubleComplex *>(beta),
        reinterpret_cast<cuDoubleComplex *>(C),
        ldc,
        strideC,
        batchCount));
686
#else
687
    PADDLE_THROW(phi::errors::Unimplemented(
688 689 690 691 692 693 694 695 696 697
        "CgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  static void GEMM(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
698 699
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *A,
700
                   int lda,
701
                   const phi::dtype::complex<double> *B,
702
                   int ldb,
703 704
                   const phi::dtype::complex<double> *beta,
                   phi::dtype::complex<double> *C,
705
                   int ldc) {
706
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemm(
707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729
        handle,
        transa,
        transb,
        m,
        n,
        k,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A),
        lda,
        reinterpret_cast<const cuDoubleComplex *>(B),
        ldb,
        reinterpret_cast<const cuDoubleComplex *>(beta),
        reinterpret_cast<cuDoubleComplex *>(C),
        ldc));
  }

  static void TRSM(cublasHandle_t handle,
                   cublasSideMode_t side,
                   cublasFillMode_t uplo,
                   cublasOperation_t transa,
                   cublasDiagType_t diag,
                   int m,
                   int n,
730 731
                   const phi::dtype::complex<double> *alpha,
                   const phi::dtype::complex<double> *A,
732
                   int lda,
733
                   phi::dtype::complex<double> *B,
734
                   int ldb) {
735
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsm(
736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756
        handle,
        side,
        uplo,
        transa,
        diag,
        m,
        n,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A),
        lda,
        reinterpret_cast<cuDoubleComplex *>(B),
        ldb));
  }

  static void TRSM_BATCH(cublasHandle_t handle,
                         cublasSideMode_t side,
                         cublasFillMode_t uplo,
                         cublasOperation_t transa,
                         cublasDiagType_t diag,
                         int m,
                         int n,
757 758
                         const phi::dtype::complex<double> *alpha,
                         const phi::dtype::complex<double> **A,
759
                         int lda,
760
                         phi::dtype::complex<double> **B,
761 762
                         int ldb,
                         int batch_size) {
763
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsmBatched(
764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781
        handle,
        side,
        uplo,
        transa,
        diag,
        m,
        n,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex **>(A),
        lda,
        reinterpret_cast<cuDoubleComplex **>(B),
        ldb,
        batch_size));
  }

  // GEMM_EX forwards to cublasGemmEx, where the storage types of A/B/C
  // (Atype/Btype/Ctype) and the compute type are chosen independently.
  // When the device context reports tensor cores as available (CUDA >= 9.0),
  // the tensor-op algorithm is requested; otherwise the default algorithm.
  // See https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  //
  // alpha/beta are forwarded untouched as `const void *`.
  // Throws Unimplemented when built against CUDA < 8.0.
  template <typename... ARGS>
  static void GEMM_EX(phi::GPUContext *dev_ctx,
                      cublasOperation_t transa,
                      cublasOperation_t transb,
                      int m,
                      int n,
                      int k,
                      const void *alpha,
                      const void *A,
                      cudaDataType_t Atype,
                      int lda,
                      const void *B,
                      cudaDataType_t Btype,
                      int ldb,
                      const void *beta,
                      void *C,
                      cudaDataType_t Ctype,
                      int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 9000
    const bool tensor_cores_on = dev_ctx->tensor_core_available();
    VLOG(5) << "use_tensor_op_math: "
            << (tensor_cores_on ? "True" : "False");
    const cublasGemmAlgo_t gemm_algo =
        tensor_cores_on ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT;
#else
    const cublasGemmAlgo_t gemm_algo = CUBLAS_GEMM_DFALT;
#endif  // CUDA_VERSION >= 9000

    auto launch_gemm_ex = [&](cublasHandle_t handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                            transa,
                                                            transb,
                                                            m,
                                                            n,
                                                            k,
                                                            alpha,
                                                            A,
                                                            Atype,
                                                            lda,
                                                            B,
                                                            Btype,
                                                            ldb,
                                                            beta,
                                                            C,
                                                            Ctype,
                                                            ldc,
                                                            computeType,
                                                            gemm_algo));
    };
    dev_ctx->TensorCoreCublasCallIfAvailable(launch_gemm_ex);
#else
    PADDLE_THROW(phi::errors::Unimplemented(
        "cublasGemmEx is not supported on cuda <= 7.5"));
#endif
  }
};

template <>
template <typename T>
void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                 CBLAS_TRANSPOSE transB,
                                 int M,
                                 int N,
                                 int K,
                                 T alpha,
                                 const T *A,
                                 const T *B,
                                 T beta,
                                 T *C) const {
  // Row-major GEMM: C(MxN) = alpha * op(A)(MxK) * op(B)(KxN) + beta * C,
  // with densely packed operands (leading dimensions derived from M/N/K).
  // cuBLAS is column-major, so a row-major buffer looks transposed to it:
  // asking cuBLAS for C^T = op(B)^T * op(A)^T yields the row-major C, which
  // is why B/A and N/M trade places in the calls below.
  const bool a_untransposed = (transA == CblasNoTrans);
  const bool b_untransposed = (transB == CblasNoTrans);
  const int ld_a = a_untransposed ? K : M;
  const int ld_b = b_untransposed ? N : K;
  const cublasOperation_t op_a = a_untransposed ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_untransposed ? CUBLAS_OP_N : CUBLAS_OP_T;

#if CUDA_VERSION >= 8000
  // fp32 with tensor-op math enabled is routed through GEMM_EX so that a
  // tensor-core algorithm can be selected by the device context.
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    auto &mutable_ctx = const_cast<phi::GPUContext &>(context_);
    CUBlas<T>::GEMM_EX(&mutable_ctx,
                       op_b,
                       op_a,
                       N,
                       M,
                       K,
                       &alpha,
                       B,
                       CUDA_R_32F,
                       ld_b,
                       A,
                       CUDA_R_32F,
                       ld_a,
                       &beta,
                       C,
                       CUDA_R_32F,
                       N);
    return;
  }
#endif  // CUDA_VERSION >= 8000

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GEMM(
        handle, op_b, op_a, N, M, K, &alpha, B, ld_b, A, ld_a, &beta, C, N);
  });
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::float16 alpha,
                                        const phi::dtype::float16 *A,
                                        const phi::dtype::float16 *B,
                                        phi::dtype::float16 beta,
                                        phi::dtype::float16 *C) const {
  // fp16 specialization of the row-major GEMM:
  //   C(MxN) = alpha * op(A)(MxK) * op(B)(KxN) + beta * C.
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B/A and N/M are swapped in the calls below.
  // Raises InvalidArgument when the device compute capability is below 53.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
      phi::errors::InvalidArgument(
          "cublas fp16 gemm requires GPU compute capability >= 53, "
          "but received %d",
          context_.GetComputeCapability()));

#if CUDA_VERSION >= 8000
  // cublasHgemm does true FP16 computation which is slow for non-Volta
  // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
  // input/output in fp16, computation in fp32, which can also be accelerated
  // using tensor cores in volta GPUs. The scalars are converted to fp32 to
  // match the CUDA_R_32F compute type.
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);
  auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
  CUBlas<phi::dtype::float16>::GEMM_EX(&cuda_ctx,
                                       cuTransB,
                                       cuTransA,
                                       N,
                                       M,
                                       K,
                                       &h_alpha,
                                       B,
                                       CUDA_R_16F,
                                       ldb,
                                       A,
                                       CUDA_R_16F,
                                       lda,
                                       &h_beta,
                                       C,
                                       CUDA_R_16F,
                                       N,
                                       CUDA_R_32F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to hgemm.
  // hgemm computes in fp16, so it takes the fp16 scalars and the operands
  // exactly as received.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<phi::dtype::float16>::GEMM(handle,
                                      cuTransB,
                                      cuTransA,
                                      N,
                                      M,
                                      K,
                                      &alpha,
                                      B,
                                      ldb,
                                      A,
                                      lda,
                                      &beta,
                                      C,
                                      N);
  });
#endif  // CUDA_VERSION >= 8000
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::bfloat16 alpha,
                                        const phi::dtype::bfloat16 *A,
                                        const phi::dtype::bfloat16 *B,
                                        phi::dtype::bfloat16 beta,
                                        phi::dtype::bfloat16 *C) const {
  // bfloat16 specialization of the row-major GEMM:
  //   C(MxN) = alpha * op(A)(MxK) * op(B)(KxN) + beta * C,
  // executed by cublasGemmEx with bf16 storage and fp32 compute.
  // Build-time requirement: CUDA >= 11.0 (otherwise Unimplemented is thrown).
  // Run-time requirement: compute capability >= 80 (else InvalidArgument).
#if CUDA_VERSION >= 11000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B/A and N/M are swapped in the call below.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      80,
      phi::errors::InvalidArgument(
          "cublas bf16 gemm requires GPU compute capability >= 80, "
          "but received %d",
          context_.GetComputeCapability()));

  // Scalars are fp32 to match the CUDA_R_32F compute type.
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
  bool use_tensor_op_math = context_.tensor_core_available();
  if (use_tensor_op_math) {
    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
  }
  VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

  context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
                                                          cuTransB,
                                                          cuTransA,
                                                          N,
                                                          M,
                                                          K,
                                                          &h_alpha,
                                                          B,
                                                          CUDA_R_16BF,
                                                          ldb,
                                                          A,
                                                          CUDA_R_16BF,
                                                          lda,
                                                          &h_beta,
                                                          C,
                                                          CUDA_R_16BF,
                                                          N,
                                                          CUDA_R_32F,
                                                          algo));
  });
#else
  // bf16 cublasGemmEx is only compiled in for CUDA >= 11.0.
  PADDLE_THROW(phi::errors::Unimplemented(
      "cublasGemmEx with bfloat16 is not supported on cuda < 11.0"));
#endif  // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::complex<float> alpha,
                                        const phi::dtype::complex<float> *A,
                                        const phi::dtype::complex<float> *B,
                                        phi::dtype::complex<float> beta,
                                        phi::dtype::complex<float> *C) const {
  // complex64 specialization of the row-major GEMM:
  //   C(MxN) = alpha * op(A)(MxK) * op(B)(KxN) + beta * C.
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B/A and N/M are swapped in the calls below.
  // Raises InvalidArgument when the device compute capability is below 53.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  // NOTE(review): this sm53 requirement appears inherited from the fp16
  // path; confirm whether complex64 GEMM actually needs it.
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
      phi::errors::InvalidArgument(
          "cublas complex64 gemm requires GPU compute capability >= 53, "
          "but received %d",
          context_.GetComputeCapability()));

#if CUDA_VERSION >= 8000
  // cublasGemmEx with CUDA_C_32F for storage and compute. The scalars are
  // handed over as thrust::complex<float> through GEMM_EX's `const void *`
  // parameters (presumably layout-compatible with cuComplex — TODO confirm).
  thrust::complex<float> c_alpha =
      thrust::complex<float>(alpha.real, alpha.imag);
  thrust::complex<float> c_beta = thrust::complex<float>(beta.real, beta.imag);
  auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
  CUBlas<phi::dtype::complex<float>>::GEMM_EX(&cuda_ctx,
                                              cuTransB,
                                              cuTransA,
                                              N,
                                              M,
                                              K,
                                              &c_alpha,
                                              B,
                                              CUDA_C_32F,
                                              ldb,
                                              A,
                                              CUDA_C_32F,
                                              lda,
                                              &c_beta,
                                              C,
                                              CUDA_C_32F,
                                              N,
                                              CUDA_C_32F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to the
  // type-specific CUBlas<complex<float>>::GEMM, passing operands and scalars
  // in their own element type.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<phi::dtype::complex<float>>::GEMM(handle,
                                             cuTransB,
                                             cuTransA,
                                             N,
                                             M,
                                             K,
                                             &alpha,
                                             B,
                                             ldb,
                                             A,
                                             lda,
                                             &beta,
                                             C,
                                             N);
  });
#endif  // CUDA_VERSION >= 8000
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::complex<double> alpha,
                                        const phi::dtype::complex<double> *A,
                                        const phi::dtype::complex<double> *B,
                                        phi::dtype::complex<double> beta,
                                        phi::dtype::complex<double> *C) const {
  // complex128 specialization of the row-major GEMM:
  //   C(MxN) = alpha * op(A)(MxK) * op(B)(KxN) + beta * C.
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B/A and N/M are swapped in the calls below.
  // Raises InvalidArgument when the device compute capability is below 53.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  // NOTE(review): this sm53 requirement appears inherited from the fp16
  // path; confirm whether complex128 GEMM actually needs it.
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(),
      53,
      phi::errors::InvalidArgument(
          "cublas complex128 gemm requires GPU compute capability >= 53, "
          "but received %d",
          context_.GetComputeCapability()));

#if CUDA_VERSION >= 8000
  // cublasGemmEx with CUDA_C_64F for storage and compute. The scalars are
  // handed over as thrust::complex<double> through GEMM_EX's `const void *`
  // parameters (presumably layout-compatible with cuDoubleComplex — TODO
  // confirm).
  thrust::complex<double> c_alpha =
      thrust::complex<double>(alpha.real, alpha.imag);
  thrust::complex<double> c_beta =
      thrust::complex<double>(beta.real, beta.imag);
  auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
  CUBlas<phi::dtype::complex<double>>::GEMM_EX(&cuda_ctx,
                                               cuTransB,
                                               cuTransA,
                                               N,
                                               M,
                                               K,
                                               &c_alpha,
                                               B,
                                               CUDA_C_64F,
                                               ldb,
                                               A,
                                               CUDA_C_64F,
                                               lda,
                                               &c_beta,
                                               C,
                                               CUDA_C_64F,
                                               N,
                                               CUDA_C_64F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to the
  // type-specific CUBlas<complex<double>>::GEMM, passing operands and
  // scalars in their own element type.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<phi::dtype::complex<double>>::GEMM(handle,
                                              cuTransB,
                                              cuTransA,
                                              N,
                                              M,
                                              K,
                                              &alpha,
                                              B,
                                              ldb,
                                              A,
                                              lda,
                                              &beta,
                                              C,
                                              N);
  });
#endif  // CUDA_VERSION >= 8000
}

template <>
template <typename T>
void Blas<phi::GPUContext>::GEMM(bool transA,
                                 bool transB,
                                 int M,
                                 int N,
                                 int K,
                                 T alpha,
                                 const T *A,
                                 int lda,
                                 const T *B,
                                 int ldb,
                                 T beta,
                                 T *C,
                                 int ldc) const {
  // Row-major GEMM with caller-supplied leading dimensions:
  //   C(MxN) = alpha * op(A)(MxK) * op(B)(KxN) + beta * C.
  // cuBLAS is column-major; computing C^T = op(B)^T * op(A)^T in its
  // convention produces the row-major C, hence B/A and N/M are swapped.
  const cublasOperation_t op_a = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_b = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

#if CUDA_VERSION >= 8000
  // fp32 with tensor-op math enabled is routed through GEMM_EX so that a
  // tensor-core algorithm can be selected by the device context.
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    auto &mutable_ctx = const_cast<phi::GPUContext &>(context_);
    CUBlas<T>::GEMM_EX(&mutable_ctx,
                       op_b,
                       op_a,
                       N,
                       M,
                       K,
                       &alpha,
                       B,
                       CUDA_R_32F,
                       ldb,
                       A,
                       CUDA_R_32F,
                       lda,
                       &beta,
                       C,
                       CUDA_R_32F,
                       ldc);
    return;
  }
#endif  // CUDA_VERSION >= 8000

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GEMM(
        handle, op_b, op_a, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc);
  });
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
                                        bool transB,
                                        int M,
                                        int N,
                                        int K,
                                        phi::dtype::float16 alpha,
                                        const phi::dtype::float16 *A,
                                        int lda,
                                        const phi::dtype::float16 *B,
                                        int ldb,
                                        phi::dtype::float16 beta,
                                        phi::dtype::float16 *C,
                                        int ldc) const {
  // fp16 row-major GEMM with caller-supplied leading dimensions. This path
  // calls CUBlas<float16>::GEMM directly with the fp16 alpha/beta (no GEMM_EX
  // route). Operands are swapped (B first, N before M) because cuBLAS is
  // column-major.
  const cublasOperation_t op_b = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_a = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  auto issue_gemm = [&](cublasHandle_t handle) {
    CUBlas<phi::dtype::float16>::GEMM(
        handle, op_b, op_a, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc);
  };
  context_.CublasCall(issue_gemm);
}

template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
  // cuBLAS axpy over n elements with unit stride on both vectors:
  // y = alpha * x + y.
  context_.CublasCall([&](cublasHandle_t handle) {
    const int unit_stride = 1;
    CUBlas<T>::AXPY(handle, n, &alpha, x, unit_stride, y, unit_stride);
  });
}

template <>
template <typename T>
void Blas<phi::GPUContext>::SCAL(int n, const T alpha, T *x) const {
  // cuBLAS scal over n elements with unit stride: x = alpha * x (in place).
  context_.CublasCall([&](cublasHandle_t handle) {
    const int incx = 1;
    CUBlas<T>::SCAL(handle, n, &alpha, x, incx);
  });
}

template <>
template <typename T>
void Blas<phi::GPUContext>::VCOPY(int n, const T *x, T *y) const {
  // cuBLAS copy of n elements from x to y, unit stride on both sides.
  context_.CublasCall([&](cublasHandle_t handle) {
    const int stride = 1;
    CUBlas<T>::VCOPY(handle, n, x, stride, y, stride);
  });
}

template <>
template <typename T>
void Blas<phi::GPUContext>::GEMV(bool trans_a,
                                 int M,
                                 int N,
                                 T alpha,
                                 const T *A,
                                 const T *B,
                                 T beta,
                                 T *C) const {
  // C = alpha * op(A) * B + beta * C, with A stored row-major as M x N
  // (leading dimension N). cuBLAS reads that buffer as a column-major N x M
  // matrix, i.e. A^T, so the requested transpose flag is inverted.
  const cublasOperation_t op = trans_a ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int ld_a = N;
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GEMV(handle, op, N, M, &alpha, A, ld_a, B, 1, &beta, C, 1);
  });
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMV(bool trans_a,
                                        int M,
                                        int N,
                                        phi::dtype::float16 alpha,
                                        const phi::dtype::float16 *A,
                                        const phi::dtype::float16 *B,
                                        phi::dtype::float16 beta,
                                        phi::dtype::float16 *C) const {
  // cuBLAS has no half-precision gemv, so the matrix-vector product is
  // expressed as a GEMM with one dimension equal to 1, dispatched to the
  // float16 GEMM specialization.
  if (!trans_a) {
    // C(Mx1) = A(MxN) * B(Nx1)
    this->template GEMM<phi::dtype::float16>(
        CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C);
    return;
  }
  // C(1xN) = B(1xM) * A(MxN), i.e. A^T * B laid out as a row vector.
  this->template GEMM<phi::dtype::float16>(
      CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C);
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMV(bool trans_a,
                                        int M,
                                        int N,
                                        phi::dtype::bfloat16 alpha,
                                        const phi::dtype::bfloat16 *A,
                                        const phi::dtype::bfloat16 *B,
                                        phi::dtype::bfloat16 beta,
                                        phi::dtype::bfloat16 *C) const {
  // cuBLAS has no bfloat16 gemv, so the matrix-vector product is expressed
  // as a GEMM with one dimension equal to 1, dispatched to the bfloat16 GEMM
  // specialization.
  if (!trans_a) {
    // C(Mx1) = A(MxN) * B(Nx1)
    this->template GEMM<phi::dtype::bfloat16>(
        CblasNoTrans, CblasNoTrans, M, 1, N, alpha, A, B, beta, C);
    return;
  }
  // C(1xN) = B(1xM) * A(MxN), i.e. A^T * B laid out as a row vector.
  this->template GEMM<phi::dtype::bfloat16>(
      CblasNoTrans, CblasNoTrans, 1, N, M, alpha, B, A, beta, C);
}

template <>
template <typename T>
void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        T alpha,
                                        const T *A,
                                        const T *B,
                                        T beta,
                                        T *C,
                                        int batchCount,
                                        int64_t strideA,
                                        int64_t strideB) const {
  // Strided-batched row-major GEMM. For each batch i in [0, batchCount):
  //   C_i(MxN) = alpha * op(A_i)(MxK) * op(B_i)(KxN) + beta * C_i,
  // with A_i = A + i * strideA, B_i = B + i * strideB, and C densely packed
  // (C_i = C + i * M * N).
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B/A and N/M are swapped in the calls below.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  // Widen before multiplying: `M * N` evaluated in int overflows (undefined
  // behaviour) once a single output matrix holds more than INT_MAX elements.
  const int64_t strideC = static_cast<int64_t>(M) * N;

#if CUDA_VERSION >= 9010
  // fp16 always, and fp32 when tensor-op math is enabled, go through
  // cublasGemmStridedBatchedEx so a tensor-core algorithm can be selected.
  if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same<T, float>::value)) ||
      std::is_same<T, phi::dtype::float16>::value) {
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
    bool use_tensor_op_math = context_.tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
    VLOG(4) << "use_half_precision_compute_type: "
            << FLAGS_gemm_use_half_precision_compute_type;

    auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
#if CUDA_VERSION >= 11000
    auto compute_type = CUBLAS_COMPUTE_32F;
#else
    auto compute_type = CUDA_R_32F;
#endif

    // Default: accumulate in fp32 (better accuracy for fp16 inputs), which
    // requires the scalars to be fp32 as well.
    float h_alpha = static_cast<float>(alpha);
    float h_beta = static_cast<float>(beta);
    void *a = static_cast<void *>(&h_alpha);
    void *b = static_cast<void *>(&h_beta);
    // Opt-in: with FLAGS_gemm_use_half_precision_compute_type, fp16 inputs
    // accumulate in fp16, so the scalars are passed in their original type.
    if (FLAGS_gemm_use_half_precision_compute_type == true &&
        std::is_same<T, phi::dtype::float16>::value) {
      a = static_cast<void *>(&alpha);
      b = static_cast<void *>(&beta);
#if CUDA_VERSION >= 11000
      compute_type = CUBLAS_COMPUTE_16F;
#else
      compute_type = CUDA_R_16F;
#endif
    }

    context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          phi::dynload::cublasGemmStridedBatchedEx(handle,
                                                   cuTransB,
                                                   cuTransA,
                                                   N,
                                                   M,
                                                   K,
                                                   a,
                                                   B,
                                                   fp,
                                                   ldb,
                                                   strideB,
                                                   A,
                                                   fp,
                                                   lda,
                                                   strideA,
                                                   b,
                                                   C,
                                                   fp,
                                                   ldc,
                                                   strideC,
                                                   batchCount,
                                                   compute_type,
                                                   algo));
    });
  } else {
#endif  // CUDA_VERSION >= 9010

    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM_STRIDED_BATCH(handle,
                                    cuTransB,
                                    cuTransA,
                                    N,
                                    M,
                                    K,
                                    &alpha,
                                    B,
                                    ldb,
                                    strideB,
                                    A,
                                    lda,
                                    strideA,
                                    &beta,
                                    C,
                                    ldc,
                                    strideC,
                                    batchCount);
    });

#if CUDA_VERSION >= 9010
  }
#endif  // CUDA_VERSION >= 9010
}

template <>
template <>
1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               phi::dtype::bfloat16 alpha,
                                               const phi::dtype::bfloat16 *A,
                                               const phi::dtype::bfloat16 *B,
                                               phi::dtype::bfloat16 beta,
                                               phi::dtype::bfloat16 *C,
                                               int batchCount,
                                               int64_t strideA,
                                               int64_t strideB) const {
1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551
#if CUDA_VERSION >= 11000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int64_t strideC = M * N;

  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
  bool use_tensor_op_math = context_.tensor_core_available();
  if (use_tensor_op_math) {
    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
  }
  VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

  context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(
1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574
        phi::dynload::cublasGemmStridedBatchedEx(handle,
                                                 cuTransB,
                                                 cuTransA,
                                                 N,
                                                 M,
                                                 K,
                                                 &h_alpha,
                                                 B,
                                                 CUDA_R_16BF,
                                                 ldb,
                                                 strideB,
                                                 A,
                                                 CUDA_R_16BF,
                                                 lda,
                                                 strideA,
                                                 &h_beta,
                                                 C,
                                                 CUDA_R_16BF,
                                                 ldc,
                                                 strideC,
                                                 batchCount,
                                                 CUBLAS_COMPUTE_32F,
                                                 algo));
1575 1576 1577
  });
#else
  // raise error
1578
  PADDLE_THROW(phi::errors::Unimplemented(
1579 1580 1581 1582 1583 1584 1585
      "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
      "11"));
#endif  // CUDA_VERSION >= 11000
}

template <>
template <typename T>
1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596
void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                        CBLAS_TRANSPOSE transB,
                                        int M,
                                        int N,
                                        int K,
                                        T alpha,
                                        const T **A,
                                        const T **B,
                                        T beta,
                                        T **C,
                                        int batchCount) const {
1597 1598 1599 1600 1601 1602 1603 1604
  for (int k = 0; k < batchCount; ++k) {
    this->template GEMM<T>(
        transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
  }
}

template <>
template <>
1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               phi::dtype::float16 alpha,
                                               const phi::dtype::float16 **A,
                                               const phi::dtype::float16 **B,
                                               phi::dtype::float16 beta,
                                               phi::dtype::float16 **C,
                                               int batchCount) const {
1616
  for (int k = 0; k < batchCount; ++k) {
1617
    this->template GEMM<phi::dtype::float16>(
1618 1619 1620 1621 1622 1623
        transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
  }
}

template <>
template <>
1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
                                               CBLAS_TRANSPOSE transB,
                                               int M,
                                               int N,
                                               int K,
                                               phi::dtype::bfloat16 alpha,
                                               const phi::dtype::bfloat16 **A,
                                               const phi::dtype::bfloat16 **B,
                                               phi::dtype::bfloat16 beta,
                                               phi::dtype::bfloat16 **C,
                                               int batchCount) const {
1635
  for (int k = 0; k < batchCount; ++k) {
1636
    this->template GEMM<phi::dtype::bfloat16>(
1637 1638 1639 1640 1641 1642
        transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
  }
}

template <>
template <typename T>
1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653
void Blas<phi::GPUContext>::TRSM(CBLAS_SIDE side,
                                 CBLAS_UPLO uplo,
                                 CBLAS_TRANSPOSE transA,
                                 CBLAS_DIAG diag,
                                 int M,
                                 int N,
                                 T alpha,
                                 const T *A,
                                 int lda,
                                 T *B,
                                 int ldb) const {
1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673
  // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' )  =  α B'`
  // where ' stands for transpose
  cublasSideMode_t cuSide =
      (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
  cublasFillMode_t cuUplo =
      (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  // use CUBLAS_OP_C (conjugate transpose) for complex
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasDiagType_t cuDiag =
      (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::TRSM(
        handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, lda, B, ldb);
  });
}

template <>
template <typename T>
1674
void Blas<phi::GPUContext>::BatchedGETRF(
1675 1676 1677 1678 1679 1680 1681 1682
    int n, T **a, int *ipiv, int *info, int batch_size) const {
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size);
  });
}

template <>
template <typename T>
1683 1684 1685 1686 1687 1688
void Blas<phi::GPUContext>::BatchedGETRI(int n,
                                         const T **a,
                                         const int *ipiv,
                                         T **a_inv,
                                         int *info,
                                         int batch_size) const {
1689 1690 1691
  PADDLE_ENFORCE_NE(
      a_inv,
      a,
1692
      phi::errors::InvalidArgument(
1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704
          "cuBLAS fuction 'cublas<S/D>getrfBatched' cannot be executed "
          "in-place. The memory space of output matrix (address: %p) cannot "
          "overlap memory space of input matrix (address: %p).",
          a_inv,
          a));
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size);
  });
}

template <>
template <typename T>
1705
void Blas<phi::GPUContext>::BatchedMatInv(
1706 1707 1708 1709 1710 1711 1712 1713
    int n, const T **a, T **a_inv, int *info, int batch_size) const {
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size);
  });
}

template <>
template <typename T>
1714 1715 1716 1717 1718 1719 1720 1721 1722 1723
void Blas<phi::GPUContext>::BatchedGETRS(CBLAS_TRANSPOSE trans,
                                         int n,
                                         int nrhs,
                                         const T **a,
                                         int lda,
                                         int *ipiv,
                                         T **b,
                                         int ldb,
                                         int *info,
                                         int batch_size) const {
1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734
  // use CUBLAS_OP_C (conjugate transpose) for complex
  cublasOperation_t cuTrans =
      (trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRS_BATCH(
        handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, batch_size);
  });
}

template <>
template <typename T>
1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746
void Blas<phi::GPUContext>::BatchedTRSM(CBLAS_SIDE side,
                                        CBLAS_UPLO uplo,
                                        CBLAS_TRANSPOSE transA,
                                        CBLAS_DIAG diag,
                                        int M,
                                        int N,
                                        T alpha,
                                        const T **A,
                                        int lda,
                                        T **B,
                                        int ldb,
                                        int batch_size) const {
1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776
  // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' )  =  α B'`
  // where ' stands for transpose
  cublasSideMode_t cuSide =
      (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
  cublasFillMode_t cuUplo =
      (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  // use CUBLAS_OP_C (conjugate transpose) for complex
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasDiagType_t cuDiag =
      (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::TRSM_BATCH(handle,
                          cuSide,
                          cuUplo,
                          cuTransA,
                          cuDiag,
                          N,
                          M,
                          &alpha,
                          A,
                          lda,
                          B,
                          ldb,
                          batch_size);
  });
}

}  // namespace funcs
1777
}  // namespace phi