diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h index 47d9ae2db93292962a6fa5aeb5cae12abff62bd2..ad3e0f9b3c5a16b160840a99ca4621dfa5ad5073 100644 --- a/dnn/src/fallback/conv_bias/im2col/factory.h +++ b/dnn/src/fallback/conv_bias/im2col/factory.h @@ -227,28 +227,28 @@ public: "DefaultStrategyType::FLOAT"_hash); } else if (format == param::ConvBias::Format::NCHW44) { -#if MEGDNN_AARCH64 +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 auto matmul_block = matmul_algo->get_inner_block_size(); - //! Optimize NCHW44 3x3s2 8X12X1 im2col+pack fuse - if (matmul_block.m == 8 && matmul_block.n == 12 && - matmul_block.k == 1 && - param.filter_meta.spatial[0] == 3 && - param.filter_meta.spatial[1] == 3 && - param.filter_meta.stride[0] == 2 && - param.filter_meta.stride[1] == 2 && - !param.filter_meta.should_flip) { - MIDOUT_BEGIN( - megdnn_fallback_im2col_factory_make_strategy, - midout_iv( - "DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) { - return std::make_unique< - StrategyFuse8x12x1Nchw44K3x3S2< - float, float, - PostprocessMode::FLOAT>>(); - } - MIDOUT_END(); - return {}; + //! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 im2col+pack fuse + if ((matmul_block.m == 8 || matmul_block.m == 4) && + matmul_block.n == 12 && matmul_block.k == 1 && + param.filter_meta.spatial[0] == 3 && + param.filter_meta.spatial[1] == 3 && + param.filter_meta.stride[0] == 2 && + param.filter_meta.stride[1] == 2 && + !param.filter_meta.should_flip) { + MIDOUT_BEGIN( + megdnn_fallback_im2col_factory_make_strategy, + midout_iv( + "DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) { + return std::make_unique< + StrategyFuseXx12x1Nchw44K3x3S2< + float, float, + PostprocessMode::FLOAT>>(); } + MIDOUT_END(); + return {}; + } #endif cb1(NCHW44, DEFAULT, dt_float32, dt_float32, @@ -345,10 +345,10 @@ public: "DefaultStrategyType::QINT8x8x32x8"_hash); } else if (format == param::ConvBias::Format::NCHW44 || format == param::ConvBias::Format::NCHW44_DOT) { -#if MEGDNN_AARCH64 - auto matmul_block = matmul_algo->get_inner_block_size(); if (format == param::ConvBias::Format::NCHW44) { //! Optimize NCHW44 3x3s1 4X4X16 im2col+pack fuse +#if MEGDNN_AARCH64 + auto matmul_block = matmul_algo->get_inner_block_size(); if (matmul_block.m == 4 && matmul_block.n == 4 && matmul_block.k == 16 && param.filter_meta.spatial[0] == 3 && @@ -368,7 +368,10 @@ public: MIDOUT_END(); return {}; } +#endif } else { +#if MEGDNN_AARCH64 + auto matmul_block = matmul_algo->get_inner_block_size(); //! Optimize NCHW44_DOT 3x3s1 8X12X4 im2col+pack fuse if (matmul_block.m == 8 && matmul_block.n == 12 && matmul_block.k == 4 && @@ -389,8 +392,30 @@ public: MIDOUT_END(); return {}; } - } #endif +#if MEGDNN_ARMV7 + auto matmul_block = matmul_algo->get_inner_block_size(); + if (matmul_block.m == 8 && matmul_block.n == 4 && + matmul_block.k == 4 && + param.filter_meta.spatial[0] == 3 && + param.filter_meta.spatial[1] == 3 && + param.filter_meta.stride[0] == 2 && + param.filter_meta.stride[1] == 2 && + !param.filter_meta.should_flip) { + MIDOUT_BEGIN( + megdnn_fallback_im2col_factory_make_strategy, + midout_iv( + "DefaultStrategyType::INT8x8x32_8x4x4_s2"_hash)) { + return std::make_unique< + StrategyFuse8x4x4Nchw44DotK3x3S2< + dt_qint32, dt_qint8, + PostprocessMode::QUANTIZED>>(); + } + MIDOUT_END(); + return {}; + } +#endif + } cb2(NCHW44, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, PostprocessMode::QUANTIZED, diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_base.h b/dnn/src/fallback/conv_bias/im2col/strategy_base.h index 1c3233d60b40210853115d51aa1fdb27ffbc8a12..d52f18e41c1e943ed50b906a6638f0eb4cb4c0fd 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_base.h +++ b/dnn/src/fallback/conv_bias/im2col/strategy_base.h @@ -488,12 +488,12 @@ public: template -class StrategyFuse8x12x1Nchw44K3x3S2 - : public Strategy { public: - StrategyFuse8x12x1Nchw44K3x3S2() = default; + StrategyFuse8x12x4Nchw44Dot() = default; constexpr static size_t BUNDLE_PADDING_INDEX = 0; constexpr static size_t BUNDLE_PACKA_INDEX = 1; @@ -508,16 +508,15 @@ public: fallback::MatrixMulImpl::KernParam matmul_param, const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; }; - - +#else template -class StrategyFuse8x12x4Nchw44Dot +class StrategyFuse8x4x4Nchw44DotK3x3S2 : public Strategy { public: - StrategyFuse8x12x4Nchw44Dot() = default; + StrategyFuse8x4x4Nchw44DotK3x3S2() = default; constexpr static size_t BUNDLE_PADDING_INDEX = 0; constexpr static size_t BUNDLE_PACKA_INDEX = 1; @@ -534,6 +533,30 @@ public: }; #endif +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 +template +class StrategyFuseXx12x1Nchw44K3x3S2 + : public Strategy { +public: + StrategyFuseXx12x1Nchw44K3x3S2() = default; + + constexpr static size_t BUNDLE_PADDING_INDEX = 0; + constexpr static size_t BUNDLE_PACKA_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0; + constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; + + void exec_im2col( + const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam matmul_param, + const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; +}; +#endif } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot_s2.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot_s2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4f8914f1edec17a70c0ccc3851d7184f5f79c44 --- /dev/null +++ b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot_s2.cpp @@ -0,0 +1,208 @@ +/** + * \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot_s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/fallback/conv_bias/im2col/strategy_base.h" + +#if MEGDNN_ARMV7 +#include +using namespace megdnn; +namespace { + +#define PACKB_ONELINE() \ + int out_index = 0; \ + outptr = output_base; \ + for (; out_index + 3 < block_size; out_index += 4) { \ + std::memcpy(outptr, tmp_output, 16); \ + outptr += ksize4; \ + tmp_output += 4; \ + } \ + \ + if (out_index < block_size) { \ + uint32_t zerobuffer[4] = {0}; \ + size_t out_remain = std::min(block_size - out_index, 4); \ + std::memcpy(outptr, tmp_output, out_remain * sizeof(uint32_t)); \ + outptr += out_remain; \ + std::memcpy(outptr, zerobuffer, (4 - out_remain) * sizeof(uint32_t)); \ + } \ + output_base += 4; + +#define STOR_IM2COL_DST() \ + output0[count] = uint32_src[index]; \ + output1[count] = uint32_src[index + 1]; \ + output2[count] = uint32_src[index + 2]; \ + count++; \ + index += SW; + +#define LOAD_AND_STOR_IM2COL_DST() \ + uint32x4x2_t val_01 = vld2q_u32(&uint32_src[index]); \ + index += 8; \ + uint32x4_t val_index8 = vdupq_n_u32(uint32_src[index]); \ + uint32x4_t val_2 = vextq_u32(val_01.val[0], val_index8, 1); \ + vst1q_u32(&output0[count], val_01.val[0]); \ + vst1q_u32(&output1[count], val_01.val[1]); \ + vst1q_u32(&output2[count], val_2); \ + count += 4; + +void fuse_packb(const dt_int8* __restrict src, dt_int8* __restrict dst, + dt_int8* __restrict b_panel, const int OW, const int IC, + const int IH, const int IW, const int cur_index, + const int block_size) { + int start_h = cur_index / OW; + int cur_remain_w = cur_index % OW; + int end_h = (cur_index + block_size) / OW; + int end_remain_w = (cur_index + block_size) % OW; + bool same_line = start_h == end_h ? true : false; + size_t newIC = IC / 4; + const uint32_t* uint32_src = + static_cast(static_cast(src)); + uint32_t* output = static_cast(static_cast(dst)); + uint32_t* b_output = static_cast(static_cast(b_panel)); + const int packed_k = newIC * 3 * 3; + const int ksize4 = packed_k * 4; + uint32_t* outptr = b_output; + uint32_t* output_base = b_output; + constexpr int FH = 3; + constexpr int SH = 2; + constexpr int SW = 2; + if (same_line) { + rep(ic, newIC) { + rep(fh, FH) { + uint32_t* output02 = output; + uint32_t* output1 = output + block_size + 1; + + size_t count = 0; + size_t index = 0; + int w = cur_remain_w; + index = (ic * IH + (start_h * SH + fh)) * IW + w * SW; + for (; w + 3 < end_remain_w; w += 4) { + uint32x4x2_t val_01 = vld2q_u32(&uint32_src[index]); + vst1q_u32(&output02[count], val_01.val[0]); + vst1q_u32(&output1[count], val_01.val[1]); + count += 4; + index += 8; + } + for (; w < end_remain_w; w++) { + output02[count] = uint32_src[index + 0]; + output1[count] = uint32_src[index + 1]; + count++; + index += SW; + } + output02[count] = uint32_src[index]; + const uint32_t* output_ptr[3]; + output_ptr[0] = output02; + output_ptr[1] = output1; + output_ptr[2] = output02 + 1; + for (int i = 0; i < 3; i++) { + const uint32_t* tmp_output = output_ptr[i]; + PACKB_ONELINE(); + } + } + } + } else { + rep(ic, newIC) { + rep(fh, FH) { + size_t count = 0; + size_t index = 0; + uint32_t* output0 = output; + uint32_t* output1 = output + block_size; + uint32_t* output2 = output1 + block_size; + int w = cur_remain_w; + index = (ic * IH + (SH * start_h + fh)) * IW + SW * w; + for (; w + 3 < OW; w += 4) { + LOAD_AND_STOR_IM2COL_DST() + } + + for (; w < OW; w++) { + STOR_IM2COL_DST() + } + + for (int h = start_h + 1; h < end_h; h++) { + int ow = 0; + index = (ic * IH + (SH * h + fh)) * IW; + for (; ow + 3 < OW; ow += 4) { + LOAD_AND_STOR_IM2COL_DST() + } + + for (; ow < OW; ow++) { + STOR_IM2COL_DST() + } + } + + index = (ic * IH + (SH * end_h + fh)) * IW; + w = 0; + for (; w + 3 < end_remain_w; w += 4) { + LOAD_AND_STOR_IM2COL_DST() + } + + for (; w < end_remain_w; w++) { + STOR_IM2COL_DST() + } + + for (int k = 0; k < 3; k++) { + const uint32_t* tmp_output = output + k * block_size; + PACKB_ONELINE(); + } + } + } + } +} +#undef PACKB_ONELINE +#undef STOR_IM2COL_DST +#undef LOAD_AND_STOR_IM2COL_DST +} // namespace + +template +void StrategyFuse8x4x4Nchw44DotK3x3S2:: + exec_im2col(const WorkspaceBundle& bundle, + const WorkspaceBundle& bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam /*matmul_param*/, + const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { + size_t ow = param.osz[1]; + size_t ic = param.filter_meta.icpg; + size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; + size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; + size_t input_offset = + ih * iw * ic * + (sparam.group_id + param.filter_meta.group * sparam.batch_id) * + sizeof(dt_int8); + + dt_int8* src2 = reinterpret_cast( + reinterpret_cast(bundle.get(BUNDLE_PADDING_INDEX)) + + input_offset); + bool is_phpwzero = param.filter_meta.padding[0] == 0 && + param.filter_meta.padding[1] == 0; + if (is_phpwzero) { + src2 = const_cast( + param.src(sparam.batch_id, sparam.group_id)); + } + dt_int8* b_panel = reinterpret_cast(reinterpret_cast( + bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); + megdnn_assert(ic % 4 == 0, "nchw44dot_dot with ic is not of time 4"); + + int8_t* im2col_dst = + static_cast(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); + + fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index, + sparam.output_block_size); +} + +namespace megdnn { + +template class StrategyFuse8x4x4Nchw44DotK3x3S2; +} // namespace megdnn + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp index 96782cdfb3171c89c41f83616af065a87a9268bf..af13d2f0d19b78c24c4148ffdd4daa4aba78155c 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp @@ -10,9 +10,8 @@ */ #include "src/fallback/conv_bias/im2col/strategy_base.h" -#include "src/fallback/convolution/img2col_helper.h" -#if MEGDNN_AARCH64 +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 #include using namespace megdnn; @@ -163,7 +162,7 @@ void fuse_packb(const float* __restrict src, float* __restrict dst, template -void StrategyFuse8x12x1Nchw44K3x3S2:: +void StrategyFuseXx12x1Nchw44K3x3S2:: exec_im2col(const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, const StrategyParam& sparam, @@ -194,14 +193,13 @@ void StrategyFuse8x12x1Nchw44K3x3S2:: float* im2col_dst = static_cast(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); - fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index, sparam.output_block_size); } namespace megdnn { -template class StrategyFuse8x12x1Nchw44K3x3S2; } // namespace megdnn diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index c96735f2b4c7543179bfa14863d3a912dce8fe67..3c82e69f812b36997c75daf3d1b49385bbb95127 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -1461,6 +1461,25 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) { #undef cb } +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT_S2_FUSE) { + UniformIntRNG rng{-50, 50}; + +#define cb(name) \ + checker_conv_bias(get_nchw44_conv_bias_args({3}, 2, false, \ + false, false, false, true), \ + handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ + dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ + dtype::QuantizedS8(60.25f), name); \ + + float epsilon = 0.001; +#if MEGDNN_AARCH64 + cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); +#elif MEGDNN_ARMV7 + cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96"); +#endif +#undef cb +} + TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { UniformIntRNG rng{-50, 50}; diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp index 7681c9f6be7125e6ec5e51b3df4e9a5d38a55e4d..0431615bebd52224716d1d0d63db97c6b45883db 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -655,6 +655,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) { bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2); } +#if __ARM_FEATURE_DOTPROD TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { constexpr size_t RUNS = 40; std::vector data_type = { @@ -708,6 +709,64 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { } +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2) { + constexpr size_t RUNS = 40; + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S, + bool is_nchw = false) { + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = P; + param.pad_w = P; + param.stride_h = S; + param.stride_w = S; + param.sparse = param::ConvBias::Sparse::DENSE; + param.format = param::ConvBias::Format::NCHW44_DOT; + auto OH = (H + 2 * P - FS) / static_cast(S) + 1; + auto OW = (W + 2 * P - FS) / static_cast(S) + 1; + TensorShape src = {N, IC / 4, H, W, 4}; + TensorShape filter = {OC / 4, IC / 4, FS, FS, 4, 4}; + if (group > 1) { + filter = {group, OC / group / 4, IC / group / 4, FS, FS, 4, 4}; + param.sparse = param::ConvBias::Sparse::GROUP; + } + if (is_nchw) { + src = {N, IC, H, W}; + filter = {OC / 4, FS, FS, IC, 4}; + } + TensorShape bias = {1, OC / 4, 1, 1, 4}; + TensorShape dst = {N, OC / 4, OH, OW, 4}; + + SmallVector shapes{src, filter, bias, {}, dst}; + float computations = + (((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + std::vector, float>> shape_arg = { + std::make_pair(shapes, computations)}; + benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, + {1, {7}}, data_type); + }; + bench_case(1, 64, 64, 56, 56, 3, 1, 1, 2); + bench_case(1, 64, 64, 128, 128, 3, 1, 1, 2); + bench_case(1, 64, 64, 256, 256, 3, 1, 1, 2); + bench_case(1, 64, 64, 156, 156, 3, 1, 1, 2); + bench_case(1, 128, 128, 28, 28, 3, 1, 1, 2); + bench_case(1, 256, 256, 14, 14, 3, 1, 1, 2); + bench_case(1, 512, 512, 7, 7, 3, 1, 1, 2); + + bench_case(1, 64, 64, 56, 56, 3, 4, 1, 2); + bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2); + bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2); + bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2); + +} + + +#endif TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) { constexpr size_t RUNS = 40; std::vector data_type = {