diff --git a/dnn/src/cuda/batched_matrix_mul/cublas.cpp b/dnn/src/cuda/batched_matrix_mul/cublas.cpp
index b0bbaef05609c743d614519b1558fc03e489c290..b58a4c066c46c197ce1583409cf4abc52e75fb67 100644
--- a/dnn/src/cuda/batched_matrix_mul/cublas.cpp
+++ b/dnn/src/cuda/batched_matrix_mul/cublas.cpp
@@ -88,7 +88,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
 
 #if CUDART_VERSION >= 9010
     auto io16_c32 = [&]() {
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#else
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
+#endif
         auto zero = handle->zero_device();
         auto one = handle->one_device();
         cublas_check(cublasGemmBatchedEx(
@@ -104,7 +108,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
 
 #if CUDART_VERSION >= 9000
     auto io16_c16 = [&]() {
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#else
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
+#endif
         auto zero = handle->zero_device_h();
         auto one = handle->one_device_h();
         cublas_check(cublasHgemmBatched(
diff --git a/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp b/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp
index 6ddcfd479750df29967f1326e070f00b78aae283..bb4f702e5b64b7de86994093178cf460f53e4e6e 100644
--- a/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp
+++ b/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp
@@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const
         batched_igemm();
     } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) {
         batched_hgemm();
-    } else if (desc.dt_compute == CUBLAS_COMPUTE_32F) {
+    } else if (desc.dt_compute == CUBLAS_COMPUTE_32F_FAST_TF32) {
         batched_sgemm();
     } else {
         megdnn_throw("compute_type must be int32/float16/float32");
diff --git a/dnn/src/cuda/matrix_mul/cublas.cpp b/dnn/src/cuda/matrix_mul/cublas.cpp
index 0097414a5bb601cdf95faac171558a354c69fdd3..4888566af7c80d1b003cb842cd34d150dc43448c 100644
--- a/dnn/src/cuda/matrix_mul/cublas.cpp
+++ b/dnn/src/cuda/matrix_mul/cublas.cpp
@@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
     auto sgemm = [&]() {
         auto zero = handle->zero_device();
         auto one = handle->one_device();
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#endif
         cublas_check(cublasSgemm(
                 cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                 param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                 args.tensor_b.ptr(), args.tensor_b.layout.stride[0],
                 args.tensor_a.ptr(), args.tensor_a.layout.stride[0], zero,
                 args.tensor_c.ptr(), args.tensor_c.layout.stride[0]));
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
+#endif
     };
 
     auto sgemm_ex = [&]() {
         auto zero = handle->zero_device();
         auto one = handle->one_device();
-#if CUDART_VERSION >= 9000
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#elif CUDART_VERSION >= 9000
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
 #endif
         auto sgemm_ex_err = cublasSgemmEx(
@@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
     };
 
     auto hgemm = [&]() {
-#if CUDART_VERSION >= 9000
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#elif CUDART_VERSION >= 9000
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
 #endif
         auto one_half = handle->one_device_h();
diff --git a/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp b/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp
index cca4ef8b354b46309cd107cc205f668ce9a6266f..106019047f177671b4a2a813682027a41658738e 100644
--- a/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp
+++ b/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp
@@ -28,7 +28,7 @@ static cublasComputeType_t to_cublas_compute_type(DType tp) {
         case DTypeEnum::Float16:
             return CUBLAS_COMPUTE_16F;
         case DTypeEnum::Float32:
-            return CUBLAS_COMPUTE_32F;
+            return CUBLAS_COMPUTE_32F_FAST_TF32;
         case DTypeEnum::Int32:
         case DTypeEnum::QuantizedS32:
             return CUBLAS_COMPUTE_32I;
diff --git a/dnn/src/cuda/matrix_mul/cublas_lt.cpp b/dnn/src/cuda/matrix_mul/cublas_lt.cpp
index dc4548aef82f1d8f352ec4a5dcc8fa722b3a1e28..9df439ed65decc6a47ec34071821c258b7c64e26 100644
--- a/dnn/src/cuda/matrix_mul/cublas_lt.cpp
+++ b/dnn/src/cuda/matrix_mul/cublas_lt.cpp
@@ -107,7 +107,7 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const {
         case CUBLAS_COMPUTE_16F:
            hgemm();
            break;
-        case CUBLAS_COMPUTE_32F:
+        case CUBLAS_COMPUTE_32F_FAST_TF32:
            sgemm();
            break;
        case CUBLAS_COMPUTE_32I:
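
Note on the change (not part of the patch): on CUDA 11+, FP32 GEMMs opt in to
TF32 tensor cores in two ways, both used above. The legacy cuBLAS path sets the
handle's math mode to CUBLAS_TF32_TENSOR_OP_MATH (the sgemm/sgemm_ex/hgemm
lambdas and the batched variants), while the cuBLASLt path requests the
CUBLAS_COMPUTE_32F_FAST_TF32 compute type per operation instead of
CUBLAS_COMPUTE_32F (the cublasLt_wrapper.cpp and cublas_lt.cpp hunks). Below is
a minimal standalone sketch of the legacy-API pattern; the function name
sgemm_tf32 is hypothetical, it assumes CUDA >= 11.0, and it passes alpha/beta
in the default host pointer mode rather than MegDNN's device-side one/zero
constants.

// Hypothetical illustration of the math-mode pattern in the sgemm lambda
// above; not MegDNN code. Requires CUDA >= 11.0 for the TF32 enums.
#include <cublas_v2.h>

cublasStatus_t sgemm_tf32(cublasHandle_t handle, int m, int n, int k,
                          const float* A, const float* B, float* C) {
    const float one = 1.0f, zero = 0.0f;
    // Allow this FP32 GEMM to run on tensor cores with TF32 (10-bit
    // mantissa); only takes effect on Ampere or newer GPUs.
    cublasStatus_t st = cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH);
    if (st != CUBLAS_STATUS_SUCCESS)
        return st;
    // Plain column-major C = A * B; alpha/beta are host pointers here,
    // unlike the patch, which uses device-side one/zero constants.
    st = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &one, A, m,
                     B, k, &zero, C, m);
    // Restore the default so later FP32 calls on this handle keep full
    // precision, mirroring the reset to CUBLAS_DEFAULT_MATH in the patch.
    cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);
    return st;
}

The math mode is handle-global state, which is why the sgemm lambda in the
patch resets it after the call; the cuBLASLt compute type avoids that by being
scoped to a single matmul descriptor.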