Commit 004cb8cd authored by Megvii Engine Team

fix(dnn/cpu): fix the workspace calculation error of fallback im2col

GitOrigin-RevId: a718daac11525668a5c6f3ee8a97fa71d47a73ab
Parent 87447df0
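The root cause of the bug: `sizeof(param.src_type)` measures the `DType` handle object itself, not the byte width of one element of the dtype it describes, while `DType::size()` (the call the fix switches to) returns the per-element size. Below is a minimal sketch of that distinction; `DTypeSketch`, its field layout, and the printed numbers are illustrative stand-ins, not MegDNN's actual `DType` implementation.

#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for megdnn::DType (illustration only): a small
// descriptor handle. sizeof() on the handle reflects the handle's own
// layout; size() reports the element width the workspace math needs.
struct DTypeSketch {
    const char* name;
    size_t bytes_per_elem;
    size_t size() const { return bytes_per_elem; }
};

int main() {
    DTypeSketch qint8{"QuantizedS8", 1};  // 1 byte per element
    DTypeSketch f32{"Float32", 4};        // 4 bytes per element

    // The buggy pattern: whatever dtype the handle describes, this is just
    // the size of the wrapper struct (16 bytes on a typical LP64 target).
    std::printf("sizeof(handle)      = %zu\n", sizeof(qint8));

    // The fixed pattern: the true per-element byte width.
    std::printf("QuantizedS8::size() = %zu\n", qint8.size());
    std::printf("Float32::size()     = %zu\n", f32.size());
    return 0;
}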
@@ -17,7 +17,7 @@ WorkspaceBundle get_thread_bundle(
             (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
              param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
     size_t matmul_dst_bytes_per_thread =
-            is_dst_8bit ? oc_tile_size * OH * OW * sizeof(param.bias_type) : 0;
+            is_dst_8bit ? oc_tile_size * OH * OW * param.bias_type.size() : 0;
     return WorkspaceBundle{nullptr, {matmul_c_size, matmul_dst_bytes_per_thread}};
 }
...
@@ -199,8 +199,7 @@ static WorkspaceBundle get_bundle(
     if (no_need_pading) {
         padding = 0; //! not need padding
     } else {
-        padding =
-                (GROUP * N * IC * IH2 * IW2) * sizeof(param.src_type); //! for padding
+        padding = (GROUP * N * IC * IH2 * IW2) * param.src_type.size(); //! for padding
     }
     packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size
...
@@ -146,15 +146,15 @@ public:
         size_t im2col = 0, packb = 0, bias_temp = 0;
         bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT;
         megdnn_assert(default_pack, "only support default packa");
-        size_t im2col_dst_size = IC * FH * FW * ohw_tile_size * sizeof(param.src_type);
+        size_t im2col_dst_size = IC * FH * FW * ohw_tile_size * param.src_type.size();
         size_t matmul_dst_size =
-                pack_oc_size * oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
+                pack_oc_size * oc_tile_size * ohw_tile_size * param.bias_type.size();
         //! matmul_dst and im2col_dst use the same memory
         WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param);
         packb = wb.get_size(1);
         im2col = std::max(im2col_dst_size, matmul_dst_size);
         if (param.bias_mode == megdnn::BiasMode::BIAS) {
-            bias_temp = oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
+            bias_temp = oc_tile_size * ohw_tile_size * param.bias_type.size();
         }
         return {nullptr, {packb, im2col, bias_temp}};
     }
@@ -231,15 +231,15 @@ public:
         size_t im2col = 0, packb = 0, matmul_dst = 0, bias_temp = 0;
         bool only_packA = matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
         megdnn_assert(only_packA, "onlysupport onlypackA mode");
-        size_t im2col_dst_size = IC * FH * FW * ohw_tile_size * sizeof(param.src_type);
-        size_t matmul_dst_size = oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
+        size_t im2col_dst_size = IC * FH * FW * ohw_tile_size * param.src_type.size();
+        size_t matmul_dst_size = oc_tile_size * ohw_tile_size * param.bias_type.size();
         //! matmul_dst and im2col_dst use the same memory
         WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param);
         packb = wb.get_size(1);
         im2col = im2col_dst_size;
         matmul_dst = matmul_dst_size;
         if (param.bias_mode == megdnn::BiasMode::BIAS) {
-            bias_temp = oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
+            bias_temp = oc_tile_size * ohw_tile_size * param.bias_type.size();
         }
         return {nullptr, {packb, im2col, matmul_dst, bias_temp}};
@@ -309,8 +309,8 @@ public:
                 param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
                (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
                 param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
-        size_t im2col_dst_size = IC * FH * FW * ohw_tile_size * sizeof(param.src_type);
-        size_t matmul_dst_size = oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
+        size_t im2col_dst_size = IC * FH * FW * ohw_tile_size * param.src_type.size();
+        size_t matmul_dst_size = oc_tile_size * ohw_tile_size * param.bias_type.size();
         im2col = im2col_dst_size;
         if (is_dst_8bit) {
             matmul_dst = matmul_dst_size;
@@ -319,9 +319,8 @@ public:
         }
         matmul_compute = matmul_algo->get_workspace(im2col_kern_param);
         if (param.bias_mode == megdnn::BiasMode::BIAS) {
-            bias_temp = oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
+            bias_temp = oc_tile_size * ohw_tile_size * param.bias_type.size();
         }
         return {nullptr, {im2col, matmul_dst, bias_temp, matmul_compute}};
     }
 };
...
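For scale, a back-of-envelope check of what the fix changes for a single QuantizedS8 im2col tile. The tile shape and the 8-byte handle size below are assumptions for illustration, not values taken from the commit; if the handle were ever smaller than the element width, the same bug would under-allocate instead, which is why the workspace bound must use the element size.

#include <cstddef>
#include <cstdio>

int main() {
    // Assumed tile shape for illustration (not values from the commit).
    const size_t IC = 16, FH = 3, FW = 3, ohw_tile_size = 192;

    const size_t handle_bytes = 8;  // assumed sizeof(DType)-style handle
    const size_t elem_bytes = 1;    // QuantizedS8 element width

    // Old (buggy) vs. new (fixed) im2col workspace request.
    size_t old_bytes = IC * FH * FW * ohw_tile_size * handle_bytes;
    size_t new_bytes = IC * FH * FW * ohw_tile_size * elem_bytes;

    std::printf("old: %zu bytes, new: %zu bytes (%zux larger)\n",
                old_bytes, new_bytes, old_bytes / new_bytes);
    // Prints: old: 221184 bytes, new: 27648 bytes (8x larger)
    return 0;
}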