diff --git a/dnn/src/arm_common/conv_bias/postprocess_helper.h b/dnn/src/arm_common/conv_bias/postprocess_helper.h index a9e93694d64540fc128fe4d5c2da2da8f511e36a..da5ccdf632f4298de3206a60e69fef2cd92e9b53 100644 --- a/dnn/src/arm_common/conv_bias/postprocess_helper.h +++ b/dnn/src/arm_common/conv_bias/postprocess_helper.h @@ -49,6 +49,14 @@ namespace { reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N, OC, OH* OW); +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ + megdnn::arm_common::OpCallerBinary<_op, \ + megdnn::arm_common::VEC_BCAST101x4>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N, OC, OH* OW, pack_oc_size); + #define FOR_NONLINEAR_BINARY(_op) \ megdnn::arm_common:: \ OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ @@ -57,20 +65,26 @@ namespace { reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N* OC* OH* OW); -#define FOR_BIAS(_mode) \ - switch (_mode) { \ - case megdnn::BiasMode::NO_BIAS: \ - FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ - break; \ - case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ - FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST) \ - break; \ - case megdnn::BiasMode::BIAS: \ - FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ - break; \ - default: \ - megdnn_throw("no quantized unsupported biasmode"); \ - break; \ +#define FOR_BIAS(_mode) \ + switch (_mode) { \ + case megdnn::BiasMode::NO_BIAS: \ + FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ + break; \ + case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ + if (pack_oc_size == 1) { \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ + } else { \ + megdnn_assert(pack_oc_size == 4, \ + "Only support nchw44 in ARM"); \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ + } \ + break; \ + case megdnn::BiasMode::BIAS: \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ + break; \ + default: \ + megdnn_throw("no quantized unsupported biasmode"); \ + break; \ } #define FOR_NONLINEAR(_caller) \ @@ -129,6 +143,7 @@ struct PostProcess { #undef FOR_NONLINEAR_UNARY #undef FOR_NONLINEAR_BINARY_BROADCAST +#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 #undef FOR_NONLINEAR_BINARY #undef FOR_NONLINEAR_NOBIAS #undef FOR_NONLINEAR @@ -187,6 +202,8 @@ struct PostProcess { if (pack_oc_size == 1) { \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ } else { \ + megdnn_assert(pack_oc_size == 4, \ + "Only support nchw44 in ARM"); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ } \ break; \ diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp index 088b8f9717fa25eda94e838a259fbf27c9e5d459..95e28d0104980378eae9b9cdb661188011c478a3 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp @@ -216,14 +216,18 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, param.nonlineMode != megdnn::NonlineMode::IDENTITY) return false; + if (opr->param().format == param::ConvBias::Format::NCHW44) { + //! nchw44 hybird mode and channel wise is not support + if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || + param.filter_meta.ocpg == 1) { + return false; + } + } + size_t OH = param.osz[0]; size_t OW = param.osz[1]; - MatrixMulImpl::KernSizeParam matmul_param = - get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param)); - - if(opr->param().format == param::ConvBias::Format::NCHW44) - matmul_param.format = param::MatrixMul::Format::MK4; - + MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param( + param, OH * OW, get_oc_tile_size_heuristic(param)); bool matmul_usable = m_matmul_algo->usable(matmul_param); return matmul_usable && diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp index 52f618000b0d59089bcf088c53ea909a9912e3b2..845d5c76faf544c810db73aef807477821284581 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp @@ -22,6 +22,20 @@ namespace conv1x1 { namespace { +size_t get_format_pack_size(param::ConvBias::Format format) { + switch(format){ + case param::ConvBias::Format::NCHW44: + case param::ConvBias::Format::NCHW4: + return 4_z; + case param::ConvBias::Format::NCHW88: + return 8_z; + case param::ConvBias::Format::NCHW: + return 1_z; + default: + megdnn_throw("unknow pack size of the format"); + } +} + struct StrategyHashParam { ConvBiasImpl::NCBKernSizeParam param; param::ConvBias::Format format; @@ -71,7 +85,7 @@ std::unique_ptr create_conv1x1_strategy( const ConvBiasImpl::NCBKernSizeParam& param, MatrixMulImpl::AlgoBase::PackMode pack_mode, param::ConvBias::Format format) { - size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4; + size_t pack_size = get_format_pack_size(format); #define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ midout_iv(_midout_tag)) { \ diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h index 2030d02c96e03fd767ba368bd5726620bfb4cdbf..8e0456de69346755d2e89b5b223a535e8fd72c61 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h +++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h @@ -41,19 +41,25 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param( param.dst_type.enumv() == DTypeEnum::QuantizedS8) || (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); + size_t pack_c_size = 1_z; + auto format = param::MatrixMul::Format::DEFAULT; + if(param.filter_meta.format == param::ConvBias::Format::NCHW44){ + pack_c_size = 4_z; + format = param::MatrixMul::Format::MK4; + } return {param.filter_type, param.src_type, is_dst_8bit ? param.bias_type : param.dst_type, M, N, K, - LDA, - LDB, - LDC, + LDA * pack_c_size, + LDB * pack_c_size, + LDC * pack_c_size, false, false, param::MatrixMul::ComputeMode::DEFAULT, - param::MatrixMul::Format::DEFAULT}; + format}; } } // namespace @@ -137,9 +143,7 @@ public: src_ctype* a_panel = reinterpret_cast( reinterpret_cast(whole_bundle.get(0)) + bytes_offset_of_a_panel); - - matmul_kern_param.LDA *= m_pack_size; - + matmul_kern_param.A_ptr = const_cast( ncb_param.filter(group_id) + numbers_offset_of_filter); @@ -172,7 +176,6 @@ public: static_cast(matmul_kern_param) = get_matmul_kern_param(param, OH * OW, OC); - matmul_kern_param.LDB *= m_pack_size; rep(batch, BATCH) { rep(g, GROUP) { @@ -282,8 +285,6 @@ public: matmul_kern_param.C_ptr = matmul_dst; - matmul_kern_param.LDC *= m_pack_size; - if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); matmul_kern(matmul_kern_param); @@ -295,14 +296,15 @@ public: //! do postprocess void* bias_ptr = nullptr; - if (param.bias_mode == megdnn::BiasMode::BIAS) + if (param.bias_mode == megdnn::BiasMode::BIAS) { bias_ptr = static_cast(const_cast( ncb_param.bias(batch_id, group_id) + numbers_of_ncb_dst_offset)); - else + } else { bias_ptr = static_cast(const_cast( ncb_param.bias(batch_id, group_id) + oc_start)); - + } + PostProcess::run( matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, param.nonlineMode, param.bias_type, param.dst_type, 1_z, diff --git a/dnn/src/fallback/conv_bias/winograd/winograd.h b/dnn/src/fallback/conv_bias/winograd/winograd.h index f4d4f25849e8f136ac11418f76512de4372b73af..ae3f1f9960291746755b53c14e132d892eeb0ad9 100644 --- a/dnn/src/fallback/conv_bias/winograd/winograd.h +++ b/dnn/src/fallback/conv_bias/winograd/winograd.h @@ -137,8 +137,8 @@ class ConvBias { sizeof(output_compute_type) * std::max(Strategy::IC_BLOCK_SIZE, Strategy::OC_BLOCK_SIZE); - size_t matmul_workspace_size = - matmul_algo->get_workspace(get_matmul_kern_param(param)); + size_t matmul_workspace_size = matmul_algo->get_workspace( + get_matmul_kern_param(param, m_unit_oc_size)); //! compute workspace is independent and separated as far as possible //! in case of false cache line sharing @@ -384,7 +384,7 @@ public: get_wbundle_compute(param, matmul_algo); fallback::MatrixMulImpl::KernParam matmul_param; static_cast(matmul_param) = - get_matmul_kern_param(param); + get_matmul_kern_param(param, m_unit_oc_size); Strategy strategy = m_strategy; size_t unit_tile_size = m_unit_tile_size; @@ -450,21 +450,24 @@ public: } fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param( - const NCBKernSizeParam& param) const { + const NCBKernSizeParam& param, size_t nr_oc_in_unit = 0) const { size_t M = 0; size_t N = 0; size_t K = 0; size_t LDA = 0, LDB = 0, LDC = 0; + if (nr_oc_in_unit == 0) { + nr_oc_in_unit = param.filter_meta.ocpg; + } if (format == param::MatrixMul::Format::DEFAULT) { M = m_unit_tile_size; - N = param.filter_meta.ocpg; + N = nr_oc_in_unit; K = param.filter_meta.icpg; LDA = K; LDB = N; LDC = N; } else { - M = param.filter_meta.ocpg; + M = nr_oc_in_unit; N = m_unit_tile_size; K = param.filter_meta.icpg; megdnn_assert(K % Strategy::IC_BLOCK_SIZE == 0, "invalid K: %zu", diff --git a/dnn/src/x86/conv_bias/postprocess_helper.h b/dnn/src/x86/conv_bias/postprocess_helper.h index 6cab32bd1ecdeade5953637eaae72304ba34dd99..04602822c441f66251e078fa6e5f00ce358fc243 100644 --- a/dnn/src/x86/conv_bias/postprocess_helper.h +++ b/dnn/src/x86/conv_bias/postprocess_helper.h @@ -126,6 +126,8 @@ struct PostProcess { DType bias_type, DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { MEGDNN_MARK_USED_VAR(pack_oc_size); + megdnn_assert(pack_oc_size == 1, + "PostProcess only support nchw in x86"); megdnn::param::Elemwise::Mode elem_mode = megdnn::param::Elemwise::Mode::ADD; if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { @@ -149,38 +151,6 @@ struct PostProcess { } }; -template -struct PostProcess { - static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, - megdnn::ConvBiasForward::BiasMode bias_mode, - megdnn::param::ConvBias::NonlineMode nonlineMode, - DType bias_type, DType dst_type, size_t N, size_t OC, - size_t OH, size_t OW, size_t pack_oc_size=1) { - MEGDNN_MARK_USED_VAR(pack_oc_size); - megdnn::param::Elemwise::Mode elem_mode = - megdnn::param::Elemwise::Mode::ADD; - if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { - switch (nonlineMode) { - BIAS_CASE(RELU); - BIAS_CASE(SIGMOID); - BIAS_CASE(H_SWISH); - IDENTITY_CASE(IDENTITY); - DEFAULT_CASE; - } - } else { - switch (nonlineMode) { - NOBIAS_CASE(RELU); - NOBIAS_CASE(SIGMOID); - NOBIAS_CASE(H_SWISH); - IDENTITY_CASE(IDENTITY); - DEFAULT_CASE; - } - } - - FOR_BIAS(bias_mode); - } -}; - template struct PostProcess { static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, @@ -297,6 +267,8 @@ struct PostProcess { DType bias_type, DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { MEGDNN_MARK_USED_VAR(pack_oc_size); + megdnn_assert(pack_oc_size == 1, + "PostProcess only support nchw nchw in x86"); megdnn::param::Elemwise::Mode elem_mode = megdnn::param::Elemwise::Mode::ADD; if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index c3a6ebf567483e4cdf771fa0a7f9644bf753aed7..948528a2d3f176ffd854e695907cb0c0853f13fc 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -1297,6 +1297,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { #endif } +#if MEGDNN_AARCH64 +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({1}, 1, true, false, false); + check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); +} +#endif + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({1}, 1, true, false, false); + std::vector args_of_4; + for (auto&& arg : args) { + if (arg.src.shape[2] * arg.src.shape[3] % 4 == 0) { + args_of_4.push_back(arg); + } + } +#if MEGDNN_AARCH64 + check_conv_bias(args_of_4, handle(), "CONV1x1:AARCH64_F32_MK4_4x16:24"); +#elif MEGDNN_ARMV7 + check_conv_bias(args_of_4, handle(), "CONV1x1:ARMV7_F32_MK4_4x8:48"); +#endif +} + #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) { using namespace conv_bias;