提交 3b08bd9e 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(dnn/fallback): fix im2col thread-safety problem

GitOrigin-RevId: f9f82d8c88379a7791aa72ea222567468df707ac
上级 3ef308e7
......@@ -47,7 +47,7 @@ protected:
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;
mutable size_t m_oc_block_size = 0;
const size_t m_oc_block_size = 0;
};
} // namespace fallback
......
......@@ -350,7 +350,8 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
}
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
const NCBKernSizeParam& param, size_t block_m, size_t block_n,
const NCBKernSizeParam& param, size_t& oc_tile_size,
size_t& ohw_tile_size, size_t block_m, size_t block_n,
bool need_pack) const {
size_t nr_threads = param.nr_threads;
size_t OC = param.filter_meta.ocpg;
......@@ -360,29 +361,29 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
//! m_ohw_tile_size and m_oc_tile_size; if these two values change, the
//! workspace size may also change, which will cause a workspace-mismatch problem,
//! so we should use the original data to init them to avoid the problem
m_oc_tile_size = DEFAULT_OC_TILE_SIZE;
m_ohw_tile_size = m_ohw_tile_origin;
oc_tile_size = DEFAULT_OC_TILE_SIZE;
ohw_tile_size = m_ohw_tile_size;
m_oc_tile_size = std::min(m_oc_tile_size, OC);
m_ohw_tile_size = std::min(m_ohw_tile_size, ohw);
oc_tile_size = std::min(oc_tile_size, OC);
ohw_tile_size = std::min(ohw_tile_size, ohw);
if (nr_threads > 1) {
if (ohw / m_ohw_tile_size < nr_threads) {
m_ohw_tile_size = round_up(div_ceil(ohw, nr_threads), block_n);
if (m_ohw_tile_size < DEFAULT_OHW_MIN_TILE_SIZE) {
m_ohw_tile_size = ohw;
m_oc_tile_size = round_up(div_ceil(OC, nr_threads), block_m);
if (m_oc_tile_size > DEFAULT_OC_MAX_TILE_SIZE) {
m_oc_tile_size = DEFAULT_OC_MAX_TILE_SIZE;
} else if (m_oc_tile_size < DEFAULT_OC_MIN_TILE_SIZE) {
m_oc_tile_size = DEFAULT_OC_MIN_TILE_SIZE;
if (ohw / ohw_tile_size < nr_threads) {
ohw_tile_size = round_up(div_ceil(ohw, nr_threads), block_n);
if (ohw_tile_size < DEFAULT_OHW_MIN_TILE_SIZE) {
ohw_tile_size = ohw;
oc_tile_size = round_up(div_ceil(OC, nr_threads), block_m);
if (oc_tile_size > DEFAULT_OC_MAX_TILE_SIZE) {
oc_tile_size = DEFAULT_OC_MAX_TILE_SIZE;
} else if (oc_tile_size < DEFAULT_OC_MIN_TILE_SIZE) {
oc_tile_size = DEFAULT_OC_MIN_TILE_SIZE;
}
}
}
} else {
if (!need_pack) { //! no pack, usually on x86 to save memory
m_ohw_tile_size = ohw;
m_oc_tile_size = OC;
ohw_tile_size = ohw;
oc_tile_size = OC;
}
}
}
......@@ -406,20 +407,22 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
size_t GROUP = param.filter_meta.group;
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT;
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
size_t oc_tile_size = 0, ohw_tile_size = 0;
if (need_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack);
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m,
inner_block.n, need_pack);
auto im2col_kern_param = get_matmul_kern_param(
param, m_ohw_tile_size, only_packA ? m_oc_tile_size : OC);
size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size);
param, ohw_tile_size, only_packA ? oc_tile_size : OC);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
WorkspaceBundle wb = m_matmul_algo->get_bundle(im2col_kern_param);
packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0)
: wb.get_size(0);
} else { //! pack is not supported, no need to pack
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, nopack_default_blockm, nopack_default_blockn,
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
need_pack);
packa_group_size = 0;
}
......@@ -434,23 +437,23 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size
WorkspaceBundle ws = {nullptr, {}};
auto im2col_kern_param =
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size);
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) {
Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
ws = defaultkern.get_thread_bundle(param, im2col_kern_param,
m_matmul_algo, m_ohw_tile_size,
m_oc_tile_size);
m_matmul_algo, ohw_tile_size,
oc_tile_size);
} else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) {
Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern;
ws = onlypackakern.get_thread_bundle(param, im2col_kern_param,
m_matmul_algo, m_ohw_tile_size,
m_oc_tile_size);
m_matmul_algo, ohw_tile_size,
oc_tile_size);
} else {
Im2colKerns<Pack_Mode::NO_PACK> nopackkern;
ws = nopackkern.get_thread_bundle(param, im2col_kern_param,
m_matmul_algo, m_ohw_tile_size,
m_oc_tile_size);
m_matmul_algo, ohw_tile_size,
oc_tile_size);
}
return {nullptr,
......@@ -476,45 +479,59 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
MEGDNN_MARK_USED_VAR(IW);
MEGDNN_MARK_USED_VAR(FH);
MEGDNN_MARK_USED_VAR(FW);
size_t oc_tile_size = 0, ohw_tile_size = 0;
size_t ohw = OH * OW;
size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size);
size_t GROUP = param.filter_meta.group;
WorkspaceBundle bundle = get_bundle(param);
WorkspaceBundle bundle_thread = {nullptr, {}};
size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size);
bool need_padding = (PH != 0 || PW != 0);
Pack_Mode packmode = m_matmul_algo->packmode();
bool default_pack = packmode == Pack_Mode::DEFAULT;
bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n, default_pack);
} else { //! pack is not supported, no need to pack
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
no_pack);
}
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
size_t packa_parallel_times = 0;
size_t pack_oc_size =
(param.filter_meta.format == param::ConvBias::Format::NCHW ? 1
: 4);
if (only_packA) {
packa_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size);
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
} else if (default_pack) {
packa_parallel_times = div_ceil<size_t>(
OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size);
}
auto matmul_param = get_matmul_kern_param(
param, m_ohw_tile_size, only_packA ? m_oc_tile_size : OC);
param, ohw_tile_size, only_packA ? oc_tile_size : OC);
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) {
Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
bundle_thread = defaultkern.get_thread_bundle(
param, matmul_param, m_matmul_algo, m_ohw_tile_size,
m_oc_tile_size);
param, matmul_param, m_matmul_algo, ohw_tile_size,
oc_tile_size);
} else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) {
Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern;
bundle_thread = onlypackakern.get_thread_bundle(
param, matmul_param, m_matmul_algo, m_ohw_tile_size,
m_oc_tile_size);
param, matmul_param, m_matmul_algo, ohw_tile_size,
oc_tile_size);
} else {
Im2colKerns<Pack_Mode::NO_PACK> nopackkern;
bundle_thread = nopackkern.get_thread_bundle(
param, matmul_param, m_matmul_algo, m_ohw_tile_size,
m_oc_tile_size);
param, matmul_param, m_matmul_algo, ohw_tile_size,
oc_tile_size);
}
StrategyParam strategyparam;
......@@ -524,10 +541,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
strategyparam.is_ohw_size_bigger = (m_ohw_tile_size >= ohw);
strategyparam.is_ohw_size_bigger = (ohw_tile_size >= ohw);
strategyparam.skip_copy_dst =
strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit;
strategyparam.oc_tile_size = m_oc_tile_size;
strategyparam.oc_tile_size = oc_tile_size;
strategyparam.pack_oc_size = pack_oc_size;
SmallVector<ConvBiasImpl::NCBKern> ret_kern;
......@@ -556,7 +573,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
auto kern_compute_default =
[bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo,
ohw_tile_size = m_ohw_tile_size,
ohw_tile_size = ohw_tile_size,
strategyparam = strategyparam,
im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
......@@ -579,7 +596,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo,
strategyparam = strategyparam,
ohw_tile_size = m_ohw_tile_size,
ohw_tile_size = ohw_tile_size,
im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns(
......@@ -599,7 +616,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo,
strategyparam = strategyparam,
ohw_tile_size = m_ohw_tile_size,
ohw_tile_size = ohw_tile_size,
im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::NO_PACK>::kerns(
......@@ -650,8 +667,25 @@ bool ConvBiasImpl::AlgoIm2col::usable(
return false;
}
size_t oc_tile_size = 0, ohw_tile_size = 0;
Pack_Mode packmode = m_matmul_algo->packmode();
bool default_pack = packmode == Pack_Mode::DEFAULT;
bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n, default_pack);
} else { //! pack is not supported, no need to pack
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
no_pack);
}
fallback::MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size);
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
bool matmulusable = m_matmul_algo->usable(matmul_param);
return matmulusable &&
(opr->param().format == param::ConvBias::Format::NCHW ||
......
......@@ -36,20 +36,21 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase {
const NCBKernSizeParam& param, size_t ohw_tile_size,
size_t oc_tile_size) const;
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
void choice_ohw_oc_block(const NCBKernSizeParam& param, size_t block_m,
size_t block_n, bool pack_default) const;
void choice_ohw_oc_block(const NCBKernSizeParam& param,
size_t& oc_tile_size, size_t& ohw_tile_size,
size_t block_m, size_t block_n,
bool pack_default) const;
public:
AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size)
: m_matmul_algo(matmul_algo),
m_ohw_tile_origin(ohw_tile_size),
m_ohw_tile_size(ohw_tile_size) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
if (m_name.empty()) {
m_name = ssprintf("IM2COLMATMUL:%s:%zu", m_matmul_algo->name(),
m_ohw_tile_origin);
m_ohw_tile_size);
}
return m_name.c_str();
}
......@@ -72,9 +73,7 @@ public:
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;
const size_t m_ohw_tile_origin;
mutable size_t m_ohw_tile_size;
mutable size_t m_oc_tile_size = DEFAULT_OC_TILE_SIZE;
const size_t m_ohw_tile_size;
};
} // namespace fallback
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册