From 07dd6b6c4e22dc4ad211f26202ea799bf6eedf8f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 25 Apr 2020 22:58:35 +0800 Subject: [PATCH] feat(dnn/arm): add arm nchw44 filter5x5 strdie1 and stride2 max pooling GitOrigin-RevId: f84c32f9222ea91ba130d4168a0c1c4a8c8b54ed --- dnn/src/arm_common/pooling/algo.cpp | 70 ++++++ dnn/src/arm_common/pooling/algo.h | 8 + .../pooling/do_max_pooling_5x5_nchw44.cpp | 202 ++++++++++++++++++ .../pooling/do_max_pooling_5x5_nchw44.h | 30 +++ dnn/src/arm_common/pooling/opr_impl.cpp | 2 + dnn/src/arm_common/pooling/opr_impl.h | 1 + dnn/test/arm_common/pooling.cpp | 50 +++++ dnn/test/arm_common/pooling_multi_thread.cpp | 50 +++++ 8 files changed, 413 insertions(+) create mode 100644 dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp create mode 100644 dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.h diff --git a/dnn/src/arm_common/pooling/algo.cpp b/dnn/src/arm_common/pooling/algo.cpp index fe36629cb..4fcd32d5d 100644 --- a/dnn/src/arm_common/pooling/algo.cpp +++ b/dnn/src/arm_common/pooling/algo.cpp @@ -13,6 +13,7 @@ #include "megdnn/opr_param_defs.h" #include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h" #include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h" +#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h" #include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h" #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h" @@ -806,6 +807,75 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( #undef DISPATCH_FUNC } +bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable( + const PoolingKernSizeParam& param) const { + auto SH = param.stride[0]; + auto SW = param.stride[1]; + auto FH = param.filter[0]; + auto FW = param.filter[1]; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + + bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.format == Param::Format::NCHW44 && + param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == SW && + (SW == 1 || SW == 2) && PH == 0 && PW == 0; + return avaible; +} + +void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( + const PoolingKernParam& param) const { + auto IH = param.isz[0], IW = param.isz[1]; + auto OH = param.osz[0], OW = param.osz[1]; + auto N = param.n, C = param.ic; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + auto SW = param.stride[0]; + + void* src_ptr = param.src_ptr; + void* dst_ptr = param.dst_ptr; + +#define DISPATCH_FUNC(type, func, midout_type_id, i) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ + midout_iv(midout_type_id)) { \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ + size_t index, size_t thread_id) { \ + MEGDNN_MARK_USED_VAR(thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ + static_cast(src_ptr) + n * C * IH * IW * 4 + \ + c * IH * IW * 4, \ + static_cast(dst_ptr) + n * C * OH * OW * 4 + \ + c * OH * OW * 4, \ + IH, IW, OH, OW, PH, PW); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ + run); \ + } \ + MIDOUT_END(); + +#define DISPATCH_STRIDE(type, func, midout_type_id) \ + switch (SW) { \ + case 1: { \ + DISPATCH_FUNC(type, func, midout_type_id, 1); \ + break; \ + } \ + case 2: { \ + DISPATCH_FUNC(type, func, midout_type_id, 2); \ + break; \ + } \ + default: \ + megdnn_assert(0, "unsupport stride size"); \ + } + + DISPATCH_STRIDE(int8_t, int8, 12); + +#undef DISPATCH_STRIDE +#undef DISPATCH_FUNC +} + } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h index 3b7424856..8b1e78a57 100644 --- a/dnn/src/arm_common/pooling/algo.h +++ b/dnn/src/arm_common/pooling/algo.h @@ -115,6 +115,14 @@ public: void exec(const PoolingKernParam& param) const override; }; +class PoolingImpl::AlgoFilter5MaxStridexNCHW44 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_FILTER5_MAX_STRIDEX_NCHW44"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; + WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); } // namespace arm_common diff --git a/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp b/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp new file mode 100644 index 000000000..a15622749 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp @@ -0,0 +1,202 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" + +namespace megdnn { +namespace arm_common { + +void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW) { + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh; + const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int8x16_t src00, src04, max_out, max_tmp0, max_tmp1, max_tmp2, + max_tmp3, max_tmp4; + int32x4_t src1234, src2345, src3456; + +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src1234 = vextq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04), 1); \ + src2345 = vextq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04), 2); \ + src3456 = vextq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04), 3); \ + max_tmp##i = vmaxq_s8(src00, vreinterpretq_s8_s32(src1234)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2345)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3456)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, src04); + + UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) + max_out = vmaxq_s8(max_tmp0, max_tmp1); + max_out = vmaxq_s8(max_out, max_tmp2); + max_out = vmaxq_s8(max_out, max_tmp3); + max_out = vmaxq_s8(max_out, max_tmp4); + + vst1q_s8(dptr, max_out); + + sptr0 += 16; + sptr1 += 16; + sptr2 += 16; + sptr3 += 16; + sptr4 += 16; + dptr += 16; +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int8x8_t src01, src23, max_out; + +#define CACULATE_ROW(i) \ + src01 = vld1_s8(sptr##i); \ + src23 = vld1_s8(sptr##i + 8); \ + int8x8_t max_tmp##i = vmax_s8(src01, src23); + + UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) + + max_out = vmax_s8(max_tmp0, max_tmp1); + max_out = vmax_s8(max_out, max_tmp2); + max_out = vmax_s8(max_out, max_tmp3); + max_out = vmax_s8(max_out, max_tmp4); + +#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); + UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) + int8x8_t max_45 = vmax_s8(src0_45, src1_45); + max_45 = vmax_s8(max_45, src1_45); + max_45 = vmax_s8(max_45, src2_45); + max_45 = vmax_s8(max_45, src3_45); + max_45 = vmax_s8(max_45, src4_45); + +#define store(i) \ + *(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); + UNROLL_CALL_NOWRAPPER(4, store) +#undef store +#undef COMPARE_SRC45 +#undef CACULATE_ROW + sptr0 += 4; + sptr1 += 4; + sptr2 += 4; + sptr3 += 4; + sptr4 += 4; + dptr += 4; + } + } +} +void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW) { + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh << 1; + const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int8x16_t src00, src04, src08, src09, src10, max_tmp0, max_tmp1, + max_tmp2, max_tmp3, max_tmp4; + int32x4_t src0246, src1357, src2468, src3579, src46810; + int32x4x2_t src_tmp; +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_s8(sptr##i + 4 * 8); \ + src09 = vld1q_s8(sptr##i + 4 * 9); \ + src10 = vld1q_s8(sptr##i + 4 * 10); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); \ + src3579 = vextq_s32(src1357, vreinterpretq_s32_s8(src09), 1); \ + src46810 = vextq_s32(src2468, vreinterpretq_s32_s8(src10), 1); \ + max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ + vreinterpretq_s8_s32(src1357)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810)); + + UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) + + int8x16_t max_out = vmaxq_s8(max_tmp0, max_tmp1); + max_out = vmaxq_s8(max_out, max_tmp2); + max_out = vmaxq_s8(max_out, max_tmp3); + max_out = vmaxq_s8(max_out, max_tmp4); + + vst1q_s8(dptr, max_out); + + sptr0 += 32; + sptr1 += 32; + sptr2 += 32; + sptr3 += 32; + sptr4 += 32; + dptr += 16; +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int8x8_t src01, src23, max_out; + +#define CACULATE_ROW(i) \ + src01 = vld1_s8(sptr##i); \ + src23 = vld1_s8(sptr##i + 8); \ + int8x8_t max_tmp##i = vmax_s8(src01, src23); + + UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) + + max_out = vmax_s8(max_tmp0, max_tmp1); + max_out = vmax_s8(max_out, max_tmp2); + max_out = vmax_s8(max_out, max_tmp3); + max_out = vmax_s8(max_out, max_tmp4); + +#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); + UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) + int8x8_t max_45 = vmax_s8(src0_45, src1_45); + max_45 = vmax_s8(max_45, src1_45); + max_45 = vmax_s8(max_45, src2_45); + max_45 = vmax_s8(max_45, src3_45); + max_45 = vmax_s8(max_45, src4_45); + +#define store(i) \ + *(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); + UNROLL_CALL_NOWRAPPER(4, store) +#undef store +#undef COMPARE_SRC45 +#undef CACULATE_ROW + sptr0 += 8; + sptr1 += 8; + sptr2 += 8; + sptr3 += 8; + sptr4 += 8; + dptr += 4; + } + } +} + +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.h b/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.h new file mode 100644 index 000000000..3c9c12566 --- /dev/null +++ b/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.h @@ -0,0 +1,30 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/common/utils.h" + +namespace megdnn { +namespace arm_common { + +#define KERN(strdie) \ + void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \ + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ + size_t OW, size_t PH, size_t PW); + +KERN(stride1) +KERN(stride2) + +#undef KERN +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp index c0e096bc2..04b930274 100644 --- a/dnn/src/arm_common/pooling/opr_impl.cpp +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -29,6 +29,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { AlgoFilter3MaxStride2NCHW44 algo_filter3_max_stride2_nchw4; AlgoFilter3MaxStride1NCHW44 algo_filter3_max_stride1_nchw4; AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4; + AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4; public: AlgoPack() { @@ -44,6 +45,7 @@ public: all_algos.emplace_back(&algo_filter3_max_stride1_nchw4); all_algos.emplace_back(&algo_filter2_max_stridex_nchw4); all_algos.emplace_back(&algo_filter4_max_stridex_nchw4); + all_algos.emplace_back(&algo_filter5_max_stridex_nchw4); } SmallVector all_algos; }; diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h index 67c63b652..9e716a467 100644 --- a/dnn/src/arm_common/pooling/opr_impl.h +++ b/dnn/src/arm_common/pooling/opr_impl.h @@ -87,6 +87,7 @@ private: class AlgoFilter3MaxStride2NCHW44; class AlgoFilter3MaxStride1NCHW44; class AlgoFilter4MaxStridexNCHW44; + class AlgoFilter5MaxStridexNCHW44; class AlgoPack; }; } // namespace arm_common diff --git a/dnn/test/arm_common/pooling.cpp b/dnn/test/arm_common/pooling.cpp index 6808ee448..557850352 100644 --- a/dnn/test/arm_common/pooling.cpp +++ b/dnn/test/arm_common/pooling.cpp @@ -254,6 +254,56 @@ TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S2x2_NCHW44) } // clang-format on } +TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S1x1_NCHW44) +{ + // clang-format off + for (size_t ih: {5, 9, 19, 20, 39}) + for (size_t iw: {5, 12, 23, 27, 39}) + for (size_t ph: {0}) + for (size_t pw: {0}) + if (ih+2*ph >= 5 && iw+2*pw >= 5) + { + UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; + Checker checker(handle()); + checker.set_dtype(0, dtype::QuantizedS8(1.1f)); + checker.set_rng(0,&rng); + + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.format = param::Pooling::Format::NCHW44; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 1; + param.window_h = param.window_w = 5; + checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); + } + // clang-format on +} +TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S2x2_NCHW44) +{ + // clang-format off + for (size_t ih: {5, 9, 19, 20, 39}) + for (size_t iw: {5, 12, 23, 27, 39}) + for (size_t ph: {0}) + for (size_t pw: {0}) + if (ih+2*ph >= 5 && iw+2*pw >= 5) + { + UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; + Checker checker(handle()); + checker.set_dtype(0, dtype::QuantizedS8(1.1f)); + checker.set_rng(0,&rng); + + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.format = param::Pooling::Format::NCHW44; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 2; + param.window_h = param.window_w = 5; + checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); + } + // clang-format on +} #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_F(ARM_COMMON, POOLING_FP16) { diff --git a/dnn/test/arm_common/pooling_multi_thread.cpp b/dnn/test/arm_common/pooling_multi_thread.cpp index b3d04147c..4b34fb984 100644 --- a/dnn/test/arm_common/pooling_multi_thread.cpp +++ b/dnn/test/arm_common/pooling_multi_thread.cpp @@ -204,6 +204,56 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S2x2_NCHW44) } // clang-format on } +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44) +{ + // clang-format off + for (size_t ih: {5, 9, 19, 20, 39}) + for (size_t iw: {5, 12, 23, 27, 39}) + for (size_t ph: {0}) + for (size_t pw: {0}) + if (ih+2*ph >= 5 && iw+2*pw >= 5) + { + UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; + Checker checker(handle()); + checker.set_dtype(0, dtype::QuantizedS8(1.1f)); + checker.set_rng(0,&rng); + + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.format = param::Pooling::Format::NCHW44; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 1; + param.window_h = param.window_w = 5; + checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); + } + // clang-format on +} +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S2x2_NCHW44) +{ + // clang-format off + for (size_t ih: {5, 9, 19, 20, 39}) + for (size_t iw: {5, 12, 23, 27, 39}) + for (size_t ph: {0}) + for (size_t pw: {0}) + if (ih+2*ph >= 5 && iw+2*pw >= 5) + { + UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; + Checker checker(handle()); + checker.set_dtype(0, dtype::QuantizedS8(1.1f)); + checker.set_rng(0,&rng); + + param::Pooling param; + param.mode = param::Pooling::Mode::MAX; + param.format = param::Pooling::Format::NCHW44; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = 2; + param.window_h = param.window_w = 5; + checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); + } + // clang-format on +} TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_INT8_W3x3_S2x2) { -- GitLab