From 2b4b4d66d923f1cfad9ec34a129a6e695e079f67 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 12 Jun 2020 16:43:07 +0800 Subject: [PATCH] feat(dnn/fallback): add aarch64 mk4 dot 3x3 s1 fuse packb GitOrigin-RevId: 3e69878d8d349d3cd21d828a3029aa7e1c61a294 --- dnn/src/fallback/conv_bias/im2col/factory.h | 2 +- .../im2col/strategy_fuse_nchw44_dot.cpp | 15 ++++++ .../arm_common/conv_bias_multi_thread.cpp | 19 +++++++ .../conv_bias_multi_thread_benchmark.cpp | 53 +++++++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h index 0dd502244..8915783ca 100644 --- a/dnn/src/fallback/conv_bias/im2col/factory.h +++ b/dnn/src/fallback/conv_bias/im2col/factory.h @@ -450,7 +450,7 @@ Strategy* StrategyDelegationStorage::get( sparam.kernel = param.filter_meta.spatial[0]; sparam.stride = param.filter_meta.stride[0]; sparam.is_square = - param.filter_meta.spatial[0] == param.filter_meta.spatial[0]; + param.filter_meta.spatial[0] == param.filter_meta.spatial[1]; sparam.is_xcorr = param.filter_meta.should_flip; MEGDNN_LOCK_GUARD(m_mtx); if (map_strategys.find(sparam) == map_strategys.end()) { diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp new file mode 100644 index 000000000..c587f3ca5 --- /dev/null +++ b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp @@ -0,0 +1,15 @@ +/** + * \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/fallback/conv_bias/im2col/strategy_base.h" + + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index a152a7f90..5e606b1a0 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -1764,7 +1764,26 @@ TEST_F(ARM_COMMON_MULTI_THREADS, #undef cb } #endif +#endif +#endif + +#if MEGDNN_AARCH64 +#if __ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44DOT_FUSE) { + UniformIntRNG rng{-50, 50}; +#define cb(name) \ + checker_conv_bias( \ + get_nchw44_conv_bias_args({3}, 1, false, false, false, false, \ + true, false, false, false), \ + handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ + dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ + dtype::QuantizedS8(60.25f), name); + float epsilon = 0.001; + cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); +#undef cb +} #endif #endif diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp index 945f6cc93..d2743bc1f 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -655,6 +655,59 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) { bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2); } +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { + constexpr size_t RUNS = 40; + std::vector data_type = { + dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS, size_t group, size_t P, size_t S, + bool is_nchw = false) { + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = P; + param.pad_w = P; + param.stride_h = S; + param.stride_w = S; + param.sparse = param::ConvBias::Sparse::DENSE; + param.format = param::ConvBias::Format::NCHW44_DOT; + auto OH = (H + 2 * P - FS) / static_cast(S) + 1; + auto OW = (W + 2 * P - FS) / static_cast(S) + 1; + TensorShape src = {N, IC / 4, H, W, 4}; + TensorShape filter = {OC / 4, IC / 4, FS, FS, 4, 4}; + if (group > 1) { + filter = {group, OC / group / 4, IC / group / 4, FS, FS, 4, 4}; + param.sparse = param::ConvBias::Sparse::GROUP; + } + if (is_nchw) { + src = {N, IC, H, W}; + filter = {OC / 4, FS, FS, IC, 4}; + } + TensorShape bias = {1, OC / 4, 1, 1, 4}; + TensorShape dst = {N, OC / 4, OH, OW, 4}; + + SmallVector shapes{src, filter, bias, {}, dst}; + float computations = + (((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + std::vector, float>> shape_arg = { + std::make_pair(shapes, computations)}; + benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, + {1, {7}}, data_type); + }; + bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1); + bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1); + bench_case(1, 256, 256, 14, 14, 3, 1, 1, 1); + bench_case(1, 512, 512, 7, 7, 3, 1, 1, 1); + + bench_case(1, 64, 64, 56, 56, 3, 4, 1, 1); + bench_case(1, 128, 128, 28, 28, 3, 4, 1, 1); + bench_case(1, 256, 256, 14, 14, 3, 4, 1, 1); + bench_case(1, 512, 512, 7, 7, 3, 4, 1, 1); + +} + TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2) { constexpr size_t RUNS = 50; -- GitLab