From 3117bfb73870d437dce17f983db405107754d57e Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Mon, 8 Jun 2020 19:33:21 +0800
Subject: [PATCH] fix(dnn/arm): nchw44 direct int8 support 8832

GitOrigin-RevId: 696fa05d943b28fcec3a236bb8518fb255eae9db
---
 dnn/src/arm_common/conv_bias/int8/algos.h     |   29 +-
 dnn/src/arm_common/conv_bias/int8/direct.h    |   20 -
 ...nchw44_algo.cpp => direct_nchw44_algo.cpp} |  285 ++--
 .../conv_bias/int8/direct_nchw44_kern.h       | 1428 +++++++++++++++++
 .../int8/direct_stride1_nchw44_algo.cpp       |  393 -----
 .../int8/direct_stride1_nchw44_kern.cpp       |  791 ---------
 .../int8/direct_stride2_nchw44_kern.cpp       |  793 ---------
 dnn/src/arm_common/conv_bias/opr_impl.cpp     |    6 +-
 dnn/src/arm_common/conv_bias/opr_impl.h       |    3 +-
 .../arm_common/elemwise_helper/kimpl/none.h   |    2 +
 dnn/test/arm_common/conv_bias.cpp             |   10 +-
 .../arm_common/conv_bias_multi_thread.cpp     |   85 +-
 12 files changed, 1645 insertions(+), 2200 deletions(-)
 rename dnn/src/arm_common/conv_bias/int8/{direct_stride2_nchw44_algo.cpp => direct_nchw44_algo.cpp} (58%)
 create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h
 delete mode 100644 dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp
 delete mode 100644 dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp
 delete mode 100644 dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp

diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h
index 47cd7b9b..31849f3c 100644
--- a/dnn/src/arm_common/conv_bias/int8/algos.h
+++ b/dnn/src/arm_common/conv_bias/int8/algos.h
@@ -38,23 +38,6 @@ public:
                            const NCBKernSizeParam& param) const override;
 };
 
-class ConvBiasImpl::AlgoS8DirectStride1NCHW44 final : public AlgoBase {
-public:
-    AlgoS8DirectStride1NCHW44() {}
-    bool is_reproducible() const override { return true; }
-    const char* name() const override { return "S8_NCHW44_DIRECT_STRD1"; }
-    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
-                AlgoSelectionStrategy algo_selection_strategy) const override;
-    size_t get_workspace(fallback::ConvBiasImpl*,
-                         const NCBKernSizeParam& param) const override;
-    virtual SmallVector<NCBKern> dispatch_kerns(
-            fallback::ConvBiasImpl* opr,
-            const NCBKernSizeParam& param) const override;
-
-    bool is_preferred(megdnn::fallback::ConvBiasImpl*,
-                      const NCBKernSizeParam& param) const override;
-};
-
 class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
     bool m_large_group;
 
@@ -74,11 +57,11 @@ public:
                            const NCBKernSizeParam& param) const override;
 };
 
-class ConvBiasImpl::AlgoS8DirectStride2NCHW44 final : public AlgoBase {
+class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
 public:
-    AlgoS8DirectStride2NCHW44() {}
+    AlgoS8DirectNCHW44() {}
     bool is_reproducible() const override { return true; }
-    const char* name() const override { return "S8_NCHW44_DIRECT_STRD2"; }
+    const char* name() const override { return "S8_NCHW44_DIRECT"; }
     bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                 AlgoSelectionStrategy algo_selection_strategy) const override;
     size_t get_workspace(fallback::ConvBiasImpl*,
@@ -245,8 +228,8 @@ private:
 //=======================input int8 compute fp32 output int8============
 class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase {
 public:
-    AlgoS8CF32WinogradF23_4x4_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
-                                     uint32_t tile_size)
+    AlgoS8CF32WinogradF23_4x4_NCHW44(
+            fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
     bool is_reproducible() const override { return true; }
     const char* name() const override {
@@ -277,7 +260,7 @@ private:
 class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase {
 public:
     AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
-                                  uint32_t tile_size)
+                                 uint32_t tile_size)
             : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
     bool is_reproducible() const override { return true; }
     const char* name() const override {
diff --git a/dnn/src/arm_common/conv_bias/int8/direct.h b/dnn/src/arm_common/conv_bias/int8/direct.h
index 1b0589e7..1a58ab66 100644
--- a/dnn/src/arm_common/conv_bias/int8/direct.h
+++ b/dnn/src/arm_common/conv_bias/int8/direct.h
@@ -36,26 +36,6 @@ KERN(stride2, 7, nchw)
 
 #undef KERN
 
-#define KERN(stride, i, layout)                                           \
-    template <BiasMode bias_mode, typename Op, int remain_w>              \
-    void conv_direct_##stride##_##i##x##i##_int8_##layout(                \
-            const int8_t* src, const int8_t* filter, const int32_t* bias, \
-            int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \
-            const size_t IH, const size_t IW, const size_t OH,            \
-            const size_t OW, const Op& op);
-KERN(stride1, 2, nchw44)
-KERN(stride1, 3, nchw44)
-KERN(stride1, 5, nchw44)
-KERN(stride1, 7, nchw44)
-
-KERN(stride2, 2, nchw44)
-KERN(stride2, 3, nchw44)
-KERN(stride2, 5, nchw44)
-KERN(stride2, 7, nchw44)
-#undef KERN
-void nchw44_pack_filter(const int8_t* src, int8_t* dst, int filter);
-void nchw44_pack_src(const int8_t* src, int8_t* dst, int length);
-
 } // namespace conv_bias
 } // namespace arm_common
 } // namespace megdnn
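// [editor sketch, not part of the patch] The two pack helpers whose
// declarations are deleted above reappear in the new direct_nchw44_kern.h as
// static inline functions. A hedged scalar model of nchw44_pack_src,
// mirroring its NEON index buffer {0,1,2,3, 0,1,2,3, 3,2,1,0, 3,2,1,0}:
// every 4-byte ic4 pixel expands to 16 bytes (forward, forward, reversed,
// reversed), which is the 4x "src_expand" factor used when sizing the
// workspace. The name *_scalar_sketch is hypothetical.
#include <cstdint>
static void nchw44_pack_src_scalar_sketch(const int8_t* src, int8_t* dst,
                                          int length) {
    static const int idx[16] = {0, 1, 2, 3, 0, 1, 2, 3,
                                3, 2, 1, 0, 3, 2, 1, 0};
    for (int i = 0; i < length; ++i) {       // one ic4 pixel per iteration
        for (int j = 0; j < 16; ++j) {
            dst[i * 16 + j] = src[i * 4 + idx[j]];
        }
    }
}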
diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
similarity index 58%
rename from dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp
rename to dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
index 9b47eefa..0d82c2bf 100644
--- a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp
+++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
@@ -1,5 +1,5 @@
 /**
- * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp
+ * \file dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -13,6 +13,7 @@
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/int8/algos.h"
 #include "src/arm_common/conv_bias/int8/direct.h"
+#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
 #include "src/arm_common/conv_bias/int8/strategy.h"
 #include "src/arm_common/elemwise_op.h"
 #include "src/common/opr_delegate.h"
@@ -25,28 +26,19 @@
 using conv_fun = std::function<void(
         WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
         const ConvBiasImpl::NCBKernIndex& ncb_index,
         const CpuNDRange& workspace_ids)>;
-MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride2)
+MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44)
 
 static void get_rectified_size(
-        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
-        size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
+        int& iw2) {
     auto&& fm = param.filter_meta;
-    size_t SW = fm.stride[1];
-    size_t IH = param.isz[0];
-    size_t IW = param.isz[1];
-    size_t OH = param.osz[0];
-    size_t OW = param.osz[1];
-    size_t FH = fm.spatial[0];
-    size_t FW = fm.spatial[1];
+    int ih = param.isz[0];
+    int iw = param.isz[1];
+    int ph = fm.padding[0];
+    int pw = fm.padding[1];
 
-    OH2 = OH;
-    OW2 = (OW + 7) & ~7;
-    IH2 = SW * OH + FH - SW;
-    IW2 = SW * OW2 + FW - SW;
-    // Because stride is 2, sometimes IW == IW2+1. Do a max update to
-    // handle this case.
-    IH2 = std::max(IH2, IH);
-    IW2 = std::max(IW2, IW);
+    ih2 = ih + ph * 2;
+    iw2 = iw + pw * 2;
 }
 static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
     constexpr size_t src_expand = 4;
@@ -57,8 +49,8 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
     size_t OC = fm.ocpg;
     size_t FH = fm.spatial[0];
     size_t FW = fm.spatial[1];
-    size_t IH2, IW2, OH2, OW2;
-    get_rectified_size(param, IH2, IW2, OH2, OW2);
+    int IH2, IW2;
+    get_rectified_size(param, IH2, IW2);
     if (group == 1) {
         size_t src_size =
                 batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand;
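// [editor sketch, not part of the patch] With the new get_rectified_size the
// rectified extents are simply the input plus padding on both borders, so
// the packed-src portion of the workspace that get_bundle reserves can be
// modeled as below. The helper name is hypothetical; src_expand = 4 matches
// the constant above.
#include <cstddef>
#include <cstdint>
static inline size_t packed_src_bytes_sketch(size_t batch, size_t group,
                                             size_t ic, size_t ih, size_t iw,
                                             size_t ph, size_t pw) {
    constexpr size_t src_expand = 4;  // 4x expansion done by nchw44_pack_src
    const size_t ih2 = ih + 2 * ph;   // pad top and bottom
    const size_t iw2 = iw + 2 * pw;   // pad left and right
    return batch * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand;
}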
nr_pad_h - : (IH2 - IH - PH) * IW2 * pack_ic * expend_element; + int nr_pad_h = PH * IW2 * pack_ic * expend_element; + int row_last_pad = (IW2 - IW - PW) * pack_ic * expend_element; + int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * expend_element; const int8_t* sptr = static_cast(kern_param.src( batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); @@ -129,7 +115,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, rep(ih_idx, IH) { std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); sptr_base += nr_pad_w; - conv_bias::nchw44_pack_src(sptr, sptr_base, IW); + nchw44_pack_src(sptr, sptr_base, IW); sptr_base += IW * pack_ic * expend_element; sptr += IW * pack_ic; std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); @@ -140,7 +126,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, } } -template +template static void do_conv_kern(WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, const ConvBiasImpl::NCBKernIndex& ncb_index, @@ -153,12 +140,12 @@ static void do_conv_kern(WorkspaceBundle bundle, size_t IC = kern_param.filter_meta.icpg; size_t OC = kern_param.filter_meta.ocpg; size_t GROUP = kern_param.filter_meta.group; - size_t IH2, IW2, OH2, OW2; - get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + int IH2, IW2; + get_rectified_size(kern_param, IH2, IW2); bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) - Op op = Op(1.0f, 4.0f); + Op op(1.f, 4.f); if (need_post_process) { float scale_bias = kern_param.bias_type.param().scale; @@ -191,49 +178,43 @@ static void do_conv_kern(WorkspaceBundle bundle, const int8_t* fptr = kern_param.filter(group_id) + oc_idx * FH * FW * IC; - void* dst = reinterpret_cast( - reinterpret_cast( - kern_param.dst(batch_id, group_id)) + - oc_idx * OH * OW); + DstType* dst = reinterpret_cast( + kern_param.dst(batch_id, group_id, oc_idx)); const int32_t* bptr = kern_param.bias(batch_id, group_id) + oc_idx; auto packed_weight = reinterpret_cast(bundle.get(1)) + group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; - conv_bias::nchw44_pack_filter(fptr, packed_weight, - oc_block / 4 * IC / 4 * FH * FW); -#define KERN1_NCHW44_CONV(filter) \ - conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw44< \ - bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \ - static_cast(dst), oc_block, IC, \ - IH2, IW2, OH, OW, op) - DISPATCH_FILTER(filter, KERN1_NCHW44_CONV) -#undef KERN1_NCHW44_CONV + nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW); + conv_direct_int8_nchw44( + sptr, packed_weight, bptr, nullptr, static_cast(dst), + oc_block, IC, IH2, IW2, OH, OW, op); } -/* ===================== stride2 algo ===================== */ -bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::usable( +bool ConvBiasImpl::AlgoS8DirectNCHW44::usable( fallback::ConvBiasImpl*, const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const { MEGDNN_MARK_USED_VAR(algo_selection_strategy); auto&& fm = param.filter_meta; - auto FH = fm.spatial[0]; - auto OC = fm.ocpg; - auto IC = fm.icpg; - bool avaible = //! src and filter are qint8, dst is qint8 or qint32 + const int fh = fm.spatial[0]; + const int fw = fm.spatial[1]; + const int oc = fm.ocpg; + const int ic = fm.icpg; + const bool avaible = //! 
@@ -140,7 +126,8 @@
     }
 }
 
-template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain>
+template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain,
+          typename DstType, int stride>
 static void do_conv_kern(WorkspaceBundle bundle,
                          const ConvBiasImpl::NCBKernParam& kern_param,
                          const ConvBiasImpl::NCBKernIndex& ncb_index,
@@ -153,12 +140,12 @@ static void do_conv_kern(WorkspaceBundle bundle,
     size_t IC = kern_param.filter_meta.icpg;
     size_t OC = kern_param.filter_meta.ocpg;
     size_t GROUP = kern_param.filter_meta.group;
-    size_t IH2, IW2, OH2, OW2;
-    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
+    int IH2, IW2;
+    get_rectified_size(kern_param, IH2, IW2);
     bool need_post_process =
             kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
     //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f)
-    Op op = Op(1.0f, 4.0f);
+    Op op(1.f, 4.f);
     if (need_post_process) {
         float scale_bias =
                 kern_param.bias_type.param<dtype::QuantizedS32>().scale;
@@ -191,49 +178,43 @@ static void do_conv_kern(WorkspaceBundle bundle,
     const int8_t* fptr =
             kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
-    void* dst = reinterpret_cast<void*>(
-            reinterpret_cast<ptrdiff_t>(
-                    kern_param.dst<void>(batch_id, group_id)) +
-            oc_idx * OH * OW);
+    DstType* dst = reinterpret_cast<DstType*>(
+            kern_param.dst<void>(batch_id, group_id, oc_idx));
     const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
     auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
                          group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
-    conv_bias::nchw44_pack_filter(fptr, packed_weight,
-                                  oc_block / 4 * IC / 4 * FH * FW);
-#define KERN1_NCHW44_CONV(filter)                                          \
-    conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw44<      \
-            bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr,  \
-                                      static_cast<int8_t*>(dst), oc_block, \
-                                      IC, IH2, IW2, OH, OW, op)
-    DISPATCH_FILTER(filter, KERN1_NCHW44_CONV)
-#undef KERN1_NCHW44_CONV
+    nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW);
+    conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType, stride>(
+            sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst),
+            oc_block, IC, IH2, IW2, OH, OW, op);
 }
 
-/* ===================== stride2 algo ===================== */
-bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::usable(
+bool ConvBiasImpl::AlgoS8DirectNCHW44::usable(
         fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
         AlgoSelectionStrategy algo_selection_strategy) const {
     MEGDNN_MARK_USED_VAR(algo_selection_strategy);
     auto&& fm = param.filter_meta;
-    auto FH = fm.spatial[0];
-    auto OC = fm.ocpg;
-    auto IC = fm.icpg;
-    bool avaible =  //! src and filter are qint8, dst is qint8 or qint32
+    const int fh = fm.spatial[0];
+    const int fw = fm.spatial[1];
+    const int oc = fm.ocpg;
+    const int ic = fm.icpg;
+    const bool avaible =  //! src and filter are qint8, dst is qint8 or qint32
             ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
              param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
              (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
              param.dst_type.enumv() == DTypeEnum::QuantizedS32))) &&
            (fm.format == param::Convolution::Format::NCHW44) &&
-            (OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip &&
+            (oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip &&
            fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
-            fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
-            FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
+            fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
+            (fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw &&
+            (fh == 2 || fh == 3 || fh == 5 || fh == 7) &&
            param.bias_mode != BiasMode::BIAS;
     return avaible;
 }
 
-bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred(
+bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred(
         megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
         const NCBKernSizeParam& param) const {
     // TODO: benchmark and fix
@@ -242,13 +223,13 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred(
     return false;
 }
 
-size_t ConvBiasImpl::AlgoS8DirectStride2NCHW44::get_workspace(
+size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace(
         fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
     return get_bundle(param).total_size_in_bytes();
 }
 
 SmallVector<ConvBiasImpl::NCBKern>
-ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns(
+ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
         fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
     auto fm = param.filter_meta;
     size_t N = param.n;
@@ -261,97 +242,129 @@ ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns(
     WorkspaceBundle wbundle = get_bundle(param);
     conv_fun do_conv_fun = nullptr;
     int ow_remain = OW % 8;
+    bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
 
 // NOTE: remain_w is not used when generating the midout hash, so the hash
 // stays compatible when shapes change at runtime
-#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op)            \
-    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride2,    \
-                 midout_iv(#filter #bias_mode #op##_hash)) {         \
-        do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \
-    }                                                                \
+#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op)    \
+    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44,                      \
+                 midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \
+        do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w,            \
+                                   dst_type, stride>;                          \
+    }                                                                          \
     MIDOUT_END();
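// [editor note, hedged, not part of the patch] Under the template order
// reconstructed above, the macro chain below resolves, for e.g. stride = 2,
// filter = 3, BiasMode::NO_BIAS, remain_w = 0, RELU and a qint8 output, to
// roughly:
//     do_conv_fun = do_conv_kern<3, BiasMode::NO_BIAS,
//                                ReluOp<dt_qint32, dt_qint8>, 0, dt_qint8, 2>;
// i.e. stride, output dtype, filter size, bias mode, tail width and the
// activation are all folded into one template instantiation chosen once at
// dispatch time.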
megdnn_assert(0, "no supported noline mode"); \ + break; \ + } \ + } else { \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \ + remain_w, NoneOp) \ + break; \ + default: \ + megdnn_assert( \ + 0, \ + "only support IDENTITY mode when dst is not qint8"); \ + break; \ + } \ } -#define GET_REMAIN_W_PARAM(filter, bias_mode) \ - switch (ow_remain) { \ - case 0: \ - GET_OP_PARAM(filter, bias_mode, 0); \ - break; \ - case 1: \ - GET_OP_PARAM(filter, bias_mode, 1); \ - break; \ - case 2: \ - GET_OP_PARAM(filter, bias_mode, 2); \ - break; \ - case 3: \ - GET_OP_PARAM(filter, bias_mode, 3); \ - break; \ - case 4: \ - GET_OP_PARAM(filter, bias_mode, 4); \ - break; \ - case 5: \ - GET_OP_PARAM(filter, bias_mode, 5); \ - break; \ - case 6: \ - GET_OP_PARAM(filter, bias_mode, 6); \ - break; \ - case 7: \ - GET_OP_PARAM(filter, bias_mode, 7); \ - break; \ - default: \ - megdnn_assert(0); \ +#define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \ + switch (ow_remain) { \ + case 0: \ + GET_OP_PARAM(stride, filter, bias_mode, 0); \ + break; \ + case 1: \ + GET_OP_PARAM(stride, filter, bias_mode, 1); \ + break; \ + case 2: \ + GET_OP_PARAM(stride, filter, bias_mode, 2); \ + break; \ + case 3: \ + GET_OP_PARAM(stride, filter, bias_mode, 3); \ + break; \ + case 4: \ + GET_OP_PARAM(stride, filter, bias_mode, 4); \ + break; \ + case 5: \ + GET_OP_PARAM(stride, filter, bias_mode, 5); \ + break; \ + case 6: \ + GET_OP_PARAM(stride, filter, bias_mode, 6); \ + break; \ + case 7: \ + GET_OP_PARAM(stride, filter, bias_mode, 7); \ + break; \ + default: \ + megdnn_assert(0); \ } -#define GET_BIAS_MODE_PARAM(filter) \ - switch (param.bias_mode) { \ - case BiasMode::NO_BIAS: \ - GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \ - break; \ - case BiasMode::BROADCAST_CHANNEL_BIAS: \ - GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_BIAS_MODE_PARAM(stride, filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_REMAIN_W_PARAM(stride, filter, \ + BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } -#define DISPATCH_CONV_KERN() \ +#define DISPATCH_CONV_KERN(stride) \ switch (param.filter_meta.spatial[0]) { \ case 2: \ - GET_BIAS_MODE_PARAM(2) \ + GET_BIAS_MODE_PARAM(stride, 2) \ break; \ case 3: \ - GET_BIAS_MODE_PARAM(3) \ + GET_BIAS_MODE_PARAM(stride, 3) \ break; \ case 5: \ - GET_BIAS_MODE_PARAM(5) \ + GET_BIAS_MODE_PARAM(stride, 5) \ break; \ case 7: \ - GET_BIAS_MODE_PARAM(7) \ + GET_BIAS_MODE_PARAM(stride, 7) \ break; \ default: \ megdnn_assert(0); \ break; \ } - DISPATCH_CONV_KERN(); + switch (param.filter_meta.stride[0]) { + case 1: + DISPATCH_CONV_KERN(1); + break; + case 2: + DISPATCH_CONV_KERN(2); + break; + default: + megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", + param.filter_meta.stride[0]) + .c_str()); + break; + } #undef DO_CONV_KERN_FUN #undef GET_REMAIN_W_PARAM diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h new file mode 100644 index 00000000..e66a50ed --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h @@ -0,0 +1,1428 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h + * MegEngine is Licensed under the Apache License, 
Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/arm_common/conv_bias/int8/direct.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace { + +template +static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, + const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + int32x4_t c[2][8]; + int8x16_t weight[2][2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[4]; + + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); + + c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]); + + c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]); + + c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]); + c[1][5] = 
vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]); + + c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); +} + +template +static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, + const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[1][8]; + int8x16_t weight[1][2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[2]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + + c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[1]); + + c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[1]); + + c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_ocx_ow8_remain_static_dt( + 
c, op, dst_ptr, ld_dst_oc);
+}
+
+template
+struct KerNeonDirectStride1Int8 {
+    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
+                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
+                     int iw, const Op& op, int ld_dst_oc);
+};
+template
+struct KerNeonDirectStride1Int8 {
+    static void impl(const int8_t*, const int8_t*, const int32_t*, DstType*,
+                     int, int, int, const Op&, int) {
+        megdnn_throw("no impl");
+    }
+};
+/**
+dot like impl. dot 4 ic to 1 oc, accumulate to c
+example: (format like weight)
+packed weight
+low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
+---------------------------------------------------------------------
+high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
+dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
+**/
+//! TODO: can try oh = 2 impl, oc = 8 impl
+template
+struct KerNeonDirectStride1Int8 {
+    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
+                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
+                     int iw, const Op& op, int ld_dst_oc) {
+        constexpr int filter_size = 3;
+        constexpr int fh = filter_size;
+        constexpr int fw = filter_size;
+        constexpr int oc_step = 4;
+        constexpr int ic_step = 4;
+        constexpr int loop_ic_step = 4;
+        constexpr int ld_weight_ic4 = 16;
+        constexpr int pack_iw_len = 4;
+
+        const int ic_stride = ih * iw * pack_iw_len;
+
+        int32x4_t c[c_dim][8];
+        int8x16_t weight[3];
+        int8x16_t src[8 + 2];
+        int16x8_t temp_c[2];
+        init_ocx_ow8(c, bias_ptr, oc_step);
+
+        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
+            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
+                const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
+                                           fh_idx * iw * ic_step * pack_iw_len;
+
+                src[0] = vld1q_s8(src_ic_0_3);
+                src[1] = vld1q_s8((src_ic_0_3 + 16));
+                src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
+                src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
+                src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
+                src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
+                src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
+                src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
+                src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
+                src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
+
+                // oc == 0
+                const int8_t* read_weight_ptr =
+                        weight_ptr + fh_idx * fw * ld_weight_ic4;
+
+                weight[0] = vld1q_s8(read_weight_ptr);
+                weight[1] = vld1q_s8(read_weight_ptr + 16);
+                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
+
+                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
+                c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
+                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
+                c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
+                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
+                c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
+
+                c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
+                c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
+                c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
+                c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
+                c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
+                c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
+
+                c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
+                c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
+                c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
+                c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
+                c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
+                c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
+
+                c[0][6] =
vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 5; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[5]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]); + + c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], 
temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + + c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 7; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[7]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[2]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); + weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); + c[0][0] = 
vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[5], src[6], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[6], src[7], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[5], src[7], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[5], src[8], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[6], src[8], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[6], src[9], c[0][3], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + + c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[5], src[9], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[5], src[0], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[6], src[0], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[6], src[1], c[0][5], temp_c[1]); + + src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); + + c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[5], src[1], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[5], src[2], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[6], src[2], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[6], src[3], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( + c, op, dst_ptr, ld_dst_oc); + } +}; + +/** +origin weight shape 
+packed weight shape
+example: (format like weight)
+origin
+<0, 0> <1, 0> <2, 0> <3, 0>
+<0, 1> <1, 1> <2, 1> <3, 1>
+<0, 2> <1, 2> <2, 2> <3, 2>
+<0, 3> <1, 3> <2, 3> <3, 3>
+packed
+low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
+---------------------------------------------------------------------
+high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
+**/
+static inline void nchw44_pack_filter(const int8_t* src, int8_t* dst,
+                                      int length) {
+    static const uint8_t weight_idx_buffer[16] = {0, 4, 9,  13, 2,  6,  11, 15,
+                                                  12, 8, 5, 1,  14, 10, 7,  3};
+    constexpr int simd_len = 16;
+    uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer);
+    for (int i = 0; i < length; i++) {
+        int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx);
+        vst1q_s8(dst + i * simd_len, result);
+    }
+}
+/**
+origin src shape
+packed src shape
+example: (format like <ic>)
+origin
+<0> <1> <2> <3>
+packed
+low 64 bit <0> <1> <2> <3> | <0> <1> <2> <3>
+---------------------------------------------------------------------
+high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0>
+**/
+static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) {
+    static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3,
+                                               3, 2, 1, 0, 3, 2, 1, 0};
+    constexpr int pack_ic = 4;
+    constexpr int simd_len = 16;
+    uint8x16_t src_idx = vld1q_u8(src_idx_buffer);
+    for (int i = 0; i < length; i++) {
+        int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx);
+        vst1q_s8(dst + i * simd_len, result);
+    }
+}
+
+template <BiasMode bias_mode, typename Op, int remain_w, typename DstType>
+void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src,
+                                         const int8_t* filter,
+                                         const int32_t* bias, int32_t* temp,
+                                         DstType* dst, const size_t oc,
+                                         const size_t ic, const size_t ih,
+                                         const size_t iw, const size_t oh,
+                                         const size_t ow, const Op& op) {
+    MEGDNN_MARK_USED_VAR(temp);
+    constexpr size_t filter_size = 2;
+    constexpr size_t fh = filter_size;
+    constexpr size_t fw = filter_size;
+    constexpr size_t ic_step = 4;
+    constexpr size_t oc_step = 4;
+    constexpr size_t big_oc_step = 8;
+    constexpr size_t oh_step = 1;
+    constexpr size_t ow_step = 8;
+    constexpr int pack_iw_len = 4;
+
+    const size_t img_stride = oh * ow;
+    const size_t ow_end = ow / ow_step * ow_step;
+    const size_t ow_remain = ow - ow_end;
+    const size_t oc_end = oc / big_oc_step * big_oc_step;
+    const size_t oc_remain = oc - oc_end;
+    const int ld_oc = oh * ow * oc_step;
+    for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
+        const size_t weight_offset = oc_idx * ic * fh * fw;
+        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
+            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
+                const size_t src_offset =
+                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
+                const size_t dst_offset =
+                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
+                ker_neon_dirctconv_2x2s1_oc8_ow8(
+                        src + src_offset, filter + weight_offset, bias + oc_idx,
+                        dst + dst_offset, ic, ih, iw, ld_oc, op);
+            }
+            if (ow_remain > 0) {
+                const size_t src_offset =
+                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
+                const size_t dst_offset =
+                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
+                ker_neon_dirctconv_2x2s1_oc8_ow8(
+                        src + src_offset, filter + weight_offset, bias + oc_idx,
+                        dst + dst_offset, ic, ih, iw, ld_oc, op);
+            }
+        }
+    }
+    if (oc_remain > 0) {
+        const size_t oc_idx = oc_end;
+        const size_t weight_offset = oc_idx * ic * fh * fw;
+        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
+            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
+                const size_t
src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_2x2s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); + } + } + } +} +template +void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, + const size_t ic, const size_t ih, + const size_t iw, const size_t oh, + const size_t ow, const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const int ld_dst_oc = oh * ow * oc_step; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDirectStride1Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, op, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + KerNeonDirectStride1Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, op, ld_dst_oc); + } + } + } +} +/////////////////////stride 2///////////////// +template +struct KerNeonDirectStride2Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc); +}; +template +struct KerNeonDirectStride2Int8 { + static void impl(const int8_t*, const int8_t*, const int32_t*, DstType*, + int, int, int, const Op&, int) { + megdnn_throw("no impl"); + } +}; + +template +static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, + const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + int32x4_t c[2][8]; + int8x16_t weight[2][2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[4]; + + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + 
fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8(src_ic_0_3 + 16); + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); + src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); + + c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); + c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); + c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]); + c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]); + + c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]); + c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]); + c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]); + c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]); + + src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); + c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]); + c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]); + c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]); + c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]); + + src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 13 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 15 * 16); + c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]); + c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]); + c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]); + c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); +} + +template +static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, + const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = 
filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[2]; + int8x16_t src[8 + 1]; + int16x8_t temp_c[2]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[1]); + + c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); + c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[1], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[0], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[1], src[2], c[0][5], temp_c[1]); + + src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 13 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 15 * 16); + c[0][6] = vdotq_s32_h(weight[0], src[3], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[0], src[5], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[1], src[4], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[6], c[0][7], temp_c[1]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); +} +/** +dot like impl. 
dot 4 ic to 1 oc, accumulate to c
+example: (format like weight)
+packed weight
+low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
+---------------------------------------------------------------------
+high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
+dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
+**/
+// TODO: can try oh = 2 impl, oc = 8 impl
+template
+struct KerNeonDirectStride2Int8 {
+    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
+                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
+                     int iw, const Op& op, int ld_dst_oc) {
+        constexpr int filter_size = 3;
+        constexpr int fh = filter_size;
+        constexpr int fw = filter_size;
+        constexpr int oc_step = 4;
+        constexpr int ic_step = 4;
+        constexpr int loop_ic_step = 4;
+        constexpr int ld_weight_ic4 = 16;
+        constexpr int pack_iw_len = 4;
+
+        const int ic_stride = ih * iw * pack_iw_len;
+
+        int32x4_t c[c_dim][8];
+        int8x16_t weight[3];
+        int8x16_t src[8 + 2];
+        int16x8_t temp_c[4];
+        init_ocx_ow8(c, bias_ptr, oc_step);
+
+        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
+            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
+                const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
+                                           fh_idx * iw * ic_step * pack_iw_len;
+
+                src[0] = vld1q_s8(src_ic_0_3);
+                src[1] = vld1q_s8((src_ic_0_3 + 16));
+                src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
+                src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
+                src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
+                src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
+                src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
+                src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
+                src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
+                src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
+
+                // oc == 0
+                const int8_t* read_weight_ptr =
+                        weight_ptr + fh_idx * fw * ld_weight_ic4;
+
+                weight[0] = vld1q_s8(read_weight_ptr);
+                weight[1] = vld1q_s8(read_weight_ptr + 16);
+                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
+
+                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
+                c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
+                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
+                c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
+                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
+                c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
+
+                c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
+                c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
+                c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
+                c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
+                c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
+                c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
+
+                src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
+                src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
+                src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
+                c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
+                c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
+                c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
+                c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
+                c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
+                c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
+
+                src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
+                src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
+                src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
+                src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
+                c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
+                c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
+                c[0][6] =
vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +template +struct KerNeonDirectStride2Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 5; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[5]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[4]; + init_ocx_ow8(c, bias_ptr, oc_step); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8((src_ic_0_3 + 16)); + src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); + src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); + src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]); + + src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); + src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); + src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); + src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); + c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], 
temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]); + + src[5] = vld1q_s8((src_ic_0_3 + 15 * 16)); + src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); + src[7] = vld1q_s8((src_ic_0_3 + 17 * 16)); + src[8] = vld1q_s8((src_ic_0_3 + 18 * 16)); + c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +template +struct KerNeonDirectStride2Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, + int iw, const Op& op, int ld_dst_oc) { + constexpr int filter_size = 7; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int oc_step = 4; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + + int32x4_t c[c_dim][8]; + int8x16_t weight[7]; + int8x16_t src[8 + 2]; + int16x8_t temp_c[4]; + init_ocx_ow8(c, bias_ptr, oc_step); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); + src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); + src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); + + // oc == 0 + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); + weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); + + c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[2], src[2], 
c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]); + c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[2]); + c[0][1] = vdotq_s32_h(weight[5], src[7], c[0][1], temp_c[3]); + c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]); + c[0][1] = vdotq_s32_h(weight[6], src[8], c[0][1], temp_c[1]); + + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); + c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]); + c[0][2] = vdotq_s32_h(weight[5], src[9], c[0][2], temp_c[0]); + c[0][3] = vdotq_s32_h(weight[5], src[1], c[0][3], temp_c[1]); + c[0][2] = vdotq_s32_h(weight[6], src[0], c[0][2], temp_c[2]); + c[0][3] = vdotq_s32_h(weight[6], src[2], c[0][3], temp_c[3]); + + src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 14 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 15 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 16 * 16); + c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]); + c[0][4] = vdotq_s32_h(weight[5], src[3], c[0][4], temp_c[2]); + c[0][5] = vdotq_s32_h(weight[5], src[5], c[0][5], temp_c[3]); + c[0][4] = vdotq_s32_h(weight[6], src[4], c[0][4], temp_c[0]); + c[0][5] = vdotq_s32_h(weight[6], src[6], c[0][5], temp_c[1]); + + src[7] = vld1q_s8(src_ic_0_3 + 17 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 18 * 16); + src[9] = vld1q_s8(src_ic_0_3 + 19 * 16); + src[0] = vld1q_s8(src_ic_0_3 + 20 * 16); + c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]); + c[0][6] = vdotq_s32_h(weight[5], src[7], 
c[0][6], temp_c[0]); + c[0][7] = vdotq_s32_h(weight[5], src[9], c[0][7], temp_c[1]); + c[0][6] = vdotq_s32_h(weight[6], src[8], c[0][6], temp_c[2]); + c[0][7] = vdotq_s32_h(weight[6], src[0], c[0][7], temp_c[3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +void conv_direct_stride2_2x2_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { + constexpr size_t filter_size = 2; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 8; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + constexpr int pack_iw_len = 4; + + const size_t out_img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oh * ow * oc_step; + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_2x2s2_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + const size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_2x2s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); + } + } + } +} +template +void conv_direct_stride2_int8_nchw44_kern( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + 
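+    // ow_step = 8: one KerNeonDirectStride2Int8 call below emits a row of
+    // eight outputs; pack_iw_len = 4 matches the 4x expansion nchw44_pack_src
+    // applies to the source, so one packed input pixel occupies
+    // ic_step * pack_iw_len = 16 bytes in the offsets computed below.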
constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const int ld_dst_oc = oh * ow * oc_step; + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDirectStride2Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, op, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * + pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + KerNeonDirectStride2Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, op, ld_dst_oc); + } + } + } +} +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, DstType* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op); +}; + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, DstType* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + if (filter_size == 2) { + conv_direct_stride1_2x2_int8_nchw44( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); + } else { + conv_direct_stride1_int8_nchw44_kern( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); + } + } +}; +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, DstType* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + if (filter_size == 2) { + conv_direct_stride2_2x2_int8_nchw44( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); + } else { + conv_direct_stride2_int8_nchw44_kern( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); + } + } +}; +template +void conv_direct_int8_nchw44(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, DstType* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + ConvDirectInt8Nchw44Choose::impl(src, filter, bias, temp, dst, oc, + ic, ih, iw, oh, ow, op); +} + +} // namespace +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp deleted file mode 100644 index 6551776d..00000000 --- a/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp +++ /dev/null @@ -1,393 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * 
Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#include "megdnn/oprs.h" -#include "src/arm_common/conv_bias/int8/algos.h" -#include "src/arm_common/conv_bias/int8/direct.h" -#include "src/arm_common/conv_bias/int8/strategy.h" -#include "src/arm_common/elemwise_op.h" -#include "src/common/opr_delegate.h" - -#include "midout.h" - -using namespace megdnn; -using namespace arm_common; -using conv_fun = std::function; -MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride1) - -static void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { - auto&& fm = param.filter_meta; - auto SW = fm.stride[1]; - auto OH = param.osz[0]; - auto OW = param.osz[1]; - auto FH = fm.spatial[0]; - auto FW = fm.spatial[1]; - - OH2 = OH; - OW2 = (OW + 7) & ~7; - IH2 = SW * OH + FH - SW; - IW2 = SW * OW2 + FW - SW; -} - -static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { - constexpr size_t src_expand = 4; - auto&& fm = param.filter_meta; - size_t group = fm.group; - size_t batch = param.n; - size_t IC = fm.icpg; - size_t OC = fm.ocpg; - size_t FH = fm.spatial[0]; - size_t FW = fm.spatial[1]; - size_t IH2, IW2, OH2, OW2; - get_rectified_size(param, IH2, IW2, OH2, OW2); - if (group == 1) { - size_t src_size = - batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; - size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); - return {nullptr, {src_size, weight_size}}; - } else { - size_t src_size = - param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; - size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); - return {nullptr, {src_size, weight_size}}; - } -}; - -static void copy_padding_kern(WorkspaceBundle bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { - size_t IH = kern_param.isz[0]; - size_t IW = kern_param.isz[1]; - size_t IC = kern_param.filter_meta.icpg; - size_t PH = kern_param.filter_meta.padding[0]; - size_t PW = kern_param.filter_meta.padding[1]; - size_t GROUP = kern_param.filter_meta.group; - - size_t IH2, IW2, OH2, OW2; - get_rectified_size(kern_param, IH2, IW2, OH2, OW2); - size_t padding_group_size = IH2 * IW2 * IC; - bundle.set(kern_param.workspace_ptr); - //! Used for get the workspace offset - constexpr int pack_ic = 4; - constexpr int expend_element = 4; - // TODO: block dim is better to get from arg - size_t workspace_ic_block = 4; - size_t workspace_batch_id = workspace_ids[0]; - size_t workspace_group_id = workspace_ids[1]; - size_t workspace_ic_id = workspace_ids[2]; - size_t workspace_ic = workspace_ic_id * workspace_ic_block; - size_t batch_id = ncb_index.ndrange_id[0]; - size_t group_id = ncb_index.ndrange_id[1]; - size_t group_pack_size = 1; - - int nr_pad_h = PH * IW2 * pack_ic * expend_element; - int nr_pad_w = PW * pack_ic * expend_element; - int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element; - //! 
copy to sptr_base to eliminate padding effect - const int8_t* sptr = static_cast(kern_param.src( - batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); - int8_t* sptr_base = static_cast(bundle.get(0)) + - (workspace_batch_id * GROUP * padding_group_size + - workspace_group_id * padding_group_size + - workspace_ic * IH2 * IW2) * - expend_element; - size_t nr_ic = workspace_ic_block; - if (GROUP > 1) { - nr_ic = IC; - } - rep_step(ic_idx, nr_ic, pack_ic) { - std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); - sptr_base += nr_pad_h; - rep(ih_idx, IH) { - std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); - sptr_base += nr_pad_w; - conv_bias::nchw44_pack_src(sptr, sptr_base, IW); - sptr_base += IW * pack_ic * expend_element; - sptr += IW * pack_ic; - std::memset(sptr_base, 0, (nr_pad_w + over_pad) * sizeof(int8_t)); - sptr_base += nr_pad_w + over_pad; - } - std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); - sptr_base += nr_pad_h; - } -} - -template -static void do_conv_kern(WorkspaceBundle bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids, - const CpuNDRange& ncb_range) { - size_t OH = kern_param.osz[0]; - size_t OW = kern_param.osz[1]; - size_t FH = kern_param.filter_meta.spatial[0]; - size_t FW = kern_param.filter_meta.spatial[1]; - size_t IC = kern_param.filter_meta.icpg; - size_t OC = kern_param.filter_meta.ocpg; - size_t GROUP = kern_param.filter_meta.group; - size_t IH2, IW2, OH2, OW2; - get_rectified_size(kern_param, IH2, IW2, OH2, OW2); - bool need_post_process = - kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; - //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) - Op op = Op(1.0f, 4.0f); - if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; - float scale_dst = kern_param.dst_type.param().scale; - op = Op(scale_bias, scale_dst); - } - size_t padding_group_size = IH2 * IW2 * IC; - bundle.set(kern_param.workspace_ptr); - - constexpr size_t pack_c = 4; - constexpr size_t src_expand_size = 4; - const size_t workspace_batch_id = workspace_ids[0]; - const size_t workspace_group_id = workspace_ids[1]; - const size_t batch_id = ncb_index.ndrange_id[0]; - const size_t group_id = ncb_index.ndrange_id[1]; - const size_t oc_id = ncb_index.ndrange_id[2]; - const size_t oc_block_num = ncb_range[2]; - size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num); - size_t oc_block = nr_pack_per_step * pack_c; - const size_t oc_idx = oc_id * oc_block; - if (oc_id == (oc_block_num - 1)) { - oc_block = OC - oc_id * nr_pack_per_step * pack_c; - } - megdnn_assert(oc_block % pack_c == 0, - "oc must be devisible by 4, but oc = %zu", oc_block); - const int8_t* sptr = - static_cast(bundle.get(0)) + - workspace_batch_id * GROUP * padding_group_size * src_expand_size + - workspace_group_id * padding_group_size * src_expand_size; - - const int8_t* fptr = - kern_param.filter(group_id) + oc_idx * FH * FW * IC; - void* dst = reinterpret_cast( - reinterpret_cast( - kern_param.dst(batch_id, group_id)) + - oc_idx * OH * OW); - const int32_t* bptr = - kern_param.bias(batch_id, group_id) + oc_idx; - auto packed_weight = reinterpret_cast(bundle.get(1)) + - group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; - conv_bias::nchw44_pack_filter(fptr, packed_weight, - oc_block / 4 * IC / 4 * FH * FW); - -#define KERN1_NCHW44_CONV(filter) \ - conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw44< \ - bias_mode, Op, ow_remain>(sptr, 
packed_weight, bptr, nullptr, \ - static_cast(dst), oc_block, IC, \ - IH2, IW2, OH, OW, op) - DISPATCH_FILTER(filter, KERN1_NCHW44_CONV) -#undef KERN1_NCHW44_CONV -} - -/* ===================== stride1 algo ===================== */ -bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::usable( - fallback::ConvBiasImpl*, const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - MEGDNN_MARK_USED_VAR(algo_selection_strategy); - auto&& fm = param.filter_meta; - auto FH = fm.spatial[0]; - auto OC = fm.ocpg; - auto IC = fm.icpg; - bool avaible = //! src and filter are qint8, dst is qint8 or qint32 - ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && - param.filter_type.enumv() == DTypeEnum::QuantizedS8 && - (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || - param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && - (fm.format == param::Convolution::Format::NCHW44) && - (OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && - param.bias_mode != BiasMode::BIAS; - return avaible; -} - -bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::is_preferred( - megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, - const NCBKernSizeParam& param) const { - // TODO: benchmark and fix - MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr); - MEGDNN_MARK_USED_VAR(param); - return false; -} - -size_t ConvBiasImpl::AlgoS8DirectStride1NCHW44::get_workspace( - fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { - return get_bundle(param).total_size_in_bytes(); -} - -SmallVector -ConvBiasImpl::AlgoS8DirectStride1NCHW44::dispatch_kerns( - fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { - auto fm = param.filter_meta; - size_t N = param.n; - size_t IC = fm.icpg; - size_t OC = fm.ocpg; - size_t OW = param.osz[1]; - size_t group = fm.group; - size_t fh = fm.spatial[0]; - size_t fw = fm.spatial[1]; - WorkspaceBundle wbundle = get_bundle(param); - conv_fun do_conv_fun = nullptr; - int ow_remain = OW % 8; -// NOTE: remain_w is not used to gen hash of midout for compatible with changing -// shape runtime -#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride1, \ - midout_iv(#filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ - MIDOUT_END(); - -#define GET_OP_PARAM(filter, bias_mode, remain_w) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ - } - -#define GET_REMAIN_W_PARAM(filter, bias_mode) \ - switch (ow_remain) { \ - case 0: \ - GET_OP_PARAM(filter, bias_mode, 0); \ - break; \ - case 1: \ - GET_OP_PARAM(filter, bias_mode, 1); \ - break; \ - case 2: \ - GET_OP_PARAM(filter, bias_mode, 2); \ - break; \ - case 3: \ - GET_OP_PARAM(filter, bias_mode, 3); \ - break; \ - case 4: \ - GET_OP_PARAM(filter, bias_mode, 4); \ - break; \ - case 5: \ - GET_OP_PARAM(filter, bias_mode, 5); \ - break; \ - case 6: \ - GET_OP_PARAM(filter, bias_mode, 6); \ - break; \ - case 7: \ - GET_OP_PARAM(filter, bias_mode, 7); \ - break; \ - 
default: \ - megdnn_assert(0); \ - } - -#define GET_BIAS_MODE_PARAM(filter) \ - switch (param.bias_mode) { \ - case BiasMode::NO_BIAS: \ - GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \ - break; \ - case BiasMode::BROADCAST_CHANNEL_BIAS: \ - GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ - } - -#define DISPATCH_CONV_KERN() \ - switch (param.filter_meta.spatial[0]) { \ - case 2: \ - GET_BIAS_MODE_PARAM(2) \ - break; \ - case 3: \ - GET_BIAS_MODE_PARAM(3) \ - break; \ - case 5: \ - GET_BIAS_MODE_PARAM(5) \ - break; \ - case 7: \ - GET_BIAS_MODE_PARAM(7) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ - } - - DISPATCH_CONV_KERN(); - -#undef DO_CONV_KERN_FUN -#undef GET_REMAIN_W_PARAM -#undef GET_OP_PARAM -#undef GET_BIAS_MODE_PARAM -#undef DISPATCH_CONV_KERN - - megdnn_assert(do_conv_fun); - - SmallVector ret_kerns; - WorkspaceBundle bundle = wbundle; - - constexpr size_t pack_oc = 4; - size_t oc_step = pack_oc; - if (fh == 2 && fw == 2 && OC >= 8) { - oc_step = 8; - } - - if (group == 1) { - CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); - }; - constexpr size_t pack_ic = 4; - ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}}); - auto do_conv = [bundle, do_conv_fun, ncb_range]( - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { - do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, - ncb_range); - }; - ret_kerns.push_back({do_conv, ncb_range}); - } else { - CpuNDRange ncb_range = {N, group, 1}; - auto do_conv = [bundle, do_conv_fun, ncb_range]( - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { - copy_padding_kern(bundle, kern_param, ncb_index, - {0, ncb_index.thread_id, 0}); - do_conv_fun(bundle, kern_param, ncb_index, - {0, ncb_index.thread_id, 0}, ncb_range); - }; - ret_kerns.push_back({do_conv, ncb_range}); - } - - return ret_kerns; -} - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp deleted file mode 100644 index ca5c45a6..00000000 --- a/dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp +++ /dev/null @@ -1,791 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#include "src/arm_common/conv_bias/int8/direct.h" -#include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_op.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/common/utils.h" -#include "src/fallback/conv_bias/common.h" - -using namespace megdnn; -using namespace arm_common; -namespace { - -/** -dot like impl. 
dot 4 ic to 1 oc, accumale to c -example: (format like weight) -packed weight -low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> ---------------------------------------------------------------------- -high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> -dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> -**/ -// TODO: can try oh = 2 impl, oc = 8 impl -template -static void ker_neon_dirctconv_3x3s1_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, - int iw, const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[2 * 4]; - int8x16_t weight[3]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[2]; - init_oc4_ow8(c, bias_ptr); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - - c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]); - - c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]); - - c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]); - - c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_oc4_ow8_remain_static(c, op, dst_ptr); -} - -template -static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { - constexpr 
int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc4 = oc_step * fh * fw * ic; - - int32x4_t c[2][8]; - int8x16_t weight[2][2]; - int8x16_t src[8 + 1]; - int16x8_t temp_c[4]; - - init_oc8_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0][0] = vld1q_s8(read_weight_ptr); - weight[0][1] = vld1q_s8(read_weight_ptr + 16); - weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); - weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); - - c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); - c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); - c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]); - c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); - c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); - c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]); - c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]); - - c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]); - c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]); - c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]); - c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]); - c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]); - c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]); - c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]); - - c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]); - c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]); - c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]); - c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]); - c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]); - c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]); - c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]); - - c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]); - c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]); - c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]); - c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]); - c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]); - c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]); - c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]); - c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]); - } - weight_ptr += fh * fw * 
ld_weight_ic4; - } - store_oc8_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); -} - -template -static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, - int iw, const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[2 * 4]; - int8x16_t weight[2]; - int8x16_t src[8 + 1]; - int16x8_t temp_c[2]; - init_oc4_ow8(c, bias_ptr); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - - c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]); - - c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]); - - c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]); - - c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_oc4_ow8_remain_static(c, op, dst_ptr); -} - -template -static void ker_neon_dirctconv_5x5s1_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, - int iw, const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[2 * 4]; - int8x16_t weight[5]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[2]; - init_oc4_ow8(c, bias_ptr); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = 
vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); - weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); - - c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]); - - c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]); - - c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); - - c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_oc4_ow8_remain_static(c, op, dst_ptr); -} - -template -static void ker_neon_dirctconv_7x7s1_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, - int iw, const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[2 * 4]; - int8x16_t weight[7]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[2]; - init_oc4_ow8(c, bias_ptr); - for (int 
ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); - weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); - weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); - weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); - - c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[5], src[6], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[6], src[7], c[1], temp_c[1]); - - c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[5], src[7], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[5], src[8], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[6], src[8], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[6], src[9], c[3], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); - - c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[5], src[9], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[5], src[0], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[6], src[0], c[4], temp_c[0]); - c[5] = 
vdotq_s32_h(weight[6], src[1], c[5], temp_c[1]); - - src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); - src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); - - c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[5], src[1], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[5], src[2], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[6], src[2], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[6], src[3], c[7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_oc4_ow8_remain_static(c, op, dst_ptr); -} - -} // namespace - -/** -origin weight shape -packed weight shape -example: (format like weight) -origin -<0, 0> <1, 0> <2, 0> <3, 0> -<0, 1> <1, 1> <2, 1> <3, 1> -<0, 2> <1, 2> <2, 2> <3, 2> -<0, 3> <1, 3> <2, 3> <3, 3> -packed -low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> ---------------------------------------------------------------------- -high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> -**/ -void conv_bias::nchw44_pack_filter(const int8_t* src, int8_t* dst, int length) { - static const uint8_t weight_idx_buffer[16] = {0, 4, 9, 13, 2, 6, 11, 15, - 12, 8, 5, 1, 14, 10, 7, 3}; - constexpr int simd_len = 16; - uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer); - for (int i = 0; i < length; i++) { - int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx); - vst1q_s8(dst + i * simd_len, result); - } -} -/** -origin src shape -packed src shape -example: (format like ) -origin -<0> <0> <0> <0> -packed -low 64 bit <0> <1> <2> <3> | <0> <1> <2> <3> ---------------------------------------------------------------------- -high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0> -**/ -void conv_bias::nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { - static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3, - 3, 2, 1, 0, 3, 2, 1, 0}; - constexpr int pack_ic = 4; - constexpr int simd_len = 16; - uint8x16_t src_idx = vld1q_u8(src_idx_buffer); - for (int i = 0; i < length; i++) { - int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx); - vst1q_s8(dst + i * simd_len, result); - } -} - -template -void conv_bias::conv_direct_stride1_2x2_int8_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t filter_size = 2; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t big_oc_step = 8; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - const size_t oc_end = oc / big_oc_step * big_oc_step; - const size_t oc_remain = oc - oc_end; - const int ld_oc = oh * ow * ic_step; - for (size_t oc_idx = 0; oc_idx < oc_end; 
oc_idx += big_oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s1_oc8_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * iw + ow_end) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_2x2s1_oc8_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_oc, op); - } - } - } - if (oc_remain > 0) { - const size_t oc_idx = oc_end; - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * iw + ow_end) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_2x2s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - } - } -} -template -void conv_bias::conv_direct_stride1_3x3_int8_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t filter_size = 3; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_3x3s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * iw + ow_end) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_3x3s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - } - } -} -template -void conv_bias::conv_direct_stride1_5x5_int8_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t 
ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t filter_size = 5; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_5x5s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * iw + ow_end) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_5x5s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - } - } -} - -template -void conv_bias::conv_direct_stride1_7x7_int8_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t filter_size = 7; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_7x7s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * iw + ow_end) * ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_7x7s1_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - } - } -} - -#define INSTANTIATION(stride, i, bias, remain_w, Op) \ - template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \ - bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \ - int32_t*, int8_t*, const size_t, const size_t, \ - const size_t, const size_t, const size_t, \ - const size_t, const Op&); - -#define FOR_OP(stride, i, bias, remain_w) \ - INSTANTIATION(stride, i, bias, remain_w, \ - TypeCvtOp) \ - INSTANTIATION(stride, i, bias, remain_w, \ - ReluOp) \ - INSTANTIATION(stride, i, bias, remain_w, \ - HSwishOp) - -#define FOR_REMAIN(stride, i, bias) \ - FOR_OP(stride, i, bias, 0) \ - FOR_OP(stride, i, bias, 
1) \
- FOR_OP(stride, i, bias, 2) \
- FOR_OP(stride, i, bias, 3) \
- FOR_OP(stride, i, bias, 4) \
- FOR_OP(stride, i, bias, 5) \
- FOR_OP(stride, i, bias, 6) \
- FOR_OP(stride, i, bias, 7)
-
-#define FOR_BIAS(stride, i) \
- FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \
- FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
-
-#define FOR_FILTER(stride) \
- FOR_BIAS(stride, 2) \
- FOR_BIAS(stride, 3) \
- FOR_BIAS(stride, 5) \
- FOR_BIAS(stride, 7)
-
-FOR_FILTER(stride1)
-
-#undef FOR_STRIDE
-#undef FOR_FILTER
-#undef FOR_IC
-#undef FOR_BIAS
-#undef FOR_NONLINEAR
-#undef FOR_REMAIN
-#undef INSTANTIATION
diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp
deleted file mode 100644
index 06d047c3..00000000
--- a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp
+++ /dev/null
@@ -1,793 +0,0 @@
-/**
- * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
-
-#include "src/arm_common/conv_bias/int8/direct.h"
-#include "src/arm_common/conv_bias/intrinsic_helper.h"
-#include "src/arm_common/elemwise_op.h"
-#include "src/arm_common/simd_macro/marm_neon.h"
-#include "src/common/utils.h"
-#include "src/fallback/conv_bias/common.h"
-
-using namespace megdnn;
-using namespace arm_common;
-namespace {
-
-/**
-dot-like impl. dot 4 ic to 1 oc, accumulate to c
-example: (format like weight)
-packed weight
-low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
-high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
-dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
-**/
-// TODO: can try oh = 2 impl, oc = 8 impl
-template
-static void ker_neon_dirctconv_3x3s2_oc4_ow8(const int8_t* src_ptr,
- const int8_t* weight_ptr,
- const int32_t* bias_ptr,
- int8_t* dst_ptr, int ic, int ih,
- int iw, const Op& op) {
- constexpr int fh = filter_size;
- constexpr int fw = filter_size;
- constexpr int ic_step = 4;
- constexpr int loop_ic_step = 4;
- constexpr int ld_weight_ic4 = 16;
- constexpr int pack_iw_len = 4;
-
- const int ic_stride = ih * iw * pack_iw_len;
-
- int32x4_t c[2 * 4];
- int8x16_t weight[3];
- int8x16_t src[8 + 2];
- int16x8_t temp_c[2];
- init_oc4_ow8(c, bias_ptr);
-
- for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
- for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
- const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
- fh_idx * iw * ic_step * pack_iw_len;
-
- src[0] = vld1q_s8(src_ic_0_3);
- src[1] = vld1q_s8((src_ic_0_3 + 16));
- src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
- src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
- src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
- src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
- src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
- src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
- src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
- src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
-
- // oc == 0
- const int8_t* read_weight_ptr =
- weight_ptr + fh_idx * fw * ld_weight_ic4;
-
- weight[0] = vld1q_s8(read_weight_ptr);
- weight[1] = vld1q_s8(read_weight_ptr + 16);
- weight[2] = vld1q_s8(read_weight_ptr + 2 * 
16); - - c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]); - - c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); - src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); - c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]); - - src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 15 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); - c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - store_oc4_ow8_remain_static(c, op, dst_ptr); -} - -template -static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int oc_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc4 = oc_step * fh * fw * ic; - - int32x4_t c[2][8]; - int8x16_t weight[2][2]; - int8x16_t src[8 + 1]; - int16x8_t temp_c[4]; - - init_oc8_ow8(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8(src_ic_0_3 + 16); - src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); - src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); - src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); - src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); - src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0][0] = vld1q_s8(read_weight_ptr); - weight[0][1] = vld1q_s8(read_weight_ptr + 16); - weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); - weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); - - c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]); - c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]); - c[0][1] = 
vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]); - c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]); - c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]); - c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]); - c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]); - c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]); - - c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]); - c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]); - c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]); - c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]); - c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]); - c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]); - c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]); - c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]); - - src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); - src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); - c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]); - c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]); - c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]); - c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]); - c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]); - c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]); - c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]); - c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]); - - src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); - src[4] = vld1q_s8(src_ic_0_3 + 13 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 15 * 16); - c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]); - c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]); - c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]); - c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]); - c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]); - c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]); - c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]); - c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - store_oc8_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); -} - -template -static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, - int iw, const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[2 * 4]; - int8x16_t weight[2]; - int8x16_t src[8 + 1]; - int16x8_t temp_c[2]; - init_oc4_ow8(c, bias_ptr); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = 
vld1q_s8((src_ic_0_3 + 8 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - - c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); - - c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 9 * 16); - src[1] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[2] = vld1q_s8(src_ic_0_3 + 11 * 16); - c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[0], src[1], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[1], src[0], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[1], src[2], c[5], temp_c[1]); - - src[3] = vld1q_s8(src_ic_0_3 + 12 * 16); - src[4] = vld1q_s8(src_ic_0_3 + 13 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 14 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 15 * 16); - c[6] = vdotq_s32_h(weight[0], src[3], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[0], src[5], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[1], src[4], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[1], src[6], c[7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_oc4_ow8_remain_static(c, op, dst_ptr); -} - -template -static void ker_neon_dirctconv_5x5s2_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, - int iw, const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[2 * 4]; - int8x16_t weight[5]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[2]; - init_oc4_ow8(c, bias_ptr); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8((src_ic_0_3 + 16)); - src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); - src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); - src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); - weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); - - c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[3], src[5], c[1], 
temp_c[1]); - c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]); - - src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); - src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); - src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); - src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); - c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]); - - src[5] = vld1q_s8((src_ic_0_3 + 15 * 16)); - src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); - src[7] = vld1q_s8((src_ic_0_3 + 17 * 16)); - src[8] = vld1q_s8((src_ic_0_3 + 18 * 16)); - c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_oc4_ow8_remain_static(c, op, dst_ptr); -} - -template -static void ker_neon_dirctconv_7x7s2_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - int8_t* dst_ptr, int ic, int ih, - int iw, const Op& op) { - constexpr int fh = filter_size; - constexpr int fw = filter_size; - constexpr int ic_step = 4; - constexpr int loop_ic_step = 4; - constexpr int ld_weight_ic4 = 16; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - - int32x4_t c[2 * 4]; - int8x16_t weight[7]; - int8x16_t src[8 + 2]; - int16x8_t temp_c[2]; - init_oc4_ow8(c, bias_ptr); - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - - src[0] = vld1q_s8(src_ic_0_3); - src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); - src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); - src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); - src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); - src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); - src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); - src[9] = 
vld1q_s8(src_ic_0_3 + 9 * 16); - - // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; - - weight[0] = vld1q_s8(read_weight_ptr); - weight[1] = vld1q_s8(read_weight_ptr + 16); - weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); - weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); - weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); - weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); - weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); - - c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[5], src[7], c[1], temp_c[1]); - c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]); - c[1] = vdotq_s32_h(weight[6], src[8], c[1], temp_c[1]); - - src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); - src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); - src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); - c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[5], src[9], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[5], src[1], c[3], temp_c[1]); - c[2] = vdotq_s32_h(weight[6], src[0], c[2], temp_c[0]); - c[3] = vdotq_s32_h(weight[6], src[2], c[3], temp_c[1]); - - src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); - src[4] = vld1q_s8(src_ic_0_3 + 14 * 16); - src[5] = vld1q_s8(src_ic_0_3 + 15 * 16); - src[6] = vld1q_s8(src_ic_0_3 + 16 * 16); - c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[5], src[3], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[5], src[5], c[5], temp_c[1]); - c[4] = vdotq_s32_h(weight[6], src[4], c[4], temp_c[0]); - c[5] = vdotq_s32_h(weight[6], src[6], c[5], temp_c[1]); - - src[7] = vld1q_s8(src_ic_0_3 + 17 * 16); - src[8] = vld1q_s8(src_ic_0_3 + 18 * 16); - src[9] = vld1q_s8(src_ic_0_3 + 19 * 16); - src[0] = vld1q_s8(src_ic_0_3 + 20 * 16); - c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]); 
- c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[5], src[7], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[5], src[9], c[7], temp_c[1]); - c[6] = vdotq_s32_h(weight[6], src[8], c[6], temp_c[0]); - c[7] = vdotq_s32_h(weight[6], src[0], c[7], temp_c[1]); - } - weight_ptr += fh * fw * ld_weight_ic4; - } - - store_oc4_ow8_remain_static(c, op, dst_ptr); -} - -} // namespace - -template -void conv_bias::conv_direct_stride2_2x2_int8_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t filter_size = 2; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t big_oc_step = 8; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr size_t stride_h = 2; - constexpr size_t stride_w = 2; - constexpr int pack_iw_len = 4; - - const size_t out_img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - const size_t oc_end = oc / big_oc_step * big_oc_step; - const size_t oc_remain = oc - oc_end; - const int ld_oc = oh * ow * ic_step; - for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s2_oc8_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_2x2s2_oc8_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ld_oc, op); - } - } - } - if (oc_remain > 0) { - const size_t oc_idx = oc_end; - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_2x2s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, 
iw, op); - } - } - } -} -template -void conv_bias::conv_direct_stride2_3x3_int8_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t filter_size = 3; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr size_t stride_h = 2; - constexpr size_t stride_w = 2; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_3x3s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_3x3s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - } - } -} -template -void conv_bias::conv_direct_stride2_5x5_int8_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t filter_size = 5; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr size_t stride_h = 2; - constexpr size_t stride_w = 2; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_5x5s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_5x5s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - } - } -} - -template -void conv_bias::conv_direct_stride2_7x7_int8_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* 
temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t filter_size = 7; - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 4; - constexpr size_t oc_step = 4; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 8; - constexpr size_t stride_h = 2; - constexpr size_t stride_w = 2; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_7x7s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - ker_neon_dirctconv_7x7s2_oc4_ow8( - src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, op); - } - } - } -} - -#define INSTANTIATION(stride, i, bias, remain_w, Op) \ - template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \ - bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \ - int32_t*, int8_t*, const size_t, const size_t, \ - const size_t, const size_t, const size_t, \ - const size_t, const Op&); - -#define FOR_OP(stride, i, bias, remain_w) \ - INSTANTIATION(stride, i, bias, remain_w, \ - TypeCvtOp) \ - INSTANTIATION(stride, i, bias, remain_w, \ - ReluOp) \ - INSTANTIATION(stride, i, bias, remain_w, \ - HSwishOp) - -#define FOR_REMAIN(stride, i, bias) \ - FOR_OP(stride, i, bias, 0) \ - FOR_OP(stride, i, bias, 1) \ - FOR_OP(stride, i, bias, 2) \ - FOR_OP(stride, i, bias, 3) \ - FOR_OP(stride, i, bias, 4) \ - FOR_OP(stride, i, bias, 5) \ - FOR_OP(stride, i, bias, 6) \ - FOR_OP(stride, i, bias, 7) - -#define FOR_BIAS(stride, i) \ - FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \ - FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) - -#define FOR_FILTER(stride) \ - FOR_BIAS(stride, 2) \ - FOR_BIAS(stride, 3) \ - FOR_BIAS(stride, 5) \ - FOR_BIAS(stride, 7) - -FOR_FILTER(stride2) - -#undef FOR_STRIDE -#undef FOR_FILTER -#undef FOR_IC -#undef FOR_BIAS -#undef FOR_NONLINEAR -#undef FOR_REMAIN -#undef INSTANTIATION diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index a813fce8..2c88b400 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -46,11 +46,10 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; - AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44; + AlgoS8DirectNCHW44 s8_direct_nchw44; AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; - 
AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44; AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; @@ -114,11 +113,10 @@ public: direct_algos.emplace_back(&qu8_direct_stride1_small_group); direct_algos.emplace_back(&s8_direct_stride2_large_group); direct_algos.emplace_back(&s8_direct_stride2_small_group); - direct_algos.emplace_back(&s8_direct_stride2_nchw44); + direct_algos.emplace_back(&s8_direct_nchw44); direct_algos.emplace_back(&s8_direct_nchw_nchw44); direct_algos.emplace_back(&s8_direct_stride1_large_group); direct_algos.emplace_back(&s8_direct_stride1_small_group); - direct_algos.emplace_back(&s8_direct_stride1_nchw44); direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 99f6f51e..50b73a4b 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -37,9 +37,8 @@ protected: private: class AlgoS8DirectStride1; - class AlgoS8DirectStride1NCHW44; class AlgoS8DirectStride2; - class AlgoS8DirectStride2NCHW44; + class AlgoS8DirectNCHW44; class AlgoS8DirectNCHWNCHW44; class AlgoQU8DirectStride1; class AlgoQU8DirectStride2; diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/none.h b/dnn/src/arm_common/elemwise_helper/kimpl/none.h index 224148eb..6c3aa479 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/none.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/none.h @@ -27,6 +27,8 @@ struct NoneOp; #define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ template <> \ struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ + NoneOp(){}; \ + NoneOp(float, float){}; \ using NoneOpBase::NoneOpBase; \ using NoneOpBase::operator(); \ constexpr static size_t SIMD_WIDTH = _simd_width; \ diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index f357319d..50ba1f65 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -226,7 +226,15 @@ static void benchmark_convbias(Handle* handle, std::string int_name, run(1, 3, 32, 224, 224, 5, 1, true); run(1, 3, 64, 224, 224, 7, 1, true); - for (size_t stride : {1, 2}) { + run(1, 64, 128, 56, 56, 3, 2, false); + run(1, 128, 256, 28, 28, 3, 2, false); + run(1, 256, 512, 14, 14, 3, 2, false); + + run(1, 128, 128, 28, 28, 3, 1, false); + run(1, 256, 256, 14, 14, 3, 1, false); + run(1, 512, 512, 7, 7, 3, 1, false); + + for (size_t stride : {1}) { printf("stride %zu\n", stride); for (size_t filter_size : {2, 3, 5, 7}) { for (size_t img_size : {32}) { diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index acb05e59..7e284f80 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -527,12 +527,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { checker_conv_bias_qint8x8x8( get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), - handle(), "S8_NCHW44_DIRECT_STRD1"); + handle(), "S8_NCHW44_DIRECT"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) { + checker_conv_bias_qint8x8x32( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), + handle(), "S8_NCHW44_DIRECT"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) { + 
checker_conv_bias_qint8x8x32( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true), + handle(), "S8_NCHW44_DIRECT"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) { checker_conv_bias_qint8x8x8( get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), - handle(), "S8_NCHW44_DIRECT_STRD2"); + handle(), "S8_NCHW44_DIRECT"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) { checker_conv_bias_qint8x8x8( @@ -1085,7 +1095,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) { dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); } - TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { using namespace conv_bias; @@ -1096,17 +1105,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { param::MatrixMul::Format format, float eps) { for (auto&& arg : args) { for (uint32_t m : out_size) { - checker.set_extra_opr_impl(std::bind( - winograd_algo_extra_impl, std::placeholders::_1, m, - arg.param, handle, format)); - checker.set_dtype(0, A_dtype) - .set_dtype(1, B_dtype) - .set_dtype(2, C_dtype) - .set_dtype(4, D_dtype) - .set_epsilon(eps) - .set_param(arg.param) - .execs({arg.src, arg.filter, arg.bias, {}, {}}); - } + checker.set_extra_opr_impl(std::bind( + winograd_algo_extra_impl, std::placeholders::_1, m, + arg.param, handle, format)); + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } } }; @@ -1118,7 +1127,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str())); - std::vector quantized_args = get_int8_nchw44_args (3,4); + std::vector quantized_args = get_int8_nchw44_args(3, 4); UniformIntRNG int_rng{-50, 50}; checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f), @@ -1126,8 +1135,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); } - -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) { using namespace conv_bias; Checker checker(handle()); @@ -1137,17 +1146,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM param::MatrixMul::Format format, float eps) { for (auto&& arg : args) { for (uint32_t m : out_size) { - checker.set_extra_opr_impl(std::bind( - winograd_algo_extra_impl, std::placeholders::_1, m, - arg.param, handle, format)); - checker.set_dtype(0, A_dtype) - .set_dtype(1, B_dtype) - .set_dtype(2, C_dtype) - .set_dtype(4, D_dtype) - .set_epsilon(eps) - .set_param(arg.param) - .execs({arg.src, arg.filter, arg.bias, {}, {}}); - } + checker.set_extra_opr_impl(std::bind( + winograd_algo_extra_impl, std::placeholders::_1, m, + arg.param, handle, format)); + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } } }; @@ -1168,7 +1177,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 
1e-3); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) { using namespace conv_bias; Checker checker(handle()); @@ -1196,21 +1206,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F #if MEGDNN_AARCH64 const char* matmul_name = "AARCH64_F32_MK4_4x16"; #else - const char* matmul_name = "ARMV7_F32_MK4_4x8"; + const char* matmul_name = "ARMV7_F32_MK4_4x8"; #endif checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); - std::vector quantized_args = - get_int8_nchw44_args(3, 4, true); + std::vector quantized_args = get_int8_nchw44_args(3, 4, true); UniformIntRNG int_rng{-50, 50}; checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f), - dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon); + dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, + epsilon); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) { using namespace conv_bias; Checker checker(handle()); @@ -1238,7 +1249,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F #if MEGDNN_AARCH64 const char* matmul_name = "AARCH64_F32_MK4_4x16"; #else - const char* matmul_name = "ARMV7_F32_MK4_4x8"; + const char* matmul_name = "ARMV7_F32_MK4_4x8"; #endif checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); @@ -1249,10 +1260,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), dtype::QuantizedS8(0.01887994f), dtype::QuantizedS32(0.41113496f * 0.01887994f), - dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon); + dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, + epsilon); } - #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) { using namespace conv_bias; -- GitLab
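
[Editor's note] The removed stride1/stride2 kernels and the merged S8_NCHW44_DIRECT algo all share one NCHW44 scheme: channels are packed four at a time, a packed weight "pixel" is a 4 oc x 4 ic block, and each output pixel accumulates the dot of 4 input channels into each of 4 output channels (the "dot 4 ic to 1 oc" comment in the deleted file). The "8832" in the subject line matches the new qint8x8x32 tests: int8 src and filter with raw int32 output, i.e. the kernel runs with no quantizing epilogue, which appears to be why NoneOp gains the NoneOp() and NoneOp(float, float) constructors, so it can be built with the same two-scale signature as TypeCvtOp/ReluOp/HSwishOp. Below is a minimal scalar sketch of that computation for a single 4-ic/4-oc group with no padding. It assumes a plain oc-major 16-byte weight block rather than MegEngine's interleaved packing shown in the deleted comment, and the names scalar_dot_ic4_oc4 and conv_nchw44_sketch are hypothetical illustrations, not MegEngine APIs.

#include <cstddef>
#include <cstdint>

// Hypothetical scalar model of one NCHW44 weight "pixel": a 4 oc x 4 ic
// block. Dot the 4 input channels of one source pixel into each of the
// 4 output-channel accumulators (assumes oc-major layout for simplicity).
static void scalar_dot_ic4_oc4(const int8_t* src4, const int8_t* w44,
                               int32_t* acc4) {
    for (int oc = 0; oc < 4; ++oc)
        for (int ic = 0; ic < 4; ++ic)
            acc4[oc] += int32_t(src4[ic]) * int32_t(w44[oc * 4 + ic]);
}

// Scalar reference for one 4-ic group x one 4-oc group, no padding; the
// caller must guarantee ih >= (oh - 1) * stride + fh and likewise for iw.
// Storing the int32 accumulators directly corresponds to the 8832 case
// (int8 in, int32 out, NoneOp epilogue); the qint8x8x8 case would instead
// apply a requantizing op before the store.
static void conv_nchw44_sketch(const int8_t* src, const int8_t* weight,
                               const int32_t* bias, int32_t* dst,
                               size_t ih, size_t iw, size_t fh, size_t fw,
                               size_t stride, size_t oh, size_t ow) {
    for (size_t oh_idx = 0; oh_idx < oh; ++oh_idx) {
        for (size_t ow_idx = 0; ow_idx < ow; ++ow_idx) {
            // per-channel (broadcast) bias seeds the accumulators
            int32_t acc[4] = {bias[0], bias[1], bias[2], bias[3]};
            for (size_t fh_idx = 0; fh_idx < fh; ++fh_idx) {
                for (size_t fw_idx = 0; fw_idx < fw; ++fw_idx) {
                    const size_t ih_idx = oh_idx * stride + fh_idx;
                    const size_t iw_idx = ow_idx * stride + fw_idx;
                    // source pixel: 4 contiguous ic values (NCHW44 packing)
                    const int8_t* s = src + (ih_idx * iw + iw_idx) * 4;
                    // weight pixel: one 16-byte 4 oc x 4 ic block
                    const int8_t* w = weight + (fh_idx * fw + fw_idx) * 16;
                    scalar_dot_ic4_oc4(s, w, acc);
                }
            }
            int32_t* d = dst + (oh_idx * ow + ow_idx) * 4;
            for (int oc = 0; oc < 4; ++oc)
                d[oc] = acc[oc];
        }
    }
}

The production kernels above differ from this sketch mainly in scheduling: ow is tiled in steps of 8 with one template instantiation per remainder width (the FOR_REMAIN 0..7 lists), oc is tiled by 8 where possible with a 4-wide tail (the oc8_ow8 vs oc4_ow8 variants), and the 4x4 dot is performed 16 bytes at a time through the vdotq_s32_h helper with int16 intermediates.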