From b717606989eec034551ff6743f7b07386803e671 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 24 Mar 2021 15:24:21 +0800
Subject: [PATCH] fix(dnn/cuda): add block size limit for cutlass gemm algo

GitOrigin-RevId: c0940e4535169bc21f075d1600bb0e6ed30af254
---
 .../cuda/matrix_mul/cutlass_float32_simt.cpp | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
index 16d80c6a1..c79f81fbe 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
@@ -22,10 +22,20 @@ using namespace cutlass_wrapper;
 
 bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(
         const SizeArgs& args) const {
-    return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
-           args.layout_a.dtype == dtype::Float32() &&
-           args.layout_b.dtype == dtype::Float32() &&
-           args.layout_c.dtype == dtype::Float32();
+    bool available =
+            args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
+            args.layout_a.dtype == dtype::Float32() &&
+            args.layout_b.dtype == dtype::Float32() &&
+            args.layout_c.dtype == dtype::Float32();
+    int n = args.layout_c.shape[1];
+    auto&& device_prop = cuda::current_device_prop();
+    int y_grid_limit = device_prop.maxGridSize[1];
+    // limit y grid
+    available &= ((n + m_algo_param.threadblock_n - 1) /
+                          m_algo_param.threadblock_n <=
+                  y_grid_limit);
+
+    return available;
 }
 
 size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
-- 
GitLab
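
Reviewer note (not part of the patch): the check added above rejects the CUTLASS SIMT GEMM algorithm when the number of threadblock tiles along N, computed with a ceiling division, would exceed the device's grid.y limit. Below is a minimal standalone C++ sketch of that same check for reference. The helper name n_tiles_fit_grid_y and the main() driver are assumptions made purely for illustration; the patch itself uses MegEngine's cuda::current_device_prop() wrapper instead of calling the CUDA runtime directly.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical helper mirroring the patch: one threadblock tile is launched per
// N-tile along grid.y, so the algorithm only fits when
// ceil(n / threadblock_n) <= maxGridSize[1].
bool n_tiles_fit_grid_y(int n, int threadblock_n, const cudaDeviceProp& prop) {
    int y_grid_limit = prop.maxGridSize[1];
    int n_tiles = (n + threadblock_n - 1) / threadblock_n;  // ceiling division
    return n_tiles <= y_grid_limit;
}

int main() {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) return 1;
    // Example: a 128-wide threadblock tile with a very large N dimension.
    printf("fits grid.y: %d\n", n_tiles_fit_grid_y(1 << 30, 128, prop));
    return 0;
}

The ceiling division rounds a partially filled tile up to a whole threadblock, so the comparison is conservative: the algorithm is reported as available only when every N-tile can be given its own grid.y index.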