From df8931b6cd3ab1a4583f6e51700f38a42cc9c994 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 6 May 2020 14:18:46 +0800 Subject: [PATCH] feat(dnn/arm): add padding support for nchw44 arm pooling and opt code GitOrigin-RevId: f125004e1f271656f2c2646913aea6afdd112e15 --- dnn/src/arm_common/pooling/algo.cpp | 271 +++++++++--------- dnn/src/arm_common/pooling/algo.h | 18 +- .../do_max_pooling_3x3_s1x1_nchw44.cpp | 91 ------ .../pooling/do_max_pooling_3x3_s1x1_nchw44.h | 25 -- .../do_max_pooling_3x3_s2x2_nchw44.cpp | 112 -------- ...2_nchw44.cpp => do_pooling_2x2_nchw44.cpp} | 41 ++- ...g_2x2_nchw44.h => do_pooling_2x2_nchw44.h} | 6 +- .../pooling/do_pooling_3x3_nchw44.cpp | 195 +++++++++++++ ..._s2x2_nchw44.h => do_pooling_3x3_nchw44.h} | 11 +- ...4_nchw44.cpp => do_pooling_4x4_nchw44.cpp} | 64 +++-- ...g_4x4_nchw44.h => do_pooling_4x4_nchw44.h} | 2 +- ...5_nchw44.cpp => do_pooling_5x5_nchw44.cpp} | 106 ++++--- ...g_5x5_nchw44.h => do_pooling_5x5_nchw44.h} | 2 +- dnn/src/arm_common/pooling/opr_impl.cpp | 12 +- dnn/src/arm_common/pooling/opr_impl.h | 3 +- dnn/test/arm_common/pooling.cpp | 205 ------------- dnn/test/arm_common/pooling_multi_thread.cpp | 182 +++++------- 17 files changed, 552 insertions(+), 794 deletions(-) delete mode 100644 dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp delete mode 100644 dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h delete mode 100644 dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp rename dnn/src/arm_common/pooling/{do_max_pooling_2x2_nchw44.cpp => do_pooling_2x2_nchw44.cpp} (78%) rename dnn/src/arm_common/pooling/{do_max_pooling_2x2_nchw44.h => do_pooling_2x2_nchw44.h} (87%) create mode 100644 dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp rename dnn/src/arm_common/pooling/{do_max_pooling_3x3_s2x2_nchw44.h => do_pooling_3x3_nchw44.h} (58%) rename dnn/src/arm_common/pooling/{do_max_pooling_4x4_nchw44.cpp => do_pooling_4x4_nchw44.cpp} (77%) rename dnn/src/arm_common/pooling/{do_max_pooling_4x4_nchw44.h => do_pooling_4x4_nchw44.h} (92%) rename dnn/src/arm_common/pooling/{do_max_pooling_5x5_nchw44.cpp => do_pooling_5x5_nchw44.cpp} (67%) rename dnn/src/arm_common/pooling/{do_max_pooling_5x5_nchw44.h => do_pooling_5x5_nchw44.h} (92%) diff --git a/dnn/src/arm_common/pooling/algo.cpp b/dnn/src/arm_common/pooling/algo.cpp index 4fcd32d5..6ba09208 100644 --- a/dnn/src/arm_common/pooling/algo.cpp +++ b/dnn/src/arm_common/pooling/algo.cpp @@ -11,14 +11,13 @@ */ #include "src/arm_common/pooling/algo.h" #include "megdnn/opr_param_defs.h" -#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h" -#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h" -#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h" -#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h" #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" -#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h" #include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h" #include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h" +#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h" +#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h" +#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h" +#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h" #include "midout.h" @@ -57,6 +56,41 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) { return ws; } +WorkspaceBundle get_bundle_nchw44( + const PoolingImpl::PoolingKernSizeParam& param) { + megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8) && + (param.format == param::Pooling::Format::NCHW44)); + auto IH = param.isz[0]; + auto IW = param.isz[1]; + auto PH = param.padding[0]; + auto PW = param.padding[1]; + size_t padding_size = 0; + if ((PH != 0) || (PW != 0)) { + padding_size = (IW + 2 * PW) * (IH + 2 * PH) * 4 * sizeof(int8_t); + } + return WorkspaceBundle(nullptr, {padding_size}); +} + +const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, + size_t& IH2, size_t& IW2, size_t PH, size_t PW, + const WorkspaceBundle& ws) { + int8_t* sptr_base = nullptr; + bool need_pad = ((PH != 0) || (PW != 0)) ? true : false; + if (need_pad) { + IH2 = IH + 2 * PH; + IW2 = IW + 2 * PW; + sptr_base = static_cast(ws.get(0)); + memset(sptr_base, -128, sizeof(int8_t) * IH2 * IW2 * 4); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 * 4 + PW * 4, + src + ih * IW * 4, sizeof(int8_t) * IW * 4); + } + } else { + IH2 = IH; + IW2 = IW; + } + return need_pad ? sptr_base : src; +} bool PoolingImpl::AlgoFilterxModexStride1::usable( const PoolingKernSizeParam& param) const { auto SH = param.stride[0]; @@ -563,47 +597,50 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec( MIDOUT_END(); } -bool PoolingImpl::AlgoFilter3MaxStride2NCHW44::usable( +bool PoolingImpl::AlgoFilter3MaxStridexNCHW44::usable( const PoolingKernSizeParam& param) const { auto SH = param.stride[0]; auto SW = param.stride[1]; auto FH = param.filter[0]; auto FW = param.filter[1]; - auto PH = param.padding[0]; - auto PW = param.padding[1]; bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.format == Param::Format::NCHW44 && - param.mode == Mode::MAX && FH == 3 && FW == 3 && SH == 2 && - SW == 2 && PH == 0 && PW == 0; + param.mode == Mode::MAX && FH == 3 && FW == 3 && SW == SH && + (SH == 1 || SW == 2); return avaible; } -void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec( +void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec( const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; auto N = param.n, C = param.ic; auto PH = param.padding[0]; auto PW = param.padding[1]; + auto SW = param.stride[0]; void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, midout_type_id) \ +#define DISPATCH_FUNC(type, func, i) \ MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ - midout_iv(midout_type_id)) { \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ + midout_iv(#type #i##_hash)) { \ + WorkspaceBundle wbundle = get_bundle_nchw44(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ size_t index, size_t thread_id) { \ - MEGDNN_MARK_USED_VAR(thread_id); \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ size_t n = index / C; \ size_t c = index % C; \ - do_max_pooling_3x3_s2x2_##func##_nchw44_NEON( \ + do_max_pooling_3x3_stride##i##_##func##_nchw44_NEON( \ static_cast(src_ptr) + n * C * IH * IW * 4 + \ c * IH * IW * 4, \ static_cast(dst_ptr) + n * C * OH * OW * 4 + \ c * OH * OW * 4, \ - IH, IW, OH, OW, PH, PW); \ + IH, IW, OH, OW, PH, PW, ws); \ }; \ MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ @@ -611,61 +648,23 @@ void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec( } \ MIDOUT_END(); - DISPATCH_FUNC(int8_t, int8, 9); - -#undef DISPATCH_FUNC -} - -bool PoolingImpl::AlgoFilter3MaxStride1NCHW44::usable( - const PoolingKernSizeParam& param) const { - auto SH = param.stride[0]; - auto SW = param.stride[1]; - auto FH = param.filter[0]; - auto FW = param.filter[1]; - auto PH = param.padding[0]; - auto PW = param.padding[1]; - - bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && - param.format == Param::Format::NCHW44 && - param.mode == Mode::MAX && FH == 3 && FW == 3 && SH == 1 && - SW == 1 && PH == 0 && PW == 0; - return avaible; -} - -void PoolingImpl::AlgoFilter3MaxStride1NCHW44::exec( - const PoolingKernParam& param) const { - auto IH = param.isz[0], IW = param.isz[1]; - auto OH = param.osz[0], OW = param.osz[1]; - auto N = param.n, C = param.ic; - auto PH = param.padding[0]; - auto PW = param.padding[1]; - - void* src_ptr = param.src_ptr; - void* dst_ptr = param.dst_ptr; - -#define DISPATCH_FUNC(type, func, midout_type_id) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ - midout_iv(midout_type_id)) { \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ - size_t index, size_t thread_id) { \ - MEGDNN_MARK_USED_VAR(thread_id); \ - size_t n = index / C; \ - size_t c = index % C; \ - do_max_pooling_3x3_s1x1_##func##_nchw44_NEON( \ - static_cast(src_ptr) + n * C * IH * IW * 4 + \ - c * IH * IW * 4, \ - static_cast(dst_ptr) + n * C * OH * OW * 4 + \ - c * OH * OW * 4, \ - IH, IW, OH, OW, PH, PW); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ - MIDOUT_END(); +#define DISPATCH_STRIDE(type, func) \ + switch (SW) { \ + case 1: { \ + DISPATCH_FUNC(type, func, 1); \ + break; \ + } \ + case 2: { \ + DISPATCH_FUNC(type, func, 2); \ + break; \ + } \ + default: \ + megdnn_assert(0, "unsupport stride size"); \ + } - DISPATCH_FUNC(int8_t, int8, 10); + DISPATCH_STRIDE(int8_t, int8); +#undef DISPATCH_STRIDE #undef DISPATCH_FUNC } @@ -675,13 +674,11 @@ bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable( auto SW = param.stride[1]; auto FH = param.filter[0]; auto FW = param.filter[1]; - auto PH = param.padding[0]; - auto PW = param.padding[1]; bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.format == Param::Format::NCHW44 && param.mode == Mode::MAX && FH == 2 && FW == 2 && SH == SW && - (SW == 1 || SW == 2) && PH == 0 && PW == 0; + (SW == 1 || SW == 2); return avaible; } @@ -697,12 +694,16 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, midout_type_id, i) \ +#define DISPATCH_FUNC(type, func, i) \ MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ - midout_iv(midout_type_id)) { \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ + midout_iv(#func #i##_hash)) { \ + WorkspaceBundle wbundle = get_bundle_nchw44(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ size_t index, size_t thread_id) { \ - MEGDNN_MARK_USED_VAR(thread_id); \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ size_t n = index / C; \ size_t c = index % C; \ do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ @@ -710,7 +711,7 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( c * IH * IW * 4, \ static_cast(dst_ptr) + n * C * OH * OW * 4 + \ c * OH * OW * 4, \ - IH, IW, OH, OW, PH, PW); \ + IH, IW, OH, OW, PH, PW, ws); \ }; \ MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ @@ -718,21 +719,21 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( } \ MIDOUT_END(); -#define DISPATCH_STRIDE(type, func, midout_type_id) \ - switch (SW) { \ - case 1: { \ - DISPATCH_FUNC(type, func, midout_type_id, 1); \ - break; \ - } \ - case 2: { \ - DISPATCH_FUNC(type, func, midout_type_id, 2); \ - break; \ - } \ - default: \ - megdnn_assert(0, "unsupport stride size"); \ +#define DISPATCH_STRIDE(type, func) \ + switch (SW) { \ + case 1: { \ + DISPATCH_FUNC(type, func, 1); \ + break; \ + } \ + case 2: { \ + DISPATCH_FUNC(type, func, 2); \ + break; \ + } \ + default: \ + megdnn_assert(0, "unsupport stride size"); \ } - DISPATCH_STRIDE(int8_t, int8, 10); + DISPATCH_STRIDE(int8_t, int8); #undef DISPATCH_STRIDE #undef DISPATCH_FUNC @@ -744,13 +745,11 @@ bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable( auto SW = param.stride[1]; auto FH = param.filter[0]; auto FW = param.filter[1]; - auto PH = param.padding[0]; - auto PW = param.padding[1]; bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.format == Param::Format::NCHW44 && param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == SW && - (SW == 1 || SW == 2) && PH == 0 && PW == 0; + (SW == 1 || SW == 2); return avaible; } @@ -766,12 +765,16 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, midout_type_id, i) \ +#define DISPATCH_FUNC(type, func, i) \ MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ - midout_iv(midout_type_id)) { \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ + midout_iv(#func #i##_hash)) { \ + WorkspaceBundle wbundle = get_bundle_nchw44(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ size_t index, size_t thread_id) { \ - MEGDNN_MARK_USED_VAR(thread_id); \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ size_t n = index / C; \ size_t c = index % C; \ do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ @@ -779,7 +782,7 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( c * IH * IW * 4, \ static_cast(dst_ptr) + n * C * OH * OW * 4 + \ c * OH * OW * 4, \ - IH, IW, OH, OW, PH, PW); \ + IH, IW, OH, OW, PH, PW, ws); \ }; \ MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ @@ -787,21 +790,21 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( } \ MIDOUT_END(); -#define DISPATCH_STRIDE(type, func, midout_type_id) \ - switch (SW) { \ - case 1: { \ - DISPATCH_FUNC(type, func, midout_type_id, 1); \ - break; \ - } \ - case 2: { \ - DISPATCH_FUNC(type, func, midout_type_id, 2); \ - break; \ - } \ - default: \ - megdnn_assert(0, "unsupport stride size"); \ +#define DISPATCH_STRIDE(type, func) \ + switch (SW) { \ + case 1: { \ + DISPATCH_FUNC(type, func, 1); \ + break; \ + } \ + case 2: { \ + DISPATCH_FUNC(type, func, 2); \ + break; \ + } \ + default: \ + megdnn_assert(0, "unsupport stride size"); \ } - DISPATCH_STRIDE(int8_t, int8, 11); + DISPATCH_STRIDE(int8_t, int8); #undef DISPATCH_STRIDE #undef DISPATCH_FUNC @@ -813,13 +816,11 @@ bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable( auto SW = param.stride[1]; auto FH = param.filter[0]; auto FW = param.filter[1]; - auto PH = param.padding[0]; - auto PW = param.padding[1]; bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.format == Param::Format::NCHW44 && param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == SW && - (SW == 1 || SW == 2) && PH == 0 && PW == 0; + (SW == 1 || SW == 2); return avaible; } @@ -835,12 +836,16 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, midout_type_id, i) \ +#define DISPATCH_FUNC(type, func, i) \ MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ - midout_iv(midout_type_id)) { \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ + midout_iv(#func #i##_hash)) { \ + WorkspaceBundle wbundle = get_bundle_nchw44(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ size_t index, size_t thread_id) { \ - MEGDNN_MARK_USED_VAR(thread_id); \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ size_t n = index / C; \ size_t c = index % C; \ do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ @@ -848,7 +853,7 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( c * IH * IW * 4, \ static_cast(dst_ptr) + n * C * OH * OW * 4 + \ c * OH * OW * 4, \ - IH, IW, OH, OW, PH, PW); \ + IH, IW, OH, OW, PH, PW, ws); \ }; \ MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ @@ -856,21 +861,21 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( } \ MIDOUT_END(); -#define DISPATCH_STRIDE(type, func, midout_type_id) \ - switch (SW) { \ - case 1: { \ - DISPATCH_FUNC(type, func, midout_type_id, 1); \ - break; \ - } \ - case 2: { \ - DISPATCH_FUNC(type, func, midout_type_id, 2); \ - break; \ - } \ - default: \ - megdnn_assert(0, "unsupport stride size"); \ +#define DISPATCH_STRIDE(type, func) \ + switch (SW) { \ + case 1: { \ + DISPATCH_FUNC(type, func, 1); \ + break; \ + } \ + case 2: { \ + DISPATCH_FUNC(type, func, 2); \ + break; \ + } \ + default: \ + megdnn_assert(0, "unsupport stride size"); \ } - DISPATCH_STRIDE(int8_t, int8, 12); + DISPATCH_STRIDE(int8_t, int8); #undef DISPATCH_STRIDE #undef DISPATCH_FUNC diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h index 8b1e78a5..1bae98bb 100644 --- a/dnn/src/arm_common/pooling/algo.h +++ b/dnn/src/arm_common/pooling/algo.h @@ -83,18 +83,10 @@ public: void exec(const PoolingKernParam& param) const override; }; -class PoolingImpl::AlgoFilter3MaxStride2NCHW44 final : public AlgoBase { +class PoolingImpl::AlgoFilter3MaxStridexNCHW44 final : public AlgoBase { public: bool is_reproducible() const override { return true; } - const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDE2_NCHW44"; } - bool usable(const PoolingKernSizeParam& param) const override; - void exec(const PoolingKernParam& param) const override; -}; - -class PoolingImpl::AlgoFilter3MaxStride1NCHW44 final : public AlgoBase { -public: - bool is_reproducible() const override { return true; } - const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDE1_NCHW44"; } + const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; }; @@ -125,6 +117,12 @@ public: WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); +WorkspaceBundle get_bundle_nchw44( + const PoolingImpl::PoolingKernSizeParam& param); + +const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, + size_t& IH2, size_t& IW2, size_t PH, size_t PW, + const WorkspaceBundle& ws); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp deleted file mode 100644 index d3b0a48c..00000000 --- a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp +++ /dev/null @@ -1,91 +0,0 @@ -/** - * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/common/unroll_macro.h" - -namespace megdnn { -namespace arm_common { - -void do_max_pooling_3x3_s1x1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW) { - size_t oh = 0; - for (; oh < OH; ++oh) { - size_t ih = oh; - const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; - const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; - const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; - int8_t* __restrict dptr = dst + oh * OW * 4; - size_t ow = 0; - for (; ow + 3 < OW; ow += 4) { - int8x16_t src0123 = vld1q_s8(sptr0); - int8x16_t src1234 = vld1q_s8(sptr0 + 4); - int8x16_t src2345 = vld1q_s8(sptr0 + 8); - int8x16_t max0 = vmaxq_s8(src0123, src1234); - max0 = vmaxq_s8(max0, src2345); - - src0123 = vld1q_s8(sptr1); - src1234 = vld1q_s8(sptr1 + 4); - src2345 = vld1q_s8(sptr1 + 8); - int8x16_t max1 = vmaxq_s8(src0123, src1234); - max1 = vmaxq_s8(max1, src2345); - - src0123 = vld1q_s8(sptr2); - src1234 = vld1q_s8(sptr2 + 4); - src2345 = vld1q_s8(sptr2 + 8); - int8x16_t max2 = vmaxq_s8(src0123, src1234); - max2 = vmaxq_s8(max2, src2345); - - int8x16_t max_out = vmaxq_s8(max0, max1); - max_out = vmaxq_s8(max_out, max2); - - vst1q_s8(dptr, max_out); - - sptr0 += 16; - sptr1 += 16; - sptr2 += 16; - dptr += 16; - } - for (; ow < OW; ++ow) { - int8x8_t src001 = vld1_s8(sptr0); - int8x8_t src012 = vld1_s8(sptr0 + 4); - - int8x8_t src101 = vld1_s8(sptr1); - int8x8_t src112 = vld1_s8(sptr1 + 4); - - int8x8_t src201 = vld1_s8(sptr2); - int8x8_t src212 = vld1_s8(sptr2 + 4); - int8x8_t max01_tmp = vmax_s8(src001, src101); - max01_tmp = vmax_s8(max01_tmp, src201); - - int8x8_t max12_tmp = vmax_s8(src012, src112); - max12_tmp = vmax_s8(max12_tmp, src212); -#define cb(i) \ - int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ - max12_tmp[i + 4]); -#define store(i) *(dptr + i) = dst##i; - UNROLL_CALL_NOWRAPPER(4, cb) - UNROLL_CALL_NOWRAPPER(4, store) -#undef store -#undef cb - sptr0 += 4; - sptr1 += 4; - sptr2 += 4; - dptr += 4; - } - } -} - -} // namespace arm_common -} // namespace megdnn - // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h deleted file mode 100644 index 6630d1c1..00000000 --- a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h +++ /dev/null @@ -1,25 +0,0 @@ -/** - * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -#pragma once -#include "src/common/utils.h" - -namespace megdnn { -namespace arm_common { - -void do_max_pooling_3x3_s1x1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW); - -} // namespace arm_common -} // namespace megdnn - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp deleted file mode 100644 index ddafdbd2..00000000 --- a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp +++ /dev/null @@ -1,112 +0,0 @@ -/** - * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/common/unroll_macro.h" - -namespace megdnn { -namespace arm_common { - -void do_max_pooling_3x3_s2x2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW) { - size_t oh = 0; - for (; oh < OH; ++oh) { - size_t ih = oh << 1; - const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; - const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; - const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; - int8_t* __restrict dptr = dst + oh * OW * 4; - size_t ow = 0; - for (; ow + 3 < OW; ow += 4) { - int8x16_t src00 = vld1q_s8(sptr0); - int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4); - int8x16_t src08 = vld1q_s8(sptr0 + 4 * 8); - int32x4x2_t src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), - vreinterpretq_s32_s8(src04)); - int32x4_t src0246 = src_tmp.val[0]; - int32x4_t src1357 = src_tmp.val[1]; - int32x4_t src2468 = - vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); - int8x16_t max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), - vreinterpretq_s8_s32(src1357)); - int8x16_t max0 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); - - int8x16_t src10 = vld1q_s8(sptr1); - int8x16_t src14 = vld1q_s8(sptr1 + 4 * 4); - int8x16_t src18 = vld1q_s8(sptr1 + 4 * 8); - - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src10), - vreinterpretq_s32_s8(src14)); - src0246 = src_tmp.val[0]; - src1357 = src_tmp.val[1]; - src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src18), 1); - max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), - vreinterpretq_s8_s32(src1357)); - int8x16_t max1 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); - - int8x16_t src20 = vld1q_s8(sptr2); - int8x16_t src24 = vld1q_s8(sptr2 + 4 * 4); - int8x16_t src28 = vld1q_s8(sptr2 + 4 * 8); - - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src20), - vreinterpretq_s32_s8(src24)); - src0246 = src_tmp.val[0]; - src1357 = src_tmp.val[1]; - src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src28), 1); - - max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), - vreinterpretq_s8_s32(src1357)); - int8x16_t max2 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); - max_tmp = vmaxq_s8(max0, max1); - int8x16_t max_out = vmaxq_s8(max_tmp, max2); - - vst1q_s8(dptr, max_out); - - sptr0 += 32; - sptr1 += 32; - sptr2 += 32; - dptr += 16; - } - for (; ow < OW; ++ow) { - int8x8_t src001 = vld1_s8(sptr0); - int8x8_t src012 = vld1_s8(sptr0 + 4); - - int8x8_t src101 = vld1_s8(sptr1); - int8x8_t src112 = vld1_s8(sptr1 + 4); - - int8x8_t src201 = vld1_s8(sptr2); - int8x8_t src212 = vld1_s8(sptr2 + 4); - int8x8_t max01_tmp = vmax_s8(src001, src101); - max01_tmp = vmax_s8(max01_tmp, src201); - - int8x8_t max12_tmp = vmax_s8(src012, src112); - max12_tmp = vmax_s8(max12_tmp, src212); -#define cb(i) \ - int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ - max12_tmp[i + 4]); -#define store(i) *(dptr + i) = dst##i; - UNROLL_CALL_NOWRAPPER(4, cb) - UNROLL_CALL_NOWRAPPER(4, store) -#undef store -#undef cb - sptr0 += 8; - sptr1 += 8; - sptr2 += 8; - dptr += 4; - } - } -} - -} // namespace arm_common -} // namespace megdnn - // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp similarity index 78% rename from dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp rename to dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp index 99e97284..fa3db885 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp @@ -9,7 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h" +#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h" +#include "src/arm_common/pooling/algo.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" @@ -19,12 +20,16 @@ namespace arm_common { void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, - size_t PH, size_t PW) { + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh; - const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; - const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; int8_t* __restrict dptr = dst + oh * OW * 4; size_t ow = 0; for (; ow + 3 < OW; ow += 4) { @@ -46,15 +51,10 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } for (; ow < OW; ++ow) { int8x8_t src001 = vld1_s8(sptr0); - int8x8_t src012 = vld1_s8(sptr0 + 4); - int8x8_t src101 = vld1_s8(sptr1); - int8x8_t src112 = vld1_s8(sptr1 + 4); - int8x8_t max01_tmp = vmax_s8(src001, src101); - int8x8_t max12_tmp = vmax_s8(src012, src112); - int8x8_t mat_out = vmax_s8(max01_tmp, max12_tmp); -#define store(i) *(dptr + i) = mat_out[i]; + int8x8_t max_out = vmax_s8(src001, src101); +#define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]); UNROLL_CALL_NOWRAPPER(4, store) #undef store sptr0 += 4; @@ -66,12 +66,16 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, - size_t PH, size_t PW) { + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh << 1; - const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; - const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; int8_t* __restrict dptr = dst + oh * OW * 4; size_t ow = 0; for (; ow + 3 < OW; ow += 4) { @@ -103,15 +107,10 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } for (; ow < OW; ++ow) { int8x8_t src001 = vld1_s8(sptr0); - int8x8_t src012 = vld1_s8(sptr0 + 4); - int8x8_t src101 = vld1_s8(sptr1); - int8x8_t src112 = vld1_s8(sptr1 + 4); - int8x8_t max01_tmp = vmax_s8(src001, src101); - int8x8_t max12_tmp = vmax_s8(src012, src112); - int8x8_t mat_out = vmax_s8(max01_tmp, max12_tmp); -#define store(i) *(dptr + i) = mat_out[i]; + int8x8_t max_out = vmax_s8(src001, src101); +#define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]); UNROLL_CALL_NOWRAPPER(4, store) #undef store sptr0 += 8; diff --git a/dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h similarity index 87% rename from dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h rename to dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h index d6409ce2..e517406b 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h +++ b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h @@ -18,11 +18,13 @@ namespace arm_common { void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, - size_t PH, size_t PW); + size_t PH, size_t PW, + const WorkspaceBundle& ws); void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, - size_t PH, size_t PW); + size_t PH, size_t PW, + const WorkspaceBundle& ws); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp new file mode 100644 index 00000000..3717cc5a --- /dev/null +++ b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp @@ -0,0 +1,195 @@ +/** + * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h" +#include "src/arm_common/pooling/algo.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" + +namespace megdnn { +namespace arm_common { + +void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int8x16_t src0123 = vld1q_s8(sptr0); + int8x16_t src1234 = vld1q_s8(sptr0 + 4); + int8x16_t src2345 = vld1q_s8(sptr0 + 8); + int8x16_t max0 = vmaxq_s8(src0123, src1234); + max0 = vmaxq_s8(max0, src2345); + + src0123 = vld1q_s8(sptr1); + src1234 = vld1q_s8(sptr1 + 4); + src2345 = vld1q_s8(sptr1 + 8); + int8x16_t max1 = vmaxq_s8(src0123, src1234); + max1 = vmaxq_s8(max1, src2345); + + src0123 = vld1q_s8(sptr2); + src1234 = vld1q_s8(sptr2 + 4); + src2345 = vld1q_s8(sptr2 + 8); + int8x16_t max2 = vmaxq_s8(src0123, src1234); + max2 = vmaxq_s8(max2, src2345); + + int8x16_t max_out = vmaxq_s8(max0, max1); + max_out = vmaxq_s8(max_out, max2); + + vst1q_s8(dptr, max_out); + + sptr0 += 16; + sptr1 += 16; + sptr2 += 16; + dptr += 16; + } + for (; ow < OW; ++ow) { + int8x8_t src001 = vld1_s8(sptr0); + int8x8_t src012 = vld1_s8(sptr0 + 4); + + int8x8_t src101 = vld1_s8(sptr1); + int8x8_t src112 = vld1_s8(sptr1 + 4); + + int8x8_t src201 = vld1_s8(sptr2); + int8x8_t src212 = vld1_s8(sptr2 + 4); + int8x8_t max01_tmp = vmax_s8(src001, src101); + max01_tmp = vmax_s8(max01_tmp, src201); + + int8x8_t max12_tmp = vmax_s8(src012, src112); + max12_tmp = vmax_s8(max12_tmp, src212); +#define cb(i) \ + int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ + max12_tmp[i + 4]); +#define store(i) *(dptr + i) = dst##i; + UNROLL_CALL_NOWRAPPER(4, cb) + UNROLL_CALL_NOWRAPPER(4, store) +#undef store +#undef cb + sptr0 += 4; + sptr1 += 4; + sptr2 += 4; + dptr += 4; + } + } +} + +void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); + size_t oh = 0; + for (; oh < OH; ++oh) { + size_t ih = oh << 1; + const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; + int8_t* __restrict dptr = dst + oh * OW * 4; + size_t ow = 0; + for (; ow + 3 < OW; ow += 4) { + int8x16_t src00 = vld1q_s8(sptr0); + int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4); + int32x4_t src08 = vld1q_dup_s32( + reinterpret_cast(sptr0 + 4 * 8)); + int32x4x2_t src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), + vreinterpretq_s32_s8(src04)); + int32x4_t src0246 = src_tmp.val[0]; + int32x4_t src1357 = src_tmp.val[1]; + int32x4_t src2468 = vextq_s32(src0246, src08, 1); + int8x16_t max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), + vreinterpretq_s8_s32(src1357)); + int8x16_t max0 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); + + int8x16_t src10 = vld1q_s8(sptr1); + int8x16_t src14 = vld1q_s8(sptr1 + 4 * 4); + int32x4_t src18 = vld1q_dup_s32( + reinterpret_cast(sptr1 + 4 * 8)); + + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src10), + vreinterpretq_s32_s8(src14)); + src0246 = src_tmp.val[0]; + src1357 = src_tmp.val[1]; + src2468 = vextq_s32(src0246, src18, 1); + max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), + vreinterpretq_s8_s32(src1357)); + int8x16_t max1 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); + + int8x16_t src20 = vld1q_s8(sptr2); + int8x16_t src24 = vld1q_s8(sptr2 + 4 * 4); + int32x4_t src28 = vld1q_dup_s32( + reinterpret_cast(sptr2 + 4 * 8)); + + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src20), + vreinterpretq_s32_s8(src24)); + src0246 = src_tmp.val[0]; + src1357 = src_tmp.val[1]; + src2468 = vextq_s32(src0246, src28, 1); + + max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), + vreinterpretq_s8_s32(src1357)); + int8x16_t max2 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); + max_tmp = vmaxq_s8(max0, max1); + int8x16_t max_out = vmaxq_s8(max_tmp, max2); + + vst1q_s8(dptr, max_out); + + sptr0 += 32; + sptr1 += 32; + sptr2 += 32; + dptr += 16; + } + for (; ow < OW; ++ow) { + int8x8_t src001 = vld1_s8(sptr0); + int8x8_t src012 = vld1_s8(sptr0 + 4); + + int8x8_t src101 = vld1_s8(sptr1); + int8x8_t src112 = vld1_s8(sptr1 + 4); + + int8x8_t src201 = vld1_s8(sptr2); + int8x8_t src212 = vld1_s8(sptr2 + 4); + int8x8_t max01_tmp = vmax_s8(src001, src101); + max01_tmp = vmax_s8(max01_tmp, src201); + + int8x8_t max12_tmp = vmax_s8(src012, src112); + max12_tmp = vmax_s8(max12_tmp, src212); +#define cb(i) \ + int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ + max12_tmp[i + 4]); +#define store(i) *(dptr + i) = dst##i; + UNROLL_CALL_NOWRAPPER(4, cb) + UNROLL_CALL_NOWRAPPER(4, store) +#undef store +#undef cb + sptr0 += 8; + sptr1 += 8; + sptr2 += 8; + dptr += 4; + } + } +} + +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h similarity index 58% rename from dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h rename to dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h index a07f3dcb..09c1ceb9 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h +++ b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h @@ -15,9 +15,16 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_3x3_s2x2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, +void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, + size_t IH, size_t IW, + size_t OH, size_t OW, + size_t PH, size_t PW, + const WorkspaceBundle& ws); + +void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW); + size_t OW, size_t PH, size_t PW, + const WorkspaceBundle& ws); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp similarity index 77% rename from dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.cpp rename to dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp index b6c3493a..e68af81b 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp @@ -9,7 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h" +#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h" +#include "src/arm_common/pooling/algo.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" @@ -19,14 +20,18 @@ namespace arm_common { void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, - size_t PH, size_t PW) { + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh; - const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; - const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; - const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; - const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; int8_t* __restrict dptr = dst + oh * OW * 4; size_t ow = 0; for (; ow + 3 < OW; ow += 4) { @@ -90,35 +95,38 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, - size_t PH, size_t PW) { + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh << 1; - const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; - const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; - const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; - const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; + const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4; int8_t* __restrict dptr = dst + oh * OW * 4; size_t ow = 0; for (; ow + 3 < OW; ow += 4) { - int8x16_t src00, src04, src08, src09, max_tmp0, max_tmp1, max_tmp2, - max_tmp3; - int32x4_t src0246, src1357, src2468, src3579; + int8x16_t src00, src04, max_tmp0, max_tmp1, max_tmp2, max_tmp3; + int32x4_t src0246, src1357, src2468, src3579, src08, src09; int32x4x2_t src_tmp; -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src08 = vld1q_s8(sptr##i + 4 * 8); \ - src09 = vld1q_s8(sptr##i + 4 * 9); \ - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04)); \ - src0246 = src_tmp.val[0]; \ - src1357 = src_tmp.val[1]; \ - src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); \ - src3579 = vextq_s32(src1357, vreinterpretq_s32_s8(src09), 1); \ - max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ - vreinterpretq_s8_s32(src1357)); \ - max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src3579 = vextq_s32(src1357, src09, 1); \ + max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ + vreinterpretq_s8_s32(src1357)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) diff --git a/dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h similarity index 92% rename from dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h rename to dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h index 581e8f93..166c99a6 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h +++ b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h @@ -18,7 +18,7 @@ namespace arm_common { #define KERN(strdie) \ void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \ const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ - size_t OW, size_t PH, size_t PW); + size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); KERN(stride1) KERN(stride2) diff --git a/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp similarity index 67% rename from dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp rename to dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp index a1562274..e1d19519 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp @@ -9,7 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h" +#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h" +#include "src/arm_common/pooling/algo.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" @@ -19,15 +20,19 @@ namespace arm_common { void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, - size_t PH, size_t PW) { + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh; - const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; - const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; - const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; - const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; - const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; + const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; + const int8_t* sptr3 = sptr + (ih + 3) * IW2 * 4; + const int8_t* sptr4 = sptr + (ih + 4) * IW2 * 4; int8_t* __restrict dptr = dst + oh * OW * 4; size_t ow = 0; for (; ow + 3 < OW; ow += 4) { @@ -80,13 +85,16 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, max_out = vmax_s8(max_out, max_tmp3); max_out = vmax_s8(max_out, max_tmp4); -#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); +#define COMPARE_SRC45(i) \ + int32x2_t src##i##_45 = \ + vld1_dup_s32(reinterpret_cast(sptr##i + 4 * 4)); UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) - int8x8_t max_45 = vmax_s8(src0_45, src1_45); - max_45 = vmax_s8(max_45, src1_45); - max_45 = vmax_s8(max_45, src2_45); - max_45 = vmax_s8(max_45, src3_45); - max_45 = vmax_s8(max_45, src4_45); + int8x8_t max_45 = vmax_s8(vreinterpret_s8_s32(src0_45), + vreinterpret_s8_s32(src1_45)); + max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src1_45)); + max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src2_45)); + max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src3_45)); + max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src4_45)); #define store(i) \ *(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); @@ -106,39 +114,44 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, - size_t PH, size_t PW) { + size_t PH, size_t PW, + const WorkspaceBundle& ws) { + const int8_t* sptr = nullptr; + size_t IH2, IW2; + sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws); size_t oh = 0; for (; oh < OH; ++oh) { size_t ih = oh << 1; - const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; - const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; - const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; - const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; - const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; + const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4; + const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4; + const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4; + const int8_t* sptr3 = sptr + (ih + 3) * IW2 * 4; + const int8_t* sptr4 = sptr + (ih + 4) * IW2 * 4; int8_t* __restrict dptr = dst + oh * OW * 4; size_t ow = 0; for (; ow + 3 < OW; ow += 4) { - int8x16_t src00, src04, src08, src09, src10, max_tmp0, max_tmp1, - max_tmp2, max_tmp3, max_tmp4; - int32x4_t src0246, src1357, src2468, src3579, src46810; + int8x16_t src00, src04, max_tmp0, max_tmp1, max_tmp2, max_tmp3, + max_tmp4; + int32x4_t src0246, src1357, src2468, src3579, src46810, src10, + src09, src08; int32x4x2_t src_tmp; -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src08 = vld1q_s8(sptr##i + 4 * 8); \ - src09 = vld1q_s8(sptr##i + 4 * 9); \ - src10 = vld1q_s8(sptr##i + 4 * 10); \ - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04)); \ - src0246 = src_tmp.val[0]; \ - src1357 = src_tmp.val[1]; \ - src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); \ - src3579 = vextq_s32(src1357, vreinterpretq_s32_s8(src09), 1); \ - src46810 = vextq_s32(src2468, vreinterpretq_s32_s8(src10), 1); \ - max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ - vreinterpretq_s8_s32(src1357)); \ - max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ - max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ + src10 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 10)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ + vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src3579 = vextq_s32(src1357, src09, 1); \ + src46810 = vextq_s32(src2468, src10, 1); \ + max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ + vreinterpretq_s8_s32(src1357)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \ max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810)); UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) @@ -173,13 +186,16 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, max_out = vmax_s8(max_out, max_tmp3); max_out = vmax_s8(max_out, max_tmp4); -#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); +#define COMPARE_SRC45(i) \ + int32x2_t src##i##_45 = \ + vld1_dup_s32(reinterpret_cast(sptr##i + 4 * 4)); UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) - int8x8_t max_45 = vmax_s8(src0_45, src1_45); - max_45 = vmax_s8(max_45, src1_45); - max_45 = vmax_s8(max_45, src2_45); - max_45 = vmax_s8(max_45, src3_45); - max_45 = vmax_s8(max_45, src4_45); + int8x8_t max_45 = vmax_s8(vreinterpret_s8_s32(src0_45), + vreinterpret_s8_s32(src1_45)); + max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src1_45)); + max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src2_45)); + max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src3_45)); + max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src4_45)); #define store(i) \ *(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); diff --git a/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.h b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h similarity index 92% rename from dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.h rename to dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h index 3c9c1256..73221f1f 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.h +++ b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h @@ -18,7 +18,7 @@ namespace arm_common { #define KERN(strdie) \ void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \ const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ - size_t OW, size_t PH, size_t PW); + size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws); KERN(stride1) KERN(stride2) diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp index 04b93027..f7556ee1 100644 --- a/dnn/src/arm_common/pooling/opr_impl.cpp +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -26,8 +26,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2; AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4; - AlgoFilter3MaxStride2NCHW44 algo_filter3_max_stride2_nchw4; - AlgoFilter3MaxStride1NCHW44 algo_filter3_max_stride1_nchw4; + AlgoFilter3MaxStridexNCHW44 algo_filter3_max_stridex_nchw4; AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4; AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4; @@ -41,8 +40,7 @@ public: all_algos.emplace_back(&algo_filter5_max_stride2); all_algos.emplace_back(&algo_int8_filter2_max_stride2); all_algos.emplace_back(&algo_int8_filter3_max_stride2); - all_algos.emplace_back(&algo_filter3_max_stride2_nchw4); - all_algos.emplace_back(&algo_filter3_max_stride1_nchw4); + all_algos.emplace_back(&algo_filter3_max_stridex_nchw4); all_algos.emplace_back(&algo_filter2_max_stridex_nchw4); all_algos.emplace_back(&algo_filter4_max_stridex_nchw4); all_algos.emplace_back(&algo_filter5_max_stridex_nchw4); @@ -119,6 +117,12 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, arm_common_workspace = ws.total_size_in_bytes() * nr_threads; } + if ((param.src_type.enumv() == DTypeEnum::QuantizedS8) && + (param.format == param::Pooling::Format::NCHW44)) { + WorkspaceBundle ws = get_bundle_nchw44(param); + arm_common_workspace = ws.total_size_in_bytes() * nr_threads; + } + if (find_algo) { return arm_common_workspace; } else { diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h index 9e716a46..2bb8f992 100644 --- a/dnn/src/arm_common/pooling/opr_impl.h +++ b/dnn/src/arm_common/pooling/opr_impl.h @@ -84,8 +84,7 @@ private: class AlgoInt8Filter2MaxStride2; class AlgoInt8Filter3MaxStride2; class AlgoFilter2MaxStridexNCHW44; - class AlgoFilter3MaxStride2NCHW44; - class AlgoFilter3MaxStride1NCHW44; + class AlgoFilter3MaxStridexNCHW44; class AlgoFilter4MaxStridexNCHW44; class AlgoFilter5MaxStridexNCHW44; class AlgoPack; diff --git a/dnn/test/arm_common/pooling.cpp b/dnn/test/arm_common/pooling.cpp index 55785035..3b9bf576 100644 --- a/dnn/test/arm_common/pooling.cpp +++ b/dnn/test/arm_common/pooling.cpp @@ -8,8 +8,6 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "megdnn/dtype.h" -#include "megdnn/opr_param_defs.h" #include "test/arm_common/fixture.h" #include "test/common/pooling.h" @@ -102,209 +100,6 @@ TEST_F(ARM_COMMON, POOLING_INT8_W3x3_S2x2) // clang-format on } -TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S2x2_NCHW44) -{ - // clang-format off - for (size_t ih: {3, 5, 10}) - for (size_t iw: {3, 5, 7, 9, 15, 20}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 3 && iw+2*pw >= 3) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; - param.stride_h = param.stride_w = 2; - param.window_h = param.window_w = 3; - checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} - -TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S1x1_NCHW44) -{ - // clang-format off - for (size_t ih: {3, 5, 10}) - for (size_t iw: {3, 5, 7, 9, 15, 20}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 3 && iw+2*pw >= 3) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; - param.stride_h = param.stride_w = 1; - param.window_h = param.window_w = 3; - checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} - -TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S1x1_NCHW44) -{ - // clang-format off - for (size_t ih: {2, 5, 10, 17}) - for (size_t iw: {2, 6, 8, 16, 26}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 2 && iw+2*pw >= 2) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; - param.stride_h = param.stride_w = 1; - param.window_h = param.window_w = 2; - checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} -TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S2x2_NCHW44) -{ - // clang-format off - for (size_t ih: {2, 5, 10, 17}) - for (size_t iw: {2, 6, 8, 16, 26}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 2 && iw+2*pw >= 2) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; - param.stride_h = param.stride_w = 2; - param.window_h = param.window_w = 2; - checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} -TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S1x1_NCHW44) -{ - // clang-format off - for (size_t ih: {4, 7, 10, 17, 20}) - for (size_t iw: {4, 8, 10, 21, 32}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 2 && iw+2*pw >= 2) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; - param.stride_h = param.stride_w = 1; - param.window_h = param.window_w = 4; - checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} -TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S2x2_NCHW44) -{ - // clang-format off - for (size_t ih: {4, 10, 18, 25, 30}) - for (size_t iw: {4, 12, 17, 20, 25}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 2 && iw+2*pw >= 2) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; - param.stride_h = param.stride_w = 2; - param.window_h = param.window_w = 4; - checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} -TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S1x1_NCHW44) -{ - // clang-format off - for (size_t ih: {5, 9, 19, 20, 39}) - for (size_t iw: {5, 12, 23, 27, 39}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 5 && iw+2*pw >= 5) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; - param.stride_h = param.stride_w = 1; - param.window_h = param.window_w = 5; - checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} -TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S2x2_NCHW44) -{ - // clang-format off - for (size_t ih: {5, 9, 19, 20, 39}) - for (size_t iw: {5, 12, 23, 27, 39}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 5 && iw+2*pw >= 5) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; - param.stride_h = param.stride_w = 2; - param.window_h = param.window_w = 5; - checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} - #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_F(ARM_COMMON, POOLING_FP16) { Checker checker(handle()); diff --git a/dnn/test/arm_common/pooling_multi_thread.cpp b/dnn/test/arm_common/pooling_multi_thread.cpp index 4b34fb98..1543ae10 100644 --- a/dnn/test/arm_common/pooling_multi_thread.cpp +++ b/dnn/test/arm_common/pooling_multi_thread.cpp @@ -8,6 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include +#include "megdnn/dtype.h" #include "test/arm_common/fixture.h" #include "test/common/pooling.h" @@ -53,38 +55,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) { checker.set_param(param).exec({{2, 3, ih, iw}, {}}); } } -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S2x2_NCHW44) -{ - // clang-format off - for (size_t ih: {3, 5, 10}) - for (size_t iw: {3, 5, 7, 9, 15, 20}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 3 && iw+2*pw >= 3) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; - param.stride_h = param.stride_w = 2; - param.window_h = param.window_w = 3; - checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44) +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44) { // clang-format off for (size_t ih: {3, 5, 10}) for (size_t iw: {3, 5, 7, 9, 15, 20}) - for (size_t ph: {0}) - for (size_t pw: {0}) + for (size_t ph: {0, 1, 2}) + for (size_t pw: {0, 1, 2}) if (ih+2*ph >= 3 && iw+2*pw >= 3) { UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; @@ -100,18 +78,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44) param.stride_h = param.stride_w = 1; param.window_h = param.window_w = 3; checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); + + param.stride_h = param.stride_w = 2; + checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); + } // clang-format on } -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44) +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44) { // clang-format off - for (size_t ih: {2, 5, 10, 17}) - for (size_t iw: {2, 6, 8, 16, 26}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 3 && iw+2*pw >= 3) + for (size_t ih: {2, 5, 10, 17}) + for (size_t iw: {2, 6, 8, 16, 26}) + for (size_t ph: {0, 1}) + for (size_t pw: {0, 1}) + if (ih+2*ph >= 2 && iw+2*pw >= 2) { UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; Checker checker(handle()); @@ -126,41 +108,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44) param.stride_h = param.stride_w = 1; param.window_h = param.window_w = 2; checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S2x2_NCHW44) -{ - // clang-format off - for (size_t ih: {2, 5, 10, 17}) - for (size_t iw: {2, 6, 8, 16, 26}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 3 && iw+2*pw >= 3) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; param.stride_h = param.stride_w = 2; - param.window_h = param.window_w = 2; checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); } // clang-format on } -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44) + +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44) { // clang-format off - for (size_t ih: {4, 7, 10, 17, 20}) - for (size_t iw: {4, 8, 10, 21, 32}) - for (size_t ph: {0}) - for (size_t pw: {0}) + for (size_t ih: {4, 10, 18, 25, 30}) + for (size_t iw: {4, 12, 17, 20, 25}) + for (size_t ph: {0, 1, 2}) + for (size_t pw: {0, 1, 2}) if (ih+2*ph >= 4 && iw+2*pw >= 4) { UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; @@ -176,41 +137,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44) param.stride_h = param.stride_w = 1; param.window_h = param.window_w = 4; checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S2x2_NCHW44) -{ - // clang-format off - for (size_t ih: {4, 10, 18, 25, 30}) - for (size_t iw: {4, 12, 17, 20, 25}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 4 && iw+2*pw >= 4) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; param.stride_h = param.stride_w = 2; - param.window_h = param.window_w = 4; checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); } // clang-format on } -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44) +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44) { // clang-format off for (size_t ih: {5, 9, 19, 20, 39}) for (size_t iw: {5, 12, 23, 27, 39}) - for (size_t ph: {0}) - for (size_t pw: {0}) + for (size_t ph: {0, 1, 2}) + for (size_t pw: {0, 1, 2}) if (ih+2*ph >= 5 && iw+2*pw >= 5) { UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; @@ -226,31 +165,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44) param.stride_h = param.stride_w = 1; param.window_h = param.window_w = 5; checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); - } - // clang-format on -} -TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S2x2_NCHW44) -{ - // clang-format off - for (size_t ih: {5, 9, 19, 20, 39}) - for (size_t iw: {5, 12, 23, 27, 39}) - for (size_t ph: {0}) - for (size_t pw: {0}) - if (ih+2*ph >= 5 && iw+2*pw >= 5) - { - UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; - Checker checker(handle()); - checker.set_dtype(0, dtype::QuantizedS8(1.1f)); - checker.set_rng(0,&rng); - param::Pooling param; - param.mode = param::Pooling::Mode::MAX; - param.format = param::Pooling::Format::NCHW44; - param.pad_h = ph; - param.pad_w = pw; param.stride_h = param.stride_w = 2; - param.window_h = param.window_w = 5; checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); + } // clang-format on } @@ -473,13 +391,15 @@ template void benchmark_impl(const typename Opr::Param& param, std::vector> shapes, size_t RUNS, TaskExecutorConfig&& multi_thread_config, - TaskExecutorConfig&& single_thread_config) { + TaskExecutorConfig&& single_thread_config, + DType data_type) { std::vector multi_thread_times, single_thread_times; { auto multi_thread_hanle = create_cpu_handle(0, true, &multi_thread_config); auto benchmarker = Benchmarker(multi_thread_hanle.get()); benchmarker.set_times(RUNS).set_display(false).set_param(param); + benchmarker.set_dtype(0, data_type); for (auto shape : shapes) { multi_thread_times.push_back(benchmarker.exec(shape) / RUNS); } @@ -489,6 +409,7 @@ void benchmark_impl(const typename Opr::Param& param, create_cpu_handle(0, true, &single_thread_config); auto benchmarker = Benchmarker(single_thread_handle.get()); benchmarker.set_times(RUNS).set_display(false).set_param(param); + benchmarker.set_dtype(0, data_type); for (auto shape : shapes) { single_thread_times.push_back(benchmarker.exec(shape) / RUNS); } @@ -540,10 +461,47 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING) { param.stride_h = param.stride_w = 2; param.pad_h = param.pad_w = 1; printf("Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, - param.stride_h, param.pad_h, static_cast(param.mode)); - benchmark_impl(param, shapes, RUNS, {4, {0, 1, 2, 3}}, {1, {0}}); - benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}); - benchmark_impl(param, shapes, RUNS, {2, {0, 1}}, {1, {0}}); + param.window_w, param.stride_h, static_cast(param.mode)); + benchmark_impl(param, shapes, RUNS, {4, {0, 1, 2, 3}}, {1, {0}}, dtype::Float32()); + benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, dtype::Float32()); + benchmark_impl(param, shapes, RUNS, {2, {0, 1}}, {1, {0}}, dtype::Float32()); +} + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING_NCHW44) { + constexpr size_t RUNS = 50; + + using Param = param::Pooling; + Param param; + param.pad_h = param.pad_w = 0; + param.mode = Param::Mode::MAX; + std::vector> shapes; + std::vector> filter_and_stride = { + {2, 1}, {2, 2}, {3, 1}, {3, 2}, {4, 1}, {4, 2}, {5, 1}, {5, 2}}; + for(auto filter:filter_and_stride){ + + shapes.push_back({{1, 32 * 4, 215, 215}, {}}); + shapes.push_back({{1, 32 * 4, 128, 128}, {}}); + shapes.push_back({{1, 16 * 4, 56, 56}, {}}); + + param.window_h = param.window_w = filter[0]; + param.stride_h = param.stride_w = filter[1]; + param.format = Param::Format::NCHW; + printf("NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, + param.window_h, param.stride_h, static_cast(param.mode)); + benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, + dtype::QuantizedS8(1.1f)); + shapes.clear(); + shapes.push_back({{1, 32, 215, 215,4}, {}}); + shapes.push_back({{1, 32, 128, 128,4}, {}}); + shapes.push_back({{1, 16, 56, 56, 4}, {}}); + + param.format = Param::Format::NCHW44; + printf("NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, + param.window_w, param.stride_h, static_cast(param.mode)); + benchmark_impl(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, + dtype::QuantizedS8(1.1f)); + shapes.clear(); + } } #endif -- GitLab