Commit 5697fa16 authored by Megvii Engine Team, committed by Xinran Xu

feat(dnn): conv1x1 nchw44 support

GitOrigin-RevId: 29b41ff460b16603c337cf67229b31d455de1cd5
Parent a323f1a4
@@ -40,7 +40,8 @@ size_t ConvBiasImpl::AlgoConv1x1::get_oc_tile_size_heuristic(
size_t OC = param.filter_meta.ocpg;
if (OH * OW >= 56 * 56 || OC >= 64)
return m_oc_block_size;
- return div_ceil(OC, param.nr_threads);
+ size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads);
+ return round_up<size_t>(oc_block_size_one_thread, 24);
}
size_t ConvBiasImpl::AlgoConv1x1::get_workspace(
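For small feature maps and few output channels, the updated heuristic first splits OC evenly across threads and then rounds the per-thread block up to a multiple of 24 so each tile keeps the matmul kernels well fed. A minimal sketch of that arithmetic, assuming the usual div_ceil/round_up semantics (the helper names below are illustrative, not MegDNN's):

    #include <cstddef>

    // Illustrative helpers; MegDNN ships its own div_ceil/round_up utilities.
    static size_t div_ceil_sketch(size_t a, size_t b) { return (a + b - 1) / b; }
    static size_t round_up_sketch(size_t a, size_t b) {
        return div_ceil_sketch(a, b) * b;
    }

    // Example: OC = 50, nr_threads = 4 -> 13 channels per thread -> 24 after
    // rounding up, so every thread works on a reasonably sized OC tile.
    size_t oc_tile_size_sketch(size_t OC, size_t nr_threads) {
        size_t oc_block_size_one_thread = div_ceil_sketch(OC, nr_threads);
        return round_up_sketch(oc_block_size_one_thread, 24);
    }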
@@ -180,8 +181,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) {
//! only support nchw format
- if (opr->param().format != param::ConvBias::Format::NCHW)
+ if (opr->param().format != param::ConvBias::Format::NCHW &&
+     opr->param().format != param::ConvBias::Format::NCHW44)
return false;
size_t FH = param.filter_meta.spatial[0],
@@ -218,8 +219,12 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param));
- bool matmulusable = m_matmul_algo->usable(matmul_param);
- return matmulusable &&
+ if(opr->param().format == param::ConvBias::Format::NCHW44)
+     matmul_param.format = param::MatrixMul::Format::MK4;
+ bool matmul_usable = m_matmul_algo->usable(matmul_param);
+ return matmul_usable &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
......
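In NCHW44 the tensors store channels in packs of 4, which is exactly what the MK4 matmul format expects, so usable() only needs to switch the matmul format before asking the matmul algorithm whether it can run. A hedged sketch of that mapping with simplified stand-in types (the real param structs carry many more fields):

    // Simplified stand-ins for the real param structs, for illustration only.
    enum class ConvFormat { NCHW, NCHW44 };
    enum class MatmulFormat { DEFAULT, MK4 };

    struct MatmulKernSizeParamSketch {
        MatmulFormat format = MatmulFormat::DEFAULT;
    };

    // NCHW44 stores channels in packs of 4 on both the filter (A) and the
    // activation (B) side of the 1x1-conv-as-matmul mapping, which matches MK4.
    void select_matmul_format(ConvFormat conv_format,
                              MatmulKernSizeParamSketch& p) {
        if (conv_format == ConvFormat::NCHW44)
            p.format = MatmulFormat::MK4;
    }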
@@ -71,15 +71,14 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format) {
- MEGDNN_MARK_USED_VAR(format);
+ size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4;
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
return std::make_unique< \
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
- _postprocess_mode, _packmode>>(); \
+ _postprocess_mode, _packmode>>(pack_size); \
} \
} \
MIDOUT_END()
@@ -95,7 +94,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \
DTypeTrait<_i_bias_type>::ctype, \
DTypeTrait<_i_dst_type>::ctype, \
- _postprocess_mode, _packmode>>(); \
+ _postprocess_mode, _packmode>>(pack_size); \
} \
} \
MIDOUT_END()
......
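create_conv1x1_strategy no longer ignores the layout: it derives a pack size from the format (1 for NCHW, 4 for NCHW44) and hands it to the strategy constructor added below. A minimal sketch of that dispatch under the same assumption, with simplified types:

    #include <cstddef>
    #include <memory>

    enum class ConvFormat { NCHW, NCHW44 };

    // Stand-in for Conv1x1StrategyBase; the real factory also dispatches on
    // dtype and pack mode via the cb1/cb2 macros shown above.
    struct StrategySketch {
        explicit StrategySketch(size_t pack_size) : m_pack_size(pack_size) {}
        size_t m_pack_size;
    };

    std::unique_ptr<StrategySketch> make_strategy_sketch(ConvFormat format) {
        // NCHW keeps one channel per element group, NCHW44 packs four.
        size_t pack_size = format == ConvFormat::NCHW ? 1 : 4;
        return std::make_unique<StrategySketch>(pack_size);
    }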
@@ -88,6 +88,8 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
megdnn::PostprocessMode postprocess_mode, MatrixMulImpl::AlgoBase::PackMode pack_mode>
class Conv1x1Strategy : public Conv1x1StrategyBase {
public:
+ explicit Conv1x1Strategy(size_t pack_size = 1) : m_pack_size(pack_size) {}
void packA(WorkspaceBundle& whole_bundle,
WorkspaceBundle& matmul_bundle,
size_t oc_tile_size,
@@ -133,6 +135,9 @@ public:
src_ctype* a_panel = reinterpret_cast<src_ctype*>(
reinterpret_cast<int8_t*>(whole_bundle.get(0)) +
bytes_offset_of_a_panel);
+ matmul_kern_param.LDA *= m_pack_size;
matmul_kern_param.A_ptr = const_cast<src_ctype*>(
ncb_param.filter<src_ctype>(group_id) +
numbers_offset_of_filter);
@@ -165,6 +170,8 @@ public:
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) =
get_matmul_kern_param(param, OH * OW, OC);
+ matmul_kern_param.LDB *= m_pack_size;
rep(batch, BATCH) {
rep(g, GROUP) {
if (SH == 2 && SW == 2)
@@ -273,6 +280,8 @@ public:
matmul_kern_param.C_ptr = matmul_dst;
+ matmul_kern_param.LDC *= m_pack_size;
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
auto matmul_kern = matmul_algo->get_kern(matmul_kern_param);
matmul_kern(matmul_kern_param);
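The leading dimensions in the matmul kern param are measured in elements, so once channels are packed in groups of m_pack_size the stride between logical rows of A (filter), B (src) and C (dst) grows by that factor; multiplying LDA, LDB and LDC by m_pack_size keeps packed-layout addressing consistent. A small sketch of the idea, assuming element-unit strides (the struct below is a simplification, not the real kern param):

    #include <cstddef>

    // Simplified kern param carrying only the strides touched by this diff.
    struct KernParamSketch {
        size_t LDA, LDB, LDC;  // leading dimensions, in elements
    };

    // With NCHW44 (pack_size == 4) each logical matrix row spans pack_size
    // packed element rows, so all three leading dimensions scale together.
    void scale_leading_dims(KernParamSketch& p, size_t pack_size) {
        p.LDA *= pack_size;
        p.LDB *= pack_size;
        p.LDC *= pack_size;
    }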
@@ -291,11 +300,14 @@ public:
else
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start));
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode,
param.nonlineMode, param.bias_type, param.dst_type, 1_z,
- oc_end - oc_start, OH, OW);
+ (oc_end - oc_start) / m_pack_size, OH, OW, m_pack_size);
}
+ private:
+     size_t m_pack_size = 1;
};
class Conv1x1Factory {
......
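Since each packed group now carries m_pack_size real channels, the post-process call passes the number of channel groups, (oc_end - oc_start) / m_pack_size, plus the pack size, instead of the raw channel count, so the total element extent it walks over is unchanged. A rough sketch of that equivalence, assuming the channel range is a multiple of the pack size:

    #include <cassert>
    #include <cstddef>

    // Illustration only: with channels a multiple of pack_size (NCHW44 implies
    // OC % 4 == 0), expressing the range as groups x pack_size covers the same
    // number of elements as the old per-channel argument did.
    void check_postprocess_extent(size_t oc_start, size_t oc_end, size_t OH,
                                  size_t OW, size_t pack_size) {
        size_t channels = oc_end - oc_start;   // old argument (NCHW, pack_size == 1)
        size_t groups = channels / pack_size;  // new argument for NCHW44
        assert(groups * pack_size * OH * OW == channels * OH * OW);
    }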