From 30d1ff3bb4aebba239a381fd11a8712e7028bf87 Mon Sep 17 00:00:00 2001 From: Zhang Ting <709968123@qq.com> Date: Tue, 21 Jul 2020 19:39:07 +0800 Subject: [PATCH] call cublasGemmStridedBatchedEx when using fp16, test=develop (#25553) --- paddle/fluid/operators/math/blas_impl.cu.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 39bddda6caa..64b35cfeaec 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -428,7 +428,8 @@ void Blas::BatchedGEMM( const int64_t strideC = M * N; #if CUDA_VERSION >= 9010 - if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { + if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || + std::is_same::value) { cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; bool use_tensor_op_math = context_.tensor_core_available(); if (use_tensor_op_math) { @@ -437,11 +438,11 @@ void Blas::BatchedGEMM( VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx( - handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb, - strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc, - strideC, batchCount, CUDA_R_32F, algo)); + handle, cuTransB, cuTransA, N, M, K, &alpha, B, fp, ldb, strideB, A, + fp, lda, strideA, &beta, C, fp, ldc, strideC, batchCount, fp, algo)); }); } else { #endif // CUDA_VERSION >= 9010 -- GitLab