//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/gpu_info.h"

DECLARE_bool(enable_cublas_tensor_op_math);
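// When this flag is enabled, the float GEMMs below are routed through
// cublasSgemmEx / cublasGemmEx so that Tensor Cores can be used on GPUs
// that support them.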

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct CUBlas;

template <>
struct CUBlas<float> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemm(args...));
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...));
  }

  template <typename... ARGS>
  static void SCAL(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSscal(args...));
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasScopy(args...));
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...));
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasSgemmStridedBatched(args...));
#else
    PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      cublasOperation_t transa, cublasOperation_t transb, int m,
                      int n, int k, const float *alpha, const void *A,
                      cudaDataType_t Atype, int lda, const void *B,
                      cudaDataType_t Btype, int ldb, const float *beta, void *C,
                      cudaDataType_t Ctype, int ldc) {
// Because gcc 4.8 does not expand a template parameter pack that appears
// inside a lambda-expression, we cannot use a template parameter pack here.
#if CUDA_VERSION >= 8000
    VLOG(5) << "use_tensor_op_math: "
            << (dev_ctx->tensor_core_available() ? "True" : "False");
    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemmEx(
          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
          beta, C, Ctype, ldc));
    });
#else
    PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif
  }
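
  // A minimal sketch (an assumption for illustration, not the actual
  // implementation) of what TensorCoreCublasCallIfAvailable in
  // platform::CUDADeviceContext is expected to do around the callback,
  // following the cublasSetMathMode documentation linked above:
  //
  //   template <typename Callback>
  //   void TensorCoreCublasCallIfAvailable(Callback callback) {
  //     if (tensor_core_available_) {  // CUDA >= 9.0 and suitable hardware
  //       cublasSetMathMode(cublas_handle_, CUBLAS_TENSOR_OP_MATH);
  //       callback(cublas_handle_);
  //       cublasSetMathMode(cublas_handle_, CUBLAS_DEFAULT_MATH);
  //     } else {
  //       callback(cublas_handle_);
  //     }
  //   }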

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasStrsm(args...));
  }
};

template <>
struct CUBlas<double> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemm(args...));
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...));
  }

  template <typename... ARGS>
  static void SCAL(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDscal(args...));
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDcopy(args...));
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...));
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasDgemmStridedBatched(args...));
#else
    PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
#endif
  }

  template <typename... ARGS>
  static void GEMM_EX(ARGS... args) {
    PADDLE_THROW("Currently there are not cublasDgemmEx.");
  }

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDtrsm(args...));
  }
};

template <>
struct CUBlas<platform::float16> {
  using float16 = platform::float16;

  static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
                   cublasOperation_t transb, int m, int n, int k,
                   const float16 *alpha, const float16 *A, int lda,
                   const float16 *B, int ldb, const float16 *beta, float16 *C,
                   int ldc) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
                                       reinterpret_cast<const __half *>(alpha),
                                       reinterpret_cast<const __half *>(A), lda,
                                       reinterpret_cast<const __half *>(B), ldb,
                                       reinterpret_cast<const __half *>(beta),
                                       reinterpret_cast<__half *>(C), ldc));
  }

  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb, int m, int n, int k,
                                 const float16 *alpha, const float16 *A,
                                 int lda, long long int strideA,  // NOLINT
                                 const float16 *B,                // NOLINT
                                 int ldb, long long int strideB,  // NOLINT
                                 const float16 *beta, float16 *C, int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasHgemmStridedBatched(
        handle, transa, transb, m, n, k,
        reinterpret_cast<const __half *>(alpha),
        reinterpret_cast<const __half *>(A), lda, strideA,
        reinterpret_cast<const __half *>(B), ldb, strideB,
        reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
        ldc, strideC, batchCount));
#else
    PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      cublasOperation_t transa, cublasOperation_t transb, int m,
                      int n, int k, const void *alpha, const void *A,
                      cudaDataType_t Atype, int lda, const void *B,
                      cudaDataType_t Btype, int ldb, const void *beta, void *C,
                      cudaDataType_t Ctype, int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif  // CUDA_VERSION >= 9000

    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
          beta, C, Ctype, ldc, computeType, algo));
    });
#else
    PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif
  }
};

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                                             CBLAS_TRANSPOSE transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             const T *B, T beta, T *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
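  // Concretely, a row-major M x N matrix has the same memory layout as a
  // column-major N x M matrix, so the row-major product C = op(A) * op(B) is
  // computed as the column-major product C' = op(B)' * op(A)'. That is why
  // the calls below swap A/B and M/N and use N as C's leading dimension.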
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

#if CUDA_VERSION >= 8000
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
                       CUDA_R_32F, N);
  } else {
#endif  // CUDA_VERSION >= 8000
    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
                      lda, &beta, C, N);
    });

#if CUDA_VERSION >= 8000
  }
#endif  // CUDA_VERSION >= 8000
}
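
// A hypothetical caller-side sketch (not part of this header; GetBlas and the
// pointer names are assumptions for illustration): compute the row-major
// product C (M x N) = A (M x K) * B (K x N) in single precision.
//
//   auto blas = math::GetBlas<platform::CUDADeviceContext, float>(dev_ctx);
//   blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, a_ptr, b_ptr, 0.0f,
//             c_ptr);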

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    platform::float16 alpha, const platform::float16 *A,
    const platform::float16 *B, platform::float16 beta,
    platform::float16 *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
                    "cublas fp16 gemm requires GPU compute capability >= 53");

  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

#if CUDA_VERSION >= 8000
  // cublasHgemm does true FP16 computation, which is slow for non-Volta
  // GPUs. So use cublasGemmEx instead, which does pseudo FP16 computation:
  // input/output in fp16, computation in fp32, which can also be accelerated
  // using tensor cores on Volta GPUs.
  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
  CUBlas<platform::float16>::GEMM_EX(
      &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16F, ldb, A,
      CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to cublasHgemm.
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
                                    &alpha, B, ldb, A, lda, &beta, C, N);
  });
#endif  // CUDA_VERSION >= 8000
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             int lda, const T *B, int ldb,
                                             T beta, T *C, int ldc) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

#if CUDA_VERSION >= 8000
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
                       CUDA_R_32F, ldc);
  } else {
#endif  // CUDA_VERSION >= 8000

    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
                      lda, &beta, C, ldc);
    });

#if CUDA_VERSION >= 8000
  }
#endif  // CUDA_VERSION >= 8000
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    bool transA, bool transB, int M, int N, int K, platform::float16 alpha,
    const platform::float16 *A, int lda, const platform::float16 *B, int ldb,
    platform::float16 beta, platform::float16 *C, int ldc) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

336 337 338 339
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
                                    B, ldb, A, lda, &beta, C, ldc);
  });
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
                                             T *y) const {
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
  });
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::SCAL(int n, const T alpha, T *x) const {
  context_.CublasCall(
      [&](cublasHandle_t handle) { CUBlas<T>::SCAL(handle, n, &alpha, x, 1); });
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::VCOPY(int n, const T *x, T *y) const {
  context_.CublasCall(
      [&](cublasHandle_t handle) { CUBlas<T>::VCOPY(handle, n, x, 1, y, 1); });
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
                                             T alpha, const T *A, const T *B,
                                             T beta, T *C) const {
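  // A row-major matrix is the transpose of the same memory interpreted
  // column-major, so a non-transposed row-major GEMV maps to CUBLAS_OP_T
  // (and vice versa), with the dimensions passed as (N, M) below.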
  cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
  });
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
    int64_t strideA, int64_t strideB) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int64_t strideC = M * N;

#if CUDA_VERSION >= 9010
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
    bool use_tensor_op_math = context_.tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");

    context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
          handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
          strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
          strideC, batchCount, CUDA_R_32F, algo));
    });
  } else {
#endif  // CUDA_VERSION >= 9010

    context_.CublasCall([&](cublasHandle_t handle) {
      CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
                                    B, ldb, strideB, A, lda, strideA, &beta, C,
                                    ldc, strideC, batchCount);
    });

#if CUDA_VERSION >= 9010
  }
#endif  // CUDA_VERSION >= 9010
}
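
// A usage sketch for BatchedGEMM above (the pointer names are hypothetical):
// batchCount independent row-major GEMMs whose operands are laid out strideA,
// strideB and M * N elements apart, e.g. for densely packed
// A[batch][M][K], B[batch][K][N] and C[batch][M][N]:
//
//   blas.BatchedGEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, a_ptr, b_ptr,
//                    0.0f, c_ptr, batchCount,
//                    /*strideA=*/static_cast<int64_t>(M) * K,
//                    /*strideB=*/static_cast<int64_t>(K) * N);
//
// Note that strideC is fixed to M * N, so C must always be densely batched.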

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
                                             CBLAS_TRANSPOSE transA,
                                             CBLAS_DIAG diag, int M, int N,
                                             T alpha, const T *A, int lda, T *B,
                                             int ldb) const {
  // Solve the row-major system `op(A) X = α B` by treating it as the
  // column-major system `X' op(A') = α B'`, where ' stands for transpose.
  cublasSideMode_t cuSide =
      (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
  cublasFillMode_t cuUplo =
      (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  // use CUBLAS_OP_C (conjugate transpose) for complex
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasDiagType_t cuDiag =
      (diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::TRSM(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A,
                    lda, B, ldb);
  });
}
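
// A usage sketch for TRSM above (the pointer names are hypothetical): solve
// A * X = alpha * B for X, where A is M x M, lower-triangular with a non-unit
// diagonal, and B is M x N, both row major; X overwrites B on return:
//
//   blas.TRSM(CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, M, N,
//             /*alpha=*/static_cast<float>(1), a_ptr, /*lda=*/M, b_ptr,
//             /*ldb=*/N);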

}  // namespace math
}  // namespace operators
}  // namespace paddle