diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h
index 39bddda6caa532df0c6d392a9ca2e76766d38f3e..64b35cfeaecd1f88395db97d0374d919356651eb
--- a/paddle/fluid/operators/math/blas_impl.cu.h
+++ b/paddle/fluid/operators/math/blas_impl.cu.h
@@ -428,7 +428,8 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
   const int64_t strideC = M * N;

 #if CUDA_VERSION >= 9010
-  if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
+  if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same<T, float>::value)) ||
+      std::is_same<T, platform::float16>::value) {
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
     bool use_tensor_op_math = context_.tensor_core_available();
     if (use_tensor_op_math) {
@@ -437,11 +438,11 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
     VLOG(5) << "use_tensor_op_math: "
             << (use_tensor_op_math ? "True" : "False");

+    auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
     context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
-          handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
-          strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
-          strideC, batchCount, CUDA_R_32F, algo));
+          handle, cuTransB, cuTransA, N, M, K, &alpha, B, fp, ldb, strideB, A,
+          fp, lda, strideA, &beta, C, fp, ldc, strideC, batchCount, fp, algo));
     });
   } else {
 #endif  // CUDA_VERSION >= 9010