Commit d00bd9eb authored by Kexin Zhao, committed by Yi Wang

Update the cuda API and enable tensor core for GEMM (#9622)

* change from hgemm to gemmEx

* fix cpplint
Parent 517f6195
```diff
@@ -39,18 +39,33 @@ void gemm<platform::CUDADeviceContext, float16>(
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const half h_alpha = static_cast<const half>(alpha);
-  const half h_beta = static_cast<const half>(beta);
-  const half* h_A = reinterpret_cast<const half*>(A);
-  const half* h_B = reinterpret_cast<const half*>(B);
-  half* h_C = reinterpret_cast<half*>(C);
 
   // TODO(kexinzhao): add processing code for compute capability < 53 case
   PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
-                    "cublas Hgemm requires GPU compute capability >= 53");
-  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
-      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
-      h_A, lda, &h_beta, h_C, N));
+                    "cublas fp16 gemm requires GPU compute capability >= 53");
+  float h_alpha = static_cast<float>(alpha);
+  float h_beta = static_cast<float>(beta);
+
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+#if CUDA_VERSION >= 9000
+  if (context.GetComputeCapability() >= 70) {
+    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
+                                                        CUBLAS_TENSOR_OP_MATH));
+    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+  } else {
+    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
+                                                        CUBLAS_DEFAULT_MATH));
+  }
+#endif
+
+  // cublasHgemm does true FP16 computation which is slow for non-Volta
+  // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
+  // input/output in fp16, computation in fp32, which can also be accelerated
+  // using tensor cores in Volta GPUs.
+  PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
+      CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
+      CUDA_R_32F, algo));
 }
 
 template <>
...
```
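For context, here is a minimal, self-contained sketch (assumed names and setup, not PaddlePaddle code) of the call sequence the new path makes: fp16 inputs and outputs, fp32 accumulation selected via the `CUDA_R_32F` compute type, and tensor-core math opted in on compute capability >= 70. Note that `alpha`/`beta` become `float` because cuBLAS interprets them according to the compute type, not the matrix type.

```cpp
// Minimal sketch of an fp16-in/fp32-accumulate GEMM with cublasGemmEx.
// Not Paddle code: plain column-major, error handling reduced to asserts.
// Build (roughly): nvcc gemm_ex_demo.cu -lcublas
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cassert>

int main() {
  const int M = 64, N = 64, K = 64;
  half *A, *B, *C;  // device buffers, left uninitialized for brevity
  assert(cudaMalloc(&A, M * K * sizeof(half)) == cudaSuccess);
  assert(cudaMalloc(&B, K * N * sizeof(half)) == cudaSuccess);
  assert(cudaMalloc(&C, M * N * sizeof(half)) == cudaSuccess);

  cublasHandle_t handle;
  assert(cublasCreate(&handle) == CUBLAS_STATUS_SUCCESS);

  // alpha/beta match the *compute* type (CUDA_R_32F), not the matrix type.
  float alpha = 1.0f, beta = 0.0f;

  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
  int device = 0, major = 0;
  assert(cudaGetDevice(&device) == cudaSuccess);
  assert(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor,
                                device) == cudaSuccess);
  if (major >= 7) {  // Volta and newer: allow tensor cores
    assert(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH) ==
           CUBLAS_STATUS_SUCCESS);
    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
  }
#endif

  // C = alpha * A * B + beta * C; inputs/outputs fp16, accumulation fp32.
  assert(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, A,
                      CUDA_R_16F, M, B, CUDA_R_16F, K, &beta, C, CUDA_R_16F,
                      M, CUDA_R_32F, algo) == CUBLAS_STATUS_SUCCESS);
  assert(cudaDeviceSynchronize() == cudaSuccess);

  cublasDestroy(handle);
  cudaFree(A);
  cudaFree(B);
  cudaFree(C);
  return 0;
}
```

The diff itself passes `cuTransB, cuTransA` with `N, M` swapped — the standard trick for computing a row-major GEMM with column-major cuBLAS; the sketch above sticks to plain column-major for clarity.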
```diff
@@ -24,6 +24,10 @@ void *cublas_dso_handle = nullptr;
 CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP);
 
+#ifdef CUBLAS_BLAS_ROUTINE_EACH_R2
+CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP);
+#endif
+
 }  // namespace dynload
 }  // namespace platform
 }  // namespace paddle
```
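The routine lists here are X-macros: `CUBLAS_BLAS_ROUTINE_EACH(M)` applies `M` to every routine name, so a single list drives both the declarations in the header and the wrapper definitions in this `.cc` file, and the `#ifdef` guard keeps the CUDA-9-only `_R2` list out of builds where the header never defines it. A compilable toy version of the pattern (hypothetical names, not Paddle's macros):

```cpp
// Toy X-macro demo (hypothetical names): one list, many expansions.
#include <cstdio>

// The single source of truth: apply __macro to each routine name.
#define ROUTINE_LIST(__macro) \
  __macro(gemm);              \
  __macro(axpy)

// One expansion declares stubs...
#define DECLARE_STUB(name) void name()
ROUTINE_LIST(DECLARE_STUB);
#undef DECLARE_STUB

// ...another defines them.
#define DEFINE_STUB(name) \
  void name() { std::puts(#name); }
ROUTINE_LIST(DEFINE_STUB);
#undef DEFINE_STUB

int main() {
  gemm();  // prints "gemm"
  axpy();  // prints "axpy"
  return 0;
}
```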
```diff
@@ -15,8 +15,9 @@ limitations under the License. */
 #pragma once
 
 #include <cublas_v2.h>
+#include <cuda.h>
 #include <dlfcn.h>
-#include <mutex>
+#include <mutex>  // NOLINT
 #include "paddle/fluid/platform/dynload/dynamic_loader.h"
 
 namespace paddle {
@@ -70,6 +71,7 @@ extern void *cublas_dso_handle;
   __macro(cublasDgemm_v2);   \
   __macro(cublasHgemm);      \
   __macro(cublasSgemmEx);    \
+  __macro(cublasGemmEx);     \
   __macro(cublasSgeam_v2);   \
   __macro(cublasDgeam_v2);   \
   __macro(cublasCreate_v2);  \
@@ -89,9 +91,15 @@ extern void *cublas_dso_handle;
   __macro(cublasSgetrfBatched); \
   __macro(cublasSgetriBatched); \
   __macro(cublasDgetrfBatched); \
-  __macro(cublasDgetriBatched)
+  __macro(cublasDgetriBatched);
 
-CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP);
+CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
+
+// APIs available after CUDA 9.0
+#if CUDA_VERSION >= 9000
+#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) __macro(cublasSetMathMode);
+CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
+#endif
 
 #undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP
 }  // namespace dynload
...
```
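What `DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP` ultimately buys is lazy binding: the binary carries no link-time dependency on libcublas, and each entry point is resolved with `dlsym` on first use. A minimal sketch of that pattern (assumed structure and names, not Paddle's exact macro):

```cpp
// Sketch of the lazy dlopen/dlsym wrapper pattern (assumed names, not
// Paddle's exact macro). Build (roughly): g++ -std=c++11 demo.cc -ldl
#include <cublas_v2.h>
#include <dlfcn.h>

#include <mutex>
#include <stdexcept>

static std::once_flag cublas_dso_flag;
static void *cublas_dso_handle = nullptr;

// Open libcublas once, then resolve individual symbols on demand.
static void *GetCublasSymbol(const char *name) {
  std::call_once(cublas_dso_flag, [] {
    cublas_dso_handle = dlopen("libcublas.so", RTLD_LAZY);
    if (cublas_dso_handle == nullptr) {
      throw std::runtime_error("cannot load libcublas.so");
    }
  });
  return dlsym(cublas_dso_handle, name);
}

// Roughly what DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSetMathMode) produces:
// a functor that forwards its arguments to the dlsym'd entry point.
struct DynLoad__cublasSetMathMode {
  template <typename... Args>
  cublasStatus_t operator()(Args... args) {
    using Func = cublasStatus_t (*)(Args...);
    static void *p = GetCublasSymbol("cublasSetMathMode");
    return reinterpret_cast<Func>(p)(args...);
  }
};

int main() {
  // Just verify the new symbol resolves; calling it needs a live handle.
  return GetCublasSymbol("cublasGemmEx") != nullptr ? 0 : 1;
}
```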