Commit 8e9fa80c authored by Megvii Engine Team

feat(dnn/fallback): add matmul description for im2col

GitOrigin-RevId: 5bde0b60f0b8102cd8bad14457cf123bc7e6dafa
Parent af3de7e1
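The diff below adds a MatmulDescription struct to fallback::MatrixMulImpl::AlgoBase and threads it through the im2col conv-bias strategies, replacing the separate packmode() / get_inner_block_size() / get_packA_type_size() queries. Before the hunks, here is a minimal, self-contained sketch of that pattern; the enclosing classes, macro name OVERRIDE_MATMUL_DESC, and the main() driver are illustrative stand-ins, not the real MegDNN headers.

// Minimal, self-contained sketch (not the real MegDNN headers): illustrates how
// the MatmulDescription struct and the MEGDNN_OVERRIDE_MATMUL_DESC macro from
// this diff fit together, using stand-in classes.
#include <cstddef>
#include <cstdio>

enum class PackMode { DEFAULT, NO_PACK, ONLY_PACKA };

struct InnerBlockSize {
    size_t m, n, k;
};

//! One struct now carries everything im2col needs from a matmul algorithm.
struct MatmulDescription {
    PackMode packmode;
    InnerBlockSize innerblocksize;
    size_t packa_type_size;
};

//! Mirrors MEGDNN_OVERRIDE_MATMUL_DESC: an algorithm states its inner block
//! and packA element size once and the description is assembled from them.
#define OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size)   \
    MatmulDescription matmul_description() const override {  \
        MatmulDescription mdesc;                              \
        mdesc.packmode = packmode();                          \
        mdesc.innerblocksize = {_m, _n, _k};                  \
        mdesc.packa_type_size = _packa_type_size;             \
        return mdesc;                                         \
    }

struct AlgoBase {
    virtual ~AlgoBase() = default;
    virtual PackMode packmode() const = 0;
    virtual MatmulDescription matmul_description() const = 0;
};

//! A NO_PACK gemv-style algorithm declaring an 8x16x1 inner block with
//! 4-byte packA elements, as several algorithms in this diff do.
struct AlgoGemvLike final : AlgoBase {
    PackMode packmode() const override { return PackMode::NO_PACK; }
    OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
};

int main() {
    AlgoGemvLike algo;
    MatmulDescription d = algo.matmul_description();
    std::printf("inner block %zux%zux%zu, packA elem size %zu\n",
                d.innerblocksize.m, d.innerblocksize.n, d.innerblocksize.k,
                d.packa_type_size);
    return 0;
}

The consumer side (get_bundle, dispatch_kerns, usable in AlgoIm2col) then fetches this description once and reads mdesc.packmode and mdesc.innerblocksize for pack and no-pack algorithms alike, which is why the hardcoded nopack_default_blockm/blockn values disappear in the hunks below.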
......@@ -60,6 +60,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4)
};
class MatrixMulImpl::AlgoF32Gemv final
......@@ -86,6 +87,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
};
#endif
......@@ -207,6 +209,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
};
#if __ARM_FEATURE_DOTPROD
......@@ -234,6 +237,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
#else
......
......@@ -12,6 +12,7 @@
#pragma once
#include "src/arm_common/matrix_mul/opr_impl.h"
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace arm_common {
......@@ -25,6 +26,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
};
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
......@@ -38,6 +40,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
class MatrixMulImpl::AlgoF32Gemv : public AlgoBase {
......@@ -54,6 +57,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -68,6 +72,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
#endif
......@@ -82,6 +87,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
};
......
......@@ -49,6 +49,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -71,6 +72,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
};
#endif
#if __ARM_FEATURE_DOTPROD
......@@ -190,6 +192,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
};
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
......
......@@ -47,14 +47,17 @@ static void copy_padding_kern(WorkspaceBundle bundle,
}
//! packA_kern
static void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
StrategyBase* im2colstrategy, size_t pack_oc_size) {
static void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
StrategyBase* im2colstrategy,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
size_t pack_oc_size) {
im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo,
ncb_index, pack_oc_size);
ncb_index, matmul_desc, pack_oc_size);
}
/*!
......@@ -72,7 +75,8 @@ public:
WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
StrategyParam strategyparam,
fallback::ConvBiasImpl::NCBKernIndex ncb_index,
size_t ohw_tile_size, StrategyBase* im2colstrategy) {
......@@ -111,7 +115,8 @@ public:
//! 2.packb and matmul compute
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
matmul_param, matmul_algo, ncb_index);
matmul_param, matmul_algo, ncb_index,
matmul_desc);
//! 3.postprocess and copy dst if need
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
......@@ -151,7 +156,8 @@ public:
WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
StrategyParam strategyparam,
fallback::ConvBiasImpl::NCBKernIndex ncb_index,
size_t ohw_tile_size, StrategyBase* im2colstrategy) {
......@@ -191,7 +197,8 @@ public:
//! 2.packb and matmul compute
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
matmul_param, matmul_algo, ncb_index);
matmul_param, matmul_algo, ncb_index,
matmul_desc);
//! 3.postprocess and copy dst if need
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
......@@ -232,7 +239,8 @@ public:
WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
StrategyParam strategyparam,
fallback::ConvBiasImpl::NCBKernIndex ncb_index,
size_t ohw_tile_size, StrategyBase* im2colstrategy) {
......@@ -272,7 +280,8 @@ public:
//! 2.packb and matmul compute
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
matmul_param, matmul_algo, ncb_index);
matmul_param, matmul_algo, ncb_index,
matmul_desc);
//! 3.postprocess and copy dst if need
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
......@@ -401,13 +410,15 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
size_t padding = 0, packa_size = 0, packa_group_size = 0;
size_t nr_threads = param.nr_threads;
size_t GROUP = param.filter_meta.group;
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT;
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
m_matmul_algo->matmul_description();
bool need_pack = mdesc.packmode == Pack_Mode::DEFAULT;
bool only_packA = mdesc.packmode == Pack_Mode::ONLY_PACKA;
size_t oc_tile_size = 0, ohw_tile_size = 0;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
mdesc.innerblocksize.m, mdesc.innerblocksize.n,
mdesc.packmode);
if (need_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m,
inner_block.n, m_matmul_algo->packmode());
auto im2col_kern_param = get_matmul_kern_param(
param, ohw_tile_size, only_packA ? oc_tile_size : OC);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
......@@ -415,11 +426,6 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0)
: wb.get_size(0);
} else { //! not support pack,not need pack
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
m_matmul_algo->packmode());
packa_group_size = 0;
}
......@@ -481,23 +487,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
WorkspaceBundle bundle = get_bundle(param);
WorkspaceBundle bundle_thread = {nullptr, {}};
bool need_padding = (PH != 0 || PW != 0);
Pack_Mode packmode = m_matmul_algo->packmode();
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
m_matmul_algo->matmul_description();
Pack_Mode packmode = mdesc.packmode;
bool default_pack = packmode == Pack_Mode::DEFAULT;
bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n,
m_matmul_algo->packmode());
} else { //! nopack_mode
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
m_matmul_algo->packmode());
}
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
mdesc.innerblocksize.m, mdesc.innerblocksize.n,
mdesc.packmode);
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
......@@ -507,18 +508,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
if (only_packA) {
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
} else if (default_pack) {
packa_parallel_times = div_ceil<size_t>(
OC, m_matmul_algo->get_inner_block_size().m);
packa_parallel_times = div_ceil<size_t>(OC, mdesc.innerblocksize.m);
}
auto matmul_param = get_matmul_kern_param(
param, ohw_tile_size, only_packA ? oc_tile_size : OC);
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) {
if (mdesc.packmode == Pack_Mode::DEFAULT) {
Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
bundle_thread = defaultkern.get_thread_bundle(
param, matmul_param, m_matmul_algo, ohw_tile_size,
oc_tile_size);
} else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) {
} else if (mdesc.packmode == Pack_Mode::ONLY_PACKA) {
Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern;
bundle_thread = onlypackakern.get_thread_bundle(
param, matmul_param, m_matmul_algo, ohw_tile_size,
......@@ -559,24 +559,24 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
auto kern_packA = [bundle, matmul_algo = m_matmul_algo,
matmul_param, im2colstrategy,
pack_oc_size = pack_oc_size](
const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
pack_oc_size = pack_oc_size,
mdesc = mdesc](const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index,
im2colstrategy, pack_oc_size);
im2colstrategy, mdesc, pack_oc_size);
};
if (default_pack) {
auto kern_compute_default =
[bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo,
ohw_tile_size = ohw_tile_size,
strategyparam = strategyparam,
strategyparam = strategyparam, matmul_desc = mdesc,
im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::DEFAULT>::kerns(
bundle, bundle_thread, param, matmul_param,
matmul_algo, strategyparam, ncb_index,
ohw_tile_size, im2colstrategy);
matmul_algo, matmul_desc, strategyparam,
ncb_index, ohw_tile_size, im2colstrategy);
};
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});
......@@ -592,13 +592,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo,
strategyparam = strategyparam,
ohw_tile_size = ohw_tile_size,
ohw_tile_size = ohw_tile_size, matmul_desc = mdesc,
im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns(
bundle, bundle_thread, param, matmul_param,
matmul_algo, strategyparam, ncb_index,
ohw_tile_size, im2colstrategy);
matmul_algo, matmul_desc, strategyparam,
ncb_index, ohw_tile_size, im2colstrategy);
};
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});
if (need_padding) {
......@@ -612,13 +612,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[bundle, bundle_thread, matmul_param,
matmul_algo = m_matmul_algo,
strategyparam = strategyparam,
ohw_tile_size = ohw_tile_size,
ohw_tile_size = ohw_tile_size, matmul_desc = mdesc,
im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
Im2colKerns<Pack_Mode::NO_PACK>::kerns(
bundle, bundle_thread, param, matmul_param,
matmul_algo, strategyparam, ncb_index,
ohw_tile_size, im2colstrategy);
matmul_algo, matmul_desc, strategyparam,
ncb_index, ohw_tile_size, im2colstrategy);
};
if (need_padding) {
......@@ -668,10 +668,12 @@ bool ConvBiasImpl::AlgoIm2col::usable(
return false;
}
}
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
m_matmul_algo->matmul_description();
if (opr->param().format == param::ConvBias::Format::NCHW44 ||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
//! current NCHW44 im2col only support DEFAULT mode matmul
if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) {
if (mdesc.packmode != Pack_Mode::DEFAULT) {
return false;
//! nchw44 hybird mode and channel wise is not support
} else if (param.filter_meta.icpg < 4_z ||
......@@ -682,22 +684,9 @@ bool ConvBiasImpl::AlgoIm2col::usable(
}
size_t oc_tile_size = 0, ohw_tile_size = 0;
Pack_Mode packmode = m_matmul_algo->packmode();
bool default_pack = packmode == Pack_Mode::DEFAULT;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n,
m_matmul_algo->packmode());
} else { //! not support pack,not need pack
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
m_matmul_algo->packmode());
}
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
mdesc.innerblocksize.m, mdesc.innerblocksize.n,
m_matmul_algo->packmode());
fallback::MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
bool matmulusable = m_matmul_algo->usable(matmul_param);
......
......@@ -58,8 +58,9 @@ public:
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
size_t pack_size) = 0;
virtual void exec_im2col(
......@@ -67,15 +68,17 @@ public:
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0;
virtual void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
) = 0;
virtual void exec_postprocess(
const fallback::ConvBiasImpl::NCBKernParam& param,
......@@ -284,26 +287,30 @@ public:
Strategy() = default;
virtual void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
virtual void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
size_t pack_size) override;
virtual void exec_im2col(
WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
) override;
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override {
......@@ -338,7 +345,7 @@ public:
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
......@@ -359,20 +366,24 @@ public:
Strategy() = default;
void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
size_t pack_size) override;
void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
) override;
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
......@@ -382,7 +393,7 @@ public:
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override {
......@@ -411,26 +422,30 @@ public:
Strategy() = default;
void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
size_t pack_size) override;
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
) override;
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
......@@ -465,7 +480,7 @@ public:
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
template <typename op_ctype, typename op_dtype,
......@@ -487,7 +502,7 @@ public:
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
......@@ -510,7 +525,7 @@ public:
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
#endif
......
......@@ -21,8 +21,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription&
matmul_desc,
size_t) {
bundle.set(param.workspace_ptr);
fallback::MatrixMulImpl::KernParam matmul_param;
......@@ -31,16 +33,16 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
matmulparam;
size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0);
size_t packed_per_oc_block_size =
round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) *
matmul_algo->get_inner_block_size().m *
matmul_algo->get_packA_type_size();
round_up(matmul_param.K, matmul_desc.innerblocksize.k) *
matmul_desc.innerblocksize.m * matmul_desc.packa_type_size;
size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size;
int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) +
group_id * packA_group_size + a_panel_offset;
matmul_param.A_ptr =
const_cast<src_ctype*>(param.filter<src_ctype>(group_id));
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1],
matmul_algo->get_inner_block_size().m);
matmul_desc.innerblocksize.m);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
......@@ -52,7 +54,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg;
......@@ -140,11 +142,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription&
matmul_desc) {
size_t packA_per_oc_block_size =
round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) *
sparam.oc_tile_size * matmul_algo->get_packA_type_size();
round_up(matmul_param.K, matmul_desc.innerblocksize.k) *
sparam.oc_tile_size * matmul_desc.packa_type_size;
size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0);
size_t a_panel_offset = ncb_index.ndrange_id[1] * packA_group_size +
ncb_index.ndrange_id[3] * packA_per_oc_block_size;
......
......@@ -33,7 +33,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg;
......
......@@ -173,7 +173,7 @@ void StrategyFuse4x4x16Nchw44<op_ctype, op_dtype, postprocess_mode>::
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam,
fallback::MatrixMulImpl::AlgoBase*) {
const fallback::MatrixMulImpl::AlgoBase*) {
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
......
......@@ -176,7 +176,7 @@ void StrategyFuse8x12x4Nchw44Dot<op_ctype, op_dtype, postprocess_mode>::
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam /*matmul_param*/,
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
......
......@@ -168,7 +168,7 @@ void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>::
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam /*matmul_param*/,
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
......
......@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::
MatmulDescription& /*matmul_desc*/,
size_t) {
MEGDNN_MARK_USED_VAR(bundle);
MEGDNN_MARK_USED_VAR(param);
......@@ -62,8 +64,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::
MatmulDescription& /*matmul_desc*/
) {
MEGDNN_MARK_USED_VAR(bundle);
MEGDNN_MARK_USED_VAR(ncb_index);
matmul_param.workspace_ptr = bundle_thread.get(THREAD_BUNDLE_MATCOMP_INDEX);
......@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
MEGDNN_MARK_USED_VAR(matmul_param);
MEGDNN_MARK_USED_VAR(matmul_algo);
size_t sh = param.filter_meta.stride[0];
......
......@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::
MatmulDescription& /*matmul_desc*/,
size_t) {
bundle.set(param.workspace_ptr);
fallback::MatrixMulImpl::KernParam matmul_param;
......@@ -57,8 +59,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
const fallback::MatrixMulImpl::AlgoBase::
MatmulDescription& /*matmul_desc*/
) {
size_t packA_group_size =
bundle.get_size(BUNDLE_PACKA_INDEX) / param.filter_meta.group;
size_t a_panel_offset = ncb_index.ndrange_id[3] *
......@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
MEGDNN_MARK_USED_VAR(matmul_param);
MEGDNN_MARK_USED_VAR(matmul_algo);
size_t sh = param.filter_meta.stride[0];
......
......@@ -37,6 +37,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
};
} // namespace fallback
......
......@@ -352,6 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
DType dtype_c) \
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {}
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \
MatmulDescription matmul_description() const override { \
MatmulDescription mdesc; \
mdesc.packmode = packmode(); \
mdesc.innerblocksize = {_m, _n, _k}; \
mdesc.packa_type_size = _packa_type_size; \
return mdesc; \
}
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \
WorkspaceBundle get_bundle(const KernSizeParam&) const override; \
kern_naked_t get_kern_naked(const KernSizeParam&) const override; \
......@@ -360,7 +369,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
void pack_B(const KernParam& kern_param, void* out, size_t x0, \
size_t xmax) const override; \
InnerBlockSize get_inner_block_size() const override; \
size_t get_packA_type_size() const override;
MatmulDescription matmul_description() const override;
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
......@@ -458,8 +467,14 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
_strategy::UNROLL_K}; \
} \
\
size_t MatrixMulImpl::_algo_name::get_packA_type_size() const { \
return sizeof(_packa_type); \
MatrixMulImpl::_algo_name::MatmulDescription \
MatrixMulImpl::_algo_name::matmul_description() const { \
MatmulDescription mdesc; \
mdesc.packmode = PackMode(); \
mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \
_strategy::UNROLL_K}; \
mdesc.packa_type_size = sizeof(_packa_type); \
return mdesc; \
}
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \
......
......@@ -104,6 +104,12 @@ public:
size_t m, n, k;
};
struct MatmulDescription {
PackMode packmode;
InnerBlockSize innerblocksize;
size_t packa_type_size;
};
virtual bool usable(const KernSizeParam&) const = 0;
virtual bool preferred(const KernSizeParam&) const { return true; }
virtual size_t get_workspace(const KernSizeParam&) const = 0;
......@@ -125,11 +131,11 @@ public:
virtual InnerBlockSize get_inner_block_size() const {
megdnn_assert(0);
};
virtual size_t get_packA_type_size() const { megdnn_assert(0); };
bool preferred_reproducible(const KernSizeParam& param,
bool reproducible = true) {
return (!reproducible || is_reproducible()) && preferred(param);
};
virtual MatmulDescription matmul_description() const = 0;
};
/**
......
......@@ -27,6 +27,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
};
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
......@@ -46,7 +47,9 @@ public:
megdnn_assert(0);
};
WorkspaceBundle get_bundle(const KernSizeParam& param) const override;
InnerBlockSize get_inner_block_size() const override { return {8, 16, 1}; };
InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; };
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
};
#endif
......@@ -124,6 +127,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4)
};
#if MEGDNN_X86_WITH_VNNI
......@@ -149,6 +153,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
#endif
} // namespace x86
......
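For reference, a hedged consumer-side sketch under the same stand-in types as above: choose_tiles, round_up, and the default tile sizes here are hypothetical illustrations of deriving tile sizes from the reported inner block, not the actual choice_ohw_oc_block policy.

// Hypothetical consumer-side sketch; choose_tiles and its defaults are
// illustrative, not the real choice_ohw_oc_block logic.
#include <algorithm>
#include <cstddef>
#include <cstdio>

enum class PackMode { DEFAULT, NO_PACK, ONLY_PACKA };
struct InnerBlockSize { size_t m, n, k; };
struct MatmulDescription {
    PackMode packmode;
    InnerBlockSize innerblocksize;
    size_t packa_type_size;
};

static size_t round_up(size_t v, size_t align) {
    return (v + align - 1) / align * align;
}

//! Start from a default tile, clamp to the problem size, then round up to the
//! matmul inner block so packed panels stay aligned. With every algorithm now
//! reporting its own block via MatmulDescription, one code path covers the
//! DEFAULT, ONLY_PACKA and NO_PACK cases alike.
static void choose_tiles(const MatmulDescription& mdesc, size_t OC, size_t OHW,
                         size_t& oc_tile, size_t& ohw_tile) {
    constexpr size_t kDefaultOcTile = 16;
    constexpr size_t kDefaultOhwTile = 192;
    oc_tile = round_up(std::min(OC, kDefaultOcTile), mdesc.innerblocksize.m);
    ohw_tile = round_up(std::min(OHW, kDefaultOhwTile), mdesc.innerblocksize.n);
}

int main() {
    MatmulDescription mdesc{PackMode::NO_PACK, {8, 16, 1}, 4};
    size_t oc_tile = 0, ohw_tile = 0;
    choose_tiles(mdesc, /*OC=*/24, /*OHW=*/1000, oc_tile, ohw_tile);
    std::printf("oc_tile=%zu ohw_tile=%zu\n", oc_tile, ohw_tile);
    return 0;
}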