diff --git a/dnn/src/arm_common/pooling/algo.cpp b/dnn/src/arm_common/pooling/algo.cpp
index 4cc52252aad178cf16b7664c8f19a21fe26543d6..061c80836f6860631755ea6bc51ae210bb9c7e34 100644
--- a/dnn/src/arm_common/pooling/algo.cpp
+++ b/dnn/src/arm_common/pooling/algo.cpp
@@ -10,7 +10,9 @@
  * implied.
  */
 #include "src/arm_common/pooling/algo.h"
+#include "megdnn/opr_param_defs.h"
 #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
+#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h"
 #include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
 #include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
 
@@ -556,6 +558,70 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec(
     }
     MIDOUT_END();
 }
+
+/**
+ * Tell whether \p param can be served by the NCHW44 3x3 stride-2 max pooling
+ * kernel: QuantizedS8 input in NCHW44 format, MAX mode, a 3x3 window, stride 2
+ * in both directions and no padding.
+ */
+bool PoolingImpl::AlgoFilter3MaxStride2NCHW44::usable(
+        const PoolingKernSizeParam& param) const {
+    auto SH = param.stride[0];
+    auto SW = param.stride[1];
+    auto FH = param.filter[0];
+    auto FW = param.filter[1];
+    auto PH = param.padding[0];
+    auto PW = param.padding[1];
+
+    bool available = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
+                     param.format == Param::Format::NCHW44 &&
+                     param.mode == Mode::MAX && FH == 3 && FW == 3 &&
+                     SH == 2 && SW == 2 && PH == 0 && PW == 0;
+    return available;
+}
+
+/**
+ * Split the work into N * C tasks, one per (batch index n, plane index c), and
+ * run the int8 NCHW44 kernel on each of them. The plane of a task starts
+ * (n * C + c) * IH * IW * 4 elements into the source, likewise for the output.
+ */
+void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec(
+        const PoolingKernParam& param) const {
+    auto IH = param.isz[0], IW = param.isz[1];
+    auto OH = param.osz[0], OW = param.osz[1];
+    auto N = param.n, C = param.ic;
+    auto PH = param.padding[0];
+    auto PW = param.padding[1];
+
+    void* src_ptr = param.src_ptr;
+    void* dst_ptr = param.dst_ptr;
+
+#define DISPATCH_FUNC(type, func, midout_type_id) \
+    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
+                 midout_iv(midout_type_id)) { \
+        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
+                           size_t index, size_t thread_id) { \
+            MEGDNN_MARK_USED_VAR(thread_id); \
+            size_t n = index / C; \
+            size_t c = index % C; \
+            do_max_pooling_3x3_s2x2_##func##_nchw44_NEON( \
+                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
+                            c * IH * IW * 4, \
+                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
+                            c * OH * OW * 4, \
+                    IH, IW, OH, OW, PH, PW); \
+        }; \
+        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
+                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
+                run); \
+    } \
+    MIDOUT_END();
+
+    DISPATCH_FUNC(int8_t, int8, 9);
+
+#undef DISPATCH_FUNC
+}
+
 }  // namespace arm_common
 }  // namespace megdnn
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h
index a99695906ea0918f95e82edbb4b34b00f9aaa741..b4c3291f5113f6c94d9b84ef57ef83fc32f2c56f 100644
--- a/dnn/src/arm_common/pooling/algo.h
+++ b/dnn/src/arm_common/pooling/algo.h
@@ -82,6 +82,18 @@ public:
     bool usable(const PoolingKernSizeParam& param) const override;
     void exec(const PoolingKernParam& param) const override;
 };
+
+//! MAX pooling with a 3x3 window and stride 2 for the NCHW44 format
+class PoolingImpl::AlgoFilter3MaxStride2NCHW44 final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override {
+        return "ARM_POOLING_FILTER3_MAX_STRIDE2_NCHW44";
+    }
+    bool usable(const PoolingKernSizeParam& param) const override;
+    void exec(const PoolingKernParam& param) const override;
+};
+
 WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param);
 
 } // namespace arm_common
diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ddafdbd2f72ebe13929d005a5022eaba22421699
--- /dev/null
+++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp
@@ -0,0 +1,113 @@
+/**
+ * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h"
+#include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/common/unroll_macro.h"
+
+#include <algorithm>
+
+namespace megdnn {
+namespace arm_common {
+
+namespace {
+
+/**
+ * Horizontal step of the 3x3 stride-2 max for four neighbouring outputs of a
+ * single input row. \p sptr addresses the first of the nine 4-channel pixels
+ * those four windows cover; 32-bit lane k of the result holds the per-channel
+ * max of pixels 2k, 2k + 1 and 2k + 2.
+ */
+inline int8x16_t max_of_row_3x_s2(const int8_t* sptr) {
+    int8x16_t src0 = vld1q_s8(sptr);
+    int8x16_t src4 = vld1q_s8(sptr + 4 * 4);
+    // Pixel 8 is the only one needed from the third group. Load just its four
+    // bytes: a 16-byte load here would also touch pixels 9..11, which can lie
+    // past the end of the row and, on the last row, past the end of the input.
+    int32x4_t src8 =
+            vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr + 4 * 8));
+    // One 32-bit lane per 4-channel pixel: separate even and odd pixels.
+    int32x4x2_t src_tmp =
+            vuzpq_s32(vreinterpretq_s32_s8(src0), vreinterpretq_s32_s8(src4));
+    int32x4_t src0246 = src_tmp.val[0];
+    int32x4_t src1357 = src_tmp.val[1];
+    int32x4_t src2468 = vextq_s32(src0246, src8, 1);
+    int8x16_t max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246),
+                                 vreinterpretq_s8_s32(src1357));
+    return vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468));
+}
+
+}  // anonymous namespace
+
+/**
+ * 3x3 max pooling with stride 2 over one plane of int8 data stored as
+ * 4-channel pixels. \p src is read as rows of IW * 4 bytes and \p dst is
+ * written as OH rows of OW * 4 bytes. IH, PH and PW are not read here.
+ */
+void do_max_pooling_3x3_s2x2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
+                                              size_t IH, size_t IW, size_t OH,
+                                              size_t OW, size_t PH, size_t PW) {
+    for (size_t oh = 0; oh < OH; ++oh) {
+        size_t ih = oh << 1;
+        const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4;
+        const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4;
+        const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4;
+        int8_t* __restrict dptr = dst + oh * OW * 4;
+        size_t ow = 0;
+        // Vector path: four output pixels (16 bytes) per iteration.
+        for (; ow + 3 < OW; ow += 4) {
+            int8x16_t max0 = max_of_row_3x_s2(sptr0);
+            int8x16_t max1 = max_of_row_3x_s2(sptr1);
+            int8x16_t max2 = max_of_row_3x_s2(sptr2);
+            int8x16_t max_out = vmaxq_s8(vmaxq_s8(max0, max1), max2);
+
+            vst1q_s8(dptr, max_out);
+
+            sptr0 += 32;
+            sptr1 += 32;
+            sptr2 += 32;
+            dptr += 16;
+        }
+        // Tail: one output pixel per iteration, built from input pixels 0..2.
+        for (; ow < OW; ++ow) {
+            int8x8_t src001 = vld1_s8(sptr0);
+            int8x8_t src012 = vld1_s8(sptr0 + 4);
+
+            int8x8_t src101 = vld1_s8(sptr1);
+            int8x8_t src112 = vld1_s8(sptr1 + 4);
+
+            int8x8_t src201 = vld1_s8(sptr2);
+            int8x8_t src212 = vld1_s8(sptr2 + 4);
+            int8x8_t max01_tmp = vmax_s8(src001, src101);
+            max01_tmp = vmax_s8(max01_tmp, src201);
+
+            int8x8_t max12_tmp = vmax_s8(src012, src112);
+            max12_tmp = vmax_s8(max12_tmp, src212);
+#define cb(i) \
+    int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \
+                             max12_tmp[i + 4]);
+#define store(i) *(dptr + i) = dst##i;
+            UNROLL_CALL_NOWRAPPER(4, cb)
+            UNROLL_CALL_NOWRAPPER(4, store)
+#undef store
+#undef cb
+            sptr0 += 8;
+            sptr1 += 8;
+            sptr2 += 8;
+            dptr += 4;
+        }
+    }
+}
+
+}  // namespace arm_common
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h
new file mode 100644
index 0000000000000000000000000000000000000000..a07f3dcb0ffd249eb7b78dc0ffb9deb28167d005
--- /dev/null
+++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h
@@ -0,0 +1,30 @@
+/**
+ * \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+#pragma once
+#include "src/common/utils.h"
+
+namespace megdnn {
+namespace arm_common {
+
+/**
+ * 3x3 max pooling with stride 2 over one plane of int8 data stored as
+ * 4-channel pixels: reads rows of IW * 4 bytes from \p src and writes OH rows
+ * of OW * 4 bytes to \p dst.
+ */
+void do_max_pooling_3x3_s2x2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
+                                              size_t IH, size_t IW, size_t OH,
+                                              size_t OW, size_t PH, size_t PW);
+
+}  // namespace arm_common
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp
index ce3cb723bbf54044480d917269df723996fa0bae..1e27ea15965a5d95576193a02ec177b642cac9ce 100644
--- a/dnn/src/arm_common/pooling/opr_impl.cpp
+++ b/dnn/src/arm_common/pooling/opr_impl.cpp
@@ -25,6 +25,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
     AlgoFilter5MaxStride2 algo_filter5_max_stride2;
     AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2;
     AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2;
+    AlgoFilter3MaxStride2NCHW44 algo_filter3_max_stride2_nchw44;
 
 public:
     AlgoPack() {
@@ -36,6 +37,7 @@ public:
         all_algos.emplace_back(&algo_filter5_max_stride2);
         all_algos.emplace_back(&algo_int8_filter2_max_stride2);
         all_algos.emplace_back(&algo_int8_filter3_max_stride2);
+        all_algos.emplace_back(&algo_filter3_max_stride2_nchw44);
     }
     SmallVector<AlgoBase*> all_algos;
 };
diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h
index 9b7159916dd0aab09bea4ca279f70eb0e93a9a8d..b7bde28798cad73b19aaf431bf676eab90dcb601 100644
--- a/dnn/src/arm_common/pooling/opr_impl.h
+++ b/dnn/src/arm_common/pooling/opr_impl.h
@@ -83,6 +83,7 @@ private:
     class AlgoFilter5MaxStride2;
     class AlgoInt8Filter2MaxStride2;
     class AlgoInt8Filter3MaxStride2;
+    class AlgoFilter3MaxStride2NCHW44;
     class AlgoPack;
 };
 } // namespace arm_common
diff --git a/dnn/test/arm_common/pooling.cpp b/dnn/test/arm_common/pooling.cpp
index 3b9bf57662014a8412e2eb2a9ce46802ff862904..8822824c786ff62f94ffa560c53431320e52209c 100644
--- a/dnn/test/arm_common/pooling.cpp
+++ b/dnn/test/arm_common/pooling.cpp
@@ -8,6 +8,8 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
+#include "megdnn/dtype.h"
+#include "megdnn/opr_param_defs.h"
 #include "test/arm_common/fixture.h"
 
 #include "test/common/pooling.h"
@@ -100,6 +102,33 @@ TEST_F(ARM_COMMON, POOLING_INT8_W3x3_S2x2)
     // clang-format on
 }
 
+//! MAX pooling, 3x3 window, stride 2, no padding, QuantizedS8 in NCHW44
+TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S2x2_NCHW44)
+{
+    // clang-format off
+    for (size_t ih: {3, 5, 10})
+    for (size_t iw: {3, 5, 7, 9, 15, 20})
+    for (size_t ph: {0})
+    for (size_t pw: {0})
+    if (ih+2*ph >= 3 && iw+2*pw >= 3)
+    {
+        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
+        Checker<Pooling> checker(handle());
+        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
+        checker.set_rng(0,&rng);
+
+        param::Pooling param;
+        param.mode = param::Pooling::Mode::MAX;
+        param.format = param::Pooling::Format::NCHW44;
+        param.pad_h = ph;
+        param.pad_w = pw;
+        param.stride_h = param.stride_w = 2;
+        param.window_h = param.window_w = 3;
+        checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
+    }
+    // clang-format on
+}
+
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_F(ARM_COMMON, POOLING_FP16) {
     Checker<Pooling> checker(handle());
diff --git a/dnn/test/arm_common/pooling_multi_thread.cpp b/dnn/test/arm_common/pooling_multi_thread.cpp
index d9fc9224f5ed8865639e98e74eede087efa16b4f..e9dd7d9ca53716a0d75cbea48841b3dfc4b12514 100644
--- a/dnn/test/arm_common/pooling_multi_thread.cpp
+++ b/dnn/test/arm_common/pooling_multi_thread.cpp
@@ -53,7 +53,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) {
             checker.set_param(param).exec({{2, 3, ih, iw}, {}});
         }
 }
+//! MAX pooling, 3x3 window, stride 2, no padding, QuantizedS8 in NCHW44
+TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S2x2_NCHW44)
+{
+    // clang-format off
+    for (size_t ih: {3, 5, 10})
+    for (size_t iw: {3, 5, 7, 9, 15, 20})
+    for (size_t ph: {0})
+    for (size_t pw: {0})
+    if (ih+2*ph >= 3 && iw+2*pw >= 3)
+    {
+        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
+        Checker<Pooling> checker(handle());
+        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
+        checker.set_rng(0,&rng);
+        param::Pooling param;
+        param.mode = param::Pooling::Mode::MAX;
+        param.format = param::Pooling::Format::NCHW44;
+        param.pad_h = ph;
+        param.pad_w = pw;
+        param.stride_h = param.stride_w = 2;
+        param.window_h = param.window_w = 3;
+        checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
+    }
+    // clang-format on
+}
 
 TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_INT8_W3x3_S2x2)
 {
     for (size_t ih: {2, 3, 7, 13, 52, 53, 54, 55})