From 954be40d0d4c04299b274e26545418b3d4bf6bed Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Thu, 3 Nov 2022 13:56:41 +0800
Subject: [PATCH] fix gemm compute_type (#47613)

---
 paddle/phi/kernels/funcs/blas/blas_impl.cu.h | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
index e4f3dbf6a79..8155a79d91a 100644
--- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
+++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
@@ -1452,7 +1452,11 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
           << FLAGS_gemm_use_half_precision_compute_type;
 
   auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
-  cudaDataType_t compute_type = CUDA_R_32F;
+#if CUDA_VERSION >= 11000
+  auto compute_type = CUBLAS_COMPUTE_32F;
+#else
+  auto compute_type = CUDA_R_32F;
+#endif
 
   float h_alpha = static_cast<float>(alpha);
   float h_beta = static_cast<float>(beta);
@@ -1463,7 +1467,11 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
       std::is_same<T, phi::dtype::float16>::value) {
     a = static_cast<void *>(&alpha);
     b = static_cast<void *>(&beta);
+#if CUDA_VERSION >= 11000
+    compute_type = CUBLAS_COMPUTE_16F;
+#else
     compute_type = CUDA_R_16F;
+#endif
   }
 
   context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
--
GitLab
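
Note on the change: since CUDA 11, cublasGemmStridedBatchedEx (and cublasGemmEx)
take their computeType argument as cublasComputeType_t (e.g. CUBLAS_COMPUTE_32F,
CUBLAS_COMPUTE_16F), whereas CUDA 10.x used cudaDataType_t (CUDA_R_32F, CUDA_R_16F)
for the same parameter; hence the #if CUDA_VERSION >= 11000 guard above. Below is a
minimal standalone sketch of the same pattern; the helper PickComputeType and the
wrapper BatchedHalfGemm are illustrative names and assumptions, not part of the
Paddle sources.

    #include <cuda.h>
    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // Mirror the patch's version guard: the alias resolves to whichever enum
    // the local toolkit's cublasGemmStridedBatchedEx expects as computeType.
    #if CUDA_VERSION >= 11000
    using GemmComputeType = cublasComputeType_t;
    inline GemmComputeType PickComputeType(bool use_half_compute) {
      return use_half_compute ? CUBLAS_COMPUTE_16F : CUBLAS_COMPUTE_32F;
    }
    #else
    using GemmComputeType = cudaDataType_t;
    inline GemmComputeType PickComputeType(bool use_half_compute) {
      return use_half_compute ? CUDA_R_16F : CUDA_R_32F;
    }
    #endif

    // Illustrative fp16 strided-batched GEMM (column-major, no transposes).
    // Accumulation stays in fp32 here, so alpha/beta are passed as float; with
    // fp16 compute they would have to be __half instead, which is exactly the
    // a/b pointer switch the patched Paddle code performs.
    cublasStatus_t BatchedHalfGemm(cublasHandle_t handle, int m, int n, int k,
                                   const __half *A, long long strideA,
                                   const __half *B, long long strideB,
                                   __half *C, long long strideC,
                                   int batch_count) {
      float alpha = 1.0f, beta = 0.0f;
      return cublasGemmStridedBatchedEx(
          handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
          A, CUDA_R_16F, m, strideA,
          B, CUDA_R_16F, k, strideB, &beta,
          C, CUDA_R_16F, m, strideC, batch_count,
          PickComputeType(/*use_half_compute=*/false), CUBLAS_GEMM_DEFAULT);
    }

The two enums are not interchangeable values, so passing CUDA_R_32F where a
cublasComputeType_t is expected is a genuine type/ABI mismatch on CUDA 11+, which
is what the guarded assignments in the patch avoid.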