提交 849f0ece 编写于 作者: M Megvii Engine Team

fix(dnn): drop batched matmul cublas algo when batch is 1

GitOrigin-RevId: 71126a27b07704fde6b67e9935b1f7c63d1e8200
上级 b5bf56e0
...@@ -22,7 +22,15 @@ bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available(const SizeArgs& args) ...@@ -22,7 +22,15 @@ bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available(const SizeArgs& args)
auto dtype = args.layout_a.dtype; auto dtype = args.layout_a.dtype;
auto&& param = args.opr->param(); auto&& param = args.opr->param();
auto&& handle = concrete_handle(args.opr->handle()); auto&& handle = concrete_handle(args.opr->handle());
if (dtype == dtype::Float32()) // fix: cublasSgemmBatched in cuBLAS versions prior to 11.1 has an error when batch = 1
// and matrix A's width > 8191. So temporarily enable this algo only when
// args.layout_a.shape[2] <= 8191 || args.layout_a.shape[0] != 1
if (dtype == dtype::Float32()
#if CUBLAS_VERSION < 11200
&& (args.layout_a.shape[args.opr->param().transposeA ? 1 : 2] <= 8191 ||
args.layout_a.shape[0] != 1)
#endif
)
return true; return true;
if (dtype != dtype::Float16()) if (dtype != dtype::Float16())
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册