未验证 提交 617e790a 编写于 作者: K Kexin Zhao 提交者: GitHub

fix cuda 7.5 compile error (#9885)

上级 859fedf3
......@@ -288,9 +288,14 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
#else
PADDLE_ENFORCE(false, "HgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
template <>
......@@ -310,9 +315,13 @@ void batched_gemm<platform::CUDADeviceContext, float>(
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
PADDLE_ENFORCE(false, "SgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
template <>
......@@ -332,9 +341,13 @@ void batched_gemm<platform::CUDADeviceContext, double>(
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
PADDLE_ENFORCE(false, "DgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册