diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h index c1e3711663e9578f7f36be42f2b685023a425795..26e9c5d14a300eff592d5ffbc91d2fb83aa657e7 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h @@ -38,6 +38,18 @@ public: const NCBKernSizeParam& param) const override; }; +class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { +public: + AlgoS8x8x16DirectNCHW44() {} + bool is_reproducible() const override { return true; } + const char* name() const override { return "S8x8x16_NCHW44_DIRECT"; } + bool usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + const NCBKernSizeParam& param) const override; +}; + class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_algo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..302aebe1f8bc6e6db19d377303ae3726e772b091 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_algo.cpp @@ -0,0 +1,481 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8x8x16/conv_direct_int8x8x16_nchw44_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8x8x16/algos.h" +#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; +using conv_fun = std::function; +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct) + +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, + int& iw2) { + auto&& fm = param.filter_meta; + int ih = param.isz[0]; + int iw = param.isz[1]; + int ph = fm.padding[0]; + int pw = fm.padding[1]; + + ih2 = ih + ph * 2; + iw2 = iw + pw * 2; +} + +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + size_t group = fm.group; + size_t batch = param.n; + size_t IC = fm.icpg; + int IH2, IW2; + get_rectified_size(param, IH2, IW2); + + if (group == 1) { + size_t src_size = 0; + bool need_padding = param.filter_meta.padding[0] > 0 || + param.filter_meta.padding[1] > 0; + src_size = need_padding + ? batch * group * IC * IH2 * IW2 * sizeof(int8_t) + : 0; +#if MEGDNN_ARMV7 + if (fm.stride[0] == 1) { + constexpr int src_expand_element = 4; + src_size = batch * group * IC * IH2 * IW2 * sizeof(int8_t) * + src_expand_element; + } +#endif + return {nullptr, {src_size}}; + } else { + size_t src_size = 0; + bool need_padding = param.filter_meta.padding[0] > 0 || + param.filter_meta.padding[1] > 0; + src_size = need_padding + ? param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) + : 0; +#if MEGDNN_ARMV7 + if (fm.stride[0] == 1) { + constexpr int src_expand_element = 4; + src_size = param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * + src_expand_element; + } +#endif + return {nullptr, {src_size}}; + } +}; + +#if MEGDNN_ARMV7 +static void copy_padding_kern(const WorkspaceBundle& bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + int IH = kern_param.isz[0]; + int IW = kern_param.isz[1]; + int IC = kern_param.filter_meta.icpg; + int PH = kern_param.filter_meta.padding[0]; + int PW = kern_param.filter_meta.padding[1]; + int GROUP = kern_param.filter_meta.group; + int IH2, IW2; + get_rectified_size(kern_param, IH2, IW2); + int padding_group_size = IH2 * IW2 * IC; + //! Used for get the workspace offset + constexpr int pack_ic = 4; + constexpr int src_expand_element = 4;; + size_t workspace_ic_block = 4; + size_t workspace_batch_id = workspace_ids[0]; + size_t workspace_group_id = workspace_ids[1]; + size_t workspace_ic_id = workspace_ids[2]; + size_t workspace_ic = workspace_ic_id * workspace_ic_block; + size_t batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + size_t group_pack_size = 1; + + int nr_pad_w = PW * pack_ic * src_expand_element; + int nr_pad_h = PH * IW2 * pack_ic * src_expand_element; + int row_last_pad = (IW2 - IW - PW) * pack_ic * src_expand_element; + int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * src_expand_element; + const int8_t* sptr = static_cast(kern_param.src( + batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); + + //! copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + (workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + + workspace_ic * IH2 * IW2) * + src_expand_element; + size_t nr_ic = workspace_ic_block; + if (GROUP > 1) { + nr_ic = IC; + } + rep_step(ic_idx, nr_ic, pack_ic) { + std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); + sptr_base += nr_pad_h; + rep(ih_idx, IH) { + std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); + sptr_base += nr_pad_w; + int8x8x16_direct_nchw44::nchw44_pack_src(sptr, sptr_base, IW); + sptr_base += IW * pack_ic * src_expand_element; + sptr += IW * pack_ic; + std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); + sptr_base += row_last_pad; + } + std::memset(sptr_base, 0, col_last_pad * sizeof(int8_t)); + sptr_base += col_last_pad; + } +} +#endif + +static void copy_padding_kern_no_pack_src(const WorkspaceBundle& bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + int IH = kern_param.isz[0]; + int IW = kern_param.isz[1]; + int IC = kern_param.filter_meta.icpg; + int PH = kern_param.filter_meta.padding[0]; + int PW = kern_param.filter_meta.padding[1]; + int GROUP = kern_param.filter_meta.group; + int IH2, IW2; + get_rectified_size(kern_param, IH2, IW2); + int padding_group_size = IH2 * IW2 * IC; + //! Used for get the workspace offset + constexpr int pack_ic = 4; + constexpr int src_expand_element = 1; + size_t workspace_ic_block = 4; + size_t workspace_batch_id = workspace_ids[0]; + size_t workspace_group_id = workspace_ids[1]; + size_t workspace_ic_id = workspace_ids[2]; + size_t workspace_ic = workspace_ic_id * workspace_ic_block; + size_t batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + size_t group_pack_size = 1; + + int nr_pad_w = PW * pack_ic * src_expand_element; + int nr_pad_h = PH * IW2 * pack_ic * src_expand_element; + int row_last_pad = (IW2 - IW - PW) * pack_ic * src_expand_element; + int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * src_expand_element; + const int8_t* sptr = static_cast(kern_param.src( + batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); + + //! copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + (workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + + workspace_ic * IH2 * IW2) * + src_expand_element; + size_t nr_ic = workspace_ic_block; + if (GROUP > 1) { + nr_ic = IC; + } + rep_step(ic_idx, nr_ic, pack_ic) { + std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t)); + sptr_base += nr_pad_h; + rep(ih_idx, IH) { + std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); + sptr_base += nr_pad_w; + std::memcpy(sptr_base, sptr, IW * pack_ic); + sptr_base += IW * pack_ic * src_expand_element; + sptr += IW * pack_ic; + std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); + sptr_base += row_last_pad; + } + std::memset(sptr_base, 0, col_last_pad * sizeof(int8_t)); + sptr_base += col_last_pad; + } +} + +template +static void do_conv_kern(const WorkspaceBundle& bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t FH = kern_param.filter_meta.spatial[0]; + size_t FW = kern_param.filter_meta.spatial[1]; + size_t IC = kern_param.filter_meta.icpg; + size_t OC = kern_param.filter_meta.ocpg; + size_t GROUP = kern_param.filter_meta.group; + int IH2, IW2; + get_rectified_size(kern_param, IH2, IW2); + size_t padding_group_size = IH2 * IW2 * IC; + + constexpr size_t pack_c = 4; + const size_t workspace_batch_id = workspace_ids[0]; + const size_t workspace_group_id = workspace_ids[1]; + const size_t batch_id = ncb_index.ndrange_id[0]; + const size_t group_id = ncb_index.ndrange_id[1]; + const size_t oc_id = ncb_index.ndrange_id[2]; + const size_t oc_block_num = ncb_range[2]; + + megdnn_assert((OC & (pack_c - 1)) == 0, "OC must times of 4"); + size_t nr_pack_per_step = div_ceil(OC / pack_c, oc_block_num); + size_t oc_block = nr_pack_per_step * pack_c; + const size_t oc_idx = oc_id * oc_block; + if (oc_id == (oc_block_num - 1)) { + oc_block = OC - oc_id * nr_pack_per_step * pack_c; + } + megdnn_assert(oc_block % pack_c == 0, + "oc must be devisible by 4, but oc = %zu", oc_block); + + bool need_padding = kern_param.filter_meta.padding[0] > 0 || + kern_param.filter_meta.padding[1] > 0; + const int8_t* sptr = need_padding + ? static_cast(bundle.get(0)) + + workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + : kern_param.src(batch_id, group_id); + //!armv7 use packsrc mode +#if MEGDNN_ARMV7 + if (stride == 1) { + constexpr size_t src_expand_size = 4; + sptr = static_cast(bundle.get(0)) + + workspace_batch_id * GROUP * padding_group_size * + src_expand_size + + workspace_group_id * padding_group_size * src_expand_size; + } +#endif + + const int8_t* fptr = + kern_param.filter(group_id) + oc_idx * FH * FW * IC; + int16_t* dst = reinterpret_cast( + kern_param.dst(batch_id, group_id, oc_idx)); + const int16_t* bptr = + kern_param.bias(batch_id, group_id) + oc_idx; + int8x8x16_direct_nchw44::ConvDirectInt8Nchw44Choose< + bias_mode, filter, stride>::impl(sptr, fptr, bptr, dst, oc_block, + IC, IH2, IW2, OH, OW); +} + +bool ConvBiasImpl::AlgoS8x8x16DirectNCHW44::usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MEGDNN_MARK_USED_VAR(algo_selection_strategy); + auto&& fm = param.filter_meta; + const int fh = fm.spatial[0]; + const int fw = fm.spatial[1]; + const int oc = fm.ocpg; + const int ic = fm.icpg; + const bool avaible = //! src and filter are int8, dst is int16_t + (param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int16) && + (fm.format == param::Convolution::Format::NCHW44) && + (oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && + (fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw && + (fh == 2 || fh == 3 || fh == 5 || fh == 7) && + param.nonlineMode == NonlineMode::IDENTITY && + param.bias_mode != BiasMode::BIAS; + return avaible; +} + +size_t ConvBiasImpl::AlgoS8x8x16DirectNCHW44::get_workspace( + const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t N = param.n; + size_t IC = fm.icpg; + size_t OC = fm.ocpg; + size_t group = fm.group; + size_t fh = fm.spatial[0]; + size_t fw = fm.spatial[1]; + size_t ph = fm.padding[0]; + size_t pw = fm.padding[1]; + WorkspaceBundle wbundle = get_bundle(param); + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct, \ + midout_iv("int8x8x16_nchw44_direct_" \ + "conv" #stride #filter #bias_mode##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(stride, dt_int16, filter, bias_mode) \ + break; \ + default: \ + megdnn_throw(ssprintf("only support IDENTITY mode when dst is " \ + "dt_int16 nonlineMode is %d", \ + uint32_t(param.nonlineMode)) \ + .c_str()); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_throw(ssprintf("only support NO_BIAS/BROADCAST biasmode " \ + "when dst is " \ + "dt_int16 biasmode is %d", \ + uint32_t(param.bias_mode)) \ + .c_str()); \ + break; \ + } + +#define DISPATCH_CONV_KERN(stride) \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(stride, 2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(stride, 3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(stride, 5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(stride, 7) \ + break; \ + default: \ + megdnn_throw(ssprintf("only support 2x2 3x3 5x5 7x7 filters size " \ + "when dst is " \ + "dt_int16 filter size is %u", \ + uint32_t(param.filter_meta.spatial[0])) \ + .c_str()); \ + break; \ + } + + switch (param.filter_meta.stride[0]) { + case 1: + DISPATCH_CONV_KERN(1); + break; + case 2: + DISPATCH_CONV_KERN(2); + break; + default: + megdnn_throw(ssprintf("Unsupport stride size %u for the 8x8x16 direct conv", + param.filter_meta.stride[0]) + .c_str()); + break; + } + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN + + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + + constexpr size_t pack_oc = 4; + size_t oc_step = pack_oc; + if (fh == fw && (fh == 2 || fw == 3) && OC >= 8) { + oc_step = 8; + } + +#if MEGDNN_ARMV7 + if (param.filter_meta.stride[0] == 1) { + if (group == 1) { + CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; + auto copy_padding = [wbundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + wbundle.set(kern_param.workspace_ptr); + copy_padding_kern(wbundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + constexpr size_t pack_ic = 4; + ret_kerns.push_back( + {copy_padding, {N, group, div_ceil(IC, pack_ic)}}); + auto do_conv = [wbundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + wbundle.set(kern_param.workspace_ptr); + do_conv_fun(wbundle, kern_param, ncb_index, + ncb_index.ndrange_id, ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + } else { + CpuNDRange ncb_range = {N, group, 1}; + auto do_conv = [wbundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + wbundle.set(kern_param.workspace_ptr); + copy_padding_kern(wbundle, kern_param, ncb_index, + {0, ncb_index.thread_id, 0}); + do_conv_fun(wbundle, kern_param, ncb_index, + {0, ncb_index.thread_id, 0}, ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + } + return ret_kerns; + } +#endif + + bool need_padding = ph > 0 || pw >0; + + if (group == 1) { + CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; + auto copy_padding = [wbundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + wbundle.set(kern_param.workspace_ptr); + copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index, + ncb_index.ndrange_id); + }; + constexpr size_t pack_ic = 4; + if (need_padding) { + ret_kerns.push_back( + {copy_padding, {N, group, div_ceil(IC, pack_ic)}}); + } + auto do_conv = [wbundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + wbundle.set(kern_param.workspace_ptr); + do_conv_fun(wbundle, kern_param, ncb_index, ncb_index.ndrange_id, + ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + } else { + CpuNDRange ncb_range = {N, group, 1}; + auto do_conv = [wbundle, do_conv_fun, ncb_range, need_padding]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + wbundle.set(kern_param.workspace_ptr); + if (need_padding) { + copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index, + {0, ncb_index.thread_id, 0}); + }; + do_conv_fun(wbundle, kern_param, ncb_index, + {0, ncb_index.thread_id, 0}, ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + } + + return ret_kerns; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h new file mode 100644 index 0000000000000000000000000000000000000000..34272f84a535c2867c8d8c5ff049aa5651a984ae --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h @@ -0,0 +1,56 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace int8x8x16_direct_nchw44 { + +/** +origin src shape +packed src shape +example: (format like ) +origin +<0> <1> <2> <3> +packed +low 64 bit <0> <0> <0> <0> | <1> <1> <1> <1> +--------------------------------------------------------------------- +high 64 bit <2> <2> <2> <2> | <3> <3> <3> <3> +**/ +static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { + static const uint8_t src_idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + constexpr int pack_ic = 4; + constexpr int simd_len = 16; + uint8x16_t src_idx = vld1q_u8(src_idx_buffer); + for (int i = 0; i < length; i++) { + int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx); + vst1q_s8(dst + i * simd_len, result); + } +} + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow); +}; + +} // namespace int8_direct_nchw44 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_aarch64.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_aarch64.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ac346294e9d59cfd1695b20bf24f1c95cd5784a --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_aarch64.cpp @@ -0,0 +1,971 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8_direct_nchw44_s1_aarch64.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#if MEGDNN_AARCH64 +#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace { + +#define INIT_SUM() \ + int16x4_t init_sum; \ + if (bias_mode == BiasMode::NO_BIAS) { \ + init_sum = vdup_n_s16(0); \ + } else { \ + init_sum = vld1_s16(bias_ptr); \ + } + +#define STORE_1_LINE_RESULT() \ + switch (remain_w) { \ + case 8: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + break; \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + break; \ + case 5: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1_s16(dst_ptr + 16, c[0][4]); \ + break; \ + case 6: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + break; \ + case 7: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1_s16(dst_ptr + 24, c[0][6]); \ + break; \ + default: \ + megdnn_assert(0, "oc 1 error remainw"); \ + }; + +#define STORE_2_LINE_RESULT_OW4() \ + switch (remain_w) { \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ + break; \ + default: \ + megdnn_assert(0, "oc 2 error remainw"); \ + break; \ + } + +#define STORE_1_LINE_RESULT_OW4_OH2() \ + switch (remain_w) { \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + 8 + ow, vcombine_s16(c[0][6], c[0][7])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + vst1_s16(dst_ptr + ow, c[0][4]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5])); \ + vst1_s16(dst_ptr + ow + 8, c[0][6]); \ + break; \ + default: \ + megdnn_assert(0, "oc 2 error remainw"); \ + break; \ + } + +#define STORE_1_LINE_RESULT_OW4() \ + switch (remain_w) { \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + break; \ + default: \ + megdnn_assert(0, "oc 1 error remainw"); \ + }; + +template +static void ker_neon_dirctconv_2x2s1_oc8_ow4(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, + int iw, int remain_w,int ld_dst_oc) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + int16x4_t c[2][4]; + int8x16_t weight[2][2]; + int8x16_t src[5]; + + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + + INIT_SUM(); +#define cb(_i) \ + c[0][_i] = init_sum; \ + c[1][_i] = init_sum; + + UNROLL_CALL_RAW(4, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* src_row0 = + src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + const int8_t* src_row1 = + src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_row0 + 0, idx); + src[1] = vld_dup_tbl_s32(src_row0 + 4, idx); + src[2] = vld_dup_tbl_s32(src_row0 + 8, idx); + + weight[0][0] = vld1q_s8(weight_ptr); + weight[0][1] = vld1q_s8(weight_ptr + 16); + weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 16); + +#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + + int16x8_t tmp0; + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]); + + src[3] = vld_dup_tbl_s32(src_row0 + 12, idx); + src[4] = vld_dup_tbl_s32(src_row0 + 16, idx); + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]); + CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]); + + CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], c[0][3]); + CALC_ONE_RESULT(src[3], src[4], weight[1][0], weight[1][1], c[1][3]); + + src[0] = vld_dup_tbl_s32(src_row1 + 0, idx); + src[1] = vld_dup_tbl_s32(src_row1 + 4, idx); + src[2] = vld_dup_tbl_s32(src_row1 + 8, idx); + + weight[0][0] = vld1q_s8(weight_ptr + 32); + weight[0][1] = vld1q_s8(weight_ptr + 48); + + weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4 + 32); + weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 48); + + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]); + + src[3] = vld_dup_tbl_s32(src_row1 + 12, idx); + src[4] = vld_dup_tbl_s32(src_row1 + 16, idx); + + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]); + CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]); + + CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], c[0][3]); + CALC_ONE_RESULT(src[3], src[4], weight[1][0], weight[1][1], c[1][3]); + + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_2_LINE_RESULT_OW4(); +} + +template +static void ker_neon_dirctconv_2x2s1_oc4_ow4(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, + int iw, int remain_w, + int /*ld_dst_oc*/) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + const int ic_stride = ih * iw; + + int16x4_t c[1][4]; + int8x16_t weight[1][2]; + int8x16_t src[5]; + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(4, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); + src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + int16x8_t tmp0; + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], + c[0][0]); + + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], + c[0][1]); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], + c[0][2]); + + CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], + c[0][3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_1_LINE_RESULT_OW4(); +} +#undef CALC_ONE_RESULT + +#define CALC_ONE_RESULT(_src0, _src1, _src2, _w, _c) \ + do { \ + int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + +template +static void ker_neon_dirctconv_3x3s1_oc4_ow4(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, + int iw, int remain_w, + int /*ld_dst_oc*/) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + const int ic_stride = ih * iw; + int16x4_t c[1][4]; + int8x16_t weight[1][3]; + int8x16_t src[6]; + + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(4, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* src_row0 = + src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + const int8_t* src_row1 = + src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + const int8_t* src_row2 = + src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_row0 + 0, idx); + src[1] = vld_dup_tbl_s32(src_row0 + 4, idx); + src[2] = vld_dup_tbl_s32(src_row0 + 8, idx); + + weight[0][0] = vld1q_s8(weight_ptr); + weight[0][1] = vld1q_s8(weight_ptr + 16); + weight[0][2] = vld1q_s8(weight_ptr + 32); + + src[3] = vld_dup_tbl_s32(src_row0 + 12, idx); + CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]); + + src[4] = vld_dup_tbl_s32(src_row0 + 16, idx); + CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]); + + src[5] = vld_dup_tbl_s32(src_row0 + 20, idx); + CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]); + + CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]); + + src[0] = vld_dup_tbl_s32(src_row1 + 0, idx); + src[1] = vld_dup_tbl_s32(src_row1 + 4, idx); + src[2] = vld_dup_tbl_s32(src_row1 + 8, idx); + + weight[0][0] = vld1q_s8(weight_ptr + 48); + weight[0][1] = vld1q_s8(weight_ptr + 64); + weight[0][2] = vld1q_s8(weight_ptr + 80); + + src[3] = vld_dup_tbl_s32(src_row1 + 12, idx); + CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]); + + src[4] = vld_dup_tbl_s32(src_row1 + 16, idx); + CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]); + + src[5] = vld_dup_tbl_s32(src_row1 + 20, idx); + CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]); + + CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]); + + src[0] = vld_dup_tbl_s32(src_row2 + 0, idx); + src[1] = vld_dup_tbl_s32(src_row2 + 4, idx); + src[2] = vld_dup_tbl_s32(src_row2 + 8, idx); + + weight[0][0] = vld1q_s8(weight_ptr + 96); + weight[0][1] = vld1q_s8(weight_ptr + 112); + weight[0][2] = vld1q_s8(weight_ptr + 128); + + src[3] = vld_dup_tbl_s32(src_row2 + 12, idx); + CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]); + + src[4] = vld_dup_tbl_s32(src_row2 + 16, idx); + CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]); + + src[5] = vld_dup_tbl_s32(src_row2 + 20, idx); + CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]); + CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]); + + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_1_LINE_RESULT_OW4(); +} + +template +static void ker_neon_dirctconv_3x3s1_oc4_ow4_oh2(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, + int ih, int iw, int remain_w, + int /*ld_dst_oc*/, int ow) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + const int ic_stride = ih * iw; + int16x4_t c[1][8]; + int8x16_t weight[2][3]; + int8x16_t src[1][6]; + + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* src_row0 = + src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + const int8_t* src_row1 = + src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + const int8_t* src_row2 = + src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; + const int8_t* src_row3 = + src_ptr + ic_idx * ic_stride + 3 * iw * ic_step; +#define LOAD_SRC(_src, _src_ptr) \ + _src[0] = vld_dup_tbl_s32(_src_ptr + 0, idx); \ + _src[1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \ + _src[2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \ + _src[3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \ + _src[4] = vld_dup_tbl_s32(_src_ptr + 16, idx); \ + _src[5] = vld_dup_tbl_s32(_src_ptr + 20, idx); + + LOAD_SRC(src[0], src_row0); + + weight[0][0] = vld1q_s8(weight_ptr); + weight[0][1] = vld1q_s8(weight_ptr + 16); + weight[0][2] = vld1q_s8(weight_ptr + 32); + + weight[1][0] = vld1q_s8(weight_ptr + 48); + weight[1][1] = vld1q_s8(weight_ptr + 64); + weight[1][2] = vld1q_s8(weight_ptr + 80); + + CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], + c[0][0]); // row0 src0 w0 + CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]); + CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]); + CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]); + + LOAD_SRC(src[0], src_row1); + CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], + c[0][4]); // row1 src1 w0 + CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]); + CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]); + CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]); + + CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1], + c[0][0]); // row1 src1 w1 + CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][1]); + CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][2]); + CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][3]); + + LOAD_SRC(src[0], src_row2); + + weight[0][0] = vld1q_s8(weight_ptr + 96); + weight[0][1] = vld1q_s8(weight_ptr + 112); + weight[0][2] = vld1q_s8(weight_ptr + 128); + CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1], + c[0][4]); // row2 src0 w1 + CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][5]); + CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][6]); + CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][7]); + + CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], + c[0][0]); // row2 w0 src[0] + CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]); + CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]); + CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]); + + LOAD_SRC(src[0], src_row3); + + CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], + c[0][4]); // row3 w0 src1 + CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]); + CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]); + CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]); + + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_1_LINE_RESULT_OW4_OH2(); +} +#undef LOAD_SRC +#undef CALC_ONE_RESULT + +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int remain_w, int ld_dst_oc); +}; + +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int remain_w, int /*ld_dst_oc*/) { + constexpr int filter_size = 5; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + int16x4_t c[1][8]; + int8x16_t weight[5]; + int8x16_t src[8 + 2]; + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0 * 4, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 1 * 4, idx); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 2 * 4, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 3 * 4, idx); + src[4] = vld_dup_tbl_s32(src_ic_0_3 + 4 * 4, idx); + src[5] = vld_dup_tbl_s32(src_ic_0_3 + 5 * 4, idx); + src[6] = vld_dup_tbl_s32(src_ic_0_3 + 6 * 4, idx); + src[7] = vld_dup_tbl_s32(src_ic_0_3 + 7 * 4, idx); + src[8] = vld_dup_tbl_s32(src_ic_0_3 + 8 * 4, idx); + src[9] = vld_dup_tbl_s32(src_ic_0_3 + 9 * 4, idx); + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); +#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \ + _w4, _c) \ + do { \ + int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ + int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + + CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][0]); + CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][1]); + CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][2]); + CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][3]); + CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][4]); + CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][5]); + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx); + CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][6]); + CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][7]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + STORE_1_LINE_RESULT(); + } +}; +#undef CALC_ONE_RESULT +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int remain_w, int /*ld_dst_oc*/) { + constexpr int filter_size = 7; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + int16x4_t c[1][8]; + int8x16_t weight[7]; + int8x16_t src[8 + 2]; + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0 * 4, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 1 * 4, idx); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 2 * 4, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 3 * 4, idx); + src[4] = vld_dup_tbl_s32(src_ic_0_3 + 4 * 4, idx); + src[5] = vld_dup_tbl_s32(src_ic_0_3 + 5 * 4, idx); + src[6] = vld_dup_tbl_s32(src_ic_0_3 + 6 * 4, idx); + src[7] = vld_dup_tbl_s32(src_ic_0_3 + 7 * 4, idx); + src[8] = vld_dup_tbl_s32(src_ic_0_3 + 8 * 4, idx); + src[9] = vld_dup_tbl_s32(src_ic_0_3 + 9 * 4, idx); + + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); + weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); + +#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \ + _c) \ + do { \ + int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \ + int16x8_t tmp2 = vmull_s8(vget_low_s8(_src1), vget_low_s8(_w[1])); \ + int16x8_t tmp3 = vmull_s8(vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \ + tmp2 = vmlal_s8(tmp2, vget_low_s8(_src3), vget_low_s8(_w[3])); \ + tmp3 = vmlal_s8(tmp3, vget_high_s8(_src3), vget_high_s8(_w[3])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \ + tmp2 = vmlal_s8(tmp2, vget_low_s8(_src5), vget_low_s8(_w[5])); \ + tmp3 = vmlal_s8(tmp3, vget_high_s8(_src5), vget_high_s8(_w[5])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src6), vget_high_s8(_w[6])); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + tmp2 = vaddq_s16(tmp2, tmp3); \ + tmp0 = vaddq_s16(tmp0, tmp2); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + + CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5], + src[6], weight, c[0][0]); + CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6], + src[7], weight, c[0][1]); + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx); + + CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7], + src[8], weight, c[0][2]); + CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8], + src[9], weight, c[0][3]); + + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 12 * 4, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 13 * 4, idx); + CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9], + src[0], weight, c[0][4]); + CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0], + src[1], weight, c[0][5]); + CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1], + src[2], weight, c[0][6]); + CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2], + src[3], weight, c[0][7]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_1_LINE_RESULT(); + } +}; + +#undef CALC_ONE_RESULT +template +void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, + const int8_t* filter, + const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, + const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + constexpr size_t filter_size = 2; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 8; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_oc = oh * ow * oc_step; + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + size_t oh_idx = 0; + for (; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s1_oc8_ow4( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_step, ld_oc); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_2x2s1_oc8_ow4( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_remain, ld_oc); + } + } + } + if (oc_remain > 0) { + const size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s1_oc4_ow4( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_step, ld_oc); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_2x2s1_oc4_ow4( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_remain, ld_oc); + } + } + } +} + +template +void conv_direct_stride1_3x3_int8x8x16_oh2_nchw44( + const int8_t* src, const int8_t* filter, const int16_t* bias, + int16_t* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow) { + constexpr size_t filter_size = 3; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const int ld_oc = oh * ow * oc_step; + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + size_t oh_idx = 0; + for (; oh_idx + 1 < oh; oh_idx += 2) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_3x3s1_oc4_ow4_oh2( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_step, ld_oc, + ow * oc_step); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_3x3s1_oc4_ow4_oh2( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_remain, ld_oc, + ow * oc_step); + } + } + for (; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_3x3s1_oc4_ow4( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_step, ld_oc); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + ker_neon_dirctconv_3x3s1_oc4_ow4( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_remain, ld_oc); + } + } + } +} + +template +void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, + const int8_t* filter, + const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, + const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + + const size_t img_stride = oh * ow; + const int ld_dst_oc = oh * ow * oc_step; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDirectStride1Int8::impl( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_step, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + KerNeonDirectStride1Int8::impl( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ow_remain, ld_dst_oc); + } + } + } +} +} // namespace + +namespace int8x8x16_direct_nchw44 { +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride1_int8_nchw44_kern( + src, filter, bias, dst, oc, ic, ih, iw, oh, ow); + } +}; + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride1_2x2_int8_nchw44(src, filter, bias, dst, + oc, ic, ih, iw, oh, ow); + } +}; + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride1_3x3_int8x8x16_oh2_nchw44( + src, filter, bias, dst, oc, ic, ih, iw, oh, ow); + } +}; + +#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \ + template struct ConvDirectInt8Nchw44Choose; + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN(stride, filter, bias_mode) + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define DISPATCH_CONV_KERN(stride) \ + GET_BIAS_MODE_PARAM(stride, 2) \ + GET_BIAS_MODE_PARAM(stride, 3) \ + GET_BIAS_MODE_PARAM(stride, 5) \ + GET_BIAS_MODE_PARAM(stride, 7) + +DISPATCH_CONV_KERN(1); + +} // namespace int8x8x16_direct_nchw44 +} // namespace arm_common +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_armv7.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_armv7.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9d034208f04c791ee2b0e5bdb35135537e426cb5 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_armv7.cpp @@ -0,0 +1,854 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8_direct_nchw44_s1_armv7.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/common/utils.h" +#if MEGDNN_ARMV7 +#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace { + +#define INIT_SUM() \ + int16x4_t init_sum; \ + if (bias_mode == BiasMode::NO_BIAS) { \ + init_sum = vdup_n_s16(0); \ + } else { \ + init_sum = vld1_s16(bias_ptr); \ + } + +#define STORE_1_LINE_RESULT() \ + switch (remain_w) { \ + case 8: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + break; \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + break; \ + case 5: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1_s16(dst_ptr + 16, c[0][4]); \ + break; \ + case 6: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + break; \ + case 7: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1_s16(dst_ptr + 24, c[0][6]); \ + break; \ + default: \ + megdnn_assert(0, "oc 1 error remainw"); \ + }; + +#define STORE_2_LINE_RESULT() \ + switch (remain_w) { \ + case 8: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, \ + vcombine_s16(c[1][4], c[1][5])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 24, \ + vcombine_s16(c[1][6], c[1][7])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ + break; \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + break; \ + case 5: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1_s16(dst_ptr + 16, c[0][4]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \ + break; \ + case 6: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, \ + vcombine_s16(c[1][4], c[1][5])); \ + break; \ + case 7: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1_s16(dst_ptr + 24, c[0][6]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, \ + vcombine_s16(c[1][4], c[1][5])); \ + vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \ + break; \ + default: \ + megdnn_assert(0, "oc 2 error remainw"); \ + break; \ + } + +template +static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int src_expand_size = 4; + const int ic_stride = ih * iw * src_expand_size; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + int16x4_t c[2][8]; + int8x16_t weight[2][2]; + int8x16_t src[4]; + + INIT_SUM(); +#define cb(_i) \ + c[0][_i] = init_sum; \ + c[1][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* src_row0 = src_ptr + ic_idx * ic_stride + + 0 * iw * ic_step * src_expand_size; + const int8_t* src_row1 = src_ptr + ic_idx * ic_stride + + 1 * iw * ic_step * src_expand_size; + + src[0] = vld1q_s8(src_row0); + src[1] = vld1q_s8(src_row0 + 16); + + weight[0][0] = vld1q_s8(weight_ptr); + weight[0][1] = vld1q_s8(weight_ptr + 16); + weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 16); + +#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + + int16x8_t tmp0; + src[2] = vld1q_s8(src_row0 + 2 * 16); + src[3] = vld1q_s8(src_row0 + 3 * 16); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]); + + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]); + CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]); + + src[0] = vld1q_s8(src_row0 + 4 * 16); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]); + + src[1] = vld1q_s8(src_row0 + 5 * 16); + CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][3]); + CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][3]); + + src[2] = vld1q_s8(src_row0 + 6 * 16); + + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][4]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][4]); + + src[3] = vld1q_s8(src_row0 + 7 * 16); + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][5]); + CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][5]); + + src[0] = vld1q_s8(src_row0 + 8 * 16); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][6]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][6]); + + src[1] = vld1q_s8(src_row1 + 0 * 16); + src[2] = vld1q_s8(src_row1 + 1 * 16); + + CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][7]); + CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][7]); + + weight[0][0] = vld1q_s8(weight_ptr + 32); + weight[0][1] = vld1q_s8(weight_ptr + 48); + src[3] = vld1q_s8(src_row1 + 2 * 16); + + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][0]); + weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4 + 32); + weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 48); + src[0] = vld1q_s8(src_row1 + 3 * 16); + CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][0]); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][1]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][1]); + + src[1] = vld1q_s8(src_row1 + 4 * 16); + + CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][2]); + CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][2]); + + src[2] = vld1q_s8(src_row1 + 5 * 16); + + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][3]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][3]); + + src[3] = vld1q_s8(src_row1 + 6 * 16); + + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][4]); + CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][4]); + + src[0] = vld1q_s8(src_row1 + 7 * 16); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][5]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][5]); + + src[1] = vld1q_s8(src_row1 + 8 * 16); + CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][6]); + CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][6]); + + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][7]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][7]); + + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_2_LINE_RESULT(); +} + +template +static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, + int iw, int /*ld_dst_oc*/) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int src_expand_size = 4; + + const int ic_stride = ih * iw * src_expand_size; + + int16x4_t c[1][8]; + int8x16_t weight[1][2]; + int8x16_t src[4]; + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * src_expand_size; + src[0] = vld1q_s8(src_ic_0_3); + src[1] = vld1q_s8(src_ic_0_3 + 16); + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + int16x8_t tmp0; + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], + c[0][0]); + + src[0] = vld1q_s8(src_ic_0_3 + 4 * 16); + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], + c[0][1]); + + src[1] = vld1q_s8(src_ic_0_3 + 5 * 16); + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], + c[0][2]); + + src[2] = vld1q_s8(src_ic_0_3 + 6 * 16); + CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], + c[0][3]); + + src[3] = vld1q_s8(src_ic_0_3 + 7 * 16); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], + c[0][4]); + + src[0] = vld1q_s8(src_ic_0_3 + 8 * 16); + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], + c[0][5]); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], + c[0][6]); + + CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], + c[0][7]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + STORE_1_LINE_RESULT(); +} +#undef CALC_ONE_RESULT + +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc); +}; + +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int /*ld_dst_oc*/) { + constexpr int filter_size = 3; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int src_expand_size = 4; + + const int ic_stride = ih * iw * src_expand_size; + + int16x4_t c[1][8]; + int8x16_t weight[3]; + int8x16_t src[5]; + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * src_expand_size; + + src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + +#define CALC_ONE_RESULT(_src0, _src1, _src2, _w0, _w1, _w2, _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w2)); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + + int16x8_t tmp0; + + CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1], + weight[2], c[0][0]); + src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); + CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1], + weight[2], c[0][1]); + src[0] = vld1q_s8(src_ic_0_3 + 5 * 16); + CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1], + weight[2], c[0][2]); + src[1] = vld1q_s8(src_ic_0_3 + 6 * 16); + CALC_ONE_RESULT(src[3], src[4], src[0], weight[0], weight[1], + weight[2], c[0][3]); + src[2] = vld1q_s8(src_ic_0_3 + 7 * 16); + CALC_ONE_RESULT(src[4], src[0], src[1], weight[0], weight[1], + weight[2], c[0][4]); + src[3] = vld1q_s8(src_ic_0_3 + 8 * 16); + CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1], + weight[2], c[0][5]); + src[4] = vld1q_s8(src_ic_0_3 + 9 * 16); + CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1], + weight[2], c[0][6]); + CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1], + weight[2], c[0][7]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_1_LINE_RESULT(); + } +}; + +#undef CALC_ONE_RESULT +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int /*ld_dst_oc*/) { + constexpr int filter_size = 5; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int src_expand_size = 4; + + const int ic_stride = ih * iw * src_expand_size; + int16x4_t c[1][8]; + int8x16_t weight[5]; + int8x16_t src[8 + 2]; + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * src_expand_size; + + src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); + src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); + src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); +#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \ + _w4, _c) \ + do { \ + int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ + int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + + CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][0]); + CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][1]); + CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][2]); + CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][3]); + CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][4]); + CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][5]); + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); + CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][6]); + CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], + weight[0], weight[1], weight[2], weight[3], + weight[4], c[0][7]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + STORE_1_LINE_RESULT(); + } +}; +#undef CALC_ONE_RESULT +template +struct KerNeonDirectStride1Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int /*ld_dst_oc*/) { + constexpr int filter_size = 7; + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + constexpr int src_expand_size = 4; + + const int ic_stride = ih * iw * src_expand_size; + + int16x4_t c[1][8]; + int8x16_t weight[7]; + int8x16_t src[8 + 2]; + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * src_expand_size; + + src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); + src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); + src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); + src[5] = vld1q_s8(src_ic_0_3 + 5 * 16); + src[6] = vld1q_s8(src_ic_0_3 + 6 * 16); + src[7] = vld1q_s8(src_ic_0_3 + 7 * 16); + src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); + src[9] = vld1q_s8(src_ic_0_3 + 9 * 16); + + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); + weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); + weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); + weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); + +#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \ + _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src3), vget_high_s8(_w[3])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src4), vget_high_s8(_w[4])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src5), vget_high_s8(_w[5])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src6), vget_high_s8(_w[6])); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + int16x8_t tmp0; + CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5], + src[6], weight, c[0][0]); + CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6], + src[7], weight, c[0][1]); + src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); + src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); + + CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7], + src[8], weight, c[0][2]); + CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8], + src[9], weight, c[0][3]); + + src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); + src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); + CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9], + src[0], weight, c[0][4]); + CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0], + src[1], weight, c[0][5]); + CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1], + src[2], weight, c[0][6]); + CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2], + src[3], weight, c[0][7]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_1_LINE_RESULT(); + } +}; + +template +void conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( + const int8_t* src, const int8_t* filter, const int16_t* bias, + int16_t* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow) { + constexpr size_t filter_size = 2; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 8; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t src_expand_size = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_oc = oh * ow * oc_step; + using remain_fun = + std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = ker_neon_dirctconv_2x2s1_oc8_ow8; \ + kern_small_oc_remain = \ + ker_neon_dirctconv_2x2s1_oc4_ow8; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } +#undef cb + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + size_t oh_idx = 0; + for (; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * src_expand_size; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s1_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * src_expand_size; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_oc); + } + } + } + if (oc_remain > 0) { + const size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * src_expand_size; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s1_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * src_expand_size; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_oc); + } + } + } +} +#undef CALC_ONE_RESULT + +template +void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, + const int8_t* filter, + const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, + const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t src_expand_size = 4; + + const size_t img_stride = oh * ow; + const int ld_dst_oc = oh * ow * oc_step; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + + using remain_fun = + std::function; + + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_small_oc_remain = KerNeonDirectStride1Int8::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } +#undef cb + + for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * iw + ow_idx) * ic_step * src_expand_size; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDirectStride1Int8::impl( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * iw + ow_end) * ic_step * src_expand_size; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc); + } + } + } +} +} // namespace + +namespace int8x8x16_direct_nchw44 { +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride1_int8_nchw44_kern( + src, filter, bias, dst, oc, ic, ih, iw, oh, ow); + } +}; + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( + src, filter, bias, dst, oc, ic, ih, iw, oh, ow); + } +}; + +#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \ + template struct ConvDirectInt8Nchw44Choose; + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN(stride, filter, bias_mode) + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define DISPATCH_CONV_KERN(stride) \ + GET_BIAS_MODE_PARAM(stride, 2) \ + GET_BIAS_MODE_PARAM(stride, 3) \ + GET_BIAS_MODE_PARAM(stride, 5) \ + GET_BIAS_MODE_PARAM(stride, 7) + +DISPATCH_CONV_KERN(1); + +} // namespace int8x8x16_direct_nchw44 +} // namespace arm_common +} // namespace megdnn +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d3e7686f941083e02570920afa8e3319fa0c09e4 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s2.cpp @@ -0,0 +1,1560 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace { + +#define INIT_SUM() \ + int16x4_t init_sum; \ + if (bias_mode == BiasMode::NO_BIAS) { \ + init_sum = vdup_n_s16(0); \ + } else { \ + init_sum = vld1_s16(bias_ptr); \ + } + +#define STORE_1_LINE_RESULT() \ + switch (remain_w) { \ + case 8: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + break; \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + break; \ + case 5: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1_s16(dst_ptr + 16, c[0][4]); \ + break; \ + case 6: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + break; \ + case 7: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1_s16(dst_ptr + 24, c[0][6]); \ + break; \ + default: \ + megdnn_assert(0, "oc 1 error remainw"); \ + break; \ + }; + +#define STORE_1_LINE_RESULT_OW4() \ + switch (remain_w) { \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + break; \ + default: \ + megdnn_assert(0, "oc 1 error remainw"); \ + break; \ + }; + +#define STORE_2_LINE_RESULT() \ + switch (remain_w) { \ + case 8: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, \ + vcombine_s16(c[1][4], c[1][5])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 24, \ + vcombine_s16(c[1][6], c[1][7])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ + break; \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + break; \ + case 5: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1_s16(dst_ptr + 16, c[0][4]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \ + break; \ + case 6: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, \ + vcombine_s16(c[1][4], c[1][5])); \ + break; \ + case 7: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1_s16(dst_ptr + 24, c[0][6]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, \ + vcombine_s16(c[1][4], c[1][5])); \ + vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \ + break; \ + default: \ + megdnn_assert(0, "oc 2 error remainw"); \ + break; \ + } + +#define STORE_2_LINE_RESULT_OW4() \ + switch (remain_w) { \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, \ + vcombine_s16(c[1][2], c[1][3])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ + break; \ + default: \ + megdnn_assert(0, "oc 2 error remainw"); \ + break; \ + } + +template +static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + int16x4_t c[2][8]; + int8x16_t weight[2][2]; + int8x16_t src[4]; + INIT_SUM(); +#define cb(_i) \ + c[0][_i] = init_sum; \ + c[1][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); + +#undef cb + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_ic_0_3, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); + + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); + +#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + + int16x8_t tmp0; + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], + c[0][0]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], + c[1][0]); + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], + c[0][1]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], + c[1][1]); + + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx); + + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], + c[0][2]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], + c[1][2]); + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 36, idx); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], + c[0][3]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], + c[1][3]); + + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 40, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 44, idx); + + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], + c[0][4]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], + c[1][4]); + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 48, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 52, idx); + + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], + c[0][5]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], + c[1][5]); + + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 56, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 60, idx); + + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], + c[0][6]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], + c[1][6]); + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], + c[0][7]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], + c[1][7]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_2_LINE_RESULT(); +} + +template +static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, + int iw, int /*ld_dst_oc*/) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + + int16x4_t c[1][8]; + int8x16_t weight[2]; + int8x16_t src[4]; + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(8, cb); + +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_ic_0_3, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); + + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + int16x8_t tmp0; + CALC_ONE_RESULT(src[0], src[1], weight[0], weight[1], c[0][0]); + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx); + + CALC_ONE_RESULT(src[2], src[3], weight[0], weight[1], c[0][1]); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx); + + CALC_ONE_RESULT(src[0], src[1], weight[0], weight[1], c[0][2]); + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 36, idx); + + CALC_ONE_RESULT(src[2], src[3], weight[0], weight[1], c[0][3]); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 40, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 44, idx); + CALC_ONE_RESULT(src[0], src[1], weight[0], weight[1], c[0][4]); + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 48, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 52, idx); + CALC_ONE_RESULT(src[2], src[3], weight[0], weight[1], c[0][5]); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 56, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 60, idx); + CALC_ONE_RESULT(src[0], src[1], weight[0], weight[1], c[0][6]); + CALC_ONE_RESULT(src[2], src[3], weight[0], weight[1], c[0][7]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_1_LINE_RESULT(); +} +#undef CALC_ONE_RESULT + +#define CALC_ONE_RESULT(_src0, _src1, _src2, _w, _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + +template +static void ker_neon_dirctconv_3x3s2_oc8_ow4(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + const int ic_stride = ih * iw; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + int16x4_t c[2][4]; + int8x16_t weight[2][3]; + int8x16_t src[5]; + + INIT_SUM(); +#define cb(_i) \ + c[0][_i] = init_sum; \ + c[1][_i] = init_sum; + + UNROLL_CALL_RAW(4, cb); + +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_ic_0_3, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); + src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + weight[0][2] = vld1q_s8(read_weight_ptr + 32); + weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); + weight[1][2] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 32); + + int16x8_t tmp0, tmp1; + CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]); + CALC_ONE_RESULT(src[0], src[1], src[2], weight[1], c[1][0]); + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx); + CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][1]); + CALC_ONE_RESULT(src[2], src[3], src[4], weight[1], c[1][1]); + + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx); + + CALC_ONE_RESULT(src[4], src[0], src[1], weight[0], c[0][2]); + CALC_ONE_RESULT(src[4], src[0], src[1], weight[1], c[1][2]); + + CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][3]); + CALC_ONE_RESULT(src[1], src[2], src[3], weight[1], c[1][3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); +} + +template +static void ker_neon_dirctconv_3x3s2_oc8_ow4_remain(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, + int ih, int iw, + int ld_dst_oc) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + const int ic_stride = ih * iw; + const int ld_weight_oc4 = oc_step * fh * fw * ic; + + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + int16x4_t c[2][4]; + int8x16_t weight[2][3]; + int8x16_t src[5]; + + INIT_SUM(); +#define cb(_i) \ + c[0][_i] = init_sum; \ + c[1][_i] = init_sum; + + UNROLL_CALL_RAW(4, cb); + +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_ic_0_3, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); + src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0][0] = vld1q_s8(read_weight_ptr); + weight[0][1] = vld1q_s8(read_weight_ptr + 16); + weight[0][2] = vld1q_s8(read_weight_ptr + 32); + weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4); + weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16); + weight[1][2] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 32); + + int16x8_t tmp0, tmp1; + CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]); + CALC_ONE_RESULT(src[0], src[1], src[2], weight[1], c[1][0]); + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx); + CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][1]); + CALC_ONE_RESULT(src[2], src[3], src[4], weight[1], c[1][1]); + + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx); + + CALC_ONE_RESULT(src[4], src[0], src[1], weight[0], c[0][2]); + CALC_ONE_RESULT(src[4], src[0], src[1], weight[1], c[1][2]); + + CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][3]); + CALC_ONE_RESULT(src[1], src[2], src[3], weight[1], c[1][3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_2_LINE_RESULT_OW4(); +} + +#undef CALC_ONE_RESULT + +#define CALC_ONE_RESULT(_src0, _src1, _src2, _w, _c) \ + do { \ + int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + +template +static void ker_neon_dirctconv_3x3s2_oc4_ow4(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, + int iw, int /*ld_dst_oc*/) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + + int16x4_t c[1][4]; + int8x16_t weight[3]; + int8x16_t src[5]; + + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(4, cb); + +#undef cb + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_ic_0_3, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); + src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + + CALC_ONE_RESULT(src[0], src[1], src[2], weight, c[0][0]); + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx); + CALC_ONE_RESULT(src[2], src[3], src[4], weight, c[0][1]); + + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx); + + CALC_ONE_RESULT(src[4], src[0], src[1], weight, c[0][2]); + + CALC_ONE_RESULT(src[1], src[2], src[3], weight, c[0][3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); +} +template +static void ker_neon_dirctconv_3x3s2_oc4_ow4_remain(const int8_t* src_ptr, + const int8_t* weight_ptr, + const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, + int ih, int iw, + int /*ld_dst_oc*/) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; + constexpr int loop_ic_step = 4; + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + + int16x4_t c[1][4]; + int8x16_t weight[3]; + int8x16_t src[5]; + INIT_SUM(); +#define cb(_i) c[0][_i] = init_sum; + + UNROLL_CALL_RAW(4, cb); + +#undef cb + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step; + + src[0] = vld_dup_tbl_s32(src_ic_0_3, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx); + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); + src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); + const int8_t* read_weight_ptr = + weight_ptr + fh_idx * fw * ld_weight_ic4; + + weight[0] = vld1q_s8(read_weight_ptr); + weight[1] = vld1q_s8(read_weight_ptr + 16); + weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); + + CALC_ONE_RESULT(src[0], src[1], src[2], weight, c[0][0]); + + src[0] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx); + src[1] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx); + CALC_ONE_RESULT(src[2], src[3], src[4], weight, c[0][1]); + + src[2] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx); + src[3] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx); + + CALC_ONE_RESULT(src[4], src[0], src[1], weight, c[0][2]); + + CALC_ONE_RESULT(src[1], src[2], src[3], weight, c[0][3]); + } + weight_ptr += fh * fw * ld_weight_ic4; + } + STORE_1_LINE_RESULT_OW4(); +} + +#undef CALC_ONE_RESULT + +template +void conv_direct_stride2_2x2_int8_nchw44(const int8_t* src, + const int8_t* filter, + const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, + const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + constexpr size_t filter_size = 2; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 8; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + + const size_t out_img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oh * ow * oc_step; + + using remain_fun = + std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = ker_neon_dirctconv_2x2s2_oc8_ow8; \ + kern_small_oc_remain = \ + ker_neon_dirctconv_2x2s2_oc4_ow8; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } +#undef cb + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc8_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc); + } + } + } + + if (oc_remain > 0) { + const size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc4_ow8( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc); + } + } + } +} + +template +void conv_direct_stride2_3x3_int8_nchw44(const int8_t* src, + const int8_t* filter, + const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, + const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + constexpr size_t filter_size = 3; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t big_oc_step = 8; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 4; + constexpr size_t ow_step4 = 4; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + + const size_t out_img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oh * ow * oc_step; + + using remain_fun = + std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + ker_neon_dirctconv_3x3s2_oc8_ow4_remain; \ + kern_small_oc_remain = \ + ker_neon_dirctconv_3x3s2_oc4_ow4_remain; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } +#undef cb + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step4) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_3x3s2_oc8_ow4( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc); + } + } + } + + if (oc_remain > 0) { + const size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_3x3s2_oc4_ow4( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); + } + if (ow_remain > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc); + } + } + } +} +#undef CALC_ONE_RESULT +#undef LOAD_SRC +template +void conv_direct_stride2_5x5_int8x8x16_nchw44( + const int8_t* src, const int8_t* filter, const int16_t* bias, + int16_t* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow) { + constexpr size_t filter_size = 5; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step1 = 1; + constexpr size_t ow_step = 4; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + const size_t remain_w = ow & 3; + + const size_t out_img_stride = oh * ow; + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + size_t oc_idx = 0; + + for (; oc_idx + 3 < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + const int16_t* bias_ptr = bias + oc_idx; + + int16x4_t init_sum; + + if (bias_mode == BiasMode::NO_BIAS) { + init_sum = vdup_n_s16(0); + } else { + init_sum = vld1_s16(bias_ptr); + } + size_t oh_idx = 0; + +#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w, _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w[3])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + + for (; oh_idx < oh; oh_idx += oh_step1) { + size_t ow_idx = 0; + for (; ow_idx + ow_step - 1 < ow; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + int16x4_t c[1][4]; + const int8_t* src_ptr = src + src_offset; + int16_t* dst_ptr = dst + dst_offset; + const int8_t* weight_ptr = filter + weight_offset; + + c[0][0] = init_sum; + c[0][1] = init_sum; + c[0][2] = init_sum; + c[0][3] = init_sum; +#if MEGDNN_AARCH64 + int8x16_t weight[3][5]; + int8x16_t ssrc[2][5]; +#else + int8x16_t weight[1][5]; + int8x16_t ssrc[1][9]; +#endif + for (size_t ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + const int8_t* src_row0 = + src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + const int8_t* src_row1 = + src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + const int8_t* src_row2 = + src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; + const int8_t* src_row3 = + src_ptr + ic_idx * ic_stride + 3 * iw * ic_step; + const int8_t* src_row4 = + src_ptr + ic_idx * ic_stride + 4 * iw * ic_step; +#if MEGDNN_AARCH64 + +#define LOAD_SRC(_src, _src_ptr) \ + _src[0] = vld_dup_tbl_s32(_src_ptr, idx); \ + _src[1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \ + _src[2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \ + _src[3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \ + _src[4] = vld_dup_tbl_s32(_src_ptr + 16, idx); + +#define LOAD_WEIGHT(_w, _w_ptr, _id0, _id1, _id2, _id3, _id4) \ + _w[0] = vld1q_s8(_w_ptr + _id0 * 16); \ + _w[1] = vld1q_s8(_w_ptr + _id1 * 16); \ + _w[2] = vld1q_s8(_w_ptr + _id2 * 16); \ + _w[3] = vld1q_s8(_w_ptr + _id3 * 16); \ + _w[4] = vld1q_s8(_w_ptr + _id4 * 16); + +#define CALC_4_RESULT(_src, _w, _src_ptr) \ + CALC_ONE_RESULT(_src[0], _src[1], _src[2], _src[3], _src[4], _w, c[0][0]); \ + _src[0] = vld_dup_tbl_s32(_src_ptr + 20, idx); \ + _src[1] = vld_dup_tbl_s32(_src_ptr + 24, idx); \ + CALC_ONE_RESULT(_src[2], _src[3], _src[4], _src[0], _src[1], _w, c[0][1]); \ + _src[2] = vld_dup_tbl_s32(_src_ptr + 28, idx); \ + _src[3] = vld_dup_tbl_s32(_src_ptr + 32, idx); \ + CALC_ONE_RESULT(_src[4], _src[0], _src[1], _src[2], _src[3], _w, c[0][2]); \ + _src[4] = vld_dup_tbl_s32(_src_ptr + 36, idx); \ + _src[0] = vld_dup_tbl_s32(_src_ptr + 40, idx); \ + CALC_ONE_RESULT(_src[1], _src[2], _src[3], _src[4], _src[0], _w, c[0][3]); + + int16x8_t tmp0, tmp1; + + LOAD_SRC(ssrc[0], src_row0); + LOAD_WEIGHT(weight[0], weight_ptr, 0, 1, 2, 3, 4); + LOAD_WEIGHT(weight[1], weight_ptr, 5, 6, 7, 8, 9); + CALC_4_RESULT(ssrc[0], weight[0], src_row0); + + LOAD_SRC(ssrc[1], src_row1); + LOAD_WEIGHT(weight[2], weight_ptr, 10, 11, 12, 13, 14); + LOAD_SRC(ssrc[0], src_row2); + CALC_4_RESULT(ssrc[1], weight[1], src_row1); + + LOAD_SRC(ssrc[1], src_row3); + LOAD_WEIGHT(weight[0], weight_ptr, 15, 16, 17, 18, 19); + CALC_4_RESULT(ssrc[0], weight[2], src_row2); + + LOAD_SRC(ssrc[0], src_row4); + LOAD_WEIGHT(weight[1], weight_ptr, 20, 21, 22, 23, 24); + CALC_4_RESULT(ssrc[1], weight[0], src_row3); + CALC_4_RESULT(ssrc[0], weight[1], src_row4); +#else + +#define LOAD_SRC(_src_ptr) \ + ssrc[0][0] = vld_dup_tbl_s32(_src_ptr, idx); \ + ssrc[0][1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \ + ssrc[0][2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \ + ssrc[0][3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \ + ssrc[0][4] = vld_dup_tbl_s32(_src_ptr + 16, idx); \ + ssrc[0][5] = vld_dup_tbl_s32(_src_ptr + 20, idx); \ + ssrc[0][6] = vld_dup_tbl_s32(_src_ptr + 24, idx); \ + ssrc[0][7] = vld_dup_tbl_s32(_src_ptr + 28, idx); \ + ssrc[0][8] = vld_dup_tbl_s32(_src_ptr + 32, idx); + +#define LOAD_WEIGHT(_w_ptr, _id0, _id1, _id2, _id3, _id4) \ + weight[0][0] = vld1q_s8(_w_ptr + _id0 * 16); \ + weight[0][1] = vld1q_s8(_w_ptr + _id1 * 16); \ + weight[0][2] = vld1q_s8(_w_ptr + _id2 * 16); \ + weight[0][3] = vld1q_s8(_w_ptr + _id3 * 16); \ + weight[0][4] = vld1q_s8(_w_ptr + _id4 * 16); + +#define CALC_4_RESULT(_src_ptr) \ + CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \ + ssrc[0][4], weight[0], c[0][0]); \ + ssrc[0][0] = vld_dup_tbl_s32(_src_ptr + 36, idx); \ + ssrc[0][1] = vld_dup_tbl_s32(_src_ptr + 40, idx); \ + CALC_ONE_RESULT(ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \ + ssrc[0][6], weight[0], c[0][1]); \ + CALC_ONE_RESULT(ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \ + ssrc[0][8], weight[0], c[0][2]); \ + CALC_ONE_RESULT(ssrc[0][6], ssrc[0][7], ssrc[0][8], ssrc[0][0], \ + ssrc[0][1], weight[0], c[0][3]); + + int16x8_t tmp0, tmp1; + + LOAD_WEIGHT(weight_ptr, 0, 1, 2, 3, 4); + LOAD_SRC(src_row0); + CALC_4_RESULT(src_row0); + + LOAD_WEIGHT(weight_ptr, 5, 6, 7, 8, 9); + LOAD_SRC(src_row1); + CALC_4_RESULT(src_row1); + + LOAD_WEIGHT(weight_ptr, 10, 11, 12, 13, 14); + LOAD_SRC(src_row2); + CALC_4_RESULT(src_row2); + + LOAD_WEIGHT(weight_ptr, 15, 16, 17, 18, 19); + LOAD_SRC(src_row3); + CALC_4_RESULT(src_row3); + + LOAD_WEIGHT(weight_ptr, 20, 21, 22, 23, 24); + LOAD_SRC(src_row4); + CALC_4_RESULT(src_row4); +#endif + weight_ptr += fh * fw * ld_weight_ic4; + } + + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); + } + if (remain_w > 0) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + int16x4_t c[1][3]; + const int8_t* src_ptr = src + src_offset; + int16_t* dst_ptr = dst + dst_offset; + const int8_t* weight_ptr = filter + weight_offset; + + c[0][0] = init_sum; + c[0][1] = init_sum; + c[0][2] = init_sum; +#if MEGDNN_AARCH64 + int8x16_t weight[3][5]; + int8x16_t ssrc[2][5]; +#else + int8x16_t weight[1][5]; + int8x16_t ssrc[1][9]; +#endif + for (size_t ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + const int8_t* src_row0 = + src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + const int8_t* src_row1 = + src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + const int8_t* src_row2 = + src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; + const int8_t* src_row3 = + src_ptr + ic_idx * ic_stride + 3 * iw * ic_step; + const int8_t* src_row4 = + src_ptr + ic_idx * ic_stride + 4 * iw * ic_step; +#if MEGDNN_AARCH64 + +#define LOAD_SRC(_src, _src_ptr) \ + _src[0] = vld_dup_tbl_s32(_src_ptr, idx); \ + _src[1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \ + _src[2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \ + _src[3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \ + _src[4] = vld_dup_tbl_s32(_src_ptr + 16, idx); + +#define LOAD_WEIGHT(_w, _w_ptr, _id0, _id1, _id2, _id3, _id4) \ + _w[0] = vld1q_s8(_w_ptr + _id0 * 16); \ + _w[1] = vld1q_s8(_w_ptr + _id1 * 16); \ + _w[2] = vld1q_s8(_w_ptr + _id2 * 16); \ + _w[3] = vld1q_s8(_w_ptr + _id3 * 16); \ + _w[4] = vld1q_s8(_w_ptr + _id4 * 16); + +#define CALC_3_RESULT(_src, _w, _src_ptr) \ + CALC_ONE_RESULT(_src[0], _src[1], _src[2], _src[3], _src[4], _w, c[0][0]); \ + _src[0] = vld_dup_tbl_s32(_src_ptr + 20, idx); \ + _src[1] = vld_dup_tbl_s32(_src_ptr + 24, idx); \ + CALC_ONE_RESULT(_src[2], _src[3], _src[4], _src[0], _src[1], _w, c[0][1]); \ + _src[2] = vld_dup_tbl_s32(_src_ptr + 28, idx); \ + _src[3] = vld_dup_tbl_s32(_src_ptr + 32, idx); \ + CALC_ONE_RESULT(_src[4], _src[0], _src[1], _src[2], _src[3], _w, c[0][2]); + + int16x8_t tmp0, tmp1; + + LOAD_SRC(ssrc[0], src_row0); + LOAD_WEIGHT(weight[0], weight_ptr, 0, 1, 2, 3, 4); + LOAD_WEIGHT(weight[1], weight_ptr, 5, 6, 7, 8, 9); + CALC_3_RESULT(ssrc[0], weight[0], src_row0); + + LOAD_SRC(ssrc[1], src_row1); + LOAD_WEIGHT(weight[2], weight_ptr, 10, 11, 12, 13, 14); + LOAD_SRC(ssrc[0], src_row2); + CALC_3_RESULT(ssrc[1], weight[1], src_row1); + + LOAD_SRC(ssrc[1], src_row3); + LOAD_WEIGHT(weight[0], weight_ptr, 15, 16, 17, 18, 19); + CALC_3_RESULT(ssrc[0], weight[2], src_row2); + + LOAD_SRC(ssrc[0], src_row4); + LOAD_WEIGHT(weight[1], weight_ptr, 20, 21, 22, 23, 24); + CALC_3_RESULT(ssrc[1], weight[0], src_row3); + CALC_3_RESULT(ssrc[0], weight[1], src_row4); +#else + +#define LOAD_SRC(_src_ptr) \ + ssrc[0][0] = vld_dup_tbl_s32(_src_ptr, idx); \ + ssrc[0][1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \ + ssrc[0][2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \ + ssrc[0][3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \ + ssrc[0][4] = vld_dup_tbl_s32(_src_ptr + 16, idx); \ + ssrc[0][5] = vld_dup_tbl_s32(_src_ptr + 20, idx); \ + ssrc[0][6] = vld_dup_tbl_s32(_src_ptr + 24, idx); \ + ssrc[0][7] = vld_dup_tbl_s32(_src_ptr + 28, idx); \ + ssrc[0][8] = vld_dup_tbl_s32(_src_ptr + 32, idx); + +#define LOAD_WEIGHT(_w_ptr, _id0, _id1, _id2, _id3, _id4) \ + weight[0][0] = vld1q_s8(_w_ptr + _id0 * 16); \ + weight[0][1] = vld1q_s8(_w_ptr + _id1 * 16); \ + weight[0][2] = vld1q_s8(_w_ptr + _id2 * 16); \ + weight[0][3] = vld1q_s8(_w_ptr + _id3 * 16); \ + weight[0][4] = vld1q_s8(_w_ptr + _id4 * 16); + +#define CALC_3_RESULT(_src_ptr) \ + CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \ + ssrc[0][4], weight[0], c[0][0]); \ + CALC_ONE_RESULT(ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \ + ssrc[0][6], weight[0], c[0][1]); \ + CALC_ONE_RESULT(ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \ + ssrc[0][8], weight[0], c[0][2]); + + int16x8_t tmp0, tmp1; + + LOAD_WEIGHT(weight_ptr, 0, 1, 2, 3, 4); + LOAD_SRC(src_row0); + CALC_3_RESULT(src_row0); + + LOAD_WEIGHT(weight_ptr, 5, 6, 7, 8, 9); + LOAD_SRC(src_row1); + CALC_3_RESULT(src_row1); + + LOAD_WEIGHT(weight_ptr, 10, 11, 12, 13, 14); + LOAD_SRC(src_row2); + CALC_3_RESULT(src_row2); + + LOAD_WEIGHT(weight_ptr, 15, 16, 17, 18, 19); + LOAD_SRC(src_row3); + CALC_3_RESULT(src_row3); + + LOAD_WEIGHT(weight_ptr, 20, 21, 22, 23, 24); + LOAD_SRC(src_row4); + CALC_3_RESULT(src_row4); +#endif + weight_ptr += fh * fw * ld_weight_ic4; + } + switch (remain_w) { + case 1: + vst1_s16(dst_ptr, c[0][0]); + break; + case 2: + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); + break; + case 3: + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); + vst1_s16(dst_ptr + 8, c[0][2]); + break; + default: + megdnn_throw("invalid remain_w"); + break; + } + } + } + } +} +#undef CALC_4_RESULT +#undef LOAD_SRC +#undef LOAD_WEIGHT +#undef CALC_ONE_RESULT + +template +void conv_direct_stride2_7x7_int8x8x16_nchw44( + const int8_t* src, const int8_t* filter, const int16_t* bias, + int16_t* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow) { + constexpr size_t filter_size = 7; + constexpr size_t fh = filter_size; + constexpr size_t fw = filter_size; + constexpr size_t ic_step = 4; + constexpr size_t oc_step = 4; + constexpr size_t oh_step1 = 1; + constexpr size_t ow_step = 4; + constexpr size_t stride_h = 2; + constexpr size_t stride_w = 2; + + const size_t out_img_stride = oh * ow; + static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 2, 2, 3, 3, 3, 3}; + static uint8x16_t idx = vld1q_u8(idx_buffer); + size_t oc_idx = 0; + + for (; oc_idx + 3 < oc; oc_idx += oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + const int16_t* bias_ptr = bias + oc_idx; + + int16x4_t init_sum; + + if (bias_mode == BiasMode::NO_BIAS) { + init_sum = vdup_n_s16(0); + } else { + init_sum = vld1_s16(bias_ptr); + } + size_t oh_idx = 0; + +#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \ + _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w[3])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src5), vget_high_s8(_w[5])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src6), vget_high_s8(_w[6])); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ + } while (0); + + for (; oh_idx < oh; oh_idx += oh_step1) { + size_t ow_idx = 0; + for (; ow_idx + ow_step - 1 < ow; ow_idx += ow_step) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + int16x4_t c[1][4]; + int8x16_t weight[1][7]; + int8x16_t ssrc[1][9]; + const int8_t* src_ptr = src + src_offset; + int16_t* dst_ptr = dst + dst_offset; + const int8_t* weight_ptr = filter + weight_offset; + + c[0][0] = init_sum; + c[0][1] = init_sum; + c[0][2] = init_sum; + c[0][3] = init_sum; + for (size_t ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + const int8_t* src_row0 = + src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + + const int8_t* src_row1 = + src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + const int8_t* src_row2 = + src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; + const int8_t* src_row3 = + src_ptr + ic_idx * ic_stride + 3 * iw * ic_step; + const int8_t* src_row4 = + src_ptr + ic_idx * ic_stride + 4 * iw * ic_step; + const int8_t* src_row5 = + src_ptr + ic_idx * ic_stride + 5 * iw * ic_step; + const int8_t* src_row6 = + src_ptr + ic_idx * ic_stride + 6 * iw * ic_step; + +#define LOAD_SRC(_src) \ + ssrc[0][0] = vld_dup_tbl_s32(_src, idx); \ + ssrc[0][1] = vld_dup_tbl_s32(_src + 4, idx); \ + ssrc[0][2] = vld_dup_tbl_s32(_src + 8, idx); \ + ssrc[0][3] = vld_dup_tbl_s32(_src + 12, idx); \ + ssrc[0][4] = vld_dup_tbl_s32(_src + 16, idx); \ + ssrc[0][5] = vld_dup_tbl_s32(_src + 20, idx); \ + ssrc[0][6] = vld_dup_tbl_s32(_src + 24, idx); + +#define LOAD_WEIGHT(_id0, _id1, _id2, _id3, _id4, _id5, _id6) \ + weight[0][0] = vld1q_s8(weight_ptr + _id0 * 16); \ + weight[0][1] = vld1q_s8(weight_ptr + _id1 * 16); \ + weight[0][2] = vld1q_s8(weight_ptr + _id2 * 16); \ + weight[0][3] = vld1q_s8(weight_ptr + _id3 * 16); \ + weight[0][4] = vld1q_s8(weight_ptr + _id4 * 16); \ + weight[0][5] = vld1q_s8(weight_ptr + _id5 * 16); \ + weight[0][6] = vld1q_s8(weight_ptr + _id6 * 16); + +#define CALC_4_RESULT(_row) \ + CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \ + ssrc[0][4], ssrc[0][5], ssrc[0][6], weight[0], c[0][0]); \ + \ + ssrc[0][7] = vld_dup_tbl_s32(_row + 28, idx); \ + ssrc[0][8] = vld_dup_tbl_s32(_row + 32, idx); \ + CALC_ONE_RESULT(ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \ + ssrc[0][6], ssrc[0][7], ssrc[0][8], weight[0], c[0][1]); \ + \ + ssrc[0][0] = vld_dup_tbl_s32(_row + 36, idx); \ + ssrc[0][1] = vld_dup_tbl_s32(_row + 40, idx); \ + \ + CALC_ONE_RESULT(ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \ + ssrc[0][8], ssrc[0][0], ssrc[0][1], weight[0], c[0][2]); \ + ssrc[0][2] = vld_dup_tbl_s32(_row + 44, idx); \ + ssrc[0][3] = vld_dup_tbl_s32(_row + 48, idx); \ + \ + CALC_ONE_RESULT(ssrc[0][6], ssrc[0][7], ssrc[0][8], ssrc[0][0], \ + ssrc[0][1], ssrc[0][2], ssrc[0][3], weight[0], c[0][3]); + + int16x8_t tmp0, tmp1; + + LOAD_SRC(src_row0); + LOAD_WEIGHT(0, 1, 2, 3, 4, 5, 6); + CALC_4_RESULT(src_row0); + + LOAD_SRC(src_row1); + LOAD_WEIGHT(7, 8, 9, 10, 11, 12, 13); + CALC_4_RESULT(src_row1); + + LOAD_SRC(src_row2); + LOAD_WEIGHT(14, 15, 16, 17, 18, 19, 20); + CALC_4_RESULT(src_row2); + + LOAD_SRC(src_row3); + LOAD_WEIGHT(21, 22, 23, 24, 25, 26, 27); + CALC_4_RESULT(src_row3); + + LOAD_SRC(src_row4); + LOAD_WEIGHT(28, 29, 30, 31, 32, 33, 34); + CALC_4_RESULT(src_row4); + + LOAD_SRC(src_row5); + LOAD_WEIGHT(35, 36, 37, 38, 39, 40, 41); + CALC_4_RESULT(src_row5); + + LOAD_SRC(src_row6); + LOAD_WEIGHT(42, 43, 44, 45, 46, 47, 48); + CALC_4_RESULT(src_row6); + weight_ptr += fh * fw * ld_weight_ic4; + } + + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); + } + for (; ow_idx < ow; ow_idx++) { + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; + const size_t dst_offset = oc_idx * out_img_stride + + (oh_idx * ow + ow_idx) * oc_step; + + constexpr int ld_weight_ic4 = 16; + + const int ic_stride = ih * iw; + int16x4_t c = init_sum; + int8x16_t weight[1][7]; + int8x16_t ssrc[1][7]; + const int8_t* src_ptr = src + src_offset; + int16_t* dst_ptr = dst + dst_offset; + const int8_t* weight_ptr = filter + weight_offset; + + for (size_t ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + const int8_t* src_row0 = + src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + + const int8_t* src_row1 = + src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + const int8_t* src_row2 = + src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; + const int8_t* src_row3 = + src_ptr + ic_idx * ic_stride + 3 * iw * ic_step; + const int8_t* src_row4 = + src_ptr + ic_idx * ic_stride + 4 * iw * ic_step; + const int8_t* src_row5 = + src_ptr + ic_idx * ic_stride + 5 * iw * ic_step; + const int8_t* src_row6 = + src_ptr + ic_idx * ic_stride + 6 * iw * ic_step; +#define CALC_1_RESULT(_row) \ + CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \ + ssrc[0][4], ssrc[0][5], ssrc[0][6], weight[0], c); + + int16x8_t tmp0, tmp1; + LOAD_SRC(src_row0); + LOAD_WEIGHT(0, 1, 2, 3, 4, 5, 6); + CALC_1_RESULT(src_row0); + + LOAD_SRC(src_row1); + LOAD_WEIGHT(7, 8, 9, 10, 11, 12, 13); + CALC_1_RESULT(src_row1); + + LOAD_SRC(src_row2); + LOAD_WEIGHT(14, 15, 16, 17, 18, 19, 20); + CALC_1_RESULT(src_row2); + + LOAD_SRC(src_row3); + LOAD_WEIGHT(21, 22, 23, 24, 25, 26, 27); + CALC_1_RESULT(src_row3); + LOAD_SRC(src_row4); + LOAD_WEIGHT(28, 29, 30, 31, 32, 33, 34); + CALC_1_RESULT(src_row4); + LOAD_SRC(src_row5); + LOAD_WEIGHT(35, 36, 37, 38, 39, 40, 41); + CALC_1_RESULT(src_row5); + LOAD_SRC(src_row6); + LOAD_WEIGHT(42, 43, 44, 45, 46, 47, 48); + CALC_1_RESULT(src_row6); + + weight_ptr += fh * fw * ld_weight_ic4; + } + vst1_s16(dst_ptr, c); + } + } + } +} +#undef CALC_ONE_RESULT +#undef CALC_1_RESULT +#undef CALC_4_RESULT +#undef LOAD_SRC +#undef LOAD_WEIGHT +} // namespace + +namespace int8x8x16_direct_nchw44 { + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride2_2x2_int8_nchw44(src, filter, bias, dst, + oc, ic, ih, iw, oh, ow); + } +}; + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride2_3x3_int8_nchw44(src, filter, bias, dst, + oc, ic, ih, iw, oh, ow); + } +}; + +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride2_5x5_int8x8x16_nchw44( + src, filter, bias, dst, oc, ic, ih, iw, oh, ow); + } +}; +template +struct ConvDirectInt8Nchw44Choose { + static void impl(const int8_t* src, const int8_t* filter, + const int16_t* bias, int16_t* dst, const size_t oc, + const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride2_7x7_int8x8x16_nchw44( + src, filter, bias, dst, oc, ic, ih, iw, oh, ow); + } +}; + +#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \ + template struct ConvDirectInt8Nchw44Choose; + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN(stride, filter, bias_mode) + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define DISPATCH_CONV_KERN(stride) \ + GET_BIAS_MODE_PARAM(stride, 2) \ + GET_BIAS_MODE_PARAM(stride, 3) \ + GET_BIAS_MODE_PARAM(stride, 5) \ + GET_BIAS_MODE_PARAM(stride, 7) + +DISPATCH_CONV_KERN(2); + +} // namespace int8x8x16_direct_nchw44 +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 4daeb9c862e481bb48c9f615c99febf93dafe7c7..374bd03103a9cfdcc59cd5d6828425d301a7b45e 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -44,6 +44,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoQU8DirectStride1 qu8_direct_stride1; AlgoS8DirectStride2 s8_direct_stride2; AlgoS8DirectNCHW44 s8_direct_nchw44; + AlgoS8x8x16DirectNCHW44 s8x8x16_direct_nchw44; AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; AlgoS8DirectStride1 s8_direct_stride1; AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; @@ -94,6 +95,7 @@ public: direct_algos.emplace_back(&qu8_direct_stride1); direct_algos.emplace_back(&s8_direct_stride2); direct_algos.emplace_back(&s8_direct_nchw44); + direct_algos.emplace_back(&s8x8x16_direct_nchw44); direct_algos.emplace_back(&s8_direct_nchw_nchw44); direct_algos.emplace_back(&s8_direct_stride1); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 4a6144ee3b9c08d5f0cbc7493151aaa67e2cbfe0..0176a9af9ebf8064589fcfd3e621b9cef7e22a96 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -39,6 +39,7 @@ private: class AlgoS8DirectStride1; class AlgoS8DirectStride2; class AlgoS8DirectNCHW44; + class AlgoS8x8x16DirectNCHW44; class AlgoS8DirectNCHWNCHW44; class AlgoQU8DirectStride1; class AlgoQU8DirectStride2; diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 29f9be37688ff51199c915ed02b9af5ce56989e3..60c297beb06500665264661bb222b77096411250 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -518,6 +518,116 @@ void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle, } } +void benchmark_nchw44_8x8x16_vs_8x8x32(const char* im2col_name, Handle* handle, + size_t kernel, size_t stride, + size_t pack_size = 1) { + megdnn_assert(stride == 1 || stride == 2, "only support stride 1 or 2"); + std::vector args; + auto pack = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p) { + if (ic % pack_size != 0 || oc % pack_size != 0) + return; + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.format = param::ConvBias::Format::NCHW44; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = p; + param.pad_w = p; + param.sparse = param::ConvBias::Sparse::DENSE; + args.push_back(conv_bias::TestArg{ + param, + TensorShape{1, ic / 4, h, w, 4}, + TensorShape{oc / 4, ic / 4, kernel, kernel, 4, 4}, + {1, oc / 4, 1, 1, 4}}); + }; + pack(1, 64, 56, 56, kernel, 0); + pack(8, 64, 56, 56, kernel, 0); + pack(16, 64, 56, 56, kernel, 1); + pack(32, 64, 56, 56, kernel, 1); + pack(1, 64, 100, 100, kernel, 1); + pack(8, 64, 100, 100, kernel, 1); + pack(1, 64, 100, 100, kernel, 0); + pack(8, 64, 100, 100, kernel, 0); + pack(16, 64, 100, 100, kernel, 1); + pack(32, 64, 100, 100, kernel, 1); + pack(64, 64, 100, 100, kernel, 1); + pack(128, 64, 100, 100, kernel, 1); + pack(256, 64, 100, 100, kernel, 1); + pack(512, 64, 100, 100, kernel, 1); + pack(1024, 64, 100, 100, kernel, 1); + pack(1, 32, 200, 200, kernel, 1); + pack(8, 64, 200, 200, kernel, 1); + pack(1, 32, 200, 200, kernel, 0); + pack(8, 64, 200, 200, kernel, 0); + pack(16, 96, 200, 200, kernel, 1); + pack(32, 32, 200, 200, kernel, 1); + pack(64, 64, 200, 200, kernel, 1); + pack(128, 96, 200, 200, kernel, 1); + pack(1, 64, 10, 10, kernel, 1); + pack(8, 64, 10, 10, kernel, 1); + pack(16, 64, 10, 10, kernel, 1); + pack(32, 64, 10, 10, kernel, 1); + pack(64, 64, 10, 10, kernel, 1); + pack(128, 64, 10, 10, kernel, 1); + pack(256, 64, 10, 10, kernel, 1); + pack(512, 64, 10, 10, kernel, 1); + pack(1024, 64, 10, 10, kernel, 1); + + using namespace conv_bias; + constexpr size_t RUN = 20; + + Benchmarker benchmark_im2col(handle); + benchmark_im2col.set_display(false); + benchmark_im2col.set_times(RUN); + + Benchmarker benchmark_8832(handle); + benchmark_8832.set_display(false); + benchmark_8832.set_times(RUN); + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, + {arg.bias, dtype::Float32()}, {}, dst_layout); + //! dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 * 4 / + (1024 * 1024 * 1024) * 1e3; + + benchmark_im2col.set_param(arg.param); + benchmark_im2col.set_dtype(0, dtype::Int8()); + benchmark_im2col.set_dtype(1, dtype::Int8()); + benchmark_im2col.set_dtype(2, dtype::Int16()); + benchmark_im2col.set_dtype(4, dtype::Int16()); + auto used_8816 = + algo_benchmark(benchmark_im2col, + {arg.src, arg.filter, {}, {}, {}}, + im2col_name) / + RUN; + benchmark_8832.set_param(arg.param); + benchmark_8832.set_dtype(0, dtype::QuantizedS8(2.5)); + benchmark_8832.set_dtype(1, dtype::QuantizedS8(2.5)); + benchmark_8832.set_dtype(2, dtype::QuantizedS32(6.25)); + benchmark_8832.set_dtype(4, {}); + auto used_8832 = + algo_benchmark(benchmark_8832, + {arg.src, arg.filter, {}, {}, {}}, + "S8_NCHW44_DIRECT") / + RUN; + + printf("%s %s: 8816: %f ms %f GFlops ", arg.src.to_string().c_str(), + arg.filter.to_string().c_str(), used_8816, + computations / used_8816); + printf("%s %s: 8832: %f ms %f GFlops ", arg.src.to_string().c_str(), + arg.filter.to_string().c_str(), used_8832, + computations / used_8832); + printf("speedup %f \n", used_8832 / used_8816); + } +} + void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, const char* im2col_name, Handle* handle, size_t kernel, DType src_type, @@ -872,6 +982,28 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_MATMUL) { #endif #if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_8X8X16_DIRECT_STRIDE1) { + benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 2, 1, + 4); + benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 3, 1, + 4); + benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 5, 1, + 4); + benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 7, 1, + 4); +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_8X8X16_DIRECT_STRIDE2) { + benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 2, 2, + 4); + benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 3, 2, + 4); + benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 5, 2, + 4); + benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 7, 2, + 4); +} + TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { #if MEGDNN_AARCH64 benchmark_winograd("WINOGRAD:AARCH64_F32:1:2", handle(), 3); diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index b918f439c07fe8c9ccb59d95ada8e5f9f7ab9fbd..f1fbe6f46184fc1c590878394ef9037cbde8a3ac 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -534,11 +534,25 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), handle(), "S8_NCHW44_DIRECT"); } + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) { + checker_conv_bias_int8x8x16( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), + handle(), "S8x8x16_NCHW44_DIRECT"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) { + checker_conv_bias_int8x8x16( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true), + handle(), "S8x8x16_NCHW44_DIRECT"); +} + TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) { checker_conv_bias_qint8x8x32( get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), handle(), "S8_NCHW44_DIRECT"); } + TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) { checker_conv_bias_qint8x8x32( get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true),