diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp index 9d3acdf7bc757879af1f4a8bdcc3df206a733b60..fb86977d17527306a1cbcf396c6a07d52813c25b 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp @@ -182,7 +182,7 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.filter_type.enumv() == DTypeEnum::QuantizedS8 && (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && - (fm.format == param::Convolution::Format::NCHW44); + (fm.format == param::Convolution::Format::NCHW44_DOT); bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4); bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && (fh == 2 || fh == 3 || fh == 5 || fh == 7); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index f7edf5956b620176effeddf0d7ee3d316ae51775..a813fce8051681e52af629f972cb24be21cda84e 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -55,7 +55,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; #if __ARM_FEATURE_DOTPROD - AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44; AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; @@ -66,6 +65,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; + AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; #endif AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; @@ -96,7 +96,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { public: AlgoPack() { #if __ARM_FEATURE_DOTPROD - direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44); direct_algos.emplace_back(&ds8_direct_stride1_large_group); direct_algos.emplace_back(&ds8_direct_stride1_small_group); direct_algos.emplace_back(&ds8_direct_stride2_large_group); @@ -107,6 +106,7 @@ public: direct_algos.emplace_back(&du8_direct_stride2_small_group); direct_algos.emplace_back(&ds8_direct_nchw44); + direct_algos.emplace_back(&ds8_direct_nchw_nchw44); #endif direct_algos.emplace_back(&qu8_direct_stride2_large_group); direct_algos.emplace_back(&qu8_direct_stride2_small_group); diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 0f9eeef4da41826a9efe9e16a5926bd1afe88648..acb05e59dfd5126ccda1ca23ca688597c8d07d73 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -582,14 +582,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { /****************************dot qint8 direct*************************/ #if __ARM_FEATURE_DOTPROD TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { - checker_conv_bias_qint8x8x8( - get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, - true), - handle(), "ARMDOTS8_NCHW_NCHW44"); - checker_conv_bias_qint8x8x8( - get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, - true), - handle(), "ARMDOTS8_NCHW_NCHW44"); + auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, + true); + for (auto&& arg : args) { + arg.param.format = param::ConvBias::Format::NCHW44_DOT; + } + checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); + + args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, + true); + for (auto&& arg : args) { + arg.param.format = param::ConvBias::Format::NCHW44_DOT; + } + checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index 15103ab4fedac40ab1a78447f13589731ff048db..004adae6bf24801b09be4f18f2c56418780c8ae3 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -987,6 +987,12 @@ Args Args::from_argv(int argc, char **argv) { cb(nchw32); cb(nhwcd4); #undef cb + if (!strcmp(argv[i], "--enable-nchw44-dot")) { + mgb_log_warn("enable-nchw44-dot optimization"); + graph_opt.graph_opt.enable_nchw44_dot(); + continue; + } + if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); diff --git a/src/core/include/megbrain/utils/persistent_cache.h b/src/core/include/megbrain/utils/persistent_cache.h index 2469e5109502cfcf276b5f44c1c0bb46ceca0172..bd6ce5a2c0ec724d7f740814df5ef03265ef93c4 100644 --- a/src/core/include/megbrain/utils/persistent_cache.h +++ b/src/core/include/megbrain/utils/persistent_cache.h @@ -94,7 +94,6 @@ namespace mgb { m_param{param}, m_param_size{param_size} { } - //! build a blob representation to be used as cache key PersistentCache::Blob build_blob() const; }; diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index 43e666d95f501e088867431a946abe40a7cf9b5c..1d896fa240044a37622e43c2e46b842929ca0e51 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -611,7 +611,6 @@ AlgoChooserProfileCache::Result AlgoChooser::get_profile_result( AlgoChooserProfileCache::Key cache_key{origin_layouts.data(), origin_layouts.size(), &origin_param, sizeof(origin_param)}; - { auto&& rst = cache.get(cache_key); if (rst.valid()) diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index 22111a7dbf7d07f181c770d1fab8d3ffbd8ef431..990aa7879e6b42bcd59638786b5a3baad1233c05 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -107,7 +107,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, src_shape[1] / group * 2; return hybird_nchwx ? computation : computation * 8; } - if (param.format == Param::Format::NCHW44) { + if (param.format == Param::Format::NCHW44 || + param.format == Param::Format::NCHW44_DOT) { //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4} if (filter_shape[1] == 1 && filter_shape[2] == 1) { group *= 4; @@ -145,6 +146,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, if (param.format == Param::Format::NCHW4 || param.format == Param::Format::NCHW88 || param.format == Param::Format::NCHW44 || + param.format == Param::Format::NCHW44_DOT || param.format == Param::Format::NCHW32) { return eval_conv_computation_nchwx(); }