From 2ae9fdef401474feb43ce905306e7055dac8a921 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 9 May 2020 17:48:47 +0800 Subject: [PATCH] feat(dnn/arm): add arm common nchw44 avg pooling GitOrigin-RevId: 25eab33e14ef4000480acea17a49e4360d42bdd7 --- dnn/src/arm_common/pooling/algo.cpp | 221 +++++++++----- dnn/src/arm_common/pooling/algo.h | 18 +- .../pooling/do_pooling_2x2_nchw44.cpp | 206 ++++++++++++- .../pooling/do_pooling_2x2_nchw44.h | 21 +- .../pooling/do_pooling_3x3_nchw44.cpp | 241 ++++++++++++++- .../pooling/do_pooling_3x3_nchw44.h | 21 +- .../pooling/do_pooling_4x4_nchw44.cpp | 250 ++++++++++++++- .../pooling/do_pooling_4x4_nchw44.h | 15 +- .../pooling/do_pooling_5x5_nchw44.cpp | 284 +++++++++++++++++- .../pooling/do_pooling_5x5_nchw44.h | 15 +- dnn/src/arm_common/pooling/opr_impl.cpp | 16 +- dnn/src/arm_common/pooling/opr_impl.h | 8 +- dnn/test/arm_common/pooling_multi_thread.cpp | 77 +++-- 13 files changed, 1214 insertions(+), 179 deletions(-) diff --git a/dnn/src/arm_common/pooling/algo.cpp b/dnn/src/arm_common/pooling/algo.cpp index 6ba09208b..9ae81eeae 100644 --- a/dnn/src/arm_common/pooling/algo.cpp +++ b/dnn/src/arm_common/pooling/algo.cpp @@ -73,14 +73,15 @@ WorkspaceBundle get_bundle_nchw44( const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, size_t& IH2, size_t& IW2, size_t PH, size_t PW, - const WorkspaceBundle& ws) { + const WorkspaceBundle& ws, bool is_max_mode) { int8_t* sptr_base = nullptr; + int8_t padding_value = is_max_mode ? INT8_MIN : 0; bool need_pad = ((PH != 0) || (PW != 0)) ? true : false; if (need_pad) { IH2 = IH + 2 * PH; IW2 = IW + 2 * PW; sptr_base = static_cast(ws.get(0)); - memset(sptr_base, -128, sizeof(int8_t) * IH2 * IW2 * 4); + memset(sptr_base, padding_value, sizeof(int8_t) * IH2 * IW2 * 4); rep(ih, IH) { std::memcpy(sptr_base + (ih + PH) * IW2 * 4 + PW * 4, src + ih * IW * 4, sizeof(int8_t) * IW * 4); @@ -597,7 +598,7 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec( MIDOUT_END(); } -bool PoolingImpl::AlgoFilter3MaxStridexNCHW44::usable( +bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable( const PoolingKernSizeParam& param) const { auto SH = param.stride[0]; auto SW = param.stride[1]; @@ -606,12 +607,12 @@ bool PoolingImpl::AlgoFilter3MaxStridexNCHW44::usable( bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.format == Param::Format::NCHW44 && - param.mode == Mode::MAX && FH == 3 && FW == 3 && SW == SH && - (SH == 1 || SW == 2); + (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && + FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2); return avaible; } -void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( +void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec( const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; @@ -623,8 +624,8 @@ void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, i) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ +#define DISPATCH_FUNC(type, func, i, mode) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(8), \ midout_iv(#type #i##_hash)) { \ WorkspaceBundle wbundle = get_bundle_nchw44(param); \ auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ @@ -635,7 +636,7 @@ void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ size_t n = index / C; \ size_t c = index % C; \ - do_max_pooling_3x3_stride##i##_##func##_nchw44_NEON( \ + do_##mode##_pooling_3x3_stride##i##_##func##_nchw44_NEON( \ static_cast(src_ptr) + n * C * IH * IW * 4 + \ c * IH * IW * 4, \ static_cast(dst_ptr) + n * C * OH * OW * 4 + \ @@ -648,27 +649,43 @@ void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( } \ MIDOUT_END(); -#define DISPATCH_STRIDE(type, func) \ - switch (SW) { \ - case 1: { \ - DISPATCH_FUNC(type, func, 1); \ - break; \ - } \ - case 2: { \ - DISPATCH_FUNC(type, func, 2); \ - break; \ - } \ - default: \ - megdnn_assert(0, "unsupport stride size"); \ +#define DISPATCH_MODE(type, func, stride) \ + switch (param.mode) { \ + case Mode::MAX: { \ + DISPATCH_FUNC(type, func, stride, max); \ + break; \ + } \ + case Mode::AVERAGE: { \ + DISPATCH_FUNC(type, func, stride, avg); \ + break; \ + } \ + default: \ + megdnn_throw(ssprintf("Unsupport pooling mode %d", param.mode) \ + .c_str()); \ + } + +#define DISPATCH_STRIDE(type, func) \ + switch (SW) { \ + case 1: { \ + DISPATCH_MODE(type, func, 1); \ + break; \ + } \ + case 2: { \ + DISPATCH_MODE(type, func, 2); \ + break; \ + } \ + default: \ + megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \ } DISPATCH_STRIDE(int8_t, int8); #undef DISPATCH_STRIDE +#undef DISPATCH_MODE #undef DISPATCH_FUNC } -bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable( +bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable( const PoolingKernSizeParam& param) const { auto SH = param.stride[0]; auto SW = param.stride[1]; @@ -677,12 +694,12 @@ bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable( bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.format == Param::Format::NCHW44 && - param.mode == Mode::MAX && FH == 2 && FW == 2 && SH == SW && - (SW == 1 || SW == 2); + (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && + FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2); return avaible; } -void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( +void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec( const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; @@ -694,8 +711,8 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, i) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ +#define DISPATCH_FUNC(type, func, i, mode) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(9), \ midout_iv(#func #i##_hash)) { \ WorkspaceBundle wbundle = get_bundle_nchw44(param); \ auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ @@ -706,7 +723,7 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ size_t n = index / C; \ size_t c = index % C; \ - do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ + do_##mode##_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ static_cast(src_ptr) + n * C * IH * IW * 4 + \ c * IH * IW * 4, \ static_cast(dst_ptr) + n * C * OH * OW * 4 + \ @@ -719,27 +736,43 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( } \ MIDOUT_END(); -#define DISPATCH_STRIDE(type, func) \ - switch (SW) { \ - case 1: { \ - DISPATCH_FUNC(type, func, 1); \ - break; \ - } \ - case 2: { \ - DISPATCH_FUNC(type, func, 2); \ - break; \ - } \ - default: \ - megdnn_assert(0, "unsupport stride size"); \ +#define DISPATCH_MODE(type, func, stride) \ + switch (param.mode) { \ + case Mode::MAX: { \ + DISPATCH_FUNC(type, func, stride, max); \ + break; \ + } \ + case Mode::AVERAGE: { \ + DISPATCH_FUNC(type, func, stride, avg); \ + break; \ + } \ + default: \ + megdnn_throw(ssprintf("Unsupport pooling mode %d", param.mode) \ + .c_str()); \ + } + +#define DISPATCH_STRIDE(type, func) \ + switch (SW) { \ + case 1: { \ + DISPATCH_MODE(type, func, 1); \ + break; \ + } \ + case 2: { \ + DISPATCH_MODE(type, func, 2); \ + break; \ + } \ + default: \ + megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \ } DISPATCH_STRIDE(int8_t, int8); #undef DISPATCH_STRIDE +#undef DISPATCH_MODE #undef DISPATCH_FUNC } -bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable( +bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable( const PoolingKernSizeParam& param) const { auto SH = param.stride[0]; auto SW = param.stride[1]; @@ -748,12 +781,12 @@ bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable( bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.format == Param::Format::NCHW44 && - param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == SW && - (SW == 1 || SW == 2); + (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && + FH == 4 && FW == 4 && SH == SW && (SW == 1 || SW == 2); return avaible; } -void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( +void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec( const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; @@ -765,8 +798,8 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, i) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ +#define DISPATCH_FUNC(type, func, i, mode) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(10), \ midout_iv(#func #i##_hash)) { \ WorkspaceBundle wbundle = get_bundle_nchw44(param); \ auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ @@ -777,7 +810,7 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ size_t n = index / C; \ size_t c = index % C; \ - do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ + do_##mode##_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ static_cast(src_ptr) + n * C * IH * IW * 4 + \ c * IH * IW * 4, \ static_cast(dst_ptr) + n * C * OH * OW * 4 + \ @@ -790,27 +823,43 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( } \ MIDOUT_END(); -#define DISPATCH_STRIDE(type, func) \ - switch (SW) { \ - case 1: { \ - DISPATCH_FUNC(type, func, 1); \ - break; \ - } \ - case 2: { \ - DISPATCH_FUNC(type, func, 2); \ - break; \ - } \ - default: \ - megdnn_assert(0, "unsupport stride size"); \ +#define DISPATCH_MODE(type, func, stride) \ + switch (param.mode) { \ + case Mode::MAX: { \ + DISPATCH_FUNC(type, func, stride, max); \ + break; \ + } \ + case Mode::AVERAGE: { \ + DISPATCH_FUNC(type, func, stride, avg); \ + break; \ + } \ + default: \ + megdnn_throw(ssprintf("Unsupport pooling mode %d", param.mode) \ + .c_str()); \ + } + +#define DISPATCH_STRIDE(type, func) \ + switch (SW) { \ + case 1: { \ + DISPATCH_MODE(type, func, 1); \ + break; \ + } \ + case 2: { \ + DISPATCH_MODE(type, func, 2); \ + break; \ + } \ + default: \ + megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \ } DISPATCH_STRIDE(int8_t, int8); #undef DISPATCH_STRIDE +#undef DISPATCH_MODE #undef DISPATCH_FUNC } -bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable( +bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable( const PoolingKernSizeParam& param) const { auto SH = param.stride[0]; auto SW = param.stride[1]; @@ -819,12 +868,12 @@ bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable( bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.format == Param::Format::NCHW44 && - param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == SW && - (SW == 1 || SW == 2); + (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && + FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2); return avaible; } -void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( +void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec( const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; @@ -836,8 +885,8 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, i) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ +#define DISPATCH_FUNC(type, func, i, mode) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(11), \ midout_iv(#func #i##_hash)) { \ WorkspaceBundle wbundle = get_bundle_nchw44(param); \ auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ @@ -848,7 +897,7 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ size_t n = index / C; \ size_t c = index % C; \ - do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ + do_##mode##_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ static_cast(src_ptr) + n * C * IH * IW * 4 + \ c * IH * IW * 4, \ static_cast(dst_ptr) + n * C * OH * OW * 4 + \ @@ -861,23 +910,39 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( } \ MIDOUT_END(); -#define DISPATCH_STRIDE(type, func) \ - switch (SW) { \ - case 1: { \ - DISPATCH_FUNC(type, func, 1); \ - break; \ - } \ - case 2: { \ - DISPATCH_FUNC(type, func, 2); \ - break; \ - } \ - default: \ - megdnn_assert(0, "unsupport stride size"); \ +#define DISPATCH_MODE(type, func, stride) \ + switch (param.mode) { \ + case Mode::MAX: { \ + DISPATCH_FUNC(type, func, stride, max); \ + break; \ + } \ + case Mode::AVERAGE: { \ + DISPATCH_FUNC(type, func, stride, avg); \ + break; \ + } \ + default: \ + megdnn_throw(ssprintf("Unsupport pooling mode %d", param.mode) \ + .c_str()); \ + } + +#define DISPATCH_STRIDE(type, func) \ + switch (SW) { \ + case 1: { \ + DISPATCH_MODE(type, func, 1); \ + break; \ + } \ + case 2: { \ + DISPATCH_MODE(type, func, 2); \ + break; \ + } \ + default: \ + megdnn_throw(ssprintf("Unsupport stride size %d", SW).c_str()); \ } DISPATCH_STRIDE(int8_t, int8); #undef DISPATCH_STRIDE +#undef DISPATCH_MODE #undef DISPATCH_FUNC } diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h index 1bae98bb4..2fde30355 100644 --- a/dnn/src/arm_common/pooling/algo.h +++ b/dnn/src/arm_common/pooling/algo.h @@ -83,34 +83,34 @@ public: void exec(const PoolingKernParam& param) const override; }; -class PoolingImpl::AlgoFilter3MaxStridexNCHW44 final : public AlgoBase { +class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDEX_NCHW44"; } + const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; }; -class PoolingImpl::AlgoFilter2MaxStridexNCHW44 final : public AlgoBase { +class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "ARM_POOLING_FILTER2_MAX_STRIDEX_NCHW44"; } + const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; }; -class PoolingImpl::AlgoFilter4MaxStridexNCHW44 final : public AlgoBase { +class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "ARM_POOLING_FILTER4_MAX_STRIDEX_NCHW44"; } + const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; }; -class PoolingImpl::AlgoFilter5MaxStridexNCHW44 final : public AlgoBase { +class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "ARM_POOLING_FILTER5_MAX_STRIDEX_NCHW44"; } + const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; }; @@ -122,7 +122,7 @@ WorkspaceBundle get_bundle_nchw44( const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, size_t& IH2, size_t& IW2, size_t PH, size_t PW, - const WorkspaceBundle& ws); + const WorkspaceBundle& ws, bool is_max_mode); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp index fa3db8853..6d38c0d12 100644 --- a/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp + * \file dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -24,7 +24,7 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; - sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh; @@ -70,7 +70,7 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; - sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh << 1; @@ -120,6 +120,206 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } +void do_avg_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + int16_t filter_size = 4; + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int8x16_t src0123, src1234; + int16x8_t src01, src23, src12, src34; + int16x8_t sum01 = vdupq_n_s16(0); + int16x8_t sum23 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src0123 = vld1q_s8(sptr##i); \ + src1234 = vld1q_s8(sptr##i + 4); \ + src01 = vmovl_s8(vget_low_s8(src0123)); \ + src23 = vmovl_s8(vget_high_s8(src0123)); \ + src12 = vmovl_s8(vget_low_s8(src1234)); \ + src34 = vmovl_s8(vget_high_s8(src1234)); \ + sum01 = vaddq_s16(sum01, src01); \ + sum01 = vaddq_s16(sum01, src12); \ + sum23 = vaddq_s16(sum23, src23); \ + sum23 = vaddq_s16(sum23, src34); + + UNROLL_CALL_NOWRAPPER(2, CACULATE_ROW) + +#define sum_define(i) int16_t sum##i; + UNROLL_CALL_NOWRAPPER(8, sum_define) + +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ + filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ + filter_size; + +#define store_sum01(i) *(dptr + i) = static_cast(sum##i); +#define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(8, sum01_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum01) + + UNROLL_CALL_NOWRAPPER(8, sum23_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum23) + + sptr0 += 16; + sptr1 += 16; + dptr += 16; +#undef store_sum01 +#undef store_sum23 +#undef sum01_avg +#undef sum23_avg +#undef sum_define +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int8x8_t src001 = vld1_s8(sptr0); + int8x8_t src101 = vld1_s8(sptr1); + int16x8_t src00 = vmovl_s8(src001); + int16x8_t src10 = vmovl_s8(src101); + int16x8_t max_tmp = vaddq_s16(src00, src10); +#define do_acc(i) \ + int16_t sum##i = \ + vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4); +#define do_avg(i) \ + sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ + : (sum##i - filter_size / 2) / filter_size; +#define store(i) *(dptr + i) = static_cast(sum##i); + UNROLL_CALL_NOWRAPPER(4, do_acc) + UNROLL_CALL_NOWRAPPER(4, do_avg) + UNROLL_CALL_NOWRAPPER(4, store) +#undef store +#undef do_avg +#undef do_acc + sptr0 += 4; + sptr1 += 4; + dptr += 4; + } + } +} + +void do_avg_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + int16_t filter_size = 4; + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh << 1; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int32x4x2_t src_tmp; + int8x16_t src00, src04; + int32x4_t src0246, src1357; + int16x8_t src02, src13, src46, src57; + int16x8_t sum01 = vdupq_n_s16(0); + int16x8_t sum23 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ + src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ + src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ + src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ + sum01 = vaddq_s16(sum01, src02); \ + sum01 = vaddq_s16(sum01, src13); \ + sum23 = vaddq_s16(sum23, src46); \ + sum23 = vaddq_s16(sum23, src57); + + UNROLL_CALL_NOWRAPPER(2, CACULATE_ROW) + +#define sum_define(i) int16_t sum##i; + UNROLL_CALL_NOWRAPPER(8, sum_define) + +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ + filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ + filter_size; +#define store_sum01(i) *(dptr + i) = static_cast(sum##i); +#define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(8, sum01_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum01) + + UNROLL_CALL_NOWRAPPER(8, sum23_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum23) + + sptr0 += 32; + sptr1 += 32; + dptr += 16; +#undef store_sum01 +#undef store_sum23 +#undef sum01_avg +#undef sum23_avg +#undef sum_define +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int8x8_t src001 = vld1_s8(sptr0); + int8x8_t src101 = vld1_s8(sptr1); + int16x8_t src00 = vmovl_s8(src001); + int16x8_t src10 = vmovl_s8(src101); + int16x8_t max_tmp = vaddq_s16(src00, src10); + +#define do_acc(i) \ + int16_t sum##i = \ + vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4); +#define do_avg(i) \ + sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ + : (sum##i - filter_size / 2) / filter_size; +#define store(i) *(dptr + i) = static_cast(sum##i); + UNROLL_CALL_NOWRAPPER(4, do_acc) + UNROLL_CALL_NOWRAPPER(4, do_avg) + UNROLL_CALL_NOWRAPPER(4, store) +#undef do_avg +#undef do_acc +#undef store + sptr0 += 8; + sptr1 += 8; + dptr += 4; + } + } +} + } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h index e517406b3..a23e1019b 100644 --- a/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h +++ b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h + * \file dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -15,16 +15,15 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws); -void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws); +#define KERN(mode, stride, ctype) \ + void do_##mode##_pooling_2x2_##stride##_##ctype##_nchw44_NEON( \ + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ + size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); +KERN(max, stride1, int8) +KERN(max, stride2, int8) +KERN(avg, stride1, int8) +KERN(avg, stride2, int8) +#undef KERN } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp index 3717cc5a7..be0092802 100644 --- a/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp + * \file dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -24,7 +24,7 @@ void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; - sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh; @@ -99,7 +99,7 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; - sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh << 1; @@ -190,6 +190,241 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } +void do_avg_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + int16_t filter_size = 9; + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int8x16_t src0123, src1234, src2345; + int16x8_t src01, src23, src12, src34, src45; + int16x8_t sum01 = vdupq_n_s16(0); + int16x8_t sum23 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src0123 = vld1q_s8(sptr##i); \ + src1234 = vld1q_s8(sptr##i + 4); \ + src2345 = vld1q_s8(sptr##i + 8); \ + src01 = vmovl_s8(vget_low_s8(src0123)); \ + src23 = vmovl_s8(vget_high_s8(src0123)); \ + src12 = vmovl_s8(vget_low_s8(src1234)); \ + src34 = vmovl_s8(vget_high_s8(src1234)); \ + src45 = vmovl_s8(vget_high_s8(src2345)); \ + sum01 = vaddq_s16(sum01, src01); \ + sum01 = vaddq_s16(sum01, src12); \ + sum01 = vaddq_s16(sum01, src23); \ + sum23 = vaddq_s16(sum23, src23); \ + sum23 = vaddq_s16(sum23, src34); \ + sum23 = vaddq_s16(sum23, src45); + + UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW) + +#define sum_define(i) int16_t sum##i; + UNROLL_CALL_NOWRAPPER(8, sum_define) + +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ + filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ + filter_size; +#define store_sum01(i) *(dptr + i) = static_cast(sum##i); +#define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(8, sum01_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum01) + + UNROLL_CALL_NOWRAPPER(8, sum23_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum23) + + sptr0 += 16; + sptr1 += 16; + sptr2 += 16; + dptr += 16; +#undef store_sum01 +#undef store_sum23 +#undef sum01_avg +#undef sum23_avg +#undef sum_define +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int8x8_t src001, src012; + int16x8_t src01, src12, sum01, sum02; + sum01 = vdupq_n_s16(0); + sum02 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src001 = vld1_s8(sptr##i); \ + src012 = vld1_s8(sptr##i + 4); \ + src01 = vmovl_s8(src001); \ + src12 = vmovl_s8(src012); \ + sum01 = vaddq_s16(sum01, src01); \ + sum02 = vaddq_s16(sum02, src12); + + UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW) + +#define do_acc(i) \ + int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \ + vgetq_lane_s16(sum02, i + 4); +#define do_avg(i) \ + sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ + : (sum##i - filter_size / 2) / filter_size; +#define store(i) *(dptr + i) = static_cast(sum##i); + UNROLL_CALL_NOWRAPPER(4, do_acc) + UNROLL_CALL_NOWRAPPER(4, do_avg) + UNROLL_CALL_NOWRAPPER(4, store) +#undef store +#undef do_avg +#undef do_acc +#undef CACULATE_ROW + sptr0 += 4; + sptr1 += 4; + sptr2 += 4; + dptr += 4; + } + } +} +void do_avg_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + int16_t filter_size = 9; + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh << 1; + const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int32x4x2_t src_tmp; + int8x16_t src00, src04; + int32x4_t src0246, src1357, src2468, src08; + int16x8_t src02, src46, src13, src57, src24, src68; + int16x8_t sum01 = vdupq_n_s16(0); + int16x8_t sum23 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ + src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ + src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ + src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ + src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ + src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ + sum01 = vaddq_s16(sum01, src02); \ + sum01 = vaddq_s16(sum01, src13); \ + sum01 = vaddq_s16(sum01, src24); \ + sum23 = vaddq_s16(sum23, src46); \ + sum23 = vaddq_s16(sum23, src57); \ + sum23 = vaddq_s16(sum23, src68); + + UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW) + +#define sum_define(i) int16_t sum##i; + UNROLL_CALL_NOWRAPPER(8, sum_define) + +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ + filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ + filter_size; +#define store_sum01(i) *(dptr + i) = static_cast(sum##i); +#define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(8, sum01_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum01) + + UNROLL_CALL_NOWRAPPER(8, sum23_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum23) + + sptr0 += 32; + sptr1 += 32; + sptr2 += 32; + dptr += 16; +#undef store_sum01 +#undef store_sum23 +#undef sum01_avg +#undef sum23_avg +#undef sum_define +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int8x8_t src001, src012; + int16x8_t src01, src12, sum01, sum02; + sum01 = vdupq_n_s16(0); + sum02 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src001 = vld1_s8(sptr##i); \ + src012 = vld1_s8(sptr##i + 4); \ + src01 = vmovl_s8(src001); \ + src12 = vmovl_s8(src012); \ + sum01 = vaddq_s16(sum01, src01); \ + sum02 = vaddq_s16(sum02, src12); + + UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW) + +#define do_acc(i) \ + int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \ + vgetq_lane_s16(sum02, i + 4); +#define do_avg(i) \ + sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ + : (sum##i - filter_size / 2) / filter_size; +#define store(i) *(dptr + i) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(4, do_acc) + UNROLL_CALL_NOWRAPPER(4, do_avg) + UNROLL_CALL_NOWRAPPER(4, store) +#undef store +#undef do_avg +#undef do_acc + sptr0 += 8; + sptr1 += 8; + sptr2 += 8; + dptr += 4; + } + } +} + } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h index 09c1ceb9c..00cc4803c 100644 --- a/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h +++ b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h + * \file dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -15,16 +15,15 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws); - -void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW, - const WorkspaceBundle& ws); +#define KERN(mode, stride, ctype) \ + void do_##mode##_pooling_3x3_##stride##_##ctype##_nchw44_NEON( \ + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ + size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); +KERN(max, stride1, int8) +KERN(max, stride2, int8) +KERN(avg, stride1, int8) +KERN(avg, stride2, int8) +#undef KERN } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp index e68af81be..cf4a514c0 100644 --- a/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp @@ -24,7 +24,7 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; - sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh; @@ -99,7 +99,7 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; - sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh << 1; @@ -171,6 +171,252 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } +void do_avg_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + int16_t filter_size = 16; + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int16x8_t src01, src23, src12, src34, src45, src56; + int16x8_t sum01 = vdupq_n_s16(0); + int16x8_t sum23 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src01 = vmovl_s8(vld1_s8(sptr##i)); \ + src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \ + src12 = vmovl_s8(vld1_s8(sptr##i + 4)); \ + src34 = vmovl_s8(vld1_s8(sptr##i + 12)); \ + src45 = vmovl_s8(vld1_s8(sptr##i + 16)); \ + src56 = vmovl_s8(vld1_s8(sptr##i + 20)); \ + sum01 = vaddq_s16(sum01, src01); \ + sum01 = vaddq_s16(sum01, src12); \ + sum01 = vaddq_s16(sum01, src23); \ + sum01 = vaddq_s16(sum01, src34); \ + sum23 = vaddq_s16(sum23, src23); \ + sum23 = vaddq_s16(sum23, src34); \ + sum23 = vaddq_s16(sum23, src45); \ + sum23 = vaddq_s16(sum23, src56); + + UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) + +#define sum_define(i) int16_t sum##i; + UNROLL_CALL_NOWRAPPER(8, sum_define) + +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ + filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ + filter_size; +#define store_sum01(i) *(dptr + i) = static_cast(sum##i); +#define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(8, sum01_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum01) + + UNROLL_CALL_NOWRAPPER(8, sum23_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum23) + + sptr0 += 16; + sptr1 += 16; + sptr2 += 16; + sptr3 += 16; + dptr += 16; + +#undef store_sum01 +#undef store_sum23 +#undef sum01_avg +#undef sum23_avg +#undef sum_define +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int16x8_t src01, src23, sum01; + sum01 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src01 = vmovl_s8(vld1_s8(sptr##i)); \ + src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \ + sum01 = vaddq_s16(sum01, src01); \ + sum01 = vaddq_s16(sum01, src23); + + UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) + +#define do_acc(i) \ + int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4); +#define do_avg(i) \ + sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ + : (sum##i - filter_size / 2) / filter_size; +#define store(i) *(dptr + i) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(4, do_acc) + UNROLL_CALL_NOWRAPPER(4, do_avg) + UNROLL_CALL_NOWRAPPER(4, store) + +#undef store +#undef do_avg +#undef do_acc +#undef CACULATE_ROW + + sptr0 += 4; + sptr1 += 4; + sptr2 += 4; + sptr3 += 4; + dptr += 4; + } + } +} +void do_avg_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + int16_t filter_size = 16; + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh << 1; + const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; + const int8_t* sptr3 = sptr + (ih + 3) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int32x4x2_t src_tmp; + int8x16_t src00, src04; + int16x8_t src02, src13, src57, src24, src68, src35, src79, src46; + int32x4_t src08, src09, src0246, src1357, src2468, src3579; + int16x8_t sum01 = vdupq_n_s16(0); + int16x8_t sum23 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src3579 = vextq_s32(src1357, src09, 1); \ + src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ + src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ + src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ + src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ + src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ + src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ + src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \ + src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \ + sum01 = vaddq_s16(sum01, src02); \ + sum01 = vaddq_s16(sum01, src13); \ + sum01 = vaddq_s16(sum01, src24); \ + sum01 = vaddq_s16(sum01, src35); \ + sum23 = vaddq_s16(sum23, src46); \ + sum23 = vaddq_s16(sum23, src57); \ + sum23 = vaddq_s16(sum23, src68); \ + sum23 = vaddq_s16(sum23, src79); + + UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) + +#define sum_define(i) int16_t sum##i; + UNROLL_CALL_NOWRAPPER(8, sum_define) + +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ + filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ + filter_size; +#define store_sum01(i) *(dptr + i) = static_cast(sum##i); +#define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(8, sum01_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum01) + + UNROLL_CALL_NOWRAPPER(8, sum23_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum23) + + sptr0 += 32; + sptr1 += 32; + sptr2 += 32; + sptr3 += 32; + dptr += 16; + +#undef store_sum01 +#undef store_sum23 +#undef sum01_avg +#undef sum23_avg +#undef sum_define +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int8x8_t src001, src023; + int16x8_t src01, src23, sum01; + sum01 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src001 = vld1_s8(sptr##i); \ + src023 = vld1_s8(sptr##i + 8); \ + src01 = vmovl_s8(src001); \ + src23 = vmovl_s8(src023); \ + sum01 = vaddq_s16(sum01, src01); \ + sum01 = vaddq_s16(sum01, src23); + + UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) + +#define do_acc(i) \ + int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4); +#define do_avg(i) \ + sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ + : (sum##i - filter_size / 2) / filter_size; +#define store(i) *(dptr + i) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(4, do_acc) + UNROLL_CALL_NOWRAPPER(4, do_avg) + UNROLL_CALL_NOWRAPPER(4, store) + +#undef store +#undef do_avg +#undef do_acc +#undef CACULATE_ROW + sptr0 += 8; + sptr1 += 8; + sptr2 += 8; + sptr3 += 8; + dptr += 4; + } + } +} + } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h index 166c99a6d..0e9741e9c 100644 --- a/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h +++ b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h + * \file dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -15,15 +15,16 @@ namespace megdnn { namespace arm_common { -#define KERN(strdie) \ - void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \ +#define KERN(mode, stride, ctype) \ + void do_##mode##_pooling_4x4_##stride##_##ctype##_nchw44_NEON( \ const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); - -KERN(stride1) -KERN(stride2) - +KERN(max, stride1, int8) +KERN(max, stride2, int8) +KERN(avg, stride1, int8) +KERN(avg, stride2, int8) #undef KERN + } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp index e1d195191..a7e5c0f98 100644 --- a/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp + * \file dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -24,7 +24,7 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; - sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh; @@ -118,7 +118,7 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; - sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh << 1; @@ -213,6 +213,284 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } +void do_avg_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + int16_t filter_size = 25; + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; + const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int16x8_t src01, src23, src12, src34, src45, src56, src67; + int16x8_t sum01 = vdupq_n_s16(0); + int16x8_t sum23 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src01 = vmovl_s8(vld1_s8(sptr##i)); \ + src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \ + src12 = vmovl_s8(vld1_s8(sptr##i + 4)); \ + src34 = vmovl_s8(vld1_s8(sptr##i + 12)); \ + src45 = vmovl_s8(vld1_s8(sptr##i + 16)); \ + src56 = vmovl_s8(vld1_s8(sptr##i + 20)); \ + src67 = vmovl_s8(vld1_s8(sptr##i + 24)); \ + sum01 = vaddq_s16(sum01, src01); \ + sum01 = vaddq_s16(sum01, src12); \ + sum01 = vaddq_s16(sum01, src23); \ + sum01 = vaddq_s16(sum01, src34); \ + sum01 = vaddq_s16(sum01, src45); \ + sum23 = vaddq_s16(sum23, src23); \ + sum23 = vaddq_s16(sum23, src34); \ + sum23 = vaddq_s16(sum23, src45); \ + sum23 = vaddq_s16(sum23, src56); \ + sum23 = vaddq_s16(sum23, src67); + + UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) + +#define sum_define(i) int16_t sum##i; + UNROLL_CALL_NOWRAPPER(8, sum_define) + +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ + filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ + filter_size; +#define store_sum01(i) *(dptr + i) = static_cast(sum##i); +#define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(8, sum01_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum01) + + UNROLL_CALL_NOWRAPPER(8, sum23_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum23) + + sptr0 += 16; + sptr1 += 16; + sptr2 += 16; + sptr3 += 16; + sptr4 += 16; + dptr += 16; + +#undef store_sum01 +#undef store_sum23 +#undef sum01_avg +#undef sum23_avg +#undef sum_define +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int32x2_t src004; + int8x8_t src001, src023; + int16x8_t src01, src23, src04, sum01, sum02; + sum01 = vdupq_n_s16(0); + sum02 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src001 = vld1_s8(sptr##i); \ + src023 = vld1_s8(sptr##i + 8); \ + src004 = vld1_dup_s32(reinterpret_cast(sptr##i + 4 * 4)); \ + src01 = vmovl_s8(src001); \ + src23 = vmovl_s8(src023); \ + src04 = vmovl_s8(vreinterpret_s8_s32(src004)); \ + sum01 = vaddq_s16(sum01, src01); \ + sum01 = vaddq_s16(sum01, src23); \ + sum02 = vaddq_s16(sum02, src04); + + UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) + +#define do_acc(i) \ + int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \ + vgetq_lane_s16(sum02, i + 4); +#define do_avg(i) \ + sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ + : (sum##i - filter_size / 2) / filter_size; +#define store(i) *(dptr + i) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(4, do_acc) + UNROLL_CALL_NOWRAPPER(4, do_avg) + UNROLL_CALL_NOWRAPPER(4, store) + +#undef store +#undef do_avg +#undef do_acc +#undef CACULATE_ROW + sptr0 += 4; + sptr1 += 4; + sptr2 += 4; + sptr3 += 4; + sptr4 += 4; + dptr += 4; + } + } +} + +void do_avg_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + int16_t filter_size = 25; + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh << 1; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; + const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int32x4x2_t src_tmp; + int8x16_t src00, src04; + int16x8_t src02, src13, src57, src24, src68, src35, src79, src46, + src810; + int32x4_t src08, src09, src10, src0246, src1357, src2468, src3579, + src46810; + int16x8_t sum01 = vdupq_n_s16(0); + int16x8_t sum23 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ + src10 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 10)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src3579 = vextq_s32(src1357, src09, 1); \ + src46810 = vextq_s32(src2468, src10, 1); \ + src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ + src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ + src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ + src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ + src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ + src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ + src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \ + src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \ + src46 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src46810))); \ + src810 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src46810))); \ + sum01 = vaddq_s16(sum01, src02); \ + sum01 = vaddq_s16(sum01, src13); \ + sum01 = vaddq_s16(sum01, src24); \ + sum01 = vaddq_s16(sum01, src35); \ + sum01 = vaddq_s16(sum01, src46); \ + sum23 = vaddq_s16(sum23, src46); \ + sum23 = vaddq_s16(sum23, src57); \ + sum23 = vaddq_s16(sum23, src68); \ + sum23 = vaddq_s16(sum23, src79); \ + sum23 = vaddq_s16(sum23, src810); + + UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) + +#define sum_define(i) int16_t sum##i; + UNROLL_CALL_NOWRAPPER(8, sum_define) + +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ + filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ + filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ + filter_size; +#define store_sum01(i) *(dptr + i) = static_cast(sum##i); +#define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(8, sum01_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum01) + + UNROLL_CALL_NOWRAPPER(8, sum23_avg) + UNROLL_CALL_NOWRAPPER(8, store_sum23) + + sptr0 += 32; + sptr1 += 32; + sptr2 += 32; + sptr3 += 32; + sptr4 += 32; + dptr += 16; + +#undef store_sum01 +#undef store_sum23 +#undef sum01_avg +#undef sum23_avg +#undef sum_define +#undef CACULATE_ROW + } + for (; ow < OW; ++ow) { + int32x2_t src004; + int8x8_t src001, src023; + int16x8_t src01, src23, src04, sum01, sum02; + sum01 = vdupq_n_s16(0); + sum02 = vdupq_n_s16(0); + +#define CACULATE_ROW(i) \ + src001 = vld1_s8(sptr##i); \ + src023 = vld1_s8(sptr##i + 8); \ + src004 = vld1_dup_s32(reinterpret_cast(sptr##i + 4 * 4)); \ + src01 = vmovl_s8(src001); \ + src23 = vmovl_s8(src023); \ + src04 = vmovl_s8(vreinterpret_s8_s32(src004)); \ + sum01 = vaddq_s16(sum01, src01); \ + sum01 = vaddq_s16(sum01, src23); \ + sum02 = vaddq_s16(sum02, src04); + + UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) + +#define do_acc(i) \ + int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \ + vgetq_lane_s16(sum02, i + 4); +#define do_avg(i) \ + sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ + : (sum##i - filter_size / 2) / filter_size; +#define store(i) *(dptr + i) = static_cast(sum##i); + + UNROLL_CALL_NOWRAPPER(4, do_acc) + UNROLL_CALL_NOWRAPPER(4, do_avg) + UNROLL_CALL_NOWRAPPER(4, store) + +#undef store +#undef do_avg +#undef do_acc +#undef CACULATE_ROW + sptr0 += 8; + sptr1 += 8; + sptr2 += 8; + sptr3 += 8; + sptr4 += 8; + dptr += 4; + } + } +} + } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h index 73221f1f4..11062e32b 100644 --- a/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h +++ b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h + * \file dnn/src/arm_common/pooling/do__pooling_5x5_nchw44.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -15,15 +15,16 @@ namespace megdnn { namespace arm_common { -#define KERN(strdie) \ - void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \ +#define KERN(mode, stride, ctype) \ + void do_##mode##_pooling_5x5_##stride##_##ctype##_nchw44_NEON( \ const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); - -KERN(stride1) -KERN(stride2) - +KERN(max, stride1, int8) +KERN(max, stride2, int8) +KERN(avg, stride1, int8) +KERN(avg, stride2, int8) #undef KERN + } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp index f7556ee15..f7165cd31 100644 --- a/dnn/src/arm_common/pooling/opr_impl.cpp +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -25,10 +25,10 @@ class PoolingImpl::AlgoPack : NonCopyableObj { AlgoFilter5MaxStride2 algo_filter5_max_stride2; AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2; AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; - AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4; - AlgoFilter3MaxStridexNCHW44 algo_filter3_max_stridex_nchw4; - AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4; - AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4; + AlgoFilter2ModexStridexNCHW44 algo_filter2_modex_stridex_nchw4; + AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw4; + AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4; + AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4; public: AlgoPack() { @@ -40,10 +40,10 @@ public: all_algos.emplace_back(&algo_filter5_max_stride2); all_algos.emplace_back(&algo_int8_filter2_max_stride2); all_algos.emplace_back(&algo_int8_filter3_max_stride2); - all_algos.emplace_back(&algo_filter3_max_stridex_nchw4); - all_algos.emplace_back(&algo_filter2_max_stridex_nchw4); - all_algos.emplace_back(&algo_filter4_max_stridex_nchw4); - all_algos.emplace_back(&algo_filter5_max_stridex_nchw4); + all_algos.emplace_back(&algo_filter3_modex_stridex_nchw4); + all_algos.emplace_back(&algo_filter2_modex_stridex_nchw4); + all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4); + all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4); } SmallVector all_algos; }; diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h index 2bb8f992f..92b6b4641 100644 --- a/dnn/src/arm_common/pooling/opr_impl.h +++ b/dnn/src/arm_common/pooling/opr_impl.h @@ -83,10 +83,10 @@ private: class AlgoFilter5MaxStride2; class AlgoInt8Filter2MaxStride2; class AlgoInt8Filter3MaxStride2; - class AlgoFilter2MaxStridexNCHW44; - class AlgoFilter3MaxStridexNCHW44; - class AlgoFilter4MaxStridexNCHW44; - class AlgoFilter5MaxStridexNCHW44; + class AlgoFilter2ModexStridexNCHW44; + class AlgoFilter3ModexStridexNCHW44; + class AlgoFilter4ModexStridexNCHW44; + class AlgoFilter5ModexStridexNCHW44; class AlgoPack; }; } // namespace arm_common diff --git a/dnn/test/arm_common/pooling_multi_thread.cpp b/dnn/test/arm_common/pooling_multi_thread.cpp index 1543ae108..3f3cfe2e4 100644 --- a/dnn/test/arm_common/pooling_multi_thread.cpp +++ b/dnn/test/arm_common/pooling_multi_thread.cpp @@ -10,6 +10,7 @@ */ #include #include "megdnn/dtype.h" +#include "megdnn/opr_param_defs.h" #include "test/arm_common/fixture.h" #include "test/common/pooling.h" @@ -56,13 +57,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) { } } -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44) +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) { // clang-format off for (size_t ih: {3, 5, 10}) for (size_t iw: {3, 5, 7, 9, 15, 20}) for (size_t ph: {0, 1, 2}) for (size_t pw: {0, 1, 2}) + for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) if (ih+2*ph >= 3 && iw+2*pw >= 3) { UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; @@ -71,7 +73,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44) checker.set_rng(0,&rng); param::Pooling param; - param.mode = param::Pooling::Mode::MAX; + param.mode = mode; param.format = param::Pooling::Format::NCHW44; param.pad_h = ph; param.pad_w = pw; @@ -86,13 +88,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44) // clang-format on } -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44) +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44) { // clang-format off for (size_t ih: {2, 5, 10, 17}) for (size_t iw: {2, 6, 8, 16, 26}) for (size_t ph: {0, 1}) for (size_t pw: {0, 1}) + for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) if (ih+2*ph >= 2 && iw+2*pw >= 2) { UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; @@ -101,7 +104,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44) checker.set_rng(0,&rng); param::Pooling param; - param.mode = param::Pooling::Mode::MAX; + param.mode = mode; param.format = param::Pooling::Format::NCHW44; param.pad_h = ph; param.pad_w = pw; @@ -115,13 +118,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44) // clang-format on } -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44) +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44) { // clang-format off for (size_t ih: {4, 10, 18, 25, 30}) for (size_t iw: {4, 12, 17, 20, 25}) for (size_t ph: {0, 1, 2}) for (size_t pw: {0, 1, 2}) + for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) if (ih+2*ph >= 4 && iw+2*pw >= 4) { UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; @@ -130,7 +134,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44) checker.set_rng(0,&rng); param::Pooling param; - param.mode = param::Pooling::Mode::MAX; + param.mode = mode; param.format = param::Pooling::Format::NCHW44; param.pad_h = ph; param.pad_w = pw; @@ -143,13 +147,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44) } // clang-format on } -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44) +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W5x5_NCHW44) { // clang-format off for (size_t ih: {5, 9, 19, 20, 39}) for (size_t iw: {5, 12, 23, 27, 39}) for (size_t ph: {0, 1, 2}) for (size_t pw: {0, 1, 2}) + for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE}) if (ih+2*ph >= 5 && iw+2*pw >= 5) { UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; @@ -158,7 +163,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44) checker.set_rng(0,&rng); param::Pooling param; - param.mode = param::Pooling::Mode::MAX; + param.mode = mode; param.format = param::Pooling::Format::NCHW44; param.pad_h = ph; param.pad_w = pw; @@ -477,31 +482,37 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING_NCHW44) { std::vector> shapes; std::vector> filter_and_stride = { {2, 1}, {2, 2}, {3, 1}, {3, 2}, {4, 1}, {4, 2}, {5, 1}, {5, 2}}; - for(auto filter:filter_and_stride){ - - shapes.push_back({{1, 32 * 4, 215, 215}, {}}); - shapes.push_back({{1, 32 * 4, 128, 128}, {}}); - shapes.push_back({{1, 16 * 4, 56, 56}, {}}); - - param.window_h = param.window_w = filter[0]; - param.stride_h = param.stride_w = filter[1]; - param.format = Param::Format::NCHW; - printf("NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, - param.window_h, param.stride_h, static_cast(param.mode)); - benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, - dtype::QuantizedS8(1.1f)); - shapes.clear(); - shapes.push_back({{1, 32, 215, 215,4}, {}}); - shapes.push_back({{1, 32, 128, 128,4}, {}}); - shapes.push_back({{1, 16, 56, 56, 4}, {}}); - - param.format = Param::Format::NCHW44; - printf("NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, - param.window_w, param.stride_h, static_cast(param.mode)); - benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, - dtype::QuantizedS8(1.1f)); - shapes.clear(); - } + + for (auto mode : + {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) { + for (auto filter : filter_and_stride) { + shapes.push_back({{1, 32 * 4, 215, 215}, {}}); + shapes.push_back({{1, 32 * 4, 128, 128}, {}}); + shapes.push_back({{1, 16 * 4, 56, 56}, {}}); + + param.mode = mode; + param.window_h = param.window_w = filter[0]; + param.stride_h = param.stride_w = filter[1]; + param.format = Param::Format::NCHW; + printf("NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", + param.window_h, param.window_h, param.stride_h, + static_cast(param.mode)); + benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, + {1, {4}}, dtype::QuantizedS8(1.1f)); + shapes.clear(); + shapes.push_back({{1, 32, 215, 215, 4}, {}}); + shapes.push_back({{1, 32, 128, 128, 4}, {}}); + shapes.push_back({{1, 16, 56, 56, 4}, {}}); + + param.format = Param::Format::NCHW44; + printf("NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", + param.window_h, param.window_w, param.stride_h, + static_cast(param.mode)); + benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, + {1, {4}}, dtype::QuantizedS8(1.1f)); + shapes.clear(); + } + } } #endif -- GitLab