// blas_impl.cu.h
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/gpu_info.h"

DECLARE_bool(enable_cublas_tensor_op_math);
namespace paddle {
namespace operators {
namespace math {

// Maps a scalar type T to its type-specific cuBLAS routines. The primary
// template is only declared, never defined, so naming CUBlas<T> for a type
// without an explicit specialization fails to compile; the explicit
// specializations that follow supply the wrappers for each supported type.
template <typename T>
struct CUBlas;

template <>
struct CUBlas<float> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
35
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemm(args...));
Y
Yu Yang 已提交
36
  }
Y
Yu Yang 已提交
37 38 39

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
40
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...));
Y
Yu Yang 已提交
41 42
  }

43 44
  template <typename... ARGS>
  static void SCAL(ARGS... args) {
45
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSscal(args...));
46 47 48 49
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
50
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasScopy(args...));
51 52
  }

Y
Yu Yang 已提交
53 54
  template <typename... ARGS>
  static void GEMV(ARGS... args) {
55
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...));
Y
Yu Yang 已提交
56 57 58
  }

  template <typename... ARGS>
59
  static void GEMM_STRIDED_BATCH(ARGS... args) {
Y
Yu Yang 已提交
60
#if CUDA_VERSION >= 8000
61 62
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasSgemmStridedBatched(args...));
Y
Yu Yang 已提交
63
#else
64 65
    PADDLE_THROW(platform::errors::Unimplemented(
        "SgemmStridedBatched is not supported on cuda <= 7.5"));
66 67 68 69 70 71 72 73 74 75 76 77
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      cublasOperation_t transa, cublasOperation_t transb, int m,
                      int n, int k, const float *alpha, const void *A,
                      cudaDataType_t Atype, int lda, const void *B,
                      cudaDataType_t Btype, int ldb, const float *beta, void *C,
                      cudaDataType_t Ctype, int ldc) {
78 79 80
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
81
#if CUDA_VERSION >= 8000
82 83 84
    VLOG(5) << "use_tensor_op_math: "
            << (dev_ctx->tensor_core_available() ? "True" : "False");
    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
85
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemmEx(
86 87 88
          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
          beta, C, Ctype, ldc));
    });
89
#else
90 91
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasSgemmEx is not supported on cuda <= 7.5"));
Y
Yu Yang 已提交
92 93
#endif
  }
G
Guo Sheng 已提交
94 95 96 97 98

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasStrsm(args...));
  }
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasSgetrfBatched(args...));
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasSgetriBatched(args...));
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasSmatinvBatched(args...));
  }
117 118 119 120 121 122

  template <typename... ARGS>
  static void GETRS_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasSgetrsBatched(args...));
  }
Y
Yu Yang 已提交
123 124 125 126 127 128
};

template <>
struct CUBlas<double> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
129
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemm(args...));
Y
Yu Yang 已提交
130
  }
Y
Yu Yang 已提交
131 132 133

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
134
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...));
Y
Yu Yang 已提交
135 136
  }

137 138
  template <typename... ARGS>
  static void SCAL(ARGS... args) {
139
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDscal(args...));
140 141 142 143
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
144
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDcopy(args...));
145 146
  }

Y
Yu Yang 已提交
147 148
  template <typename... ARGS>
  static void GEMV(ARGS... args) {
149
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...));
Y
Yu Yang 已提交
150 151 152
  }

  template <typename... ARGS>
153
  static void GEMM_STRIDED_BATCH(ARGS... args) {
Y
Yu Yang 已提交
154
#if CUDA_VERSION >= 8000
155 156
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasDgemmStridedBatched(args...));
Y
Yu Yang 已提交
157
#else
158 159
    PADDLE_THROW(platform::errors::Unimplemented(
        "DgemmStridedBatched is not supported on cuda <= 7.5"));
Y
Yu Yang 已提交
160 161
#endif
  }
162 163 164

  template <typename... ARGS>
  static void GEMM_EX(ARGS... args) {
165 166
    PADDLE_THROW(platform::errors::Unimplemented(
        "Currently there are not cublasDgemmEx."));
167
  }
G
Guo Sheng 已提交
168 169 170 171 172

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDtrsm(args...));
  }
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasDgetrfBatched(args...));
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasDgetriBatched(args...));
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasDmatinvBatched(args...));
  }
191 192 193 194 195 196

  template <typename... ARGS>
  static void GETRS_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasDgetrsBatched(args...));
  }
Y
Yu Yang 已提交
197 198 199 200
};

template <>
struct CUBlas<platform::float16> {
Y
Yu Yang 已提交
201 202 203 204 205 206 207
  using float16 = platform::float16;

  static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
                   cublasOperation_t transb, int m, int n, int k,
                   const float16 *alpha, const float16 *A, int lda,
                   const float16 *B, int ldb, const float16 *beta, float16 *C,
                   int ldc) {
208
    PADDLE_ENFORCE_CUDA_SUCCESS(
Y
Yu Yang 已提交
209 210 211 212 213 214
        platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
                                       reinterpret_cast<const __half *>(alpha),
                                       reinterpret_cast<const __half *>(A), lda,
                                       reinterpret_cast<const __half *>(B), ldb,
                                       reinterpret_cast<const __half *>(beta),
                                       reinterpret_cast<__half *>(C), ldc));
Y
Yu Yang 已提交
215
  }
Y
Yu Yang 已提交
216

217 218 219 220 221 222 223 224 225 226
  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb, int m, int n, int k,
                                 const float16 *alpha, const float16 *A,
                                 int lda, long long int strideA,  // NOLINT
                                 const float16 *B,                // NOLINT
                                 int ldb, long long int strideB,  // NOLINT
                                 const float16 *beta, float16 *C, int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
Y
Yu Yang 已提交
227
#if CUDA_VERSION >= 8000
228
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasHgemmStridedBatched(
Y
yuyang18 已提交
229 230 231 232 233 234
        handle, transa, transb, m, n, k,
        reinterpret_cast<const __half *>(alpha),
        reinterpret_cast<const __half *>(A), lda, strideA,
        reinterpret_cast<const __half *>(B), ldb, strideB,
        reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
        ldc, strideC, batchCount));
Y
Yu Yang 已提交
235
#else
236 237
    PADDLE_THROW(platform::errors::Unimplemented(
        "HgemmStridedBatched is not supported on cuda <= 7.5"));
238 239 240 241 242 243 244 245 246 247 248 249 250 251
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      cublasOperation_t transa, cublasOperation_t transb, int m,
                      int n, int k, const void *alpha, const void *A,
                      cudaDataType_t Atype, int lda, const void *B,
                      cudaDataType_t Btype, int ldb, const void *beta, void *C,
                      cudaDataType_t Ctype, int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
252
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
253
#if CUDA_VERSION >= 9000
254 255 256 257 258 259
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
260 261
#endif  // CUDA_VERSION >= 9000

262
    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
263
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
264 265 266
          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
          beta, C, Ctype, ldc, computeType, algo));
    });
267
#else
268 269
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasGemmEx is not supported on cuda <= 7.5"));
Y
Yu Yang 已提交
270 271
#endif
  }
Y
Yu Yang 已提交
272 273
};

274
template <>
275
struct CUBlas<platform::complex<float>> {
276
  static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m,
277 278 279 280 281
                   int n, const platform::complex<float> *alpha,
                   const platform::complex<float> *A, int lda,
                   const platform::complex<float> *B, int ldb,
                   const platform::complex<float> *beta,
                   platform::complex<float> *C, int ldc) {
282 283 284 285 286 287 288 289
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemv(
        handle, transa, m, n, reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A), lda,
        reinterpret_cast<const cuFloatComplex *>(B), ldb,
        reinterpret_cast<const cuFloatComplex *>(beta),
        reinterpret_cast<cuFloatComplex *>(C), ldc));
  }

290 291 292 293
  static void AXPY(cublasHandle_t handle, int n,
                   const platform::complex<float> *alpha,
                   const platform::complex<float> *X, const int incX,
                   platform::complex<float> *Y, const int incY) {
294 295 296 297 298 299
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCaxpy(
        handle, n, reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(X), incX,
        reinterpret_cast<cuFloatComplex *>(Y), incY));
  }

300 301 302
  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb, int m, int n, int k,
303 304 305 306 307 308 309
                                 const platform::complex<float> *alpha,
                                 const platform::complex<float> *A, int lda,
                                 long long int strideA,              // NOLINT
                                 const platform::complex<float> *B,  // NOLINT
                                 int ldb, long long int strideB,     // NOLINT
                                 const platform::complex<float> *beta,
                                 platform::complex<float> *C, int ldc,
310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemmStridedBatched(
        handle, transa, transb, m, n, k,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A), lda, strideA,
        reinterpret_cast<const cuFloatComplex *>(B), ldb, strideB,
        reinterpret_cast<const cuFloatComplex *>(beta),
        reinterpret_cast<cuFloatComplex *>(C), ldc, strideC, batchCount));
#else
    PADDLE_THROW(platform::errors::Unimplemented(
        "CgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
                   cublasOperation_t transb, int m, int n, int k,
328 329 330 331 332
                   const platform::complex<float> *alpha,
                   const platform::complex<float> *A, int lda,
                   const platform::complex<float> *B, int ldb,
                   const platform::complex<float> *beta,
                   platform::complex<float> *C, int ldc) {
333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemm(
        handle, transa, transb, m, n, k,
        reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(A), lda,
        reinterpret_cast<const cuFloatComplex *>(B), ldb,
        reinterpret_cast<const cuFloatComplex *>(beta),
        reinterpret_cast<cuFloatComplex *>(C), ldc));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      cublasOperation_t transa, cublasOperation_t transb, int m,
                      int n, int k, const void *alpha, const void *A,
                      cudaDataType_t Atype, int lda, const void *B,
                      cudaDataType_t Btype, int ldb, const void *beta, void *C,
                      cudaDataType_t Ctype, int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif  // CUDA_VERSION >= 9000

    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
          beta, C, Ctype, ldc, computeType, algo));
    });
#else
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasGemmEx is not supported on cuda <= 7.5"));
#endif
  }
};

template <>
376
struct CUBlas<platform::complex<double>> {
377
  static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m,
378 379 380 381 382
                   int n, const platform::complex<double> *alpha,
                   const platform::complex<double> *A, int lda,
                   const platform::complex<double> *B, int ldb,
                   const platform::complex<double> *beta,
                   platform::complex<double> *C, int ldc) {
383 384 385 386 387 388 389 390
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemv(
        handle, transa, m, n, reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A), lda,
        reinterpret_cast<const cuDoubleComplex *>(B), ldb,
        reinterpret_cast<const cuDoubleComplex *>(beta),
        reinterpret_cast<cuDoubleComplex *>(C), ldc));
  }

391 392 393 394
  static void AXPY(cublasHandle_t handle, int n,
                   const platform::complex<double> *alpha,
                   const platform::complex<double> *X, const int incX,
                   platform::complex<double> *Y, const int incY) {
395 396 397 398 399 400
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZaxpy(
        handle, n, reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(X), incX,
        reinterpret_cast<cuDoubleComplex *>(Y), incY));
  }

401 402 403
  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb, int m, int n, int k,
404 405 406 407 408 409 410
                                 const platform::complex<double> *alpha,
                                 const platform::complex<double> *A, int lda,
                                 long long int strideA,               // NOLINT
                                 const platform::complex<double> *B,  // NOLINT
                                 int ldb, long long int strideB,      // NOLINT
                                 const platform::complex<double> *beta,
                                 platform::complex<double> *C, int ldc,
411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemmStridedBatched(
        handle, transa, transb, m, n, k,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A), lda, strideA,
        reinterpret_cast<const cuDoubleComplex *>(B), ldb, strideB,
        reinterpret_cast<const cuDoubleComplex *>(beta),
        reinterpret_cast<cuDoubleComplex *>(C), ldc, strideC, batchCount));
#else
    PADDLE_THROW(platform::errors::Unimplemented(
        "CgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
  }

  static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
                   cublasOperation_t transb, int m, int n, int k,
429 430 431 432 433
                   const platform::complex<double> *alpha,
                   const platform::complex<double> *A, int lda,
                   const platform::complex<double> *B, int ldb,
                   const platform::complex<double> *beta,
                   platform::complex<double> *C, int ldc) {
434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemm(
        handle, transa, transb, m, n, k,
        reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(A), lda,
        reinterpret_cast<const cuDoubleComplex *>(B), ldb,
        reinterpret_cast<const cuDoubleComplex *>(beta),
        reinterpret_cast<cuDoubleComplex *>(C), ldc));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      cublasOperation_t transa, cublasOperation_t transb, int m,
                      int n, int k, const void *alpha, const void *A,
                      cudaDataType_t Atype, int lda, const void *B,
                      cudaDataType_t Btype, int ldb, const void *beta, void *C,
                      cudaDataType_t Ctype, int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif  // CUDA_VERSION >= 9000

    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
          beta, C, Ctype, ldc, computeType, algo));
    });
#else
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasGemmEx is not supported on cuda <= 7.5"));
#endif
  }
};

Y
Yu Yang 已提交
476 477
template <>
template <typename T>
Y
Yu Yang 已提交
478 479 480 481
void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                                             CBLAS_TRANSPOSE transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             const T *B, T beta, T *C) const {
Y
Yu Yang 已提交
482 483 484 485 486 487 488 489 490
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

491 492 493 494 495 496 497 498
#if CUDA_VERSION >= 8000
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
                       CUDA_R_32F, N);
  } else {
#endif  // CUDA_VERSION >= 8000
499 500 501 502
    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
                      lda, &beta, C, N);
    });
503 504 505 506

#if CUDA_VERSION >= 8000
  }
#endif  // CUDA_VERSION >= 8000
Y
Yu Yang 已提交
507 508 509 510 511
}

// Row-major fp16 GEMM in the cblas convention: C = alpha * op(A) * op(B) +
// beta * C with op(A) M x K, op(B) K x N, C M x N, densely packed.
//
// cuBLAS is column-major, so A/B and M/N are swapped in the cuBLAS arguments
// (computing C^T = op(B)^T * op(A)^T). Requires compute capability >= 53.
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    platform::float16 alpha, const platform::float16 *A,
    const platform::float16 *B, platform::float16 beta,
    platform::float16 *C) const {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(), 53,
      platform::errors::InvalidArgument(
          "cublas fp16 gemm requires GPU compute capability >= 53, "
          "but received %d",
          context_.GetComputeCapability()));

#if CUDA_VERSION >= 8000
  // cublasHgemm does true FP16 computation, which is slow on non-Volta GPUs.
  // Use cublasGemmEx instead, which does pseudo FP16 computation: input and
  // output in fp16, computation in fp32 (so alpha/beta are passed as float).
  // It can also be accelerated with Tensor Cores on Volta GPUs.
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);
  // GEMM_EX takes a non-const context pointer, while this method is const.
  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
  CUBlas<platform::float16>::GEMM_EX(
      &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16F, ldb, A,
      CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence fall back to cublasHgemm,
  // which takes alpha/beta in fp16 and the caller's A/B/C directly.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
                                    &alpha, B, ldb, A, lda, &beta, C, N);
  });
#endif  // CUDA_VERSION >= 8000
}

556 557 558 559
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
560 561 562
    platform::complex<float> alpha, const platform::complex<float> *A,
    const platform::complex<float> *B, platform::complex<float> beta,
    platform::complex<float> *C) const {
563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(), 53,
      platform::errors::InvalidArgument(
          "cublas complex64 gemm requires GPU compute capability >= 53,"
          "but received %d",
          context_.GetComputeCapability()));

  thrust::complex<float> c_alpha =
      thrust::complex<float>(alpha.real, alpha.imag);
  thrust::complex<float> c_beta = thrust::complex<float>(beta.real, beta.imag);

#if CUDA_VERSION >= 8000
  // cublasHgemm does true FP16 computation which is slow for non-Volta
  // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
  // input/output in fp16, computation in fp32, which can also be accelerated
  // using tensor cores in volta GPUs.
  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
590
  CUBlas<platform::complex<float>>::GEMM_EX(
591 592 593 594 595 596
      &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_32F, ldb, A,
      CUDA_C_32F, lda, &c_beta, C, CUDA_C_32F, N, CUDA_C_32F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm

  context_.CublasCall([&](cublasHandle_t handle) {
597 598 599
    CUBlas<platform::complex<float>>::GEMM(handle, cuTransB, cuTransA, N, M, K,
                                           &c_alpha, h_B, ldb, h_A, lda,
                                           &c_beta, h_C, N);
600 601 602 603 604 605 606 607
  });
#endif  // CUDA_VERSION >= 8000
}

// Row-major complex128 GEMM in the cblas convention: C = alpha * op(A) * op(B)
// + beta * C with op(A) M x K, op(B) K x N, C M x N, densely packed. op() is a
// plain transpose (CUBLAS_OP_T); no conjugation is applied.
//
// cuBLAS is column-major, so A/B and M/N are swapped in the cuBLAS arguments
// (computing C^T = op(B)^T * op(A)^T).
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    platform::complex<double> alpha, const platform::complex<double> *A,
    const platform::complex<double> *B, platform::complex<double> beta,
    platform::complex<double> *C) const {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  // NOTE(review): the >= 53 requirement looks inherited from the fp16 path;
  // confirm whether complex128 GEMM really needs it.
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(), 53,
      platform::errors::InvalidArgument(
          "cublas complex128 gemm requires GPU compute capability >= 53, "
          "but received %d",
          context_.GetComputeCapability()));

#if CUDA_VERSION >= 8000
  // cublasGemmEx with complex128 storage (CUDA_C_64F) and complex128 compute
  // type; GEMM_EX selects the Tensor Core algorithm when the device context
  // reports Tensor Cores as available. alpha/beta are passed as
  // thrust::complex<double> scalars.
  thrust::complex<double> c_alpha =
      thrust::complex<double>(alpha.real, alpha.imag);
  thrust::complex<double> c_beta =
      thrust::complex<double>(beta.real, beta.imag);
  // GEMM_EX takes a non-const context pointer, while this method is const.
  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
  CUBlas<platform::complex<double>>::GEMM_EX(
      &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_64F, ldb, A,
      CUDA_C_64F, lda, &c_beta, C, CUDA_C_64F, N, CUDA_C_64F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence fall back to cublasZgemm
  // through the CUBlas wrapper, which takes platform::complex<double> scalars.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<platform::complex<double>>::GEMM(handle, cuTransB, cuTransA, N, M, K,
                                            &alpha, B, ldb, A, lda, &beta, C,
                                            N);
  });
#endif  // CUDA_VERSION >= 8000
}

Y
Yu Yang 已提交
653 654
template <>
template <typename T>
Y
Yu Yang 已提交
655 656 657 658
void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             int lda, const T *B, int ldb,
                                             T beta, T *C, int ldc) const {
Y
Yu Yang 已提交
659 660
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
Y
Yu Yang 已提交
661 662
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
663 664 665 666 667 668 669 670 671 672

#if CUDA_VERSION >= 8000
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
                       CUDA_R_32F, ldc);
  } else {
#endif  // CUDA_VERSION >= 8000

673 674 675 676
    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
                      lda, &beta, C, ldc);
    });
677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693

#if CUDA_VERSION >= 8000
  }
#endif  // CUDA_VERSION >= 8000
}

// Row-major fp16 GEMM with explicit leading dimensions, computed by
// cublasHgemm (via CUBlas<float16>::GEMM) with fp16 alpha/beta:
// C = alpha * op(A) * op(B) + beta * C, where op(X) is X^T when the matching
// trans flag is true.
//
// cuBLAS is column-major (Fortran order), so A/B and M/N are swapped in the
// cuBLAS arguments (computing C^T = op(B)^T * op(A)^T).
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    bool transA, bool transB, int M, int N, int K, platform::float16 alpha,
    const platform::float16 *A, int lda, const platform::float16 *B, int ldb,
    platform::float16 beta, platform::float16 *C, int ldc) const {
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
                                    B, ldb, A, lda, &beta, C, ldc);
  });
}

Y
Yu Yang 已提交
700 701 702 703
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
                                             T *y) const {
  // y = alpha * x + y over n elements, both vectors with unit stride.
  // alpha is a by-value parameter, so its address handed to cuBLAS is a
  // host address.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
  });
}

709 710 711 712 713 714 715 716 717 718 719 720 721 722
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::SCAL(int n, const T alpha, T *x) const {
  // In-place scaling x = alpha * x over n elements with unit stride.
  auto scale = [&](cublasHandle_t handle) {
    CUBlas<T>::SCAL(handle, n, &alpha, x, 1);
  };
  context_.CublasCall(scale);
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::VCOPY(int n, const T *x, T *y) const {
  // Copies n elements from x into y; both vectors are walked with stride 1.
  const int inc = 1;
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::VCOPY(handle, n, x, inc, y, inc);
  });
}

Y
Yu Yang 已提交
723 724 725 726 727 728 729
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
                                             T alpha, const T *A, const T *B,
                                             T beta, T *C) const {
  // C = alpha * op(A) * B + beta * C, where A is a row-major M x N matrix
  // (leading dimension N), op(A) is A^T when trans_a is set and A otherwise,
  // and B / C are unit-stride vectors.
  //
  // cuBLAS is column-major, so it sees the row-major A as its N x M
  // transpose; the cuBLAS transpose flag is therefore the inverse of trans_a.
  cublasOperation_t cuTransA = trans_a ? CUBLAS_OP_N : CUBLAS_OP_T;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
  });
}

S
ShenLiang 已提交
735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
    bool trans_a, int M, int N, platform::float16 alpha,
    const platform::float16 *A, const platform::float16 *B,
    platform::float16 beta, platform::float16 *C) const {
  // cuBLAS offers no half-precision gemv, so the matrix-vector product is
  // expressed as a GEMM with one extent equal to 1 (the original note says
  // this ends up in cublasHgemm).
  if (!trans_a) {
    // C (M x 1) = A (M x N) * B (N x 1)
    this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, M, 1, N,
                                           alpha, A, B, beta, C);
    return;
  }
  // C^T (1 x N) = B^T (1 x M) * A (M x N)
  this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, 1, N, M,
                                         alpha, B, A, beta, C);
}

Y
Yu Yang 已提交
751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
    int64_t strideA, int64_t strideB) const {
  // Strided batched GEMM: for each i in [0, batchCount),
  //   C_i = alpha * op(A_i) * op(B_i) + beta * C_i
  // with A_i = A + i * strideA, B_i = B + i * strideB and C_i = C + i * M * N.
  // Operands are row-major and densely packed (lda/ldb/ldc derived below).
  //
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: A/B and M/N are swapped in the cuBLAS calls.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  // Widen before multiplying: `M * N` in int overflows for outputs with more
  // than INT_MAX elements before the result would be converted to int64_t.
  const int64_t strideC = static_cast<int64_t>(M) * N;

#if CUDA_VERSION >= 9010
  if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same<T, float>::value)) ||
      std::is_same<T, paddle::platform::float16>::value) {
    // float (with the tensor-op flag) and float16 use the Ex API so the
    // tensor-op algorithm can be selected when the device supports it.
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
    bool use_tensor_op_math = context_.tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");

    // Same type for inputs, output and compute: fp32 or fp16.
    auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
    context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
          handle, cuTransB, cuTransA, N, M, K, &alpha, B, fp, ldb, strideB, A,
          fp, lda, strideA, &beta, C, fp, ldc, strideC, batchCount, fp, algo));
    });
  } else {
#endif  // CUDA_VERSION >= 9010

    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
                                    B, ldb, strideB, A, lda, strideA, &beta, C,
                                    ldc, strideC, batchCount);
    });

#if CUDA_VERSION >= 9010
  }
#endif  // CUDA_VERSION >= 9010
}

S
ShenLiang 已提交
799 800 801 802 803 804 805 806 807 808 809
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const {
  // Pointer-array batched GEMM, issued as one GEMM per batch entry in order:
  //   C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i], i in [0, batchCount).
  // The pointer arrays A, B and C are indexed here in host code, so they must
  // be host-accessible.
  int idx = 0;
  while (idx < batchCount) {
    this->template GEMM<T>(transA, transB, M, N, K, alpha, A[idx], B[idx],
                           beta, C[idx]);
    ++idx;
  }
}

S
ShenLiang 已提交
810 811 812 813 814 815 816 817 818 819 820 821 822
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    platform::float16 alpha, const platform::float16 **A,
    const platform::float16 **B, platform::float16 beta, platform::float16 **C,
    int batchCount) const {
  // float16 pointer-array batched GEMM: one float16 GEMM per batch entry,
  // issued in order. The pointer arrays are dereferenced on the host.
  for (int i = 0; i != batchCount; ++i) {
    const platform::float16 *a_i = A[i];
    const platform::float16 *b_i = B[i];
    platform::float16 *c_i = C[i];
    this->template GEMM<platform::float16>(transA, transB, M, N, K, alpha, a_i,
                                           b_i, beta, c_i);
  }
}

G
Guo Sheng 已提交
823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
                                             CBLAS_TRANSPOSE transA,
                                             CBLAS_DIAG diag, int M, int N,
                                             T alpha, const T *A, int lda, T *B,
                                             int ldb) const {
  // Triangular solve for row-major operands: op(A) X = alpha * B (left) or
  // X op(A) = alpha * B (right), with B an M x N matrix.
  //
  // cuBLAS is column-major, so each row-major matrix is seen transposed and
  // `op(A) X = alpha B` becomes `X' op(A') = alpha B'` (' = transpose). That
  // flips the side and the triangle, keeps the transpose flag, and swaps M/N.
  const bool row_left = (side == CblasLeft);
  const bool row_lower = (uplo == CblasLower);
  const bool row_no_trans = (transA == CblasNoTrans);
  const bool unit_diag = (diag == CblasUnit);

  cublasSideMode_t cuSide = row_left ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
  cublasFillMode_t cuUplo =
      row_lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  // NOTE: complex T would need CUBLAS_OP_C (conjugate transpose) here.
  cublasOperation_t cuTransA = row_no_trans ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasDiagType_t cuDiag =
      unit_diag ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::TRSM(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A,
                    lda, B, ldb);
  });
}

848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRF(int n, T **a, int *ipiv,
                                                     int *info,
                                                     int batch_size) const {
  // Batched LU factorization of batch_size n x n matrices referenced by a,
  // using leading dimension n, with pivot and status output via ipiv / info.
  // NOTE(review): assumes GETRF_BATCH forwards to cublas<t>getrfBatched
  // (factorization in place) — confirm against the CUBlas wrapper.
  const int lda = n;
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRF_BATCH(handle, n, a, lda, ipiv, info, batch_size);
  });
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRI(int n, const T **a,
                                                     const int *ipiv, T **a_inv,
                                                     int *info,
                                                     int batch_size) const {
  // Batched matrix inversion from LU factors: for batch_size n x n matrices,
  // takes the factors referenced by a and the pivots in ipiv, writes the
  // inverses through a_inv (both with leading dimension n), status via info.
  // NOTE(review): assumes GETRI_BATCH forwards to cublas<t>getriBatched.
  //
  // getriBatched cannot run in place, so reject an output pointer array equal
  // to the input one. NOTE(review): this compares only the pointer-array
  // addresses; it does not detect overlapping matrix storage.
  PADDLE_ENFORCE_NE(
      a_inv, a,
      platform::errors::InvalidArgument(
          "cuBLAS function 'cublas<S/D>getriBatched' cannot be executed "
          "in-place. The memory space of output matrix (address: %p) cannot "
          "overlap memory space of input matrix (address: %p).",
          a_inv, a));
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size);
  });
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedMatInv(int n, const T **a,
                                                      T **a_inv, int *info,
                                                      int batch_size) const {
  // Inverts batch_size n x n matrices referenced by a into a_inv; input and
  // output both use leading dimension n, and per-matrix status goes to info.
  // NOTE(review): presumably MATINV_BATCH wraps cublas<t>matinvBatched.
  const int ld = n;
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::MATINV_BATCH(handle, n, a, ld, a_inv, ld, info, batch_size);
  });
}

886 887 888 889 890 891 892 893 894 895 896 897 898 899
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRS(
    CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv,
    T **b, int ldb, int *info, int batch_size) const {
  // Batched linear solve using LU factors (a, ipiv) for batch_size n x n
  // systems with nrhs right-hand sides in b. Unlike the GEMM/TRSM wrappers,
  // no row/column-major remapping happens here: trans is mapped directly.
  // NOTE(review): assumes GETRS_BATCH forwards to cublas<t>getrsBatched.
  // NOTE: complex T would need CUBLAS_OP_C (conjugate transpose) here.
  const cublasOperation_t cuTrans =
      (trans != CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info,
                           batch_size);
  });
}

Y
Yu Yang 已提交
900 901 902
}  // namespace math
}  // namespace operators
}  // namespace paddle