Unverified commit 617e790a authored by Kexin Zhao, committed by GitHub

fix cuda 7.5 compile error (#9885)

Parent 859fedf3
@@ -288,9 +288,14 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
   // TODO(kexinzhao): add processing code for compute capability < 53 case
   PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                     "cublas Hgemm requires GPU compute capability >= 53");
+#if CUDA_VERSION >= 8000
   PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
       strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
+#else
+  PADDLE_ENFORCE(false, "HgemmStridedBatched is not supported on cuda <= 7.5");
+#endif
 }

 template <>
@@ -310,9 +315,13 @@ void batched_gemm<platform::CUDADeviceContext, float>(
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   const int strideC = M * N;
+#if CUDA_VERSION >= 8000
   PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
       strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
+#else
+  PADDLE_ENFORCE(false, "SgemmStridedBatched is not supported on cuda <= 7.5");
+#endif
 }

 template <>
@@ -332,9 +341,13 @@ void batched_gemm<platform::CUDADeviceContext, double>(
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   const int strideC = M * N;
+#if CUDA_VERSION >= 8000
   PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
       strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
+#else
+  PADDLE_ENFORCE(false, "DgemmStridedBatched is not supported on cuda <= 7.5");
+#endif
 }

 template <>
...
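For context, the fix works because the cublas*gemmStridedBatched entry points only exist in cuBLAS starting with CUDA 8.0, so on CUDA 7.5 the calls must be compiled out and replaced with a run-time failure. Below is a minimal standalone sketch of the same compile-time guard, assuming a hypothetical wrapper named SgemmStridedBatchedOrDie; it uses plain exceptions rather than Paddle's actual PADDLE_ENFORCE/dynload machinery.

// Minimal sketch of the CUDA_VERSION guard (hypothetical wrapper, not
// PaddlePaddle's actual helper). CUDA_VERSION comes from <cuda.h>:
// 7050 for CUDA 7.5, 8000 for CUDA 8.0.
#include <cublas_v2.h>
#include <cuda.h>
#include <stdexcept>

void SgemmStridedBatchedOrDie(cublasHandle_t handle, int m, int n, int k,
                              const float* A, int lda, long long strideA,
                              const float* B, int ldb, long long strideB,
                              float* C, int ldc, long long strideC,
                              int batchCount) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
#if CUDA_VERSION >= 8000
  // The strided-batched GEMM entry points were added in cuBLAS for CUDA 8.0.
  cublasStatus_t status = cublasSgemmStridedBatched(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, A, lda, strideA, B,
      ldb, strideB, &beta, C, ldc, strideC, batchCount);
  if (status != CUBLAS_STATUS_SUCCESS) {
    throw std::runtime_error("cublasSgemmStridedBatched failed");
  }
#else
  // On CUDA <= 7.5 the symbol does not exist, so the call is compiled out
  // entirely; reaching this path is a hard error at run time instead of a
  // compile error at build time.
  throw std::runtime_error(
      "SgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}

The diff above applies exactly this pattern inside the batched_gemm specializations for float16, float, and double, keeping the pre-existing PADDLE_ENFORCE error handling in place of exceptions.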