提交 a0a5fcf1 编写于 作者: M Megvii Engine Team

feat(dnn): support tf32

GitOrigin-RevId: 9e5871f933744468b91b7ab5ac6159a4b7a67084
上级 f0088335
...@@ -88,7 +88,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { ...@@ -88,7 +88,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
#if CUDART_VERSION >= 9010 #if CUDART_VERSION >= 9010
auto io16_c32 = [&]() { auto io16_c32 = [&]() {
#if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#else
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
auto zero = handle->zero_device(); auto zero = handle->zero_device();
auto one = handle->one_device(); auto one = handle->one_device();
cublas_check(cublasGemmBatchedEx( cublas_check(cublasGemmBatchedEx(
...@@ -104,7 +108,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { ...@@ -104,7 +108,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
#if CUDART_VERSION >= 9000 #if CUDART_VERSION >= 9000
auto io16_c16 = [&]() { auto io16_c16 = [&]() {
#if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#else
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
auto zero = handle->zero_device_h(); auto zero = handle->zero_device_h();
auto one = handle->one_device_h(); auto one = handle->one_device_h();
cublas_check(cublasHgemmBatched( cublas_check(cublasHgemmBatched(
......
...@@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const ...@@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const
batched_igemm(); batched_igemm();
} else if (desc.dt_compute == CUBLAS_COMPUTE_16F) { } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) {
batched_hgemm(); batched_hgemm();
} else if (desc.dt_compute == CUBLAS_COMPUTE_32F) { } else if (desc.dt_compute == CUBLAS_COMPUTE_32F_FAST_TF32) {
batched_sgemm(); batched_sgemm();
} else { } else {
megdnn_throw("compute_type must be int32/float16/float32"); megdnn_throw("compute_type must be int32/float16/float32");
......
...@@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { ...@@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
auto sgemm = [&]() { auto sgemm = [&]() {
auto zero = handle->zero_device(); auto zero = handle->zero_device();
auto one = handle->one_device(); auto one = handle->one_device();
#if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
cublas_check(cublasSgemm( cublas_check(cublasSgemm(
cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0], args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0],
args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero, args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero,
args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0])); args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0]));
#if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
#endif
}; };
auto sgemm_ex = [&]() { auto sgemm_ex = [&]() {
auto zero = handle->zero_device(); auto zero = handle->zero_device();
auto one = handle->one_device(); auto one = handle->one_device();
#if CUDART_VERSION >= 9000 #if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#elif CUDART_VERSION >= 9000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif #endif
auto sgemm_ex_err = cublasSgemmEx( auto sgemm_ex_err = cublasSgemmEx(
...@@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { ...@@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
}; };
auto hgemm = [&]() { auto hgemm = [&]() {
#if CUDART_VERSION >= 9000 #if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#elif CUDART_VERSION >= 9000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif #endif
auto one_half = handle->one_device_h(); auto one_half = handle->one_device_h();
......
...@@ -28,7 +28,7 @@ static cublasComputeType_t to_cublas_compute_type(DType tp) { ...@@ -28,7 +28,7 @@ static cublasComputeType_t to_cublas_compute_type(DType tp) {
case DTypeEnum::Float16: case DTypeEnum::Float16:
return CUBLAS_COMPUTE_16F; return CUBLAS_COMPUTE_16F;
case DTypeEnum::Float32: case DTypeEnum::Float32:
return CUBLAS_COMPUTE_32F; return CUBLAS_COMPUTE_32F_FAST_TF32;
case DTypeEnum::Int32: case DTypeEnum::Int32:
case DTypeEnum::QuantizedS32: case DTypeEnum::QuantizedS32:
return CUBLAS_COMPUTE_32I; return CUBLAS_COMPUTE_32I;
......
...@@ -107,7 +107,7 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const { ...@@ -107,7 +107,7 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const {
case CUBLAS_COMPUTE_16F: case CUBLAS_COMPUTE_16F:
hgemm(); hgemm();
break; break;
case CUBLAS_COMPUTE_32F: case CUBLAS_COMPUTE_32F_FAST_TF32:
sgemm(); sgemm();
break; break;
case CUBLAS_COMPUTE_32I: case CUBLAS_COMPUTE_32I:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册