From 3afa3893d7e73dcaef682aa6178d10d3a38b307b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 27 Jul 2021 13:15:25 +0800 Subject: [PATCH] perf(arm_common): optimize arm common pooling 9x9 and 13x13 GitOrigin-RevId: 33d5a624784a5dde61b6c9cfe461297a0f2950fe --- dnn/src/arm_common/intrinsic_helper.h | 3 ++- .../pooling/algo_fp32_pooling_nchw44.cpp | 27 +++++++++++++++---- .../pooling/kern_fp32_pooling_nchw44.h | 4 ++- dnn/test/arm_common/pooling_multi_thread.cpp | 25 +++++++++++++++++ .../cross_build_android_arm_inference.sh | 2 +- 5 files changed, 53 insertions(+), 8 deletions(-) diff --git a/dnn/src/arm_common/intrinsic_helper.h b/dnn/src/arm_common/intrinsic_helper.h index 6db691e63..b29f3e3c9 100644 --- a/dnn/src/arm_common/intrinsic_helper.h +++ b/dnn/src/arm_common/intrinsic_helper.h @@ -124,4 +124,5 @@ __ai void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) { } // namespace } // namespace megdnn #undef __ai -// vim: syntax=cpp.doxygen \ No newline at end of file + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp b/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp index c8d6a9c81..18c058da4 100644 --- a/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp +++ b/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp @@ -30,10 +30,12 @@ bool PoolingImpl::AlgoFp32ModexStridexNCHW44::usable( bool avaible = param.src_type.enumv() == DTypeEnum::Float32 && param.format == Param::Format::NCHW44 && (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && - fh == fw && sh == sw && - (fh == 2 || fh == 3 || fh == 4 || fh == 5) && - (sh == 1 || sh == 2); - return avaible; + fh == fw && sh == sw; + bool size_ok = ((fh == 2 || fh == 3 || fh == 4 || fh == 5) && + (sh == 1 || sh == 2)); + size_ok |= ((fh == 9 || fh == 13) && (sh == 1)); + + return avaible && size_ok; } void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( @@ -94,6 +96,15 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( megdnn_assert(0, "invalid stride %d", sh); \ } +#define DISPATCH_STRIDE_1(filter) \ + switch (sh) { \ + case 1: \ + DISPATCH_MODE(filter, 1); \ + break; \ + default: \ + megdnn_assert(0, "invalid stride %d", sh); \ + } + #define DISPATCH_FILTER() \ switch (fh) { \ case 2: \ @@ -108,6 +119,12 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( case 5: \ DISPATCH_STRIDE(5); \ break; \ + case 9: \ + DISPATCH_STRIDE_1(9); \ + break; \ + case 13: \ + DISPATCH_STRIDE_1(13); \ + break; \ default: \ megdnn_assert(0, "invalid filter %d", fh); \ } @@ -123,4 +140,4 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( } // namespace arm_common } // namespace megdnn -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h b/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h index a4050c39a..4a6ad128d 100644 --- a/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h +++ b/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h @@ -64,6 +64,8 @@ INSTANCE_CAL(2) INSTANCE_CAL(3) INSTANCE_CAL(4) INSTANCE_CAL(5) +INSTANCE_CAL(9) +INSTANCE_CAL(13) #undef INSTANCE_CAL #undef CALCULATE_AVG_CB @@ -305,4 +307,4 @@ static inline void pooling_fp32_nchw44(const float32_t* src, float32_t* dst, } // namespace arm_common } // namespace megdnn -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/pooling_multi_thread.cpp b/dnn/test/arm_common/pooling_multi_thread.cpp index 2532a80e8..deb2ef467 100644 --- a/dnn/test/arm_common/pooling_multi_thread.cpp +++ b/dnn/test/arm_common/pooling_multi_thread.cpp @@ -116,6 +116,31 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) { } } +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W9_w13_NCHW44) +{ + UniformIntRNG rng{-10, 10}; + Checker checker(handle()); + checker.set_rng(0, &rng); + // clang-format off + for (size_t ih: {20, 15}) + for (size_t iw: {15, 20}) + for (size_t kernel: {9, 13}) + for (size_t pad: {4, 6}) + for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) + if (kernel > pad) + { + param::Pooling param; + param.mode = mode; + param.format = param::Pooling::Format::NCHW44; + param.pad_h = pad; + param.pad_w = pad; + param.stride_h = param.stride_w = 1; + param.window_h = param.window_w = kernel ; + checker.set_param(param).exec(TensorShapeArray{{2, 8, ih, iw, 4}, {}}); + } + // clang-format on +} + TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) { UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; diff --git a/scripts/cmake-build/cross_build_android_arm_inference.sh b/scripts/cmake-build/cross_build_android_arm_inference.sh index 1f5236e91..1e360046b 100755 --- a/scripts/cmake-build/cross_build_android_arm_inference.sh +++ b/scripts/cmake-build/cross_build_android_arm_inference.sh @@ -2,7 +2,7 @@ set -e ARCHS=("arm64-v8a" "armeabi-v7a") -BUILD_TYPE=Release +BUILD_TYPE=RelWithDebInfo MGE_ARMV8_2_FEATURE_FP16=OFF MGE_DISABLE_FLOAT16=OFF ARCH=arm64-v8a -- GitLab