From d8bb3ff5b4e0d6668c523c403d0009cc500bdb15 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 2 Mar 2022 10:40:05 +0800
Subject: [PATCH] fix(cuda): fix fp16 tensorcore gemm split k workspace

GitOrigin-RevId: d04a0e098541e7e1b519bb392966f598cde3639a
---
 dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp
index a3ef56fb8..772e20fd1 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp
@@ -55,12 +55,12 @@ size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes(
         k = args.layout_a.shape[param.transposeA ? 0 : 1];
     int split_k_slices = std::max(1, k / n);
     if (!aligned.first)
-        return args.layout_c.dtype.size(m * n * split_k_slices);
+        return sizeof(float) * (m * n * split_k_slices);
     const auto& layouts = aligned.second;
     int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1],
         align_k = layouts[0].shape[1];
     split_k_slices = std::max(1, align_k / align_n);
-    size_t ws_size = args.layout_c.dtype.size(align_m * align_n * split_k_slices);
+    size_t ws_size = sizeof(float) * (align_m * align_n * split_k_slices);
     for (auto&& ly : layouts)
         ws_size += ly.span().dist_byte();
     return ws_size;
-- 
GitLab