diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp index 995540de8d47e3f49550fcc98211224ba8e220cd..bad00033313b6fa348982f0a33841ede01d0c6d0 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp @@ -40,7 +40,8 @@ size_t ConvBiasImpl::AlgoConv1x1::get_oc_tile_size_heuristic( size_t OC = param.filter_meta.ocpg; if (OH * OW >= 56 * 56 || OC >= 64) return m_oc_block_size; - return div_ceil(OC, param.nr_threads); + size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads); + return round_up(oc_block_size_one_thread, 24); } size_t ConvBiasImpl::AlgoConv1x1::get_workspace( @@ -180,8 +181,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, const NCBKernSizeParam& param, AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { - //! only support nchw format - if (opr->param().format != param::ConvBias::Format::NCHW) + if (opr->param().format != param::ConvBias::Format::NCHW && + opr->param().format != param::ConvBias::Format::NCHW44) return false; size_t FH = param.filter_meta.spatial[0], @@ -218,8 +219,12 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param)); - bool matmulusable = m_matmul_algo->usable(matmul_param); - return matmulusable && + if(opr->param().format == param::ConvBias::Format::NCHW44) + matmul_param.format = param::MatrixMul::Format::MK4; + + bool matmul_usable = m_matmul_algo->usable(matmul_param); + + return matmul_usable && (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp index 05322417c0d777f3d44c9c3e75a984ec5759869c..52f618000b0d59089bcf088c53ea909a9912e3b2 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp @@ -71,33 +71,32 @@ std::unique_ptr create_conv1x1_strategy( const ConvBiasImpl::NCBKernSizeParam& param, MatrixMulImpl::AlgoBase::PackMode pack_mode, param::ConvBias::Format format) { - MEGDNN_MARK_USED_VAR(format); - -#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ - MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ - midout_iv(_midout_tag)) { \ - if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ - return std::make_unique< \ - Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ - _postprocess_mode, _packmode>>(); \ - } \ - } \ + size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4; +#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ + MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ + midout_iv(_midout_tag)) { \ + if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ + return std::make_unique< \ + Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ + _postprocess_mode, _packmode>>(pack_size); \ + } \ + } \ MIDOUT_END() -#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ - _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ - MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ - midout_iv(_midout_tag)) { \ - if (param.filter_type.enumv() == param.src_type.enumv() && \ - param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ - param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ - return std::make_unique< \ - Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ - DTypeTrait<_i_bias_type>::ctype, \ - DTypeTrait<_i_dst_type>::ctype, \ - _postprocess_mode, _packmode>>(); \ - } \ - } \ +#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ + _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ + MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ + midout_iv(_midout_tag)) { \ + if (param.filter_type.enumv() == param.src_type.enumv() && \ + param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ + param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ + return std::make_unique< \ + Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ + DTypeTrait<_i_bias_type>::ctype, \ + DTypeTrait<_i_dst_type>::ctype, \ + _postprocess_mode, _packmode>>(pack_size); \ + } \ + } \ MIDOUT_END() switch (pack_mode) { diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h index f0904ec833e12f11c551be1ee7c8c66c940d27a0..65ca322de2adeb452d02c4eaf824360dde3f95fe 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h +++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h @@ -88,6 +88,8 @@ template class Conv1x1Strategy : public Conv1x1StrategyBase { public: + explicit Conv1x1Strategy(size_t pack_size = 1) : m_pack_size(pack_size) {} + void packA(WorkspaceBundle& whole_bundle, WorkspaceBundle& matmul_bundle, size_t oc_tile_size, @@ -133,6 +135,9 @@ public: src_ctype* a_panel = reinterpret_cast( reinterpret_cast(whole_bundle.get(0)) + bytes_offset_of_a_panel); + + matmul_kern_param.LDA *= m_pack_size; + matmul_kern_param.A_ptr = const_cast( ncb_param.filter(group_id) + numbers_offset_of_filter); @@ -165,6 +170,8 @@ public: static_cast(matmul_kern_param) = get_matmul_kern_param(param, OH * OW, OC); + matmul_kern_param.LDB *= m_pack_size; + rep(batch, BATCH) { rep(g, GROUP) { if (SH == 2 && SW == 2) @@ -273,6 +280,8 @@ public: matmul_kern_param.C_ptr = matmul_dst; + matmul_kern_param.LDC *= m_pack_size; + if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); matmul_kern(matmul_kern_param); @@ -291,11 +300,14 @@ public: else bias_ptr = static_cast(const_cast( ncb_param.bias(batch_id, group_id) + oc_start)); + PostProcess::run( matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, param.nonlineMode, param.bias_type, param.dst_type, 1_z, - oc_end - oc_start, OH, OW); + (oc_end - oc_start) / m_pack_size, OH, OW, m_pack_size); } +private: + size_t m_pack_size = 1; }; class Conv1x1Factory {