提交 8e9fa80c 编写于 作者: M Megvii Engine Team

feat(dnn/fallback): add matmul description for im2col

GitOrigin-RevId: 5bde0b60f0b8102cd8bad14457cf123bc7e6dafa
上级 af3de7e1
...@@ -60,6 +60,7 @@ public: ...@@ -60,6 +60,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4)
}; };
class MatrixMulImpl::AlgoF32Gemv final class MatrixMulImpl::AlgoF32Gemv final
...@@ -86,6 +87,7 @@ public: ...@@ -86,6 +87,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
}; };
#endif #endif
...@@ -207,6 +209,7 @@ public: ...@@ -207,6 +209,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
}; };
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
...@@ -234,6 +237,7 @@ public: ...@@ -234,6 +237,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
}; };
#else #else
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#pragma once #pragma once
#include "src/arm_common/matrix_mul/opr_impl.h" #include "src/arm_common/matrix_mul/opr_impl.h"
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {
...@@ -25,6 +26,7 @@ public: ...@@ -25,6 +26,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
}; };
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
...@@ -38,6 +40,7 @@ public: ...@@ -38,6 +40,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
}; };
class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { class MatrixMulImpl::AlgoF32Gemv : public AlgoBase {
...@@ -54,6 +57,7 @@ public: ...@@ -54,6 +57,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
}; };
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...@@ -68,6 +72,7 @@ public: ...@@ -68,6 +72,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
}; };
#endif #endif
...@@ -82,6 +87,7 @@ public: ...@@ -82,6 +87,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
}; };
......
...@@ -49,6 +49,7 @@ public: ...@@ -49,6 +49,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4)
}; };
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...@@ -71,6 +72,7 @@ public: ...@@ -71,6 +72,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
}; };
#endif #endif
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
...@@ -190,6 +192,7 @@ public: ...@@ -190,6 +192,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
}; };
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
......
...@@ -47,14 +47,17 @@ static void copy_padding_kern(WorkspaceBundle bundle, ...@@ -47,14 +47,17 @@ static void copy_padding_kern(WorkspaceBundle bundle,
} }
//! packA_kern //! packA_kern
static void packA_kern(WorkspaceBundle bundle, static void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
StrategyBase* im2colstrategy, size_t pack_oc_size) { StrategyBase* im2colstrategy,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
size_t pack_oc_size) {
im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo, im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo,
ncb_index, pack_oc_size); ncb_index, matmul_desc, pack_oc_size);
} }
/*! /*!
...@@ -72,7 +75,8 @@ public: ...@@ -72,7 +75,8 @@ public:
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const ConvBiasImpl::NCBKernParam& param, const ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
StrategyParam strategyparam, StrategyParam strategyparam,
fallback::ConvBiasImpl::NCBKernIndex ncb_index, fallback::ConvBiasImpl::NCBKernIndex ncb_index,
size_t ohw_tile_size, StrategyBase* im2colstrategy) { size_t ohw_tile_size, StrategyBase* im2colstrategy) {
...@@ -111,7 +115,8 @@ public: ...@@ -111,7 +115,8 @@ public:
//! 2.packb and matmul compute //! 2.packb and matmul compute
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
matmul_param, matmul_algo, ncb_index); matmul_param, matmul_algo, ncb_index,
matmul_desc);
//! 3.postprocess and copy dst if need //! 3.postprocess and copy dst if need
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
...@@ -151,7 +156,8 @@ public: ...@@ -151,7 +156,8 @@ public:
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const ConvBiasImpl::NCBKernParam& param, const ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
StrategyParam strategyparam, StrategyParam strategyparam,
fallback::ConvBiasImpl::NCBKernIndex ncb_index, fallback::ConvBiasImpl::NCBKernIndex ncb_index,
size_t ohw_tile_size, StrategyBase* im2colstrategy) { size_t ohw_tile_size, StrategyBase* im2colstrategy) {
...@@ -191,7 +197,8 @@ public: ...@@ -191,7 +197,8 @@ public:
//! 2.packb and matmul compute //! 2.packb and matmul compute
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
matmul_param, matmul_algo, ncb_index); matmul_param, matmul_algo, ncb_index,
matmul_desc);
//! 3.postprocess and copy dst if need //! 3.postprocess and copy dst if need
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
...@@ -232,7 +239,8 @@ public: ...@@ -232,7 +239,8 @@ public:
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const ConvBiasImpl::NCBKernParam& param, const ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
StrategyParam strategyparam, StrategyParam strategyparam,
fallback::ConvBiasImpl::NCBKernIndex ncb_index, fallback::ConvBiasImpl::NCBKernIndex ncb_index,
size_t ohw_tile_size, StrategyBase* im2colstrategy) { size_t ohw_tile_size, StrategyBase* im2colstrategy) {
...@@ -272,7 +280,8 @@ public: ...@@ -272,7 +280,8 @@ public:
//! 2.packb and matmul compute //! 2.packb and matmul compute
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
matmul_param, matmul_algo, ncb_index); matmul_param, matmul_algo, ncb_index,
matmul_desc);
//! 3.postprocess and copy dst if need //! 3.postprocess and copy dst if need
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
...@@ -401,13 +410,15 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( ...@@ -401,13 +410,15 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
size_t padding = 0, packa_size = 0, packa_group_size = 0; size_t padding = 0, packa_size = 0, packa_group_size = 0;
size_t nr_threads = param.nr_threads; size_t nr_threads = param.nr_threads;
size_t GROUP = param.filter_meta.group; size_t GROUP = param.filter_meta.group;
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT; fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA; m_matmul_algo->matmul_description();
bool need_pack = mdesc.packmode == Pack_Mode::DEFAULT;
bool only_packA = mdesc.packmode == Pack_Mode::ONLY_PACKA;
size_t oc_tile_size = 0, ohw_tile_size = 0; size_t oc_tile_size = 0, ohw_tile_size = 0;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
mdesc.innerblocksize.m, mdesc.innerblocksize.n,
mdesc.packmode);
if (need_pack || only_packA) { if (need_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m,
inner_block.n, m_matmul_algo->packmode());
auto im2col_kern_param = get_matmul_kern_param( auto im2col_kern_param = get_matmul_kern_param(
param, ohw_tile_size, only_packA ? oc_tile_size : OC); param, ohw_tile_size, only_packA ? oc_tile_size : OC);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
...@@ -415,11 +426,6 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( ...@@ -415,11 +426,6 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0) packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0)
: wb.get_size(0); : wb.get_size(0);
} else { //! not support pack,not need pack } else { //! not support pack,not need pack
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
m_matmul_algo->packmode());
packa_group_size = 0; packa_group_size = 0;
} }
...@@ -481,23 +487,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -481,23 +487,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
WorkspaceBundle bundle = get_bundle(param); WorkspaceBundle bundle = get_bundle(param);
WorkspaceBundle bundle_thread = {nullptr, {}}; WorkspaceBundle bundle_thread = {nullptr, {}};
bool need_padding = (PH != 0 || PW != 0); bool need_padding = (PH != 0 || PW != 0);
Pack_Mode packmode = m_matmul_algo->packmode();
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
m_matmul_algo->matmul_description();
Pack_Mode packmode = mdesc.packmode;
bool default_pack = packmode == Pack_Mode::DEFAULT; bool default_pack = packmode == Pack_Mode::DEFAULT;
bool no_pack = packmode == Pack_Mode::NO_PACK; bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n,
m_matmul_algo->packmode());
} else { //! nopack_mode
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn, mdesc.innerblocksize.m, mdesc.innerblocksize.n,
m_matmul_algo->packmode()); mdesc.packmode);
}
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
...@@ -507,18 +508,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -507,18 +508,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
if (only_packA) { if (only_packA) {
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
} else if (default_pack) { } else if (default_pack) {
packa_parallel_times = div_ceil<size_t>( packa_parallel_times = div_ceil<size_t>(OC, mdesc.innerblocksize.m);
OC, m_matmul_algo->get_inner_block_size().m);
} }
auto matmul_param = get_matmul_kern_param( auto matmul_param = get_matmul_kern_param(
param, ohw_tile_size, only_packA ? oc_tile_size : OC); param, ohw_tile_size, only_packA ? oc_tile_size : OC);
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { if (mdesc.packmode == Pack_Mode::DEFAULT) {
Im2colKerns<Pack_Mode::DEFAULT> defaultkern; Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
bundle_thread = defaultkern.get_thread_bundle( bundle_thread = defaultkern.get_thread_bundle(
param, matmul_param, m_matmul_algo, ohw_tile_size, param, matmul_param, m_matmul_algo, ohw_tile_size,
oc_tile_size); oc_tile_size);
} else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) { } else if (mdesc.packmode == Pack_Mode::ONLY_PACKA) {
Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern; Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern;
bundle_thread = onlypackakern.get_thread_bundle( bundle_thread = onlypackakern.get_thread_bundle(
param, matmul_param, m_matmul_algo, ohw_tile_size, param, matmul_param, m_matmul_algo, ohw_tile_size,
...@@ -559,24 +559,24 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -559,24 +559,24 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
auto kern_packA = [bundle, matmul_algo = m_matmul_algo, auto kern_packA = [bundle, matmul_algo = m_matmul_algo,
matmul_param, im2colstrategy, matmul_param, im2colstrategy,
pack_oc_size = pack_oc_size]( pack_oc_size = pack_oc_size,
const NCBKernParam& param, mdesc = mdesc](const NCBKernParam& param,
const NCBKernIndex& ncb_index) { const NCBKernIndex& ncb_index) {
packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index, packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index,
im2colstrategy, pack_oc_size); im2colstrategy, mdesc, pack_oc_size);
}; };
if (default_pack) { if (default_pack) {
auto kern_compute_default = auto kern_compute_default =
[bundle, bundle_thread, matmul_param, [bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo, matmul_algo = m_matmul_algo,
ohw_tile_size = ohw_tile_size, ohw_tile_size = ohw_tile_size,
strategyparam = strategyparam, strategyparam = strategyparam, matmul_desc = mdesc,
im2colstrategy](const NCBKernParam& param, im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) { const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::DEFAULT>::kerns( Im2colKerns<Pack_Mode::DEFAULT>::kerns(
bundle, bundle_thread, param, matmul_param, bundle, bundle_thread, param, matmul_param,
matmul_algo, strategyparam, ncb_index, matmul_algo, matmul_desc, strategyparam,
ohw_tile_size, im2colstrategy); ncb_index, ohw_tile_size, im2colstrategy);
}; };
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});
...@@ -592,13 +592,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -592,13 +592,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[bundle, bundle_thread, matmul_param, [bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo, matmul_algo = m_matmul_algo,
strategyparam = strategyparam, strategyparam = strategyparam,
ohw_tile_size = ohw_tile_size, ohw_tile_size = ohw_tile_size, matmul_desc = mdesc,
im2colstrategy](const NCBKernParam& param, im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) { const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns( Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns(
bundle, bundle_thread, param, matmul_param, bundle, bundle_thread, param, matmul_param,
matmul_algo, strategyparam, ncb_index, matmul_algo, matmul_desc, strategyparam,
ohw_tile_size, im2colstrategy); ncb_index, ohw_tile_size, im2colstrategy);
}; };
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});
if (need_padding) { if (need_padding) {
...@@ -612,13 +612,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -612,13 +612,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[bundle, bundle_thread, matmul_param, [bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo, matmul_algo = m_matmul_algo,
strategyparam = strategyparam, strategyparam = strategyparam,
ohw_tile_size = ohw_tile_size, ohw_tile_size = ohw_tile_size, matmul_desc = mdesc,
im2colstrategy](const NCBKernParam& param, im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) { const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::NO_PACK>::kerns( Im2colKerns<Pack_Mode::NO_PACK>::kerns(
bundle, bundle_thread, param, matmul_param, bundle, bundle_thread, param, matmul_param,
matmul_algo, strategyparam, ncb_index, matmul_algo, matmul_desc, strategyparam,
ohw_tile_size, im2colstrategy); ncb_index, ohw_tile_size, im2colstrategy);
}; };
if (need_padding) { if (need_padding) {
...@@ -668,10 +668,12 @@ bool ConvBiasImpl::AlgoIm2col::usable( ...@@ -668,10 +668,12 @@ bool ConvBiasImpl::AlgoIm2col::usable(
return false; return false;
} }
} }
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
m_matmul_algo->matmul_description();
if (opr->param().format == param::ConvBias::Format::NCHW44 || if (opr->param().format == param::ConvBias::Format::NCHW44 ||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
//! current NCHW44 im2col only support DEFAULT mode matmul //! current NCHW44 im2col only support DEFAULT mode matmul
if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { if (mdesc.packmode != Pack_Mode::DEFAULT) {
return false; return false;
//! nchw44 hybird mode and channel wise is not support //! nchw44 hybird mode and channel wise is not support
} else if (param.filter_meta.icpg < 4_z || } else if (param.filter_meta.icpg < 4_z ||
...@@ -682,22 +684,9 @@ bool ConvBiasImpl::AlgoIm2col::usable( ...@@ -682,22 +684,9 @@ bool ConvBiasImpl::AlgoIm2col::usable(
} }
size_t oc_tile_size = 0, ohw_tile_size = 0; size_t oc_tile_size = 0, ohw_tile_size = 0;
Pack_Mode packmode = m_matmul_algo->packmode();
bool default_pack = packmode == Pack_Mode::DEFAULT;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n,
m_matmul_algo->packmode());
} else { //! not support pack,not need pack
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn, mdesc.innerblocksize.m, mdesc.innerblocksize.n,
m_matmul_algo->packmode()); m_matmul_algo->packmode());
}
fallback::MatrixMulImpl::KernSizeParam matmul_param = fallback::MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
bool matmulusable = m_matmul_algo->usable(matmul_param); bool matmulusable = m_matmul_algo->usable(matmul_param);
......
...@@ -58,8 +58,9 @@ public: ...@@ -58,8 +58,9 @@ public:
WorkspaceBundle bundle, WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desec,
size_t pack_size) = 0; size_t pack_size) = 0;
virtual void exec_im2col( virtual void exec_im2col(
...@@ -67,15 +68,17 @@ public: ...@@ -67,15 +68,17 @@ public:
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0; const fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0;
virtual void exec_matmul( virtual void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0; const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
) = 0;
virtual void exec_postprocess( virtual void exec_postprocess(
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
...@@ -284,26 +287,30 @@ public: ...@@ -284,26 +287,30 @@ public:
Strategy() = default; Strategy() = default;
virtual void packA_kern(WorkspaceBundle bundle, virtual void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
size_t pack_size) override; size_t pack_size) override;
virtual void exec_im2col( virtual void exec_im2col(
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
void exec_matmul( void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
) override;
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override { WorkspaceBundle bundle_thread) override {
...@@ -338,7 +345,7 @@ public: ...@@ -338,7 +345,7 @@ public:
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
}; };
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
...@@ -359,11 +366,13 @@ public: ...@@ -359,11 +366,13 @@ public:
Strategy() = default; Strategy() = default;
void packA_kern(WorkspaceBundle bundle, void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& MDsec,
size_t pack_size) override; size_t pack_size) override;
void exec_matmul( void exec_matmul(
...@@ -371,8 +380,10 @@ public: ...@@ -371,8 +380,10 @@ public:
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
) override;
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread, const WorkspaceBundle& bundle_thread,
...@@ -382,7 +393,7 @@ public: ...@@ -382,7 +393,7 @@ public:
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override { WorkspaceBundle bundle_thread) override {
...@@ -411,26 +422,30 @@ public: ...@@ -411,26 +422,30 @@ public:
Strategy() = default; Strategy() = default;
void packA_kern(WorkspaceBundle bundle, void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& MDsec,
size_t pack_size) override; size_t pack_size) override;
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
void exec_matmul( void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
) override;
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread, const WorkspaceBundle& bundle_thread,
...@@ -465,7 +480,7 @@ public: ...@@ -465,7 +480,7 @@ public:
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
}; };
template <typename op_ctype, typename op_dtype, template <typename op_ctype, typename op_dtype,
...@@ -487,7 +502,7 @@ public: ...@@ -487,7 +502,7 @@ public:
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
}; };
...@@ -510,7 +525,7 @@ public: ...@@ -510,7 +525,7 @@ public:
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
}; };
#endif #endif
......
...@@ -21,8 +21,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -21,8 +21,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
packA_kern(WorkspaceBundle bundle, packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription&
matmul_desc,
size_t) { size_t) {
bundle.set(param.workspace_ptr); bundle.set(param.workspace_ptr);
fallback::MatrixMulImpl::KernParam matmul_param; fallback::MatrixMulImpl::KernParam matmul_param;
...@@ -31,16 +33,16 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -31,16 +33,16 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
matmulparam; matmulparam;
size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0); size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0);
size_t packed_per_oc_block_size = size_t packed_per_oc_block_size =
round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) * round_up(matmul_param.K, matmul_desc.innerblocksize.k) *
matmul_algo->get_inner_block_size().m * matmul_desc.innerblocksize.m * matmul_desc.packa_type_size;
matmul_algo->get_packA_type_size();
size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size;
int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) +
group_id * packA_group_size + a_panel_offset; group_id * packA_group_size + a_panel_offset;
matmul_param.A_ptr = matmul_param.A_ptr =
const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); const_cast<src_ctype*>(param.filter<src_ctype>(group_id));
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1],
matmul_algo->get_inner_block_size().m); matmul_desc.innerblocksize.m);
} }
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
...@@ -52,7 +54,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -52,7 +54,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { const fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
size_t sh = param.filter_meta.stride[0]; size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1]; size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg; size_t oc = param.filter_meta.ocpg;
...@@ -140,11 +142,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -140,11 +142,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription&
matmul_desc) {
size_t packA_per_oc_block_size = size_t packA_per_oc_block_size =
round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) * round_up(matmul_param.K, matmul_desc.innerblocksize.k) *
sparam.oc_tile_size * matmul_algo->get_packA_type_size(); sparam.oc_tile_size * matmul_desc.packa_type_size;
size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0); size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0);
size_t a_panel_offset = ncb_index.ndrange_id[1] * packA_group_size + size_t a_panel_offset = ncb_index.ndrange_id[1] * packA_group_size +
ncb_index.ndrange_id[3] * packA_per_oc_block_size; ncb_index.ndrange_id[3] * packA_per_oc_block_size;
......
...@@ -33,7 +33,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -33,7 +33,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { const fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
size_t sh = param.filter_meta.stride[0]; size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1]; size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg; size_t oc = param.filter_meta.ocpg;
......
...@@ -173,7 +173,7 @@ void StrategyFuse4x4x16Nchw44<op_ctype, op_dtype, postprocess_mode>:: ...@@ -173,7 +173,7 @@ void StrategyFuse4x4x16Nchw44<op_ctype, op_dtype, postprocess_mode>::
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam, fallback::MatrixMulImpl::KernParam,
fallback::MatrixMulImpl::AlgoBase*) { const fallback::MatrixMulImpl::AlgoBase*) {
size_t ow = param.osz[1]; size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg; size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
......
...@@ -176,7 +176,7 @@ void StrategyFuse8x12x4Nchw44Dot<op_ctype, op_dtype, postprocess_mode>:: ...@@ -176,7 +176,7 @@ void StrategyFuse8x12x4Nchw44Dot<op_ctype, op_dtype, postprocess_mode>::
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam /*matmul_param*/, fallback::MatrixMulImpl::KernParam /*matmul_param*/,
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
size_t ow = param.osz[1]; size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg; size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
......
...@@ -168,7 +168,7 @@ void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>:: ...@@ -168,7 +168,7 @@ void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>::
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam /*matmul_param*/, fallback::MatrixMulImpl::KernParam /*matmul_param*/,
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
size_t ow = param.osz[1]; size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg; size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
......
...@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
packA_kern(WorkspaceBundle bundle, packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::
MatmulDescription& /*matmul_dsec*/,
size_t) { size_t) {
MEGDNN_MARK_USED_VAR(bundle); MEGDNN_MARK_USED_VAR(bundle);
MEGDNN_MARK_USED_VAR(param); MEGDNN_MARK_USED_VAR(param);
...@@ -62,8 +64,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -62,8 +64,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::
MatmulDescription& /*matmul_desc*/
) {
MEGDNN_MARK_USED_VAR(bundle); MEGDNN_MARK_USED_VAR(bundle);
MEGDNN_MARK_USED_VAR(ncb_index); MEGDNN_MARK_USED_VAR(ncb_index);
matmul_param.workspace_ptr = bundle_thread.get(THREAD_BUNDLE_MATCOMP_INDEX); matmul_param.workspace_ptr = bundle_thread.get(THREAD_BUNDLE_MATCOMP_INDEX);
...@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { const fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
MEGDNN_MARK_USED_VAR(matmul_param); MEGDNN_MARK_USED_VAR(matmul_param);
MEGDNN_MARK_USED_VAR(matmul_algo); MEGDNN_MARK_USED_VAR(matmul_algo);
size_t sh = param.filter_meta.stride[0]; size_t sh = param.filter_meta.stride[0];
......
...@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
packA_kern(WorkspaceBundle bundle, packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::
MatmulDescription& /*matmul_desc*/,
size_t) { size_t) {
bundle.set(param.workspace_ptr); bundle.set(param.workspace_ptr);
fallback::MatrixMulImpl::KernParam matmul_param; fallback::MatrixMulImpl::KernParam matmul_param;
...@@ -57,8 +59,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -57,8 +59,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, WorkspaceBundle bundle, const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread, WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::
MatmulDescription& /*matmul_desc*/
) {
size_t packA_group_size = size_t packA_group_size =
bundle.get_size(BUNDLE_PACKA_INDEX) / param.filter_meta.group; bundle.get_size(BUNDLE_PACKA_INDEX) / param.filter_meta.group;
size_t a_panel_offset = ncb_index.ndrange_id[3] * size_t a_panel_offset = ncb_index.ndrange_id[3] *
...@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param, fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { const fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
MEGDNN_MARK_USED_VAR(matmul_param); MEGDNN_MARK_USED_VAR(matmul_param);
MEGDNN_MARK_USED_VAR(matmul_algo); MEGDNN_MARK_USED_VAR(matmul_algo);
size_t sh = param.filter_meta.stride[0]; size_t sh = param.filter_meta.stride[0];
......
...@@ -37,6 +37,7 @@ public: ...@@ -37,6 +37,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
}; };
} // namespace fallback } // namespace fallback
......
...@@ -352,6 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, ...@@ -352,6 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
DType dtype_c) \ DType dtype_c) \
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {} : A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {}
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \
MatmulDescription matmul_description() const override { \
MatmulDescription mdesc; \
mdesc.packmode = packmode(); \
mdesc.innerblocksize = {_m, _n, _k}; \
mdesc.packa_type_size = _packa_type_size; \
return mdesc; \
}
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \ #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \
WorkspaceBundle get_bundle(const KernSizeParam&) const override; \ WorkspaceBundle get_bundle(const KernSizeParam&) const override; \
kern_naked_t get_kern_naked(const KernSizeParam&) const override; \ kern_naked_t get_kern_naked(const KernSizeParam&) const override; \
...@@ -360,7 +369,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, ...@@ -360,7 +369,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
void pack_B(const KernParam& kern_param, void* out, size_t x0, \ void pack_B(const KernParam& kern_param, void* out, size_t x0, \
size_t xmax) const override; \ size_t xmax) const override; \
InnerBlockSize get_inner_block_size() const override; \ InnerBlockSize get_inner_block_size() const override; \
size_t get_packA_type_size() const override; MatmulDescription matmul_description() const override;
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
...@@ -458,8 +467,14 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, ...@@ -458,8 +467,14 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
_strategy::UNROLL_K}; \ _strategy::UNROLL_K}; \
} \ } \
\ \
size_t MatrixMulImpl::_algo_name::get_packA_type_size() const { \ MatrixMulImpl::_algo_name::MatmulDescription \
return sizeof(_packa_type); \ MatrixMulImpl::_algo_name::matmul_description() const { \
MatmulDescription mdesc; \
mdesc.packmode = PackMode(); \
mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \
_strategy::UNROLL_K}; \
mdesc.packa_type_size = sizeof(_packa_type); \
return mdesc; \
} }
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \
......
...@@ -104,6 +104,12 @@ public: ...@@ -104,6 +104,12 @@ public:
size_t m, n, k; size_t m, n, k;
}; };
struct MatmulDescription {
PackMode packmode;
InnerBlockSize innerblocksize;
size_t packa_type_size;
};
virtual bool usable(const KernSizeParam&) const = 0; virtual bool usable(const KernSizeParam&) const = 0;
virtual bool preferred(const KernSizeParam&) const { return true; } virtual bool preferred(const KernSizeParam&) const { return true; }
virtual size_t get_workspace(const KernSizeParam&) const = 0; virtual size_t get_workspace(const KernSizeParam&) const = 0;
...@@ -125,11 +131,11 @@ public: ...@@ -125,11 +131,11 @@ public:
virtual InnerBlockSize get_inner_block_size() const { virtual InnerBlockSize get_inner_block_size() const {
megdnn_assert(0); megdnn_assert(0);
}; };
virtual size_t get_packA_type_size() const { megdnn_assert(0); };
bool preferred_reproducible(const KernSizeParam& param, bool preferred_reproducible(const KernSizeParam& param,
bool reproducible = true) { bool reproducible = true) {
return (!reproducible || is_reproducible()) && preferred(param); return (!reproducible || is_reproducible()) && preferred(param);
}; };
virtual MatmulDescription matmul_description() const = 0;
}; };
/** /**
......
...@@ -27,6 +27,7 @@ public: ...@@ -27,6 +27,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; } void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
}; };
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
...@@ -46,7 +47,9 @@ public: ...@@ -46,7 +47,9 @@ public:
megdnn_assert(0); megdnn_assert(0);
}; };
WorkspaceBundle get_bundle(const KernSizeParam& param) const override; WorkspaceBundle get_bundle(const KernSizeParam& param) const override;
InnerBlockSize get_inner_block_size() const override { return {8, 16, 1}; }; InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; };
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
}; };
#endif #endif
...@@ -124,6 +127,7 @@ public: ...@@ -124,6 +127,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; } void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4)
}; };
#if MEGDNN_X86_WITH_VNNI #if MEGDNN_X86_WITH_VNNI
...@@ -149,6 +153,7 @@ public: ...@@ -149,6 +153,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; } void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
}; };
#endif #endif
} // namespace x86 } // namespace x86
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册