//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/gpu_info.h"

DECLARE_bool(enable_cublas_tensor_op_math);
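// When the flag above is set and the device supports it, the float GEMM
// paths below dispatch to the cublas*GemmEx entry points so that Tensor
// Core math can be used (see GEMM_EX and the CUDA_VERSION >= 8000 branches).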

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct CUBlas;

template <>
struct CUBlas<float> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
    PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...));
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
    PADDLE_ENFORCE(platform::dynload::cublasSaxpy(args...));
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
    PADDLE_ENFORCE(platform::dynload::cublasSgemv(args...));
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...));
#else
    PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
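  // The handle returned by dev_ctx->possible_cublas_tensor_core_handle() is
  // assumed to have CUBLAS_TENSOR_OP_MATH enabled (via cublasSetMathMode)
  // whenever dev_ctx->tensor_core_available() is true; otherwise it falls
  // back to the default cublas handle.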
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      cublasOperation_t transa, cublasOperation_t transb, int m,
                      int n, int k, const float *alpha, const void *A,
                      cudaDataType_t Atype, int lda, const void *B,
                      cudaDataType_t Btype, int ldb, const float *beta, void *C,
                      cudaDataType_t Ctype, int ldc) {
// Because gcc 4.8 does not expand a template parameter pack that appears
// inside a lambda expression, a template parameter pack cannot be used
// here.
#if CUDA_VERSION >= 8000
    VLOG(5) << "use_tensor_op_math: "
            << (dev_ctx->tensor_core_available() ? "True" : "False");
    PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
        dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
        alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc));
#else
    PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif
  }
};

template <>
struct CUBlas<double> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
    PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...));
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
    PADDLE_ENFORCE(platform::dynload::cublasDaxpy(args...));
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
    PADDLE_ENFORCE(platform::dynload::cublasDgemv(args...));
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...));
#else
    PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
#endif
  }

  template <typename... ARGS>
  static void GEMM_EX(ARGS... args) {
    PADDLE_THROW("Currently there are not cublasDgemmEx.");
  }
};

template <>
struct CUBlas<platform::float16> {
  using float16 = platform::float16;

  static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
                   cublasOperation_t transb, int m, int n, int k,
                   const float16 *alpha, const float16 *A, int lda,
                   const float16 *B, int ldb, const float16 *beta, float16 *C,
                   int ldc) {
    PADDLE_ENFORCE(
        platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
                                       reinterpret_cast<const __half *>(alpha),
                                       reinterpret_cast<const __half *>(A), lda,
                                       reinterpret_cast<const __half *>(B), ldb,
                                       reinterpret_cast<const __half *>(beta),
                                       reinterpret_cast<__half *>(C), ldc));
  }

  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb, int m, int n, int k,
                                 const float16 *alpha, const float16 *A,
                                 int lda, long long int strideA,  // NOLINT
                                 const float16 *B,                // NOLINT
                                 int ldb, long long int strideB,  // NOLINT
                                 const float16 *beta, float16 *C, int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
#if CUDA_VERSION >= 8000
    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
        handle, transa, transb, m, n, k,
        reinterpret_cast<const __half *>(alpha),
        reinterpret_cast<const __half *>(A), lda, strideA,
        reinterpret_cast<const __half *>(B), ldb, strideB,
        reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
        ldc, strideC, batchCount));
#else
    PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
#endif
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      cublasOperation_t transa, cublasOperation_t transb, int m,
                      int n, int k, const void *alpha, const void *A,
                      cudaDataType_t Atype, int lda, const void *B,
                      cudaDataType_t Btype, int ldb, const void *beta, void *C,
                      cudaDataType_t Ctype, int ldc,
                      cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif  // CUDA_VERSION >= 9000

    PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
        dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
        alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType,
        algo));
#else
    PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif
  }
};

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                                             CBLAS_TRANSPOSE transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             const T *B, T beta, T *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
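  // In row-major (cblas) terms this call computes C = op(A) * op(B); viewed
  // in cublas's column-major convention the same memory holds
  // C^T = op(B)^T * op(A)^T, which is why B and A (and the dimensions N and
  // M) are swapped in the cuBLAS calls below.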
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

#if CUDA_VERSION >= 8000
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
                       CUDA_R_32F, N);
  } else {
#endif  // CUDA_VERSION >= 8000

    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
                    &alpha, B, ldb, A, lda, &beta, C, N);

#if CUDA_VERSION >= 8000
  }
#endif  // CUDA_VERSION >= 8000
}
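
// A minimal usage sketch (illustrative only; it assumes a CUDADeviceContext
// `ctx`, device buffers d_a, d_b, d_c of the proper sizes, and the GetBlas
// helper declared in blas.h):
//   auto blas = GetBlas<platform::CUDADeviceContext, float>(ctx);
//   blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, d_a, d_b, 0.0f, d_c);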

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    platform::float16 alpha, const platform::float16 *A,
    const platform::float16 *B, platform::float16 beta,
    platform::float16 *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
                    "cublas fp16 gemm requires GPU compute capability >= 53");

  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

#if CUDA_VERSION >= 8000
  // cublasHgemm does true FP16 computation which is slow for non-Volta
  // GPUs. So use cublasGemmEx instead, which does pseudo FP16 computation:
  // input/output in fp16, computation in fp32, which can also be accelerated
  // using Tensor Cores on Volta GPUs.
  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
  CUBlas<platform::float16>::GEMM_EX(
      &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16F, ldb, A,
      CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
                                  N, M, K, &h_alpha, h_B, ldb, h_A, lda,
                                  &h_beta, h_C, N);
#endif  // CUDA_VERSION >= 8000
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             int lda, const T *B, int ldb,
                                             T beta, T *C, int ldc) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

#if CUDA_VERSION >= 8000
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
    CUBlas<T>::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B,
                       CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C,
                       CUDA_R_32F, ldc);
  } else {
#endif  // CUDA_VERSION >= 8000

    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
                    &alpha, B, ldb, A, lda, &beta, C, ldc);

#if CUDA_VERSION >= 8000
  }
#endif  // CUDA_VERSION >= 8000
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    bool transA, bool transB, int M, int N, int K, platform::float16 alpha,
    const platform::float16 *A, int lda, const platform::float16 *B, int ldb,
    platform::float16 beta, platform::float16 *C, int ldc) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
                                  N, M, K, &alpha, B, ldb, A, lda, &beta, C,
                                  ldc);
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
                                             T *y) const {
  CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1);
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
                                             T alpha, const T *A, const T *B,
                                             T beta, T *C) const {
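  // In cublas's column-major view a row-major M x N matrix is its N x M
  // transpose, so the "no transpose" case maps to CUBLAS_OP_T here (and the
  // transposed case to CUBLAS_OP_N), with the dimensions passed as N, M.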
  cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;

  CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1,
                  &beta, C, 1);
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
    int64_t strideA, int64_t strideB) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int64_t strideC = M * N;
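  // Output matrices are packed back-to-back in C, so consecutive batch
  // entries are M * N elements apart; strideA and strideB are supplied by
  // the caller.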

#if CUDA_VERSION >= 9010
  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
    bool use_tensor_op_math = context_.tensor_core_available();
    if (use_tensor_op_math) {
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
    }
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");

    PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
        context_.possible_cublas_tensor_core_handle(), cuTransB, cuTransA, N, M,
        K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA,
        &beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
  } else {
#endif  // CUDA_VERSION >= 9010

    CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
                                  N, M, K, &alpha, B, ldb, strideB, A, lda,
                                  strideA, &beta, C, ldc, strideC, batchCount);

#if CUDA_VERSION >= 9010
  }
#endif  // CUDA_VERSION >= 9010
}

}  // namespace math
}  // namespace operators
}  // namespace paddle