diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp
index 6d8c238420499bb84b88841b1c9f20a0d5583eeb..e694520eed5c3512a74998d863989f8f70b427df 100644
--- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp
+++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp
@@ -171,7 +171,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
     if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT ||
         pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) {
-        ret_kern.push_back({kern_packA, {GROUP, oc_blocks_per_group}});
+        //! if filter preprocess is enabled, kern_packA must not be dispatched
+        if (!is_enable_filter_preprocess(param)) {
+            ret_kern.push_back({kern_packA, {GROUP, oc_blocks_per_group}});
+        }
         if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT) {
             ret_kern.push_back({kern_packB, {1}});
         }
@@ -200,6 +203,13 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
         return false;
     }

+    //! only matmul algos whose packmode is ONLY_PACKA or DEFAULT support
+    //! weight preprocess
+    if (is_enable_filter_preprocess(param) &&
+        (m_matmul_algo->packmode() ==
+         fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK)) {
+        return false;
+    }
+
     if (param.src_type.enumv() != DTypeEnum::Int8 &&
         param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
         param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
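
Both guards above hinge on is_enable_filter_preprocess(), which is defined elsewhere in the fallback conv_bias code and is not part of this patch. Judging from the way conv1x1_strategy.h dereferences param.preprocessed_filter further down, it presumably reduces to a null/empty check. A minimal sketch of the assumed shape, with stub types rather than the real NCBKernSizeParam:

    // Sketch only (not from this patch): assumes the kern param carries the
    // preprocessed_filter pointer that conv1x1_strategy.h dereferences below.
    #include <vector>

    struct PreprocessedFilterStub {
        std::vector<void*> tensors;  // packed-A storage, one entry per layout
    };
    struct NCBKernSizeParamStub {
        const PreprocessedFilterStub* preprocessed_filter = nullptr;
    };

    // Preprocess counts as "enabled" only when the framework has already run
    // the preprocess kerns and handed the algo its packed panels.
    static bool is_enable_filter_preprocess(const NCBKernSizeParamStub& param) {
        return param.preprocessed_filter != nullptr &&
               !param.preprocessed_filter->tensors.empty();
    }
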
@@ -253,6 +263,122 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
     return false;
 }

+SmallVector<TensorLayout>
+ConvBiasImpl::AlgoConv1x1::deduce_preprocessed_filter_layout(
+        const NCBKernSizeParam& param) const {
+    MIDOUT_BEGIN(
+            megdnn_fallback_conv1x1,
+            midout_iv(
+                    "ConvBiasImpl::AlgoConv1x1::deduce_preprocessed_filter_layout"_hash)) {
+        fallback::MatrixMulImpl::AlgoBase::MatmulDescription matmul_desc =
+                m_matmul_algo->matmul_description();
+        bool default_pack = matmul_desc.packmode ==
+                            MatrixMulImpl::AlgoBase::PackMode::DEFAULT;
+        bool only_packA = matmul_desc.packmode ==
+                          MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA;
+        //! only the DEFAULT and ONLY_PACKA pack modes are supported
+        if (matmul_desc.packmode ==
+            MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
+            return {};
+        }
+        size_t OH = param.osz[0];
+        size_t OW = param.osz[1];
+        size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
+
+        auto matmul_param = utils::get_matmul_kern_param(param, OH * OW,
+                                                         compt_oc_block_size);
+
+        WorkspaceBundle wb(nullptr, {});
+        if (default_pack) {
+            Conv1x1Kerns<MatrixMulImpl::AlgoBase::PackMode::DEFAULT>
+                    dispatcher;
+            wb = dispatcher.get_bundle(param, matmul_param, m_matmul_algo,
+                                       compt_oc_block_size);
+        } else if (only_packA) {
+            Conv1x1Kerns<MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA>
+                    dispatcher;
+            wb = dispatcher.get_bundle(param, matmul_param, m_matmul_algo,
+                                       compt_oc_block_size);
+        }
+
+        size_t GROUP = param.filter_meta.group;
+        SmallVector<TensorLayout> preprocessed_layouts;
+        preprocessed_layouts.push_back(
+                {{GROUP, wb.get_size(0)}, dtype::Int8()});
+        return preprocessed_layouts;
+    }
+    MIDOUT_END();
+    return {};
+}
+
+SmallVector<ConvBiasImpl::NCBKern>
+ConvBiasImpl::AlgoConv1x1::dispatch_preprocess_kerns(
+        const NCBKernSizeParam& param) const {
+    MIDOUT_BEGIN(
+            megdnn_fallback_conv1x1,
+            midout_iv(
+                    "ConvBiasImpl::AlgoConv1x1::dispatch_preprocess_kerns"_hash)) {
+        SmallVector<ConvBiasImpl::NCBKern> ret_kern;
+        size_t OH = param.osz[0];
+        size_t OW = param.osz[1];
+        size_t OC = param.filter_meta.ocpg;
+        size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
+        size_t GROUP = param.filter_meta.group;
+        size_t oc_blocks_per_group = div_ceil(OC, compt_oc_block_size);
+
+        auto matmul_param = utils::get_matmul_kern_param(param, OH * OW,
+                                                         compt_oc_block_size);
+        WorkspaceBundle whole_bundle = {nullptr, {}};
+        WorkspaceBundle matmul_bundle = {nullptr, {}};
+        auto pack_mode = m_matmul_algo->packmode();
+        if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT) {
+            MIDOUT_BEGIN(megdnn_fallback_conv1x1,
+                         midout_iv("get_default_matmul_packmode_bundle"_hash)) {
+                Conv1x1Kerns<MatrixMulImpl::AlgoBase::PackMode::DEFAULT>
+                        dispatcher;
+                whole_bundle = dispatcher.get_bundle(
+                        param, matmul_param, m_matmul_algo, compt_oc_block_size);
+                matmul_bundle = m_matmul_algo->get_bundle(matmul_param);
+            }
+            MIDOUT_END();
+        } else if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) {
+            MIDOUT_BEGIN(
+                    megdnn_fallback_conv1x1,
+                    midout_iv("get_onlypacka_matmul_packmode_bundle"_hash)) {
+                Conv1x1Kerns<MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA>
+                        dispatcher;
+                whole_bundle = dispatcher.get_bundle(
+                        param, matmul_param, m_matmul_algo, compt_oc_block_size);
+                matmul_bundle = m_matmul_algo->get_bundle(matmul_param);
+            }
+            MIDOUT_END();
+        } else {
+            //! for NO_PACK return an empty kern list so that
+            //! OprWeightPreprocessProxy can still run, just without preprocess
+            return {};
+        }
+
+        Conv1x1StrategyBase* conv1x1_strategy =
+                Conv1x1Factory::make_conv1x1_strategy(param, pack_mode,
+                                                      param.filter_meta.format);
+
+        auto kern_packA = [this, whole_bundle, matmul_bundle, param,
+                           compt_oc_block_size, conv1x1_strategy](
+                                  const NCBKernParam& ncb_param,
+                                  const NCBKernIndex& ncb_index) mutable {
+            conv1x1_strategy->packA(whole_bundle, matmul_bundle,
+                                    compt_oc_block_size, this->m_matmul_algo,
+                                    param, ncb_param, std::move(ncb_index));
+        };
+
+        ret_kern.push_back({kern_packA, {GROUP, oc_blocks_per_group}});
+        return ret_kern;
+    }
+    MIDOUT_END();
+    return {};
+}
+
 bool ConvBiasImpl::AlgoConv1x1::is_preferred(
         const NCBKernSizeParam& param) const {
     size_t OH = param.osz[0];
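
Note that the deduced layout is a flat byte buffer: a single TensorLayout of shape {GROUP, wb.get_size(0)} with dtype Int8 regardless of the real weight dtype, where slot 0 of the pack-mode bundle is exactly the packed-A region. Reusing the bundle size guarantees the preprocessed tensor agrees byte-for-byte with the workspace region the exec path would otherwise use. A small self-contained illustration of the allocation this implies (stub layout type, hypothetical sizes):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct LayoutStub {
        size_t group, packa_bytes_per_group;  // i.e. {GROUP, wb.get_size(0)}
    };

    int main() {
        LayoutStub ly{2, 24 * 1024};  // hypothetical GROUP=2, 24 KiB per group
        // dtype is Int8, so the element count equals the byte count.
        std::vector<int8_t> preprocessed_filter(ly.group *
                                                ly.packa_bytes_per_group);
        std::printf("preprocessed filter storage: %zu bytes\n",
                    preprocessed_filter.size());
    }
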
diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.h b/dnn/src/fallback/conv_bias/conv1x1/algos.h
index f7bab4b0748f168927dfa51e432b80e80c5a7966..b42671003073358cedc977fefcef1e799b037d38 100644
--- a/dnn/src/fallback/conv_bias/conv1x1/algos.h
+++ b/dnn/src/fallback/conv_bias/conv1x1/algos.h
@@ -41,6 +41,15 @@ public:
             const NCBKernSizeParam& param) const override;

     bool is_preferred(const NCBKernSizeParam&) const override;
+
+    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
+            const NCBKernSizeParam& param) const override;
+    size_t get_preprocess_workspace(
+            const NCBKernSizeParam& /*param*/) const override {
+        return 0;
+    }
+    SmallVector<NCBKern> dispatch_preprocess_kerns(
+            const NCBKernSizeParam& param) const override;

 protected:
     size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;
diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h
index e204f90d897f8dbb6af29360adf7e606a10dd7fd..30bb6147375b4938b7175b7faf657442b32b050e 100644
--- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h
+++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h
@@ -35,7 +35,6 @@ public:
         auto matmul_bundle = matmul_algo->get_bundle(matmul_param);
         auto thread_bundle = utils::get_thread_bundle(
                 param, matmul_bundle.get_size(2), oc_tile_size);
-        //! size per thread
         size_t all_threads_bytes =
                 thread_bundle.total_size_in_bytes() * param.nr_threads;
@@ -44,7 +43,9 @@ public:
         size_t packa_bytes_per_oc_tile = matmul_bundle.get_size(0);
         size_t oc_tiles_per_group = div_ceil(OC, oc_tile_size);
         size_t all_packa_bytes =
-                packa_bytes_per_oc_tile * oc_tiles_per_group * GROUP;
+                is_enable_filter_preprocess(param)
+                        ? 0
+                        : packa_bytes_per_oc_tile * oc_tiles_per_group * GROUP;

         if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA)
             return WorkspaceBundle{nullptr,
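
The effect of the change above is that, once the filter has been preprocessed, the packed-A region drops out of the runtime workspace entirely and only the per-thread matmul scratch remains. A compilable toy calculation of the two workspace sizes (all numbers hypothetical):

    #include <cstddef>
    #include <cstdio>

    int main() {
        size_t packa_bytes_per_oc_tile = 24 * 1024;  // matmul_bundle.get_size(0)
        size_t oc_tiles_per_group = 4, GROUP = 2;
        size_t all_threads_bytes = 256 * 1024;  // thread scratch * nr_threads
        for (int preprocess = 0; preprocess <= 1; ++preprocess) {
            size_t all_packa_bytes =
                    preprocess
                            ? 0
                            : packa_bytes_per_oc_tile * oc_tiles_per_group * GROUP;
            std::printf("filter_preprocess=%d -> workspace=%zu bytes\n",
                        preprocess, all_packa_bytes + all_threads_bytes);
        }
    }
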
diff --git a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h
index 5fe94f4db8f7e46129d7f3e4e0e484f293ffa66b..cc173cf5085b59a1373948301dee45d3962847f0 100644
--- a/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h
+++ b/dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h
@@ -106,9 +106,14 @@ public:
         size_t numbers_offset_of_filter =
                 oc_tile_size * IC * oc_tile_id_in_group;

-        src_ctype* a_panel = reinterpret_cast<src_ctype*>(
-                reinterpret_cast<int8_t*>(whole_bundle.get(0)) +
-                bytes_offset_of_a_panel);
+        int8_t* tmp_ptr =
+                is_enable_filter_preprocess(param)
+                        ? static_cast<int8_t*>(
+                                  param.preprocessed_filter->tensors[0].raw_ptr)
+                        : static_cast<int8_t*>(whole_bundle.get(0));
+
+        src_ctype* a_panel =
+                reinterpret_cast<src_ctype*>(tmp_ptr + bytes_offset_of_a_panel);

         matmul_kern_param.A_ptr = const_cast<src_ctype*>(
                 ncb_param.filter<src_ctype>(group_id) +
@@ -206,8 +211,14 @@ public:
         size_t bytes_offset_of_a_panel =
                 group_id * packa_bytes_per_group +
                 oc_tile_id_in_group * packa_bytes_per_oc_tile;
-        int8_t* a_panel = reinterpret_cast<int8_t*>(whole_bundle.get(0)) +
-                          bytes_offset_of_a_panel;
+
+        int8_t* tmp_ptr =
+                is_enable_filter_preprocess(param)
+                        ? static_cast<int8_t*>(
+                                  param.preprocessed_filter->tensors[0].raw_ptr)
+                        : static_cast<int8_t*>(whole_bundle.get(0));
+
+        int8_t* a_panel = tmp_ptr + bytes_offset_of_a_panel;

         size_t bytes_offset_of_b_panel =
                 batch_id * packb_bytes_per_group * GROUP +
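
Both hunks above apply the same idea: the byte offsets into the packed-A storage are unchanged, only the base pointer switches between the runtime workspace and preprocessed_filter->tensors[0]. This works because deduce_preprocessed_filter_layout() sized that tensor from the same bundle slot. A self-contained sketch of the addressing (hypothetical sizes):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        size_t packa_bytes_per_oc_tile = 4096, oc_tiles_per_group = 3, GROUP = 2;
        size_t packa_bytes_per_group =
                packa_bytes_per_oc_tile * oc_tiles_per_group;
        std::vector<int8_t> preprocessed(GROUP * packa_bytes_per_group);
        size_t group_id = 1, oc_tile_id_in_group = 2;
        // Same arithmetic as bytes_offset_of_a_panel in the patch above.
        int8_t* a_panel = preprocessed.data() +
                          group_id * packa_bytes_per_group +
                          oc_tile_id_in_group * packa_bytes_per_oc_tile;
        std::printf("a_panel offset = %td bytes\n",
                    a_panel - preprocessed.data());
    }
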
diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp
index 5300f1b41baf707c9998510378643fc890babac2..4ce397d67d107bf0d469f3aa15761574139a373b 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -2724,7 +2724,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
     }
     check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
 }
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32_PREPROCESS) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
+#define cb(name)                                                    \
+    check_conv_bias_preprocess(args, handle(), nullptr, 0.001,      \
+                               dtype::Float32(), dtype::Float32(),  \
+                               dtype::Float32(), dtype::Float32(), name);
+
+#if MEGDNN_AARCH64
+    cb("CONV1x1:AARCH64_F32K8X12X1:24");
+#elif MEGDNN_ARMV7
+    cb("CONV1x1:ARMV7_F32:48");
+#endif
+#undef cb
+}
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args =
@@ -2741,7 +2756,21 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
     }
     check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
 }
-
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32_PREPROCESS) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args =
+            get_nchw44_conv_bias_args({1}, 1, true, false, false);
+#define cb(name)                                                    \
+    check_conv_bias_preprocess(args, handle(), nullptr, 0.001,      \
+                               dtype::Float32(), dtype::Float32(),  \
+                               dtype::Float32(), dtype::Float32(), name);
+#if MEGDNN_AARCH64
+    cb("CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
+#elif MEGDNN_ARMV7
+    cb("CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
+#endif
+#undef cb
+}
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args =
@@ -2780,6 +2809,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
     }
     check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
 }
+
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16_PREPROCESS) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
+    NormalRNG rng(1);
+#if MEGDNN_AARCH64
+    check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{},
+                               dtype::Float16{}, dtype::Float16{},
+                               dtype::Float16{},
+                               "CONV1x1:AARCH64_F16_K8X24X1:48");
+#elif MEGDNN_ARMV7
+    check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{},
+                               dtype::Float16{}, dtype::Float16{},
+                               dtype::Float16{},
+                               "CONV1x1:AARCH32_F16_K4X16X1:24");
+#endif
+}
+
 #endif

 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
@@ -2814,6 +2859,31 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
                          "CONV1x1_GEMV");
 }

+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM_PREPROCESS) {
+    UniformIntRNG rng{-50, 50};
+    float epsilon = 0.001;
+    std::vector<conv_bias::TestArg> args =
+            get_conv_bias_1x1_args(false, false, true, true);
+#define cb(name)                                                     \
+    check_conv_bias_preprocess(                                      \
+            args, handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
+            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),    \
+            dtype::QuantizedS8(60.25f), name);
+#if MEGDNN_AARCH64
+#if __ARM_FEATURE_DOTPROD
+    cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
+#else
+    cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
+    cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:48");
+#endif
+#elif MEGDNN_ARMV7
+    epsilon = 1;
+    cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48");
+#endif
+#undef cb
+}
+
 #if MEGDNN_AARCH64 || MEGDNN_ARMV7
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
     UniformIntRNG rng{-50, 50};
@@ -2849,6 +2919,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
                          dtype::Quantized8Asymm(50.3f, (uint8_t)120),
                          "CONV1x1_GEMV");
 }
+
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM_PREPROCESS) {
+    UniformIntRNG rng{-50, 50};
+    float epsilon = 0.001;
+    std::vector<conv_bias::TestArg> args =
+            get_conv_bias_1x1_args(false, false, true, true);
+#define cb(name)                                                            \
+    check_conv_bias_preprocess(args, handle(), &rng, epsilon,               \
+                               dtype::Quantized8Asymm(1.2f, (uint8_t)125),  \
+                               dtype::Quantized8Asymm(1.3f, (uint8_t)129),  \
+                               dtype::QuantizedS32(1.2 * 1.3),              \
+                               dtype::Quantized8Asymm(50.3f, (uint8_t)120), \
+                               name);
+#if MEGDNN_AARCH64
+#if __ARM_FEATURE_DOTPROD
+    cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
+#else
+    cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
+#endif
+#elif MEGDNN_ARMV7
+    epsilon = 1;
+    cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48");
+#endif
+#undef cb
+}
+
 #endif

 #if MEGDNN_AARCH64 || MEGDNN_ARMV7
@@ -2887,6 +2983,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
                          dtype::QuantizedS32(1.2 * 1.3), {}, "CONV1x1_GEMV");
 }

+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) {
+    NormalRNG rng(128.f);
+    float epsilon = 0.001;
+    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
+#define cb(name)                                                           \
+    check_conv_bias_preprocess(args, handle(), &rng, epsilon,              \
+                               dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
+                               dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
+                               dtype::QuantizedS32(1.2 * 1.3), {}, name);
+
+#if MEGDNN_AARCH64
+#if __ARM_FEATURE_DOTPROD
+    cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
+#else
+    cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
+#endif
+#elif MEGDNN_ARMV7
+#if __ARM_FEATURE_DOTPROD
+    cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
+#endif
+    cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
+#endif
+#undef cb
+}
+
 TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
     UniformIntRNG rng{-50, 50};
     float epsilon = 0.001;
@@ -2924,6 +3046,28 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
                          dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
                          "CONV1x1_GEMV");
 }
+
+TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) {
+    UniformIntRNG rng{-50, 50};
+    float epsilon = 0.001;
+    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
+#define cb(name)                                                             \
+    check_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{}, \
+                               dtype::Int8{}, dtype::Int16{},                \
+                               dtype::Int16{}, name);
+
+#if MEGDNN_AARCH64
+    cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
+    cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
+    cb("CONV1x1:ARM_COMMON_INT8X8X16:24");  //! also exercises the nopack path
+#elif MEGDNN_ARMV7
+    cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
+    cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
+    cb("CONV1x1:ARM_COMMON_INT8X8X16:24");  //! also exercises the nopack path
+#endif
+#undef cb
+}
+
 #endif

 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
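
The check_conv_bias_preprocess helper used by all the *_PREPROCESS tests above lives in the shared test utilities and is not shown in this patch. Judging from the x86 checker_conv_bias_preprocess added near the end of this diff, its assumed shape is a Checker driven by OprWeightPreprocessProxy, so every case runs the preprocess pass and then exec:

    // Assumed declaration only; the likely implementation pattern is the x86
    // helper added later in this diff: Checker<ConvBiasForward,
    // OprWeightPreprocessProxy<ConvBiasForward>> plus per-tensor dtypes/RNG.
    void check_conv_bias_preprocess(std::vector<conv_bias::TestArg> args,
                                    Handle* handle, RNG* rng, float epsilon,
                                    DType type0, DType type1, DType type2,
                                    DType type3, const char* algo_name);
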
@@ -2959,6 +3103,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
     checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV");
 }

+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_PREPROCESS) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
+
+#define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name);
+
+#if MEGDNN_AARCH64
+#if __ARM_FEATURE_DOTPROD
+    cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
+#else
+    cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
+    cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
+#endif
+#elif MEGDNN_ARMV7
+#if __ARM_FEATURE_DOTPROD
+    cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
+#endif
+    cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
+#endif
+
+#if MEGDNN_ARMV7
+    cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48");
+#endif
+#undef cb
+}
+
 #ifndef __ARM_FEATURE_DOTPROD
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
     using namespace conv_bias;
@@ -2988,6 +3158,36 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
 #endif
 #undef cb
 }
+
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4_PREPROCESS) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args =
+            get_nchw44_conv_bias_args({1}, 1, true, true, true);
+
+#define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name);
+
+#if MEGDNN_AARCH64
+    cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
+#elif MEGDNN_ARMV7
+    cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
+#endif
+#undef cb
+
+    UniformIntRNG rng{-50, 50};
+    float epsilon = 0.001;
+#define cb(name)                                                       \
+    check_conv_bias_preprocess(                                        \
+            get_nchw44_conv_bias_args({1}, 1, true, false, false),     \
+            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),         \
+            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),      \
+            dtype::QuantizedS8(60.25f), name);
+#if MEGDNN_AARCH64
+    cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
+#elif MEGDNN_ARMV7
+    cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
+#endif
+#undef cb
+}
+
 #endif

 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
diff --git a/dnn/test/x86/conv_bias.cpp b/dnn/test/x86/conv_bias.cpp
index e3ec400104c05bacacb28ee962d2b622c57164ab..40a4efd39518f7024598a6b3fff2f5318c342526 100644
--- a/dnn/test/x86/conv_bias.cpp
+++ b/dnn/test/x86/conv_bias.cpp
@@ -1321,6 +1321,31 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
                 {arg.src, arg.filter, arg.bias, {}, {}});
     }
 }
+
+void checker_conv_bias_preprocess(std::vector<conv_bias::TestArg> args,
+                                  Handle* handle, RNG* rng, float epsilon,
+                                  DType type0, DType type1, DType type2,
+                                  DType type3, const char* algo_name) {
+    using namespace conv_bias;
+
+    Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>>
+            checker(handle);
+    checker.set_before_exec_callback(
+            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
+    checker.set_dtype(0, type0);
+    checker.set_dtype(1, type1);
+    checker.set_dtype(2, type2);
+    checker.set_dtype(4, type3);
+    checker.set_epsilon(epsilon);
+    if (rng != nullptr) {
+        checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng).set_rng(3, rng);
+    }
+    for (auto&& arg : args) {
+        checker.set_param(arg.param).execs(
+                {arg.src, arg.filter, arg.bias, {}, {}});
+    }
+}
+
 }  // namespace

 #if MEGDNN_X86_WITH_MKL
 TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_PACKA) {
@@ -1330,14 +1355,32 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_PACKA) {
     check_conv_bias(args, handle(), "CONV1x1:X86_F32_MKL_PACKA:24");
 }

+TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_PACKA_PREPROCESS) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
+    checker_conv_bias_preprocess(args, handle(), nullptr, 0.001,
+                                 dtype::Float32{}, dtype::Float32{},
+                                 dtype::Float32{}, dtype::Float32{},
+                                 "CONV1x1:X86_F32_MKL_PACKA:24");
+}
+
 TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
     check_conv_bias(args, handle(), "CONV1x1:X86_F32_BLAS:48");
 }
+
+TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS_NOPACK_PREPROCESS) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
+    checker_conv_bias_preprocess(args, handle(), nullptr, 0.001,
+                                 dtype::Float32{}, dtype::Float32{},
+                                 dtype::Float32{}, dtype::Float32{},
+                                 "CONV1x1:X86_F32_BLAS:24");
+}
 #endif

-TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X) {
+TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
     using namespace conv_bias;
     UniformIntRNG rng{-50, 50};
     float epsilon = 0.001;
@@ -1374,6 +1417,38 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
                       dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
                       "CONV1x1:X86_INT8X8X16_SSE");
 }
+
+TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32_PREPROCESS) {
+    using namespace conv_bias;
+    UniformIntRNG rng{-50, 50};
+    float epsilon = 0.001;
+    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
+#if MEGDNN_X86_WITH_VNNI
+    if (x86::is_supported(x86::SIMDType::VNNI)) {
+        checker_conv_bias_preprocess(args, handle(), &rng, epsilon,
+                                     dtype::Int8{}, dtype::Int8{},
+                                     dtype::Int32{}, dtype::Int32{},
+                                     "CONV1x1:X86_INT8X8X32_VNNI:24");
+    }
+#endif
+    if (x86::is_supported(x86::SIMDType::AVX2)) {
+        checker_conv_bias_preprocess(args, handle(), &rng, epsilon,
+                                     dtype::Int8{}, dtype::Int8{},
+                                     dtype::Int32{}, dtype::Int32{},
+                                     "CONV1x1:X86_INT8X8X32_AVX2_4X16X2:24");
+        checker_conv_bias_preprocess(args, handle(), &rng, epsilon,
+                                     dtype::Int8{}, dtype::Int8{},
+                                     dtype::Int32{}, dtype::Int32{},
+                                     "CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24");
+        checker_conv_bias_preprocess(args, handle(), &rng, epsilon,
+                                     dtype::Int8{}, dtype::Int8{},
+                                     dtype::Int16{}, dtype::Int16{},
+                                     "CONV1x1:X86_INT8X8X16_AVX2");
+    }
+    checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},
+                                 dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
+                                 "CONV1x1:X86_INT8X8X32_SSE_4X8X2:48");
+    checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},
+                                 dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
+                                 "CONV1x1:X86_INT8X8X16_SSE");
+}
+
 /************************* End Conv1x1 PackA ************************/
 #endif
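
Taken together, the patch splits Conv1x1 execution into an optional pack phase and a compute phase: the framework asks for the packed-A layout, runs kern_packA once into the preprocessed filter, and every later exec skips packing; NO_PACK matmuls opt out by returning an empty kern list and run exactly as before. A stub-typed sketch of that lifecycle (illustrative only, not the megdnn API):

    #include <cstdio>
    #include <vector>

    struct PreprocessedFilter {
        std::vector<char> packed_a;  // sized per deduce_preprocessed_filter_layout
    };

    void exec_preprocess(PreprocessedFilter& pf) {
        pf.packed_a.assign(24 * 1024, 0);  // stand-in for the kern_packA kerns
        std::printf("packA: %zu bytes packed once\n", pf.packed_a.size());
    }

    void exec(const PreprocessedFilter& pf) {
        // is_enable_filter_preprocess() is true here, so no packA kern is
        // dispatched; the matmul reads the pre-packed panels directly.
        std::printf("exec: reuse %zu packed bytes\n", pf.packed_a.size());
    }

    int main() {
        PreprocessedFilter pf;
        exec_preprocess(pf);  // once, e.g. at model load time
        for (int i = 0; i < 3; ++i)
            exec(pf);  // every inference reuses the packed filter
    }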