From a0a5fcf1820a84e7d6bf470948e3fadf34288f95 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 19 Apr 2022 17:49:03 +0800 Subject: [PATCH] feat(dnn): support tf32 GitOrigin-RevId: 9e5871f933744468b91b7ab5ac6159a4b7a67084 --- dnn/src/cuda/batched_matrix_mul/cublas.cpp | 8 ++++++++ dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp | 2 +- dnn/src/cuda/matrix_mul/cublas.cpp | 14 ++++++++++++-- dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp | 2 +- dnn/src/cuda/matrix_mul/cublas_lt.cpp | 2 +- 5 files changed, 23 insertions(+), 5 deletions(-) diff --git a/dnn/src/cuda/batched_matrix_mul/cublas.cpp b/dnn/src/cuda/batched_matrix_mul/cublas.cpp index b0bbaef0..b58a4c06 100644 --- a/dnn/src/cuda/batched_matrix_mul/cublas.cpp +++ b/dnn/src/cuda/batched_matrix_mul/cublas.cpp @@ -88,7 +88,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { #if CUDART_VERSION >= 9010 auto io16_c32 = [&]() { +#if CUDART_VERSION >= 11000 + cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); +#else cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); +#endif auto zero = handle->zero_device(); auto one = handle->one_device(); cublas_check(cublasGemmBatchedEx( @@ -104,7 +108,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { #if CUDART_VERSION >= 9000 auto io16_c16 = [&]() { +#if CUDART_VERSION >= 11000 + cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); +#else cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); +#endif auto zero = handle->zero_device_h(); auto one = handle->one_device_h(); cublas_check(cublasHgemmBatched( diff --git a/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp b/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp index 6ddcfd47..bb4f702e 100644 --- a/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp +++ b/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp @@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const batched_igemm(); } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) { batched_hgemm(); - } else if (desc.dt_compute == CUBLAS_COMPUTE_32F) { + } else if (desc.dt_compute == CUBLAS_COMPUTE_32F_FAST_TF32) { batched_sgemm(); } else { megdnn_throw("compute_type must be int32/float16/float32"); diff --git a/dnn/src/cuda/matrix_mul/cublas.cpp b/dnn/src/cuda/matrix_mul/cublas.cpp index 0097414a..4888566a 100644 --- a/dnn/src/cuda/matrix_mul/cublas.cpp +++ b/dnn/src/cuda/matrix_mul/cublas.cpp @@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { auto sgemm = [&]() { auto zero = handle->zero_device(); auto one = handle->one_device(); +#if CUDART_VERSION >= 11000 + cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); +#endif cublas_check(cublasSgemm( cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, args.tensor_b.ptr(), args.tensor_b.layout.stride[0], args.tensor_a.ptr(), args.tensor_a.layout.stride[0], zero, args.tensor_c.ptr(), args.tensor_c.layout.stride[0])); +#if CUDART_VERSION >= 11000 + cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH)); +#endif }; auto sgemm_ex = [&]() { auto zero = handle->zero_device(); auto one = handle->one_device(); -#if CUDART_VERSION >= 9000 +#if CUDART_VERSION >= 11000 + cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); +#elif CUDART_VERSION >= 9000 cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); #endif auto sgemm_ex_err = cublasSgemmEx( @@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { }; auto hgemm = [&]() { -#if CUDART_VERSION >= 9000 +#if CUDART_VERSION >= 11000 + cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); +#elif CUDART_VERSION >= 9000 cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); #endif auto one_half = handle->one_device_h(); diff --git a/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp b/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp index cca4ef8b..10601904 100644 --- a/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp +++ b/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp @@ -28,7 +28,7 @@ static cublasComputeType_t to_cublas_compute_type(DType tp) { case DTypeEnum::Float16: return CUBLAS_COMPUTE_16F; case DTypeEnum::Float32: - return CUBLAS_COMPUTE_32F; + return CUBLAS_COMPUTE_32F_FAST_TF32; case DTypeEnum::Int32: case DTypeEnum::QuantizedS32: return CUBLAS_COMPUTE_32I; diff --git a/dnn/src/cuda/matrix_mul/cublas_lt.cpp b/dnn/src/cuda/matrix_mul/cublas_lt.cpp index dc4548ae..9df439ed 100644 --- a/dnn/src/cuda/matrix_mul/cublas_lt.cpp +++ b/dnn/src/cuda/matrix_mul/cublas_lt.cpp @@ -107,7 +107,7 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const { case CUBLAS_COMPUTE_16F: hgemm(); break; - case CUBLAS_COMPUTE_32F: + case CUBLAS_COMPUTE_32F_FAST_TF32: sgemm(); break; case CUBLAS_COMPUTE_32I: -- GitLab