From 6b2760dd72151e7eb444d4e32b9a88123123cdc1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 11 Jun 2020 18:46:17 +0800 Subject: [PATCH] feat(dnn/fallback): add float32 nchw44 fuse packb 3x3 s2 GitOrigin-RevId: 3b664bb4f578f5e3f2c36fc963217e37676c9b78 --- dnn/src/fallback/conv_bias/im2col/factory.h | 71 ++++++ .../fallback/conv_bias/im2col/strategy_base.h | 69 ++++++ .../im2col/strategy_default_nchw44.cpp | 15 ++ .../conv_bias/im2col/strategy_fuse_nchw44.cpp | 230 ++++++++++++++++++ .../im2col/strategy_fuse_nchw44_dot.cpp | 204 ++++++++++++++++ .../im2col/strategy_fuse_nchw44_fp32_s2.cpp | 209 ++++++++++++++++ .../arm_common/conv_bias_multi_thread.cpp | 11 +- .../conv_bias_multi_thread_benchmark.cpp | 60 +++++ 8 files changed, 868 insertions(+), 1 deletion(-) create mode 100644 dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h index 8915783ca..47d9ae2db 100644 --- a/dnn/src/fallback/conv_bias/im2col/factory.h +++ b/dnn/src/fallback/conv_bias/im2col/factory.h @@ -226,6 +226,31 @@ public: PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash); } else if (format == param::ConvBias::Format::NCHW44) { + +#if MEGDNN_AARCH64 + auto matmul_block = matmul_algo->get_inner_block_size(); + //! Optimize NCHW44 3x3s2 8X12X1 im2col+pack fuse + if (matmul_block.m == 8 && matmul_block.n == 12 && + matmul_block.k == 1 && + param.filter_meta.spatial[0] == 3 && + param.filter_meta.spatial[1] == 3 && + param.filter_meta.stride[0] == 2 && + param.filter_meta.stride[1] == 2 && + !param.filter_meta.should_flip) { + MIDOUT_BEGIN( + megdnn_fallback_im2col_factory_make_strategy, + midout_iv( + "DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) { + return std::make_unique< + StrategyFuse8x12x1Nchw44K3x3S2< + float, float, + PostprocessMode::FLOAT>>(); + } + MIDOUT_END(); + return {}; + } +#endif + cb1(NCHW44, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, "DefaultStrategyTypeNCHW44::FLOAT"_hash); @@ -320,6 +345,52 @@ public: "DefaultStrategyType::QINT8x8x32x8"_hash); } else if (format == param::ConvBias::Format::NCHW44 || format == param::ConvBias::Format::NCHW44_DOT) { +#if MEGDNN_AARCH64 + auto matmul_block = matmul_algo->get_inner_block_size(); + if (format == param::ConvBias::Format::NCHW44) { + //! Optimize NCHW44 3x3s1 4X4X16 im2col+pack fuse + if (matmul_block.m == 4 && matmul_block.n == 4 && + matmul_block.k == 16 && + param.filter_meta.spatial[0] == 3 && + param.filter_meta.spatial[1] == 3 && + param.filter_meta.stride[0] == 1 && + param.filter_meta.stride[1] == 1 && + !param.filter_meta.should_flip) { + MIDOUT_BEGIN( + megdnn_fallback_im2col_factory_make_strategy, + midout_iv( + "DefaultStrategyType::INT8x8x32_4x4x16"_hash)) { + return std::make_unique< + StrategyFuse4x4x16Nchw44< + dt_qint32, dt_qint8, + PostprocessMode::QUANTIZED>>(); + } + MIDOUT_END(); + return {}; + } + } else { + //! Optimize NCHW44_DOT 3x3s1 8X12X4 im2col+pack fuse + if (matmul_block.m == 8 && matmul_block.n == 12 && + matmul_block.k == 4 && + param.filter_meta.spatial[0] == 3 && + param.filter_meta.spatial[1] == 3 && + param.filter_meta.stride[0] == 1 && + param.filter_meta.stride[1] == 1 && + !param.filter_meta.should_flip) { + MIDOUT_BEGIN( + megdnn_fallback_im2col_factory_make_strategy, + midout_iv( + "DefaultStrategyType::INT8x8x32_8x12x4"_hash)) { + return std::make_unique< + StrategyFuse8x12x4Nchw44Dot< + dt_qint32, dt_qint8, + PostprocessMode::QUANTIZED>>(); + } + MIDOUT_END(); + return {}; + } + } +#endif cb2(NCHW44, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, PostprocessMode::QUANTIZED, diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_base.h b/dnn/src/fallback/conv_bias/im2col/strategy_base.h index bd1cf1999..0ce050ddb 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_base.h +++ b/dnn/src/fallback/conv_bias/im2col/strategy_base.h @@ -445,6 +445,75 @@ public: THREAD_BUNDLE_BIAS_INDEX); } }; +#if MEGDNN_AARCH64 +template +class StrategyFuse4x4x16Nchw44 + : public Strategy { +public: + StrategyFuse4x4x16Nchw44() = default; + + constexpr static size_t BUNDLE_PADDING_INDEX = 0; + constexpr static size_t BUNDLE_PACKA_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0; + constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; + + void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam matmul_param, + fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; +}; + +template +class StrategyFuse8x12x1Nchw44K3x3S2 + : public Strategy { +public: + StrategyFuse8x12x1Nchw44K3x3S2() = default; + + constexpr static size_t BUNDLE_PADDING_INDEX = 0; + constexpr static size_t BUNDLE_PACKA_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0; + constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; + + void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam matmul_param, + fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; +}; + + +template +class StrategyFuse8x12x4Nchw44Dot + : public Strategy { +public: + StrategyFuse8x12x4Nchw44Dot() = default; + + constexpr static size_t BUNDLE_PADDING_INDEX = 0; + constexpr static size_t BUNDLE_PACKA_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0; + constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; + + void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam matmul_param, + fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; +}; +#endif + } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp index 8b305d74e..dc61b814e 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp @@ -14,6 +14,9 @@ #include "src/x86/conv_bias/postprocess_helper.h" #endif +#if (MEGDNN_ARMV7 || MEGDNN_AARCH64) +#include "src/arm_common/conv_bias/postprocess_helper.h" +#endif using namespace megdnn; #if MEGDNN_X86 @@ -101,11 +104,23 @@ void Strategy + +using namespace megdnn; + +namespace { +#define TRANS_AND_STORE(input0, input1, input2, input3) \ + { \ + auto tmp01 = vzipq_s32(input0, input1); \ + auto tmp23 = vzipq_s32(input2, input3); \ + auto dst0 = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \ + vreinterpretq_s64_s32(tmp23.val[0])); \ + auto dst1 = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \ + vreinterpretq_s64_s32(tmp23.val[0])); \ + auto dst2 = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \ + vreinterpretq_s64_s32(tmp23.val[1])); \ + auto dst3 = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \ + vreinterpretq_s64_s32(tmp23.val[1])); \ + vst1q_s32(dst, vreinterpretq_s32_s64(dst0)); \ + vst1q_s32(dst + 4, vreinterpretq_s32_s64(dst1)); \ + vst1q_s32(dst + 8, vreinterpretq_s32_s64(dst2)); \ + vst1q_s32(dst + 12, vreinterpretq_s32_s64(dst3)); \ + dst += 16; \ + } + +#define TRANS_AND_STORE_REMAIN(input0, input1, input2, input3, remain) \ + { \ + auto tmp01 = vzipq_s32(input0, input1); \ + auto tmp23 = vzipq_s32(input2, input3); \ + vdst[0] = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \ + vreinterpretq_s64_s32(tmp23.val[0])); \ + vdst[1] = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \ + vreinterpretq_s64_s32(tmp23.val[0])); \ + vdst[2] = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \ + vreinterpretq_s64_s32(tmp23.val[1])); \ + vdst[3] = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \ + vreinterpretq_s64_s32(tmp23.val[1])); \ + for (size_t i = 0; i < remain; i++) { \ + vst1q_s32(dst + i * 4, vreinterpretq_s32_s64(vdst[i])); \ + } \ + dst += 16; \ + } + +void optimize_fuse_im2col_packB(dt_int8* src, size_t ic, size_t iw, size_t ih, + size_t curr_iw, size_t curr_ih, dt_int8* dst_ptr) { + int* src_line0 = + reinterpret_cast(src + curr_ih * iw * 4 + curr_iw * 4); + int* src_line1 = + reinterpret_cast(src + (curr_ih + 1) * iw * 4 + curr_iw * 4); + int* src_line2 = + reinterpret_cast(src + (curr_ih + 2) * iw * 4 + curr_iw * 4); + int* dst = reinterpret_cast(dst_ptr); + int32x4_t input[12]; + int remain = 0; + for (size_t c = 0; c < ic; c++) { + input[remain] = vld1q_s32(src_line0); + input[remain + 1] = vld1q_s32(src_line0 + 1); + input[remain + 2] = vld1q_s32(src_line0 + 2); + input[remain + 3] = vld1q_s32(src_line1); + input[remain + 4] = vld1q_s32(src_line1 + 1); + input[remain + 5] = vld1q_s32(src_line1 + 2); + input[remain + 6] = vld1q_s32(src_line2); + input[remain + 7] = vld1q_s32(src_line2 + 1); + input[remain + 8] = vld1q_s32(src_line2 + 2); + TRANS_AND_STORE(input[0], input[1], input[2], input[3]); + TRANS_AND_STORE(input[4], input[5], input[6], input[7]); + if (remain == 3) { + TRANS_AND_STORE(input[8], input[9], input[10], input[11]); + remain = 0; + } else { + for (int i = 0; i <= remain; i++) { + input[i] = input[8 + i]; + } + remain++; + } + src_line0 += ih * iw; + src_line1 += ih * iw; + src_line2 += ih * iw; + } + //! pad remain to 4 + if (remain > 0) { + TRANS_AND_STORE(input[0], input[1], input[2], input[3]); + } +} + +void naive_fuse_im2col_packB(dt_int8* src, size_t ic, size_t iw, size_t ih, + size_t curr_iw, size_t curr_ih, size_t num_point, + size_t ow, dt_int8* dst_ptr) { + megdnn_assert(num_point <= 4_z, + "fuse im2col and packB of 4x4x16 num_point must less than 4"); + int* src_line0 = reinterpret_cast(src + curr_ih * iw * 4); + int* src_line1 = reinterpret_cast(src + (curr_ih + 1) * iw * 4); + int* src_line2 = reinterpret_cast(src + (curr_ih + 2) * iw * 4); + int remain = 0; + int out[9][4] = {{0}}; + int32x4_t input[12]; + int* dst = reinterpret_cast(dst_ptr); + for (size_t c = 0; c < ic; c++) { + //! Read int buffer out + size_t index = 0, w = curr_iw, dalta_h = 0; + while (index < num_point) { + int* src_next_line0 = src_line0 + dalta_h * iw; + int* src_next_line1 = src_next_line0 + iw; + int* src_next_line2 = src_next_line1 + iw; + for (; index < num_point && w < ow; index++, w++) { + out[0][index] = src_next_line0[w]; + out[1][index] = src_next_line0[w + 1]; + out[2][index] = src_next_line0[w + 2]; + out[3][index] = src_next_line1[w]; + out[4][index] = src_next_line1[w + 1]; + out[5][index] = src_next_line1[w + 2]; + out[6][index] = src_next_line2[w]; + out[7][index] = src_next_line2[w + 1]; + out[8][index] = src_next_line2[w + 2]; + } + //! next line + w = 0; + dalta_h += 1; + } + //! load int vector + input[remain] = vld1q_s32(out[0]); + input[remain + 1] = vld1q_s32(out[1]); + input[remain + 2] = vld1q_s32(out[2]); + input[remain + 3] = vld1q_s32(out[3]); + input[remain + 4] = vld1q_s32(out[4]); + input[remain + 5] = vld1q_s32(out[5]); + input[remain + 6] = vld1q_s32(out[6]); + input[remain + 7] = vld1q_s32(out[7]); + input[remain + 8] = vld1q_s32(out[8]); + int64x2_t vdst[4]; + TRANS_AND_STORE_REMAIN(input[0], input[1], input[2], input[3], num_point); + TRANS_AND_STORE_REMAIN(input[4], input[5], input[6], input[7], num_point); + if (remain == 3) { + TRANS_AND_STORE_REMAIN(input[8], input[9], input[10], input[11], + num_point); + remain = 0; + } else { + for (int i = 0; i <= remain; i++) { + input[i] = input[8 + i]; + } + remain++; + } + src_line0 += ih * iw; + src_line1 += ih * iw; + src_line2 += ih * iw; + } + //! pad remain to 4 + if (remain > 0) { + int64x2_t vdst[4]; + TRANS_AND_STORE_REMAIN(input[0], input[1], input[2], input[3], + num_point); + } +} +} // namespace + +template +void StrategyFuse4x4x16Nchw44:: + exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam, + fallback::MatrixMulImpl::AlgoBase*) { + size_t ow = param.osz[1]; + size_t ic = param.filter_meta.icpg; + size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; + size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; + constexpr static size_t pack_size = 4; + size_t input_offset = + ih * iw * ic * + (sparam.group_id + param.filter_meta.group * sparam.batch_id) * + sizeof(dt_int8); + + dt_int8* src2 = reinterpret_cast( + reinterpret_cast(bundle.get(BUNDLE_PADDING_INDEX)) + + input_offset); + bool is_phpwzero = param.filter_meta.padding[0] == 0 && + param.filter_meta.padding[1] == 0; + if (is_phpwzero) { + src2 = const_cast( + param.src(sparam.batch_id, sparam.group_id)); + } + dt_int8* b_panel = + reinterpret_cast(reinterpret_cast( + bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); + megdnn_assert(ic % 4 == 0, "nchw44 with ic is not of time 4"); + const int packed_k = (ic * 3 * 3) / pack_size; + const int ksize4 = round_up(packed_k, 4) * 16 * sizeof(dt_int8); + size_t out_size = sparam.output_block_size; + size_t curr_index = sparam.ohw_cur_index; + size_t curr_ih = curr_index / ow; + size_t curr_iw = curr_index % ow; + size_t out_index = 0; + while (out_index < out_size) { + for (; curr_iw + 3 < ow && out_index + 3 < out_size; + curr_iw += 4, out_index += 4) { + dt_int8* dst = b_panel + (out_index / 4) * ksize4; + optimize_fuse_im2col_packB(src2, ic / 4, iw, ih, curr_iw, curr_ih, + dst); + } + if (curr_iw < ow && out_index < out_size) { + size_t out_remain = std::min(out_size - out_index, 4_z); + size_t remain_point_this_line = std::min(ow - curr_iw, out_remain); + size_t start_point_next_line = + (out_remain - remain_point_this_line) % ow; + size_t pass_lines = (out_remain - remain_point_this_line) / ow; + dt_int8* dst = b_panel + (out_index / 4) * ksize4; + naive_fuse_im2col_packB(src2, ic / 4, iw, ih, curr_iw, curr_ih, + out_remain, ow, dst); + out_index += out_remain; + curr_iw = start_point_next_line; + curr_ih += (pass_lines + 1); + } else { + curr_iw = 0; + curr_ih++; + } + } +} +#undef TRANS_AND_STORE_REMAIN +#undef TRANS_AND_STORE + + + +namespace megdnn { + +template class StrategyFuse4x4x16Nchw44; +} // namespace megdnn + +#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp index c587f3ca5..ab5ad5fcf 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp @@ -11,5 +11,209 @@ #include "src/fallback/conv_bias/im2col/strategy_base.h" +#if MEGDNN_AARCH64 +#include + +using namespace megdnn; + +namespace { + +#define PACKB_ONELINE() \ + int out_index = 0; \ + outptr = output_base; \ + for (; out_index + 11 < block_size; out_index += 12) { \ + std::memcpy(outptr, tmp_output, 48); \ + outptr += ksize12; \ + tmp_output += 12; \ + } \ + \ + outptr = output_base4; \ + for (; out_index + 3 < block_size; out_index += 4) { \ + std::memcpy(outptr, tmp_output, 16); \ + outptr += ksize4; \ + tmp_output += 4; \ + } \ + \ + if (out_index < block_size) { \ + uint32_t zerobuffer[4] = {0}; \ + size_t out_remain = std::min(block_size - out_index, 4); \ + std::memcpy(outptr, tmp_output, out_remain * sizeof(uint32_t)); \ + outptr += out_remain; \ + std::memcpy(outptr, zerobuffer, (4 - out_remain) * sizeof(uint32_t)); \ + } \ + output_base += 12; \ + output_base4 += 4; + +#define STOR_IM2COL_DST() \ + output0[count] = uint32_src[index + 0]; \ + output1[count] = uint32_src[index + 1]; \ + output2[count] = uint32_src[index + 2]; + +#define LOAD_AND_STOR_IM2COL_DST() \ + uint32x4_t v_tmp = vld1q_u32(&uint32_src[index + 4]); \ + uint32x4_t v_o1 = vextq_u32(v_o0, v_tmp, 1); \ + uint32x4_t v_o2 = vextq_u32(v_o0, v_tmp, 2); \ + vst1q_u32(&output0[count], v_o0); \ + vst1q_u32(&output1[count], v_o1); \ + vst1q_u32(&output2[count], v_o2); \ + v_o0 = v_tmp; + +void fuse_packb(const dt_int8* __restrict src, dt_int8* __restrict dst, + dt_int8* __restrict b_panel, const int OW, const int IC, + const int IH, const int IW, + const int cur_index, const int block_size) { + int start_h = cur_index / OW; + int cur_remain_w = cur_index % OW; + int end_h = (cur_index + block_size) / OW; + int end_remain_w = (cur_index + block_size) % OW; + bool same_line = start_h == end_h ? true : false; + size_t newIC = IC / 4; + const uint32_t* uint32_src = + static_cast(static_cast(src)); + uint32_t* output = static_cast(static_cast(dst)); + uint32_t* b_output = static_cast(static_cast(b_panel)); + const int packed_k = newIC * 3 * 3; + const int ksize12 = packed_k * 12 * sizeof(dt_int8); + const int ksize4 = packed_k * 4 * sizeof(dt_int8); + uint32_t* outptr = b_output; + uint32_t* output_base = b_output; + uint32_t* output_base4 = b_output + block_size / 12 * ksize12; + constexpr int FH = 3; + if (same_line) { + rep(ic, newIC) { + rep(fh, FH) { + size_t count = 0; + size_t index = 0; + int w = cur_remain_w; + index = (ic * IH + (start_h + fh)) * IW + w; + for (; w + 3 < end_remain_w; w += 4) { + vst1q_u32(&output[count], vld1q_u32(&uint32_src[index])); + count += 4; + index += 4; + } + for (; w < end_remain_w; w++) { + output[count++] = uint32_src[index++]; + } + output[count++] = uint32_src[index]; + output[count++] = uint32_src[index + 1]; + for (int i = 0; i < 3; i++) { + const uint32_t* tmp_output = output + i; + PACKB_ONELINE(); + } + } + } + } else { + rep(ic, newIC) { + rep(fh, FH) { + size_t count = 0; + size_t index = 0; + uint32_t* output0 = output; + uint32_t* output1 = output + block_size; + uint32_t* output2 = output1 + block_size; + int w = cur_remain_w; + index = (ic * IH + (start_h + fh)) * IW + w; + uint32x4_t v_o0 = vld1q_u32(&uint32_src[index]); + for ( ; w + 3 < OW; w += 4) { + LOAD_AND_STOR_IM2COL_DST(); + count += 4; + index += 4; + } + + for (; w < OW; w++) { + STOR_IM2COL_DST(); + count++; + index++; + } + + for (int h = start_h + 1; h < end_h; h++) { + int ow = 0; + index = (ic * IH + (h + fh)) * IW + ow; + v_o0 = vld1q_u32(&uint32_src[index]); + for (; ow + 3 < OW; ow += 4) { + LOAD_AND_STOR_IM2COL_DST(); + count += 4; + index += 4; + } + + for (; ow < OW; ow++) { + STOR_IM2COL_DST(); + count++; + index++; + } + } + + index = (ic * IH + (end_h + fh)) * IW; + w = 0; + v_o0 = vld1q_u32(&uint32_src[index]); + for ( ; w + 3 < end_remain_w; w+=4) { + LOAD_AND_STOR_IM2COL_DST(); + count+=4; + index+=4; + } + for ( ; w < end_remain_w; w++) { + STOR_IM2COL_DST(); + count++; + index++; + } + + for (int k = 0; k < 3; k++) { + const uint32_t* tmp_output = output + k * block_size; + PACKB_ONELINE(); + } + } + } + } +} +#undef PACKB_ONELINE +#undef STOR_IM2COL_DST +#undef LOAD_AND_STOR_IM2COL_DST +} // namespace + +template +void StrategyFuse8x12x4Nchw44Dot:: + exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam /*matmul_param*/, + fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { + size_t ow = param.osz[1]; + size_t ic = param.filter_meta.icpg; + size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; + size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; + size_t input_offset = + ih * iw * ic * + (sparam.group_id + param.filter_meta.group * sparam.batch_id) * + sizeof(dt_int8); + + dt_int8* src2 = reinterpret_cast( + reinterpret_cast(bundle.get(BUNDLE_PADDING_INDEX)) + + input_offset); + bool is_phpwzero = param.filter_meta.padding[0] == 0 && + param.filter_meta.padding[1] == 0; + if (is_phpwzero) { + src2 = const_cast( + param.src(sparam.batch_id, sparam.group_id)); + } + dt_int8* b_panel = + reinterpret_cast(reinterpret_cast( + bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); + megdnn_assert(ic % 4 == 0, "nchw44_dot with ic is not of time 4"); + + int8_t* im2col_dst = static_cast( + bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); + + fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index, + sparam.output_block_size); +} + + +namespace megdnn { + +template class StrategyFuse8x12x4Nchw44Dot; +} // namespace megdnn + +#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp new file mode 100644 index 000000000..328cb2d11 --- /dev/null +++ b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp @@ -0,0 +1,209 @@ +/** + * \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/fallback/conv_bias/im2col/strategy_base.h" +#include "src/fallback/convolution/img2col_helper.h" + +#if MEGDNN_AARCH64 +#include + +using namespace megdnn; + +namespace { + +#define PACKB_ONELINE() \ + int out_index = 0; \ + outptr = output_base; \ + for (; out_index + 11 < block_size; out_index += 12) { \ + float32x4x4_t v0 = vld4q_f32(tmp_output); \ + float32x4x4_t v1 = vld4q_f32(tmp_output + 16); \ + float32x4x4_t v2 = vld4q_f32(tmp_output + 32); \ + vst1q_f32(outptr, v0.val[0]); \ + vst1q_f32(outptr + 4, v1.val[0]); \ + vst1q_f32(outptr + 8, v2.val[0]); \ + vst1q_f32(outptr + 12, v0.val[1]); \ + vst1q_f32(outptr + 16, v1.val[1]); \ + vst1q_f32(outptr + 20, v2.val[1]); \ + vst1q_f32(outptr + 24, v0.val[2]); \ + vst1q_f32(outptr + 28, v1.val[2]); \ + vst1q_f32(outptr + 32, v2.val[2]); \ + vst1q_f32(outptr + 36, v0.val[3]); \ + vst1q_f32(outptr + 40, v1.val[3]); \ + vst1q_f32(outptr + 44, v2.val[3]); \ + outptr += ksize12; \ + tmp_output += 48; \ + } \ + \ + outptr = output_base4; \ + for (; out_index + 3 < block_size; out_index += 4) { \ + float32x4x4_t v0 = vld4q_f32(tmp_output); \ + vst1q_f32(outptr, v0.val[0]); \ + vst1q_f32(outptr + 4, v0.val[1]); \ + vst1q_f32(outptr + 8, v0.val[2]); \ + vst1q_f32(outptr + 12, v0.val[3]); \ + outptr += ksize4; \ + tmp_output += 16; \ + } \ + \ + if (out_index < block_size) { \ + float zerobuffer[16] = {0}; \ + size_t out_remain = std::min(block_size - out_index, 4); \ + std::memcpy(zerobuffer, tmp_output, out_remain * sizeof(float) * 4); \ + float32x4x4_t v0 = vld4q_f32(zerobuffer); \ + vst1q_f32(outptr, v0.val[0]); \ + vst1q_f32(outptr + 4, v0.val[1]); \ + vst1q_f32(outptr + 8, v0.val[2]); \ + vst1q_f32(outptr + 12, v0.val[3]); \ + } \ + output_base += 48; \ + output_base4 += 16; + +#define LOAD_AND_STOR_IM2COL_DST() \ + float32x4_t v1 = vld1q_f32(&src[index + 4]); \ + float32x4_t v2 = vld1q_f32(&src[index + 8]); \ + vst1q_f32(&output0[i], v0); \ + vst1q_f32(&output1[i], v1); \ + vst1q_f32(&output2[i], v2); \ + i += 4; \ + index += 8; \ + v0 = v2; + +void fuse_packb(const float* __restrict src, float* __restrict dst, + float* __restrict b_panel, const int OW, const int IC, + const int IH, const int IW, const int cur_index, + const int block_size) { + int start_h = cur_index / OW; + int cur_remain_w = cur_index % OW; + int end_h = (cur_index + block_size) / OW; + int end_remain_w = (cur_index + block_size) % OW; + bool same_line = start_h == end_h ? true : false; + size_t newIC = IC / 4; + float* b_output = b_panel; + const int packed_k = IC * 3 * 3; + const int ksize12 = packed_k * 12; + const int ksize4 = packed_k * 4; + float* outptr = b_output; + float* output_base = b_output; + float* output_base4 = b_output + block_size / 12 * ksize12; + constexpr int FH = 3; + constexpr int SH = 2; + constexpr int SW = 2; + if (same_line) { + rep(ic, newIC) { + rep(fh, FH) { + float* output02 = dst; + float* output1 = dst + block_size * 4 + 4; + size_t i = 0; + + size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + + cur_remain_w * SW); + for (int w = cur_remain_w; w < end_remain_w; w++) { + vst1q_f32(&output02[i], vld1q_f32(&src[index])); + vst1q_f32(&output1[i], vld1q_f32(&src[index + 4])); + i += 4; + index += 8; + } + vst1q_f32(&output02[i], vld1q_f32(&src[index])); + float* output[3]; + output[0] = output02; + output[1] = output1; + output[2] = output02 + 4; + for (int i = 0; i < 3; i++) { + const float* tmp_output = output[i]; + PACKB_ONELINE(); + } + } + } + } else { + rep(ic, newIC) { + rep(fh, FH) { + float* output0 = dst; + float* output1 = dst + block_size * 4; + float* output2 = output1 + block_size * 4; + size_t i = 0; + + size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + + (cur_remain_w * SW)); + float32x4_t v0 = vld1q_f32(&src[index]); + for (int w = cur_remain_w; w < OW; w++) { + LOAD_AND_STOR_IM2COL_DST(); + } + + for (int h = start_h + 1; h < end_h; h++) { + size_t index = 4 * (ic * IH * IW + (h * SH + fh) * IW); + v0 = vld1q_f32(&src[index]); + rep(ow, OW) { LOAD_AND_STOR_IM2COL_DST(); } + } + + index = 4 * (ic * IH * IW + (end_h * SH + fh) * IW); + v0 = vld1q_f32(&src[index]); + for (int w = 0; w < end_remain_w; w++) { + LOAD_AND_STOR_IM2COL_DST(); + } + + for (int i = 0; i < 3; i++) { + const float* tmp_output = output0 + i * block_size * 4; + PACKB_ONELINE(); + } + } + } + } +} +#undef PACKB_ONELINE +#undef LOAD_AND_STOR_IM2COL_DST +} // namespace + +template +void StrategyFuse8x12x1Nchw44K3x3S2:: + exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam /*matmul_param*/, + fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { + size_t ow = param.osz[1]; + size_t ic = param.filter_meta.icpg; + size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; + size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; + size_t input_offset = + ih * iw * ic * + (sparam.group_id + param.filter_meta.group * sparam.batch_id) * + sizeof(float); + + float* src2 = reinterpret_cast( + reinterpret_cast(bundle.get(BUNDLE_PADDING_INDEX)) + + input_offset); + bool is_phpwzero = param.filter_meta.padding[0] == 0 && + param.filter_meta.padding[1] == 0; + if (is_phpwzero) { + src2 = const_cast( + param.src(sparam.batch_id, sparam.group_id)); + } + float* b_panel = reinterpret_cast(reinterpret_cast( + bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); + megdnn_assert(ic % 4 == 0, "nchw44_dot with ic is not of time 4"); + + float* im2col_dst = + static_cast(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); + + fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index, + sparam.output_block_size); +} + +namespace megdnn { + +template class StrategyFuse8x12x1Nchw44K3x3S2; +} // namespace megdnn + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 5e606b1a0..ef19dea7e 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -1838,7 +1838,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12"); #endif } - +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE) { + using namespace conv_bias; + std::vector args = get_nchw44_conv_bias_args( + {3}, 2, false, false, false, false, false, true, true,false); +#if MEGDNN_AARCH64 + check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); +#elif MEGDNN_ARMV7 + check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12"); +#endif +} /***************************** Conv1x1 Algo Test ***********************/ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { using namespace conv_bias; diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp index d2743bc1f..7681c9f6b 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -708,6 +708,66 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { } +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) { + constexpr size_t RUNS = 40; + std::vector data_type = { + dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32()}; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S, + bool is_nchw = false) { + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = P; + param.pad_w = P; + param.stride_h = S; + param.stride_w = S; + param.sparse = param::ConvBias::Sparse::DENSE; + param.format = param::ConvBias::Format::NCHW44; + auto OH = (H + 2 * P - FS) / static_cast(S) + 1; + auto OW = (W + 2 * P - FS) / static_cast(S) + 1; + TensorShape src = {N, IC / 4, H, W, 4}; + TensorShape filter = {OC / 4, IC / 4, FS, FS, 4, 4}; + if (group > 1) { + filter = {group, OC / group / 4, IC / group / 4, FS, FS, 4, 4}; + param.sparse = param::ConvBias::Sparse::GROUP; + } + if (is_nchw) { + src = {N, IC, H, W}; + filter = {OC / 4, FS, FS, IC, 4}; + } + TensorShape bias = {1, OC / 4, 1, 1, 4}; + TensorShape dst = {N, OC / 4, OH, OW, 4}; + + SmallVector shapes{src, filter, bias, {}, dst}; + float computations = + (((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + std::vector, float>> shape_arg = { + std::make_pair(shapes, computations)}; + benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, + {1, {7}}, data_type); + }; + bench_case(1, 64, 64, 56, 56, 3, 1, 1, 2); + bench_case(1, 128, 128, 28, 28, 3, 1, 1, 2); + bench_case(1, 256, 256, 14, 14, 3, 1, 1, 2); + bench_case(1, 512, 512, 7, 7, 3, 1, 1, 2); + + bench_case(1, 64, 64, 56, 56, 3, 4, 1, 2); + bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2); + bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2); + bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2); + + bench_case(1, 64, 64, 56*2, 56*2, 3, 4, 1, 2); + bench_case(1, 128, 128, 28*2, 28*2, 3, 4, 1, 2); + bench_case(1, 256, 256, 14*2, 14*2, 3, 4, 1, 2); + bench_case(1, 512, 512, 7*2, 7*2, 3, 4, 1, 2); +} + + + + TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2) { constexpr size_t RUNS = 50; -- GitLab