From 3597a6dbd72ec0701ffe2fc681eff776c411bc91 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 23 Jun 2021 21:18:09 +0800
Subject: [PATCH] feat(dnn/arm): nchw_nchw44 conv support 1x1s1

GitOrigin-RevId: 8c8f7d7c763b603961ca27b1fd17425bafd019cd
---
 .../dot_direct_nchw_nchw44_s1.cpp             | 47 +++++++++++++++++++
 .../dot_direct_nchw_nchw44_s2.cpp             | 12 +++++
 .../int8_direct_nchw_nchw44_s1.cpp            | 42 +++++++++++++++++
 .../int8_direct_nchw_nchw44_s2.cpp            | 10 ++++
 .../int8/direct_nchw_nchw44_algo.cpp          |  3 ++
 .../int8/dot_direct_nchw_nchw44_algo.cpp      |  3 ++
 dnn/src/common/nchw_nchwxx_valid.h            | 16 ++++---
 .../arm_common/conv_bias_multi_thread.cpp     | 16 +++++++
 .../conv_bias_multi_thread_benchmark.cpp      |  1 +
 src/plugin/impl/opr_footprint.cpp             |  3 +-
 10 files changed, 146 insertions(+), 7 deletions(-)

diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
index 170de66b9..565e7cff7 100644
--- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
+++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
@@ -47,6 +47,52 @@ struct ShiftCalHelper {
     }
 };
 ////////////////////stride 1///////////////////
+
+template
+struct KerNeonDotXXs2Nchw44Int8 {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
+    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
+                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
+                     int iw, int ld_dst_oc, const Op& op) {
+        constexpr int stride = 1;
+        constexpr int filter_hight = 1;
+        constexpr int filter_width = 4;
+        constexpr int weight_reg = 2;
+        constexpr int src_reg = 2;
+
+        constexpr int oc_step = 4;
+        constexpr int ic_step = 1;
+        constexpr int pack_iw_len = 4;
+        constexpr int simd_len = 16;
+
+        const int ld_bias = oc_step;
+        const int ic_stride = ih * iw * pack_iw_len;
+        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
+        constexpr int c_dim = OCHelper::val;
+
+        int32x4_t c[c_dim][8];
+        init_ocx_ow8(c, bias_ptr, ld_bias);
+        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
+            int8x16_t src[src_reg];
+            int8x16_t weight[c_dim][weight_reg];
+            // row 0
+            load_helper(
+                    src, src_ptr + 0 * iw * pack_iw_len, 0);
+            load_helper(
+                    weight, weight_ptr, ld_weight_oc);
+            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
+                                                                       weight);
+
+            src_ptr += ic_stride;
+            weight_ptr += filter_hight * filter_width * oc_step;
+        }
+        store_ocx_ow8_remain_static_dt(
+                c, op, dst_ptr, ld_dst_oc);
+    }
+};
+
 template
 struct KerNeonDotXXs2Nchw44Int8 {
     }
 };
+template
+struct KerNeonDotXXs2Nchw44Int8 {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
+    static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int,
+                     int, int, int, const Op&) {
+        megdnn_assert(0, "not impl");
+    }
+};
+
 template
 struct KerNeonDotXXs2Nchw44Int8 {
     static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
 };
+template
+struct KerNeonXXs2NchwNchw44 {
+    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
+                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
+                     int iw, int ld_dst_oc, const Op& op) {
+        constexpr int stride = 1;
+        constexpr int filter_height = 1;
+        constexpr int filter_width = 4;
+        constexpr int oc_step = 4;
+        constexpr int loop_ic_step = 1;
+        constexpr int simd_len = 16;
+        constexpr int pack_iw_len = 16;
+        constexpr int src_reg = 8;
+        constexpr int weight_reg = 1;
+
+        const int ic_stride = ih * iw * pack_iw_len;
+        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
+        constexpr int c_dim = OCHelper::val;
+        int32x4_t c[c_dim][8];
+        init_ocx_ow8(c, bias_ptr, oc_step);
+
+        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
+            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
+            int8x16_t src[src_reg];
+            int8x16_t dot4_weight[c_dim][weight_reg];
+            int16x8_t temp_c[4];
+            load_helper(
+                    dot4_weight, weight_ptr, ld_weight_oc);
+            load_helper(
+                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
+            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
+
+            weight_ptr += oc_step * filter_height * filter_width;
+        }
+
+        store_ocx_ow8_remain_static_dt(
+                c, op, dst_ptr, ld_dst_oc);
+    }
+};
+
+
 template
 struct KerNeonXXs2NchwNchw44 {
     static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
@@ -547,6 +588,7 @@ struct ConvDiectStrideInt8NchwNchw44 {
     INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
 #define INSTANCE_CONV_KERN(stride) \
+    INSTANCE_BIAS_MODE_PARAM(stride, 1) \
     INSTANCE_BIAS_MODE_PARAM(stride, 2) \
     INSTANCE_BIAS_MODE_PARAM(stride, 3) \
     INSTANCE_BIAS_MODE_PARAM(stride, 5) \
diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp
index ea1273722..23345c0e9 100644
--- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp
+++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp
@@ -1033,6 +1033,15 @@ struct KerNeonXXs2NchwNchw44 {
     }
 };
+template
+struct KerNeonXXs2NchwNchw44 {
+    static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int,
+                     int, int, int, const Op&) {
+        megdnn_assert(0, "not impl nchw_nchw44 1x1 s2");
+    }
+};
+
 enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 };
 template
 MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr,
@@ -1398,6 +1407,7 @@ struct ConvDiectStrideInt8NchwNchw44 {
     INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
 #define INSTANCE_CONV_KERN(stride) \
+    INSTANCE_BIAS_MODE_PARAM(stride, 1) \
     INSTANCE_BIAS_MODE_PARAM(stride, 2) \
     INSTANCE_BIAS_MODE_PARAM(stride, 3) \
     INSTANCE_BIAS_MODE_PARAM(stride, 5) \
diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
index a0f44b56c..be29f3450 100644
--- a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
+++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
@@ -291,6 +291,9 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns(
 #define DISPATCH_CONV_KERN(stride) \
     switch (param.filter_meta.spatial[0]) { \
+        case 1: \
+            GET_BIAS_MODE_PARAM(stride, 1) \
+            break; \
         case 2: \
             GET_BIAS_MODE_PARAM(stride, 2) \
             break; \
diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
index 7c579e9a6..bfa8ac8bb 100644
--- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
+++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
@@ -245,6 +245,9 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns(
 #define DISPATCH_CONV_KERN(stride) \
     switch (param.filter_meta.spatial[0]) { \
+        case 1: \
+            GET_BIAS_MODE_PARAM(stride, 1) \
+            break; \
         case 2: \
             GET_BIAS_MODE_PARAM(stride, 2) \
             break; \
diff --git a/dnn/src/common/nchw_nchwxx_valid.h b/dnn/src/common/nchw_nchwxx_valid.h
index 4402dc871..df527a235 100644
--- a/dnn/src/common/nchw_nchwxx_valid.h
+++ b/dnn/src/common/nchw_nchwxx_valid.h
@@ -74,9 +74,11 @@ inline bool nchw_nchwxx_valid(
             nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
     bool ok_src_dst = fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) &&
                       fm.group == 1;
-    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
-                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
-                      fm.spatial[0] == 5 || fm.spatial[0] == 7);
+    bool ok_filter =
+            fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+            (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
+             fm.spatial[0] == 7 ||
+             (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1));
     bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                     fm.stride[0] == fm.stride[1] &&
                     (fm.stride[0] == 1 || fm.stride[1] == 2);
@@ -126,9 +128,11 @@ inline bool nchw_nchwxx_valid(
             nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
     bool ok_src_dst = fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) &&
                       fm.group == 1;
-    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
-                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
-                      fm.spatial[0] == 5 || fm.spatial[0] == 7);
+    bool ok_filter =
+            fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+            (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
+             fm.spatial[0] == 7 ||
+             (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1));
     bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                     fm.stride[0] == fm.stride[1] &&
                     (fm.stride[0] == 1 || fm.stride[1] == 2);
diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp
index e76ef63cc..9e56345dc 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -487,6 +487,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
             handle(), "S8_CONV_NCHW_NCHW44");
 }
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) {
+    checker_conv_bias_qint8x8x8(
+            get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1,
+                                      false, true),
+            handle(), "S8_CONV_NCHW_NCHW44");
+}
+
 /*****************************quint8 direct****************************/
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
     checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
@@ -517,6 +524,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
     checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
 }
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) {
+    auto args = get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE,
+                                          1, false, true);
+    for (auto&& arg : args) {
+        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
+    }
+    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
+}
+
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
     checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
index 3d4d4ff67..b50937b2c 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
@@ -635,6 +635,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) {
         benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}},
                        {1, {7}}, data_type);
     };
+    bench_case(1, 2, 64, 160, 160, 1, 1, 0, 1, true);
     bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true);
     bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1);
     bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1);
diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp
index 40d981343..752806bb1 100644
--- a/src/plugin/impl/opr_footprint.cpp
+++ b/src/plugin/impl/opr_footprint.cpp
@@ -131,7 +131,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
     if (param.format == Param::Format::NCHW44 ||
         param.format == Param::Format::NCHW44_DOT) {
         //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
-        if (filter_shape[1] == 1 && filter_shape[2] == 1) {
+        if (filter_shape[1] == 1 && filter_shape[2] == 1 &&
+            filter_shape.ndim == 6) {
             group *= 4;
         }
         size_t computation = dst_shape.total_nr_elems() * fh * fw *
--
GitLab