diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
index 16d80c6a1b7299b183e96113317a4e6835e6c645..c79f81fbeea694036b0b6c64b1a67e52960afb51 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
@@ -22,10 +22,20 @@ using namespace cutlass_wrapper;
 
 bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(
         const SizeArgs& args) const {
-    return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
-           args.layout_a.dtype == dtype::Float32() &&
-           args.layout_b.dtype == dtype::Float32() &&
-           args.layout_c.dtype == dtype::Float32();
+    bool available =
+            args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
+            args.layout_a.dtype == dtype::Float32() &&
+            args.layout_b.dtype == dtype::Float32() &&
+            args.layout_c.dtype == dtype::Float32();
+    int n = args.layout_c.shape[1];
+    auto&& device_prop = cuda::current_device_prop();
+    int y_grid_limit = device_prop.maxGridSize[1];
+    // limit y grid
+    available &= ((n + m_algo_param.threadblock_n - 1) /
+                          m_algo_param.threadblock_n <=
+                  y_grid_limit);
+
+    return available;
 }
 
 size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
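
The hunk above extends is_available() so the algorithm is rejected when the number of threadblocks along N would exceed the device's gridDim.y limit (deviceProp.maxGridSize[1], typically 65535). Below is a minimal standalone sketch of the same ceil-division check using the plain CUDA runtime API instead of MegEngine's cuda::current_device_prop() helper; the function name and the example values for n and threadblock_n are hypothetical, chosen only to illustrate the logic.

// Standalone sketch of the grid-y limit check introduced in the patch.
// In the real code, n comes from args.layout_c.shape[1] and threadblock_n
// from m_algo_param.threadblock_n; here they are plain parameters.
#include <cuda_runtime.h>
#include <cstdio>

bool fits_y_grid_limit(int n, int threadblock_n) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, /*device=*/0);
    int y_grid_limit = prop.maxGridSize[1];  // maximum gridDim.y
    // ceil(n / threadblock_n) threadblocks would be launched along y
    int blocks_y = (n + threadblock_n - 1) / threadblock_n;
    return blocks_y <= y_grid_limit;
}

int main() {
    // e.g. a GEMM with N = 1 << 22 columns and a 128-wide threadblock tile
    std::printf("fits y grid limit: %d\n", fits_y_grid_limit(1 << 22, 128));
    return 0;
}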