diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index e4f3dbf6a791cd074c9cd5b39652f8f2b6a25e47..8155a79d91a818cf092cf36b700650a5ebb0d07e 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -1452,7 +1452,11 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, << FLAGS_gemm_use_half_precision_compute_type; auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; - cudaDataType_t compute_type = CUDA_R_32F; +#if CUDA_VERSION >= 11000 + auto compute_type = CUBLAS_COMPUTE_32F; +#else + auto compute_type = CUDA_R_32F; +#endif float h_alpha = static_cast(alpha); float h_beta = static_cast(beta); @@ -1463,7 +1467,11 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, std::is_same::value) { a = static_cast(&alpha); b = static_cast(&beta); +#if CUDA_VERSION >= 11000 + compute_type = CUBLAS_COMPUTE_16F; +#else compute_type = CUDA_R_16F; +#endif } context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {