From 02abc36ea64e70213b4ad95d11f7540891384de3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 11 Jun 2020 19:59:19 +0800 Subject: [PATCH] fix(mbg/arm_common): fix nchw44-dot misc issue GitOrigin-RevId: f870ad964c075fe55fb1c7ca131680873bec61eb --- .../int8/dot_direct_nchw_nchw44_algo.cpp | 2 +- dnn/src/arm_common/conv_bias/opr_impl.cpp | 4 ++-- .../arm_common/conv_bias_multi_thread.cpp | 21 ++++++++++++------- sdk/load-and-run/src/mgblar.cpp | 6 ++++++ .../include/megbrain/utils/persistent_cache.h | 1 - src/opr/impl/dnn/convolution.cpp | 1 - src/plugin/impl/opr_footprint.cpp | 4 +++- 7 files changed, 25 insertions(+), 14 deletions(-) diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp index 9d3acdf7b..fb86977d1 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp @@ -182,7 +182,7 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.filter_type.enumv() == DTypeEnum::QuantizedS8 && (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && - (fm.format == param::Convolution::Format::NCHW44); + (fm.format == param::Convolution::Format::NCHW44_DOT); bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4); bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && (fh == 2 || fh == 3 || fh == 5 || fh == 7); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index f7edf5956..a813fce80 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -55,7 +55,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; #if __ARM_FEATURE_DOTPROD - AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44; AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; @@ -66,6 +65,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; + AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; #endif AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; @@ -96,7 +96,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { public: AlgoPack() { #if __ARM_FEATURE_DOTPROD - direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44); direct_algos.emplace_back(&ds8_direct_stride1_large_group); direct_algos.emplace_back(&ds8_direct_stride1_small_group); direct_algos.emplace_back(&ds8_direct_stride2_large_group); @@ -107,6 +106,7 @@ public: direct_algos.emplace_back(&du8_direct_stride2_small_group); direct_algos.emplace_back(&ds8_direct_nchw44); + direct_algos.emplace_back(&ds8_direct_nchw_nchw44); #endif direct_algos.emplace_back(&qu8_direct_stride2_large_group); direct_algos.emplace_back(&qu8_direct_stride2_small_group); diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 0f9eeef4d..acb05e59d 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -582,14 +582,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { /****************************dot qint8 direct*************************/ #if __ARM_FEATURE_DOTPROD TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { - checker_conv_bias_qint8x8x8( - get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, - true), - handle(), "ARMDOTS8_NCHW_NCHW44"); - checker_conv_bias_qint8x8x8( - get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, - true), - handle(), "ARMDOTS8_NCHW_NCHW44"); + auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, + true); + for (auto&& arg : args) { + arg.param.format = param::ConvBias::Format::NCHW44_DOT; + } + checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); + + args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, + true); + for (auto&& arg : args) { + arg.param.format = param::ConvBias::Format::NCHW44_DOT; + } + checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index 15103ab4f..004adae6b 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -987,6 +987,12 @@ Args Args::from_argv(int argc, char **argv) { cb(nchw32); cb(nhwcd4); #undef cb + if (!strcmp(argv[i], "--enable-nchw44-dot")) { + mgb_log_warn("enable-nchw44-dot optimization"); + graph_opt.graph_opt.enable_nchw44_dot(); + continue; + } + if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); diff --git a/src/core/include/megbrain/utils/persistent_cache.h b/src/core/include/megbrain/utils/persistent_cache.h index 2469e5109..bd6ce5a2c 100644 --- a/src/core/include/megbrain/utils/persistent_cache.h +++ b/src/core/include/megbrain/utils/persistent_cache.h @@ -94,7 +94,6 @@ namespace mgb { m_param{param}, m_param_size{param_size} { } - //! build a blob representation to be used as cache key PersistentCache::Blob build_blob() const; }; diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index 43e666d95..1d896fa24 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -611,7 +611,6 @@ AlgoChooserProfileCache::Result AlgoChooser::get_profile_result( AlgoChooserProfileCache::Key cache_key{origin_layouts.data(), origin_layouts.size(), &origin_param, sizeof(origin_param)}; - { auto&& rst = cache.get(cache_key); if (rst.valid()) diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index 22111a7db..990aa7879 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -107,7 +107,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, src_shape[1] / group * 2; return hybird_nchwx ? computation : computation * 8; } - if (param.format == Param::Format::NCHW44) { + if (param.format == Param::Format::NCHW44 || + param.format == Param::Format::NCHW44_DOT) { //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4} if (filter_shape[1] == 1 && filter_shape[2] == 1) { group *= 4; @@ -145,6 +146,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, if (param.format == Param::Format::NCHW4 || param.format == Param::Format::NCHW88 || param.format == Param::Format::NCHW44 || + param.format == Param::Format::NCHW44_DOT || param.format == Param::Format::NCHW32) { return eval_conv_computation_nchwx(); } -- GitLab