[Paddle-TRT] FP16 with cublasGemmStridedBatchedEx
Created by: zlsh80826
- PaddlePaddle version: develop
- GPU: including CUDA/CUDNN version
- OS Platform: Ubuntu 16.04
Describe the feature and the current behavior/state.
The implementation in blas_impl.cu.h doesn't use cublasGemmStridedBatchedEx when the data type is fp16, even when CUDA_VERSION is greater than or equal to 9010. cublasGemmStridedBatchedEx has more capabilities, such as selecting the algorithm (reference), and it also delivers better performance. We should use it when available, i.e., when CUDA_VERSION >= 9010.
So the code should be like this:
#if CUDA_VERSION >= 9010
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
  bool use_tensor_op_math = context_.tensor_core_available();
  if (FLAGS_enable_cublas_tensor_op_math && use_tensor_op_math) {
    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
  }
  VLOG(5) << "use_tensor_op_math: "
          << (use_tensor_op_math ? "True" : "False");

  auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
  context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
        handle, cuTransB, cuTransA, N, M, K, &alpha, B, fp, ldb, strideB, A,
        fp, lda, strideA, &beta, C, fp, ldc, strideC, batchCount, fp, algo));
  });
#else  // CUDA_VERSION < 9010
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
                                  B, ldb, strideB, A, lda, strideA, &beta, C,
                                  ldc, strideC, batchCount);
  });
#endif  // CUDA_VERSION >= 9010
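
For reference, below is a minimal standalone sketch of the raw cuBLAS call the proposal relies on; it is not Paddle code, and the matrix sizes, the plain cudaMalloc buffers, and the explicit CUBLAS_GEMM_DFALT_TENSOR_OP choice are only illustrative assumptions for trying out the FP16 strided-batched path on CUDA >= 9.1 (build with something like `nvcc example.cu -lcublas`):

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  // Illustrative sizes; real workloads come from the Paddle op.
  const int M = 64, N = 64, K = 64, batchCount = 8;
  const long long strideA = static_cast<long long>(M) * K;
  const long long strideB = static_cast<long long>(K) * N;
  const long long strideC = static_cast<long long>(M) * N;

  // One strided allocation per operand covering all batches.
  __half *A, *B, *C;
  cudaMalloc(reinterpret_cast<void **>(&A), sizeof(__half) * strideA * batchCount);
  cudaMalloc(reinterpret_cast<void **>(&B), sizeof(__half) * strideB * batchCount);
  cudaMalloc(reinterpret_cast<void **>(&C), sizeof(__half) * strideC * batchCount);

  cublasHandle_t handle;
  cublasCreate(&handle);
  // Opt in to Tensor Core math (pre-CUDA 11 style, matching the flag in the issue).
  cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);

  // With an FP16 compute type, alpha/beta must be __half as well.
  const __half alpha = __float2half(1.f), beta = __float2half(0.f);
  cublasStatus_t status = cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
      A, CUDA_R_16F, M, strideA,
      B, CUDA_R_16F, K, strideB, &beta,
      C, CUDA_R_16F, M, strideC, batchCount,
      CUDA_R_16F,                    // compute type
      CUBLAS_GEMM_DFALT_TENSOR_OP);  // explicitly request a tensor-op algorithm
  printf("cublasGemmStridedBatchedEx status: %d\n", status);

  cublasDestroy(handle);
  cudaFree(A);
  cudaFree(B);
  cudaFree(C);
  return 0;
}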