提交 3b08bd9e 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(dnn/fallback): fix im2col thread-safety problem

GitOrigin-RevId: f9f82d8c88379a7791aa72ea222567468df707ac
上级 3ef308e7
...@@ -47,7 +47,7 @@ protected: ...@@ -47,7 +47,7 @@ protected:
private: private:
MatrixMulImpl::AlgoBase* m_matmul_algo; MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name; mutable std::string m_name;
mutable size_t m_oc_block_size = 0; const size_t m_oc_block_size = 0;
}; };
} // namespace fallback } // namespace fallback
......
...@@ -350,7 +350,8 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, ...@@ -350,7 +350,8 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
} }
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
const NCBKernSizeParam& param, size_t block_m, size_t block_n, const NCBKernSizeParam& param, size_t& oc_tile_size,
size_t& ohw_tile_size, size_t block_m, size_t block_n,
bool need_pack) const { bool need_pack) const {
size_t nr_threads = param.nr_threads; size_t nr_threads = param.nr_threads;
size_t OC = param.filter_meta.ocpg; size_t OC = param.filter_meta.ocpg;
...@@ -360,29 +361,29 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( ...@@ -360,29 +361,29 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
//! m_ohw_tile_size and m_oc_tile_size; if the two values change, the //! m_ohw_tile_size and m_oc_tile_size; if the two values change, the
//! workspace size may change and a workspace mismatch will occur, so //! workspace size may change and a workspace mismatch will occur, so
//! initialize them from the original data to avoid the problem //! initialize them from the original data to avoid the problem
m_oc_tile_size = DEFAULT_OC_TILE_SIZE; oc_tile_size = DEFAULT_OC_TILE_SIZE;
m_ohw_tile_size = m_ohw_tile_origin; ohw_tile_size = m_ohw_tile_size;
m_oc_tile_size = std::min(m_oc_tile_size, OC); oc_tile_size = std::min(oc_tile_size, OC);
m_ohw_tile_size = std::min(m_ohw_tile_size, ohw); ohw_tile_size = std::min(ohw_tile_size, ohw);
if (nr_threads > 1) { if (nr_threads > 1) {
if (ohw / m_ohw_tile_size < nr_threads) { if (ohw / ohw_tile_size < nr_threads) {
m_ohw_tile_size = round_up(div_ceil(ohw, nr_threads), block_n); ohw_tile_size = round_up(div_ceil(ohw, nr_threads), block_n);
if (m_ohw_tile_size < DEFAULT_OHW_MIN_TILE_SIZE) { if (ohw_tile_size < DEFAULT_OHW_MIN_TILE_SIZE) {
m_ohw_tile_size = ohw; ohw_tile_size = ohw;
m_oc_tile_size = round_up(div_ceil(OC, nr_threads), block_m); oc_tile_size = round_up(div_ceil(OC, nr_threads), block_m);
if (m_oc_tile_size > DEFAULT_OC_MAX_TILE_SIZE) { if (oc_tile_size > DEFAULT_OC_MAX_TILE_SIZE) {
m_oc_tile_size = DEFAULT_OC_MAX_TILE_SIZE; oc_tile_size = DEFAULT_OC_MAX_TILE_SIZE;
} else if (m_oc_tile_size < DEFAULT_OC_MIN_TILE_SIZE) { } else if (oc_tile_size < DEFAULT_OC_MIN_TILE_SIZE) {
m_oc_tile_size = DEFAULT_OC_MIN_TILE_SIZE; oc_tile_size = DEFAULT_OC_MIN_TILE_SIZE;
} }
} }
} }
} else { } else {
if (!need_pack) { //! no pack ,usually in x86 save memroy if (!need_pack) { //! no pack ,usually in x86 save memroy
m_ohw_tile_size = ohw; ohw_tile_size = ohw;
m_oc_tile_size = OC; oc_tile_size = OC;
} }
} }
} }
...@@ -406,20 +407,22 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( ...@@ -406,20 +407,22 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
size_t GROUP = param.filter_meta.group; size_t GROUP = param.filter_meta.group;
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT; bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT;
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA; bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
size_t oc_tile_size = 0, ohw_tile_size = 0;
if (need_pack || only_packA) { if (need_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size(); auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack); choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m,
inner_block.n, need_pack);
auto im2col_kern_param = get_matmul_kern_param( auto im2col_kern_param = get_matmul_kern_param(
param, m_ohw_tile_size, only_packA ? m_oc_tile_size : OC); param, ohw_tile_size, only_packA ? oc_tile_size : OC);
size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
WorkspaceBundle wb = m_matmul_algo->get_bundle(im2col_kern_param); WorkspaceBundle wb = m_matmul_algo->get_bundle(im2col_kern_param);
packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0) packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0)
: wb.get_size(0); : wb.get_size(0);
} else { //! not support pack,not need pack } else { //! not support pack,not need pack
size_t nopack_default_blockm = 8; size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16; size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, nopack_default_blockm, nopack_default_blockn, choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
need_pack); need_pack);
packa_group_size = 0; packa_group_size = 0;
} }
...@@ -434,23 +437,23 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( ...@@ -434,23 +437,23 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size
WorkspaceBundle ws = {nullptr, {}}; WorkspaceBundle ws = {nullptr, {}};
auto im2col_kern_param = auto im2col_kern_param =
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) {
Im2colKerns<Pack_Mode::DEFAULT> defaultkern; Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
ws = defaultkern.get_thread_bundle(param, im2col_kern_param, ws = defaultkern.get_thread_bundle(param, im2col_kern_param,
m_matmul_algo, m_ohw_tile_size, m_matmul_algo, ohw_tile_size,
m_oc_tile_size); oc_tile_size);
} else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) { } else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) {
Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern; Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern;
ws = onlypackakern.get_thread_bundle(param, im2col_kern_param, ws = onlypackakern.get_thread_bundle(param, im2col_kern_param,
m_matmul_algo, m_ohw_tile_size, m_matmul_algo, ohw_tile_size,
m_oc_tile_size); oc_tile_size);
} else { } else {
Im2colKerns<Pack_Mode::NO_PACK> nopackkern; Im2colKerns<Pack_Mode::NO_PACK> nopackkern;
ws = nopackkern.get_thread_bundle(param, im2col_kern_param, ws = nopackkern.get_thread_bundle(param, im2col_kern_param,
m_matmul_algo, m_ohw_tile_size, m_matmul_algo, ohw_tile_size,
m_oc_tile_size); oc_tile_size);
} }
return {nullptr, return {nullptr,
...@@ -476,45 +479,59 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -476,45 +479,59 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
MEGDNN_MARK_USED_VAR(IW); MEGDNN_MARK_USED_VAR(IW);
MEGDNN_MARK_USED_VAR(FH); MEGDNN_MARK_USED_VAR(FH);
MEGDNN_MARK_USED_VAR(FW); MEGDNN_MARK_USED_VAR(FW);
size_t oc_tile_size = 0, ohw_tile_size = 0;
size_t ohw = OH * OW; size_t ohw = OH * OW;
size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size);
size_t GROUP = param.filter_meta.group; size_t GROUP = param.filter_meta.group;
WorkspaceBundle bundle = get_bundle(param); WorkspaceBundle bundle = get_bundle(param);
WorkspaceBundle bundle_thread = {nullptr, {}}; WorkspaceBundle bundle_thread = {nullptr, {}};
size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size);
bool need_padding = (PH != 0 || PW != 0); bool need_padding = (PH != 0 || PW != 0);
Pack_Mode packmode = m_matmul_algo->packmode(); Pack_Mode packmode = m_matmul_algo->packmode();
bool default_pack = packmode == Pack_Mode::DEFAULT; bool default_pack = packmode == Pack_Mode::DEFAULT;
bool no_pack = packmode == Pack_Mode::NO_PACK; bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n, default_pack);
} else { //! not support pack,not need pack
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
no_pack);
}
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
size_t packa_parallel_times = 0; size_t packa_parallel_times = 0;
size_t pack_oc_size = size_t pack_oc_size =
(param.filter_meta.format == param::ConvBias::Format::NCHW ? 1 (param.filter_meta.format == param::ConvBias::Format::NCHW ? 1
: 4); : 4);
if (only_packA) { if (only_packA) {
packa_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
} else if (default_pack) { } else if (default_pack) {
packa_parallel_times = div_ceil<size_t>( packa_parallel_times = div_ceil<size_t>(
OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size); OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size);
} }
auto matmul_param = get_matmul_kern_param( auto matmul_param = get_matmul_kern_param(
param, m_ohw_tile_size, only_packA ? m_oc_tile_size : OC); param, ohw_tile_size, only_packA ? oc_tile_size : OC);
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) {
Im2colKerns<Pack_Mode::DEFAULT> defaultkern; Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
bundle_thread = defaultkern.get_thread_bundle( bundle_thread = defaultkern.get_thread_bundle(
param, matmul_param, m_matmul_algo, m_ohw_tile_size, param, matmul_param, m_matmul_algo, ohw_tile_size,
m_oc_tile_size); oc_tile_size);
} else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) { } else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) {
Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern; Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern;
bundle_thread = onlypackakern.get_thread_bundle( bundle_thread = onlypackakern.get_thread_bundle(
param, matmul_param, m_matmul_algo, m_ohw_tile_size, param, matmul_param, m_matmul_algo, ohw_tile_size,
m_oc_tile_size); oc_tile_size);
} else { } else {
Im2colKerns<Pack_Mode::NO_PACK> nopackkern; Im2colKerns<Pack_Mode::NO_PACK> nopackkern;
bundle_thread = nopackkern.get_thread_bundle( bundle_thread = nopackkern.get_thread_bundle(
param, matmul_param, m_matmul_algo, m_ohw_tile_size, param, matmul_param, m_matmul_algo, ohw_tile_size,
m_oc_tile_size); oc_tile_size);
} }
StrategyParam strategyparam; StrategyParam strategyparam;
...@@ -524,10 +541,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -524,10 +541,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
strategyparam.is_ohw_size_bigger = (m_ohw_tile_size >= ohw); strategyparam.is_ohw_size_bigger = (ohw_tile_size >= ohw);
strategyparam.skip_copy_dst = strategyparam.skip_copy_dst =
strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit; strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit;
strategyparam.oc_tile_size = m_oc_tile_size; strategyparam.oc_tile_size = oc_tile_size;
strategyparam.pack_oc_size = pack_oc_size; strategyparam.pack_oc_size = pack_oc_size;
SmallVector<ConvBiasImpl::NCBKern> ret_kern; SmallVector<ConvBiasImpl::NCBKern> ret_kern;
...@@ -556,7 +573,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -556,7 +573,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
auto kern_compute_default = auto kern_compute_default =
[bundle, bundle_thread, matmul_param, [bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo, matmul_algo = m_matmul_algo,
ohw_tile_size = m_ohw_tile_size, ohw_tile_size = ohw_tile_size,
strategyparam = strategyparam, strategyparam = strategyparam,
im2colstrategy](const NCBKernParam& param, im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) { const NCBKernIndex& ncb_index) {
...@@ -579,7 +596,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -579,7 +596,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[bundle, bundle_thread, matmul_param, [bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo, matmul_algo = m_matmul_algo,
strategyparam = strategyparam, strategyparam = strategyparam,
ohw_tile_size = m_ohw_tile_size, ohw_tile_size = ohw_tile_size,
im2colstrategy](const NCBKernParam& param, im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) { const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns( Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns(
...@@ -599,7 +616,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -599,7 +616,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[bundle, bundle_thread, matmul_param, [bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo, matmul_algo = m_matmul_algo,
strategyparam = strategyparam, strategyparam = strategyparam,
ohw_tile_size = m_ohw_tile_size, ohw_tile_size = ohw_tile_size,
im2colstrategy](const NCBKernParam& param, im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) { const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::NO_PACK>::kerns( Im2colKerns<Pack_Mode::NO_PACK>::kerns(
...@@ -650,8 +667,25 @@ bool ConvBiasImpl::AlgoIm2col::usable( ...@@ -650,8 +667,25 @@ bool ConvBiasImpl::AlgoIm2col::usable(
return false; return false;
} }
size_t oc_tile_size = 0, ohw_tile_size = 0;
Pack_Mode packmode = m_matmul_algo->packmode();
bool default_pack = packmode == Pack_Mode::DEFAULT;
bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n, default_pack);
} else { //! not support pack,not need pack
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
no_pack);
}
fallback::MatrixMulImpl::KernSizeParam matmul_param = fallback::MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
bool matmulusable = m_matmul_algo->usable(matmul_param); bool matmulusable = m_matmul_algo->usable(matmul_param);
return matmulusable && return matmulusable &&
(opr->param().format == param::ConvBias::Format::NCHW || (opr->param().format == param::ConvBias::Format::NCHW ||
......
...@@ -36,20 +36,21 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase { ...@@ -36,20 +36,21 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase {
const NCBKernSizeParam& param, size_t ohw_tile_size, const NCBKernSizeParam& param, size_t ohw_tile_size,
size_t oc_tile_size) const; size_t oc_tile_size) const;
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
void choice_ohw_oc_block(const NCBKernSizeParam& param, size_t block_m, void choice_ohw_oc_block(const NCBKernSizeParam& param,
size_t block_n, bool pack_default) const; size_t& oc_tile_size, size_t& ohw_tile_size,
size_t block_m, size_t block_n,
bool pack_default) const;
public: public:
AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size) AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size)
: m_matmul_algo(matmul_algo), : m_matmul_algo(matmul_algo),
m_ohw_tile_origin(ohw_tile_size),
m_ohw_tile_size(ohw_tile_size) {} m_ohw_tile_size(ohw_tile_size) {}
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { const char* name() const override {
if (m_name.empty()) { if (m_name.empty()) {
m_name = ssprintf("IM2COLMATMUL:%s:%zu", m_matmul_algo->name(), m_name = ssprintf("IM2COLMATMUL:%s:%zu", m_matmul_algo->name(),
m_ohw_tile_origin); m_ohw_tile_size);
} }
return m_name.c_str(); return m_name.c_str();
} }
...@@ -72,9 +73,7 @@ public: ...@@ -72,9 +73,7 @@ public:
private: private:
MatrixMulImpl::AlgoBase* m_matmul_algo; MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name; mutable std::string m_name;
const size_t m_ohw_tile_origin; const size_t m_ohw_tile_size;
mutable size_t m_ohw_tile_size;
mutable size_t m_oc_tile_size = DEFAULT_OC_TILE_SIZE;
}; };
} // namespace fallback } // namespace fallback
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册