Commit 7ed457e7 authored by Kexin Zhao, committed by qingqing01

Fix cuda 7.5 error with cublas GEMM (#9811)

* fix gemm error for cuda 7.5

* fix version number
Parent 20f202ac
@@ -39,13 +39,14 @@ void gemm<platform::CUDADeviceContext, float16>(
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas fp16 gemm requires GPU compute capability >= 53");
#if CUDA_VERSION >= 8000
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
if (context.GetComputeCapability() >= 70) {
@@ -56,7 +57,7 @@ void gemm<platform::CUDADeviceContext, float16>(
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
CUBLAS_DEFAULT_MATH));
}
#endif
#endif // CUDA_VERSION >= 9000
// cublasHgemm does true FP16 computation which is slow for non-Volta
// GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
@@ -66,6 +67,18 @@ void gemm<platform::CUDADeviceContext, float16>(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
CUDA_R_32F, algo));
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to cublasHgemm
const half h_alpha = static_cast<const half>(alpha);
const half h_beta = static_cast<const half>(beta);
const half* h_A = reinterpret_cast<const half*>(A);
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
h_A, lda, &h_beta, h_C, N));
#endif // CUDA_VERSION >= 8000
}
template <>
......
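To see the two call paths in isolation, below is a minimal self-contained sketch of what the hunks above gate on CUDA_VERSION: cublasGemmEx (fp16 storage, fp32 accumulation) when building against CUDA 8.0 or newer, and cublasHgemm (pure fp16) on CUDA 7.5. The function names, the row-major M/N/K layout, and the leading dimensions are illustrative assumptions, not Paddle's actual wrapper signature.

#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>

#if CUDA_VERSION >= 8000
// CUDA >= 8.0: cublasGemmEx does "pseudo fp16" -- fp16 inputs/outputs with fp32
// accumulation, so alpha/beta are plain floats. A row-major C = alpha*A*B + beta*C
// is expressed as the column-major product B^T * A^T, hence the swapped operands.
cublasStatus_t gemm_fp16_fp32acc(cublasHandle_t handle, int M, int N, int K,
                                 float alpha, const __half* A, const __half* B,
                                 float beta, __half* C) {
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
                      B, CUDA_R_16F, N, A, CUDA_R_16F, K, &beta,
                      C, CUDA_R_16F, N, CUDA_R_32F, CUBLAS_GEMM_DFALT);
}
#else
// CUDA 7.5 has no cublasGemmEx; cublasHgemm computes entirely in fp16, so the
// scalars must be supplied as half values as well (the commit bit-casts them
// from Paddle's float16 type).
cublasStatus_t gemm_fp16(cublasHandle_t handle, int M, int N, int K,
                         const __half* alpha, const __half* A, const __half* B,
                         const __half* beta, __half* C) {
  return cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, alpha,
                     B, N, A, K, beta, C, N);
}
#endif  // CUDA_VERSION >= 8000

The hunks above additionally enforce a GPU compute capability of at least 53, which fp16 GEMM requires on either path.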
@@ -28,6 +28,10 @@ CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP);
CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP);
#endif
#ifdef CUBLAS_BLAS_ROUTINE_EACH_R3
CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP);
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
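The DEFINE_WRAP pass in this file instantiates one global callable per routine name produced by the CUBLAS_BLAS_ROUTINE_EACH_* lists, and the new #ifdef guard simply skips the R3 group when it was never defined (CUDA < 9.0). A rough, simplified sketch of that dynamic-loading idiom follows; the handle, once-flag, and loader here are stand-ins, not Paddle's exact macros.

#include <cublas_v2.h>
#include <dlfcn.h>
#include <mutex>

namespace dynload {

// Stand-ins for Paddle's cublas_dso_handle / cublas_dso_flag / GetCublasDsoHandle.
static void* cublas_dso_handle = nullptr;
static std::once_flag cublas_dso_flag;

inline void EnsureCublasLoaded() {
  std::call_once(cublas_dso_flag,
                 [] { cublas_dso_handle = dlopen("libcublas.so", RTLD_LAZY); });
}

// Header side: declare a functor that resolves the real symbol by name on first use.
#define DECLARE_WRAP(__name)                                                \
  struct DynLoad__##__name {                                                \
    template <typename... Args>                                             \
    cublasStatus_t operator()(Args... args) {                               \
      EnsureCublasLoaded();                                                  \
      using Func = cublasStatus_t (*)(Args...);                             \
      auto fn = reinterpret_cast<Func>(dlsym(cublas_dso_handle, #__name));  \
      return fn(args...);                                                   \
    }                                                                       \
  };                                                                        \
  extern DynLoad__##__name __name

// Source side (what DEFINE_WRAP does in this .cc file): the single global instance.
#define DEFINE_WRAP(__name) DynLoad__##__name __name

DECLARE_WRAP(cublasHgemm);  // normally generated by a CUBLAS_BLAS_ROUTINE_EACH_* list
DEFINE_WRAP(cublasHgemm);

}  // namespace dynload

With this in place, call sites invoke dynload::cublasHgemm(handle, ...) exactly the way the math_function.cu hunk calls platform::dynload::cublasHgemm.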
@@ -71,7 +71,6 @@ extern void *cublas_dso_handle;
__macro(cublasDgemm_v2); \
__macro(cublasHgemm); \
__macro(cublasSgemmEx); \
__macro(cublasGemmEx); \
__macro(cublasSgeam_v2); \
__macro(cublasDgeam_v2); \
__macro(cublasCreate_v2); \
@@ -83,11 +82,6 @@ extern void *cublas_dso_handle;
__macro(cublasDgemmBatched); \
__macro(cublasCgemmBatched); \
__macro(cublasZgemmBatched); \
__macro(cublasSgemmStridedBatched); \
__macro(cublasDgemmStridedBatched); \
__macro(cublasCgemmStridedBatched); \
__macro(cublasZgemmStridedBatched); \
__macro(cublasHgemmStridedBatched); \
__macro(cublasSgetrfBatched); \
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
@@ -95,10 +89,24 @@ extern void *cublas_dso_handle;
CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
// APIs available after CUDA 8.0
#if CUDA_VERSION >= 8000
#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \
__macro(cublasGemmEx); \
__macro(cublasSgemmStridedBatched); \
__macro(cublasDgemmStridedBatched); \
__macro(cublasCgemmStridedBatched); \
__macro(cublasZgemmStridedBatched); \
__macro(cublasHgemmStridedBatched);
CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
#endif
// APIs available after CUDA 9.0
#if CUDA_VERSION >= 9000
#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) __macro(cublasSetMathMode);
CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) __macro(cublasSetMathMode);
CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
#endif
#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP
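The net effect of the header change is that every routine is declared in the lowest group whose CUDA version actually ships the symbol, so a CUDA 7.5 build never exposes wrappers for cublasGemmEx, the *StridedBatched GEMMs, or cublasSetMathMode, none of which its libcublas provides. A condensed illustration of that layering (abbreviated lists, illustrative macro names, not the full Paddle lists):

// Symbols available in every toolkit this code supports (CUDA 7.5 included).
#define CUBLAS_ROUTINES_BASE(__macro) \
  __macro(cublasHgemm);               \
  __macro(cublasSgemmEx);

// Symbols introduced in CUDA 8.0 (cublasGemmEx and the *StridedBatched family).
#if CUDA_VERSION >= 8000
#define CUBLAS_ROUTINES_R2(__macro)   \
  __macro(cublasGemmEx);              \
  __macro(cublasHgemmStridedBatched);
#endif

// Symbols introduced in CUDA 9.0.
#if CUDA_VERSION >= 9000
#define CUBLAS_ROUTINES_R3(__macro) __macro(cublasSetMathMode);
#endif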
......