From 849f0ece9d062b391555633d9dfa6f53e91acb9b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 14 Oct 2021 15:37:59 +0800
Subject: [PATCH] fix(dnn): drop batched matmul cublas algo when batch is 1

GitOrigin-RevId: 71126a27b07704fde6b67e9935b1f7c63d1e8200
---
 dnn/src/cuda/batched_matrix_mul/cublas.cpp | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/dnn/src/cuda/batched_matrix_mul/cublas.cpp b/dnn/src/cuda/batched_matrix_mul/cublas.cpp
index b2261fec0..7a1a1e43b 100644
--- a/dnn/src/cuda/batched_matrix_mul/cublas.cpp
+++ b/dnn/src/cuda/batched_matrix_mul/cublas.cpp
@@ -22,7 +22,15 @@ bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available(const SizeArgs& args)
     auto dtype = args.layout_a.dtype;
     auto&& param = args.opr->param();
     auto&& handle = concrete_handle(args.opr->handle());
-    if (dtype == dtype::Float32())
+    // fix: cublasSgemmBatched in cuBLAS versions prior to 11.1 has an error when batch = 1
+    // and matrix A's width > 8191, so temporarily drop this algo unless
+    // args.layout_a.shape[2] <= 8191 || args.layout_a.shape[0] != 1
+    if (dtype == dtype::Float32()
+#if CUBLAS_VERSION < 11200
+        && (args.layout_a.shape[args.opr->param().transposeA ? 1 : 2] <= 8191 ||
+            args.layout_a.shape[0] != 1)
+#endif
+    )
         return true;
     if (dtype != dtype::Float16())
         return false;
--
GitLab
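
Note (illustrative, not part of the patch): the sketch below is a hypothetical standalone C++ program that drives cublasSgemmBatched in exactly the regime the new guard excludes, batch = 1 with matrix A's width K just above 8191. The file name, build line, matrix sizes, and the all-ones correctness check are assumptions made for illustration; on an affected cuBLAS build (per the comment, prior to 11.1) the printed C[0] would not match the expected K, whereas the guard above simply stops MegDNN from selecting this algo in that case.

// repro_sgemm_batched.cpp -- hypothetical standalone sketch, not part of MegEngine.
// Exercises cublasSgemmBatched with batch = 1 and K = 8192 (> 8191), the case the
// guard above disables on older cuBLAS builds. With all-ones inputs, every element
// of C should equal K; a mismatch would indicate the problem the comment refers to.
// Build (assumed): nvcc repro_sgemm_batched.cpp -lcublas -o repro
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>

#define CHECK_CUDA(expr)                                                     \
    do {                                                                     \
        cudaError_t err_ = (expr);                                           \
        if (err_ != cudaSuccess) {                                           \
            std::printf("CUDA error %d at line %d\n", int(err_), __LINE__);  \
            return 1;                                                        \
        }                                                                    \
    } while (0)

#define CHECK_CUBLAS(expr)                                                   \
    do {                                                                     \
        cublasStatus_t st_ = (expr);                                         \
        if (st_ != CUBLAS_STATUS_SUCCESS) {                                  \
            std::printf("cuBLAS error %d at line %d\n", int(st_), __LINE__); \
            return 1;                                                        \
        }                                                                    \
    } while (0)

int main() {
    const int batch = 1, m = 4, n = 4, k = 8192;  // batch == 1, A's width K > 8191

    // Host buffers: A is m x k, B is k x n, C is m x n (column-major for cuBLAS).
    std::vector<float> host_a(size_t(m) * k, 1.0f);
    std::vector<float> host_b(size_t(k) * n, 1.0f);
    std::vector<float> host_c(size_t(m) * n, 0.0f);

    float *dev_a = nullptr, *dev_b = nullptr, *dev_c = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&dev_a, host_a.size() * sizeof(float)));
    CHECK_CUDA(cudaMalloc((void**)&dev_b, host_b.size() * sizeof(float)));
    CHECK_CUDA(cudaMalloc((void**)&dev_c, host_c.size() * sizeof(float)));
    CHECK_CUDA(cudaMemcpy(dev_a, host_a.data(), host_a.size() * sizeof(float),
                          cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(dev_b, host_b.data(), host_b.size() * sizeof(float),
                          cudaMemcpyHostToDevice));

    // cublasSgemmBatched takes device-side arrays of per-batch matrix pointers.
    const float* host_a_ptrs[] = {dev_a};
    const float* host_b_ptrs[] = {dev_b};
    float* host_c_ptrs[] = {dev_c};
    const float **dev_a_ptrs = nullptr, **dev_b_ptrs = nullptr;
    float** dev_c_ptrs = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&dev_a_ptrs, sizeof(host_a_ptrs)));
    CHECK_CUDA(cudaMalloc((void**)&dev_b_ptrs, sizeof(host_b_ptrs)));
    CHECK_CUDA(cudaMalloc((void**)&dev_c_ptrs, sizeof(host_c_ptrs)));
    CHECK_CUDA(cudaMemcpy(dev_a_ptrs, host_a_ptrs, sizeof(host_a_ptrs),
                          cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(dev_b_ptrs, host_b_ptrs, sizeof(host_b_ptrs),
                          cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(dev_c_ptrs, host_c_ptrs, sizeof(host_c_ptrs),
                          cudaMemcpyHostToDevice));

    cublasHandle_t handle;
    CHECK_CUBLAS(cublasCreate(&handle));

    // C = alpha * A * B + beta * C, column-major, no transposes.
    const float alpha = 1.0f, beta = 0.0f;
    CHECK_CUBLAS(cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                                    dev_a_ptrs, m, dev_b_ptrs, k, &beta, dev_c_ptrs, m,
                                    batch));

    CHECK_CUDA(cudaMemcpy(host_c.data(), dev_c, host_c.size() * sizeof(float),
                          cudaMemcpyDeviceToHost));
    std::printf("C[0] = %.1f (expected %d)\n", host_c[0], k);

    cublasDestroy(handle);
    cudaFree(dev_a); cudaFree(dev_b); cudaFree(dev_c);
    cudaFree(dev_a_ptrs); cudaFree(dev_b_ptrs); cudaFree(dev_c_ptrs);
    return 0;
}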