提交 d8bb3ff5 编写于 作者: M Megvii Engine Team

fix(cuda): fix fp16 tensorcore gemm split k workspace

GitOrigin-RevId: d04a0e098541e7e1b519bb392966f598cde3639a
上级 597efed4
......@@ -55,12 +55,12 @@ size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes(
k = args.layout_a.shape[param.transposeA ? 0 : 1];
int split_k_slices = std::max(1, k / n);
if (!aligned.first)
return args.layout_c.dtype.size(m * n * split_k_slices);
return sizeof(float) * (m * n * split_k_slices);
const auto& layouts = aligned.second;
int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1],
align_k = layouts[0].shape[1];
split_k_slices = std::max(1, align_k / align_n);
size_t ws_size = args.layout_c.dtype.size(align_m * align_n * split_k_slices);
size_t ws_size = sizeof(float) * (align_m * align_n * split_k_slices);
for (auto&& ly : layouts)
ws_size += ly.span().dist_byte();
return ws_size;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册