From 7b0dbe6af8ecc4c2d9c9e2cc8f1a5e37749cb62c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 3 Jun 2020 22:38:53 +0800 Subject: [PATCH] fix(dnn/arm): fix stride 1 support for int8 nchw_nchw44 GitOrigin-RevId: 9d718eb7a4dae3c2724ea07ba2b639fbfb319f78 --- .../conv_bias/fp32/f32_direct_nchw44_algo.cpp | 5 +- dnn/src/arm_common/conv_bias/int8/algos.h | 4 +- .../int8/direct_nchw_nchw44_algo.cpp | 373 +++++ .../conv_bias/int8/direct_nchw_nchw44_kern.h | 1287 +++++++++++++++++ .../int8/direct_stride2_nchw_nchw44_algo.cpp | 305 ---- .../int8/direct_stride2_nchw_nchw44_kern.cpp | 789 ---------- .../int8/direct_stride2_nchw_nchw44_kern.h | 44 - dnn/src/arm_common/conv_bias/opr_impl.cpp | 4 +- dnn/src/arm_common/conv_bias/opr_impl.h | 2 +- dnn/test/arm_common/conv_bias.cpp | 8 + .../arm_common/conv_bias_multi_thread.cpp | 7 +- 11 files changed, 1682 insertions(+), 1146 deletions(-) create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h delete mode 100644 dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp delete mode 100644 dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp delete mode 100644 dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp index b1ee38aa9..fcc2aeb9a 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp @@ -37,7 +37,7 @@ static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, static void get_rectified_size( const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, int& iw2, int& oh2, int& ow2) { - constexpr int cacheline = 64 / sizeof(float); + constexpr int nr_elements_in_cacheline = 64 / sizeof(float); int ic = 
param.filter_meta.icpg; int iw = param.isz[1]; int oh = param.osz[0]; @@ -52,7 +52,8 @@ static void get_rectified_size( int block_oh = l2_block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * stride_h); ih2 = block_oh * stride_h + filter_h - stride_h; - iw2 = round_up(iw + 2 * static_cast(fm.padding[1]), cacheline); + iw2 = round_up(iw + 2 * static_cast(fm.padding[1]), + nr_elements_in_cacheline); } static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index 8da31ae54..47cd7b9b9 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -90,9 +90,9 @@ public: const NCBKernSizeParam& param) const override; }; -class ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44 final : public AlgoBase { +class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { public: - AlgoS8DirectStride2NCHWNCHW44() {} + AlgoS8DirectNCHWNCHW44() {} bool is_reproducible() const override { return true; } const char* name() const override { return "S8_CONV_NCHW_NCHW44"; } bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp new file mode 100644 index 000000000..4fa25bfef --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp @@ -0,0 +1,373 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" +#include "src/arm_common/conv_bias/int8/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; +using conv_fun = std::function; +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44) + +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, + int& iw2, int& oh2, int& ow2) { + auto&& fm = param.filter_meta; + int ih = param.isz[0]; + int iw = param.isz[1]; + int oh = param.osz[0]; + int ow = param.osz[1]; + int ph = fm.padding[0]; + int pw = fm.padding[1]; + int stride_h = fm.stride[0]; + + oh2 = oh; + ow2 = ow; + ih2 = stride_h == 2 ? round_up(ih + 2 * ph, 2) : ih + 2 * ph; + iw2 = iw + 2 * pw; +} +static inline size_t get_temp_bytes(const int iw, const int pw) { + //! border_size is used to avoid read illegal memory + constexpr int cacheline_size = 64; + constexpr int border_size = 1 * cacheline_size; + + return round_up(iw + pw * 2, cacheline_size) + border_size; +} +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + int group = fm.group; + int batch = param.n; + int ic = fm.icpg; + int oc = fm.ocpg; + int fh = fm.spatial[0]; + int fw = fm.spatial[1]; + int stride_h = fm.stride[0]; + int iw = param.isz[1]; + int pw = fm.padding[1]; + int ih2, iw2, oh2, ow2; + const size_t src_expand = stride_h == 2 ? 
4 : 16; + get_rectified_size(param, ih2, iw2, oh2, ow2); + megdnn_assert(group == 1, "only support group == 1 now"); + size_t src_size = + batch * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand; + size_t weight_size = group * oc * ic * fh * fw * sizeof(int8_t); + size_t tmp_size = 0; + if (stride_h == 1) { + weight_size = group * oc * ic * fh * round_up(fw, 4) * sizeof(int8_t); + tmp_size = get_temp_bytes(iw, pw); + } + return {nullptr, {src_size, weight_size, tmp_size * param.nr_threads}}; +}; + +static void copy_padding_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids) { + int ih = kern_param.isz[0]; + int iw = kern_param.isz[1]; + int ic = kern_param.filter_meta.icpg; + int ph = kern_param.filter_meta.padding[0]; + int pw = kern_param.filter_meta.padding[1]; + int group = kern_param.filter_meta.group; + int stride_h = kern_param.filter_meta.stride[0]; + + int ih2, iw2, oh2, ow2; + get_rectified_size(kern_param, ih2, iw2, oh2, ow2); + int padding_group_size = ih2 * iw2 * ic; + bundle.set(kern_param.workspace_ptr); + //! Used for get the workspace offset + const int src_expand = stride_h == 2 ? 4 : 16; + + //! TODO: block dim is better to get from arg + int workspace_ic_block = 1; + int workspace_batch_id = workspace_ids[0]; + int workspace_group_id = workspace_ids[1]; + int workspace_ic_id = workspace_ids[2]; + int workspace_ic = workspace_ic_id * workspace_ic_block; + int batch_id = ncb_index.ndrange_id[0]; + int group_id = ncb_index.ndrange_id[1]; + + const int8_t* sptr = static_cast( + kern_param.src(batch_id, group_id, workspace_ic_id, 1, 1)); + //! 
copy to sptr_base to eliminate padding effect + int8_t* sptr_base = static_cast(bundle.get(0)) + + (workspace_batch_id * group * padding_group_size + + workspace_group_id * padding_group_size + + workspace_ic * ih2 * iw2) * + src_expand; + if (stride_h == 1) { + const size_t tmp_size = get_temp_bytes(iw, pw); + int8_t* tmp_ptr = reinterpret_cast(bundle.get(2)) + + ncb_index.thread_id * tmp_size; + pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih, + iw, iw2, pw, tmp_ptr); + } else { + pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih, + iw, iw2, pw, nullptr); + } +} +static void pack_weight(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { + bundle.set(kern_param.workspace_ptr); + const int group_id = ncb_index.ndrange_id[0]; + int fh = kern_param.filter_meta.spatial[0]; + int fw = kern_param.filter_meta.spatial[1]; + int oc = kern_param.filter_meta.ocpg; + int ic = kern_param.filter_meta.icpg; + int stride_h = kern_param.filter_meta.stride[0]; + int fw2 = stride_h == 2 ? fw : round_up(fw, 4); + int oc_block = oc; + int oc_idx = 0; + const int8_t* fptr = + kern_param.filter(group_id) + oc_idx * fh * fw * ic; + auto packed_weight = reinterpret_cast(bundle.get(1)) + + group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2; + + if (stride_h == 1) { + pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw, + oc_block); + } else { + pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw, + oc_block); + } +} +template +static void do_conv_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range) { + int oh = kern_param.osz[0]; + int ow = kern_param.osz[1]; + int fh = kern_param.filter_meta.spatial[0]; + int fw = kern_param.filter_meta.spatial[1]; + int fw2 = stride == 2 ? 
fw : round_up(fw, 4); + int ic = kern_param.filter_meta.icpg; + int oc = kern_param.filter_meta.ocpg; + int group = kern_param.filter_meta.group; + int ih2, iw2, oh2, ow2; + get_rectified_size(kern_param, ih2, iw2, oh2, ow2); + bool need_post_process = + kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) + Op op = Op(1.0f, 4.0f); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + int padding_group_size = ih2 * iw2 * ic; + bundle.set(kern_param.workspace_ptr); + + constexpr int pack_c = 4; + constexpr int src_expand_size = stride == 2 ? 4 : 16; + const int workspace_batch_id = workspace_ids[0]; + const int workspace_group_id = workspace_ids[1]; + const int batch_id = ncb_index.ndrange_id[0]; + const int group_id = ncb_index.ndrange_id[1]; + const int oc_id = ncb_index.ndrange_id[2]; + const int oc_block_num = ncb_range[2]; + int nr_pack_per_step = div_ceil(div_ceil(oc, pack_c), oc_block_num); + int oc_block = nr_pack_per_step * pack_c; + const int oc_idx = oc_id * oc_block; + if (oc_id == (oc_block_num - 1)) { + oc_block = oc - oc_id * nr_pack_per_step * pack_c; + } + megdnn_assert(oc_block % pack_c == 0, + "oc must be devisible by 4, but oc = %d", oc_block); + const int8_t* sptr = + static_cast(bundle.get(0)) + + workspace_batch_id * group * padding_group_size * src_expand_size + + workspace_group_id * padding_group_size * src_expand_size; + + int8_t* dst = reinterpret_cast( + reinterpret_cast( + kern_param.dst(batch_id, group_id)) + + oc_idx * oh * ow); + + const int32_t* bptr = + kern_param.bias(batch_id, group_id) + oc_idx; + int8_t* packed_weight = reinterpret_cast(bundle.get(1)) + + group_id * oc * ic * fh * fw2 + + oc_idx * ic * fh * fw2; + conv_direct_int8_nchw_nchw44( + sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh, + ow, op); +} 
+ +bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + MEGDNN_MARK_USED_VAR(algo_selection_strategy); + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto OC = fm.ocpg; + bool avaible = //! src and filter are qint8, dst is qint8 + fm.icpg < 4 && // must be nchw input + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && + (fm.format == param::Convolution::Format::NCHW44) && + (OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && + (fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 && + param.bias_mode != BiasMode::BIAS; + return avaible; +} + +bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred( + megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, + const NCBKernSizeParam& param) const { + // TODO: benchmark and fix + MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr); + MEGDNN_MARK_USED_VAR(param); + return false; +} + +size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + size_t N = param.n; + size_t OC = fm.ocpg; + size_t group = fm.group; + WorkspaceBundle wbundle = get_bundle(param); + conv_fun do_conv_fun = nullptr; +// NOTE: remain_w is not used to gen hash of midout for compatible with changing +// shape runtime +#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44, \ + 
midout_iv(#stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define GET_BIAS_MODE_PARAM(stride, filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_KERN(stride) \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(stride, 2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(stride, 3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(stride, 5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(stride, 7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + switch (param.filter_meta.stride[0]) { + case 1: + DISPATCH_CONV_KERN(1); + break; + case 2: + DISPATCH_CONV_KERN(2); + break; + default: + megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", + param.filter_meta.stride[0]) + .c_str()); + break; + } + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN + + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + WorkspaceBundle bundle = wbundle; + + constexpr size_t pack_oc = 8; + size_t oc_step = pack_oc; + auto copy_padding = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, 
kern_param, ncb_index, ncb_index.ndrange_id); }; + ret_kerns.push_back({copy_padding, {N, group, fm.icpg}}); + + auto do_pack_weight = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + pack_weight(bundle, kern_param, ncb_index); + }; + ret_kerns.push_back({do_pack_weight, {static_cast(group)}}); + + CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; + auto do_conv = [bundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, + ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + + return ret_kerns; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h new file mode 100644 index 000000000..bc2b2c1a2 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h @@ -0,0 +1,1287 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace { +/** + * @brief core code for calculation pattern + * + * @tparam src_idx is offset of src reg + * @tparam weight_idx is offset of weight reg + * @tparam c_dim is output channel + * @tparam Func mla operation function + * @tparam stride + * @tparam T output regs type + * @tparam T2 src regs type + * @tparam T3 weight regs type + * @tparam T4 temp regs type + */ +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight, T4& temp); + static void impl(T& c, T2& src, T3& weight); +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], + temp[0]); + c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0], + temp[1]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], + temp[2]); + c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1], + temp[3]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], + temp[0]); + c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2], + temp[1]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], + temp[2]); + c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], + temp[3]); + } + static void impl(T& c, T2& src, T3& weight) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); + c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); + c[1][1] = Func::impl(src[1 + src_idx], 
weight[1][weight_idx], c[1][1]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); + c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); + c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3]); + } +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], + temp[0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], + temp[2]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], + temp[0]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], + temp[2]); + } + static void impl(T& c, T2& src, T3& weight) { + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); + } +}; + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx], + c[0][0], temp[0]); + c[1][0] = Func::impl(src[(0 + src_idx) % 8], weight[1][weight_idx], + c[1][0], temp[1]); + c[0][1] = Func::impl(src[(1 + src_idx) % 8], weight[0][weight_idx], + c[0][1], temp[2]); + c[1][1] = Func::impl(src[(1 + src_idx) % 8], weight[1][weight_idx], + c[1][1], temp[3]); + c[0][2] = Func::impl(src[(2 + src_idx) % 8], weight[0][weight_idx], + c[0][2], temp[0]); + c[1][2] = Func::impl(src[(2 + src_idx) % 8], weight[1][weight_idx], + c[1][2], temp[1]); + c[0][3] = Func::impl(src[(3 + src_idx) % 8], weight[0][weight_idx], + c[0][3], temp[2]); + c[1][3] = Func::impl(src[(3 + src_idx) % 8], weight[1][weight_idx], + c[1][3], temp[3]); + + c[0][4] = Func::impl(src[(4 + src_idx) % 8], 
weight[0][weight_idx], + c[0][4], temp[0]); + c[1][4] = Func::impl(src[(4 + src_idx) % 8], weight[1][weight_idx], + c[1][4], temp[1]); + c[0][5] = Func::impl(src[(5 + src_idx) % 8], weight[0][weight_idx], + c[0][5], temp[2]); + c[1][5] = Func::impl(src[(5 + src_idx) % 8], weight[1][weight_idx], + c[1][5], temp[3]); + c[0][6] = Func::impl(src[(6 + src_idx) % 8], weight[0][weight_idx], + c[0][6], temp[0]); + c[1][6] = Func::impl(src[(6 + src_idx) % 8], weight[1][weight_idx], + c[1][6], temp[1]); + c[0][7] = Func::impl(src[(7 + src_idx) % 8], weight[0][weight_idx], + c[0][7], temp[2]); + c[1][7] = Func::impl(src[(7 + src_idx) % 8], weight[1][weight_idx], + c[1][7], temp[3]); + } + static void impl(T&, T2&, T3&); +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight, T4& temp) { + c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx], + c[0][0], temp[0]); + c[0][1] = Func::impl(src[(1 + src_idx) % 8], weight[0][weight_idx], + c[0][1], temp[1]); + c[0][2] = Func::impl(src[(2 + src_idx) % 8], weight[0][weight_idx], + c[0][2], temp[2]); + c[0][3] = Func::impl(src[(3 + src_idx) % 8], weight[0][weight_idx], + c[0][3], temp[3]); + c[0][4] = Func::impl(src[(4 + src_idx) % 8], weight[0][weight_idx], + c[0][4], temp[0]); + c[0][5] = Func::impl(src[(5 + src_idx) % 8], weight[0][weight_idx], + c[0][5], temp[1]); + c[0][6] = Func::impl(src[(6 + src_idx) % 8], weight[0][weight_idx], + c[0][6], temp[2]); + c[0][7] = Func::impl(src[(7 + src_idx) % 8], weight[0][weight_idx], + c[0][7], temp[3]); + } + static void impl(T&, T2&, T3&); +}; + +template +inline void cal_helper(T& c, T2& src, T3& weight, T4& temp) { + ShiftCalHelper::impl(c, src, weight, temp); +} +template +inline void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl(c, src, weight); +}; + +template +struct OCHelper { +public: + static const int val = 0; +}; +template <> +struct OCHelper<4> { +public: + static const int val = 1; +}; +template <> +struct OCHelper<8> { 
+public: + static const int val = 2; +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op); +}; +/** + * filter shape = (oc/4, ic, 7, 7, 4), first 4 oc is f0 = filter[0, 0, :, :, :] + * calculate sequence \ + * f0[0:1, 0:1, 4] dot4, \ + * f0[0:1, 2:3, 4] dot4, \ + * f0[0:1, 4:5, 4] dot4, \ + * f0[0:1, 6, 4] dot2, \ + * ... + * f0[6, 0:1, 4] dot2, \ + * f0[6, 2:3, 4] dot2, \ + * f0[6, 4:5, 4] dot2, \ + * f0[6, 6, 4] dot1, \ + * look like: + * |---|---|---|-| + * |x x|x x|x x|x| + * |x x|x x|x x|x| + * |---|---|---|-| + * |x x|x x|x x|x| + * |x x|x x|x x|x| + * |---|---|---|-| + * |x x|x x|x x|x| + * |x x|x x|x x|x| + * |---|---|---|-| + * |x x|x x|x x|x| + * |---|---|---|-| + **/ +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int filter_size = 7; + constexpr int ic_step = 1; + constexpr int oc_step = 4; + constexpr int pack_iw_len = 4; + constexpr int fh_step = 2; + constexpr int fh_end = filter_size / fh_step * fh_step; + constexpr int c_dim = OCHelper::val; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; + + int32x4_t c[c_dim][4]; + + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { + const int8_t* nchw_src_ptr = + src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + int8x16_t src[6]; + int8x16_t dot4_weight[c_dim][3]; + int16x8_t temp_c[4]; + load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_dot4_weight_oc); + 
load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + cal_helper<1, 1, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + cal_helper<2, 2, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + load_helper<1, 3 * 16, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, + 0); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + weight_ptr += filter_size * pack_iw_len * fh_step; + } + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + 6 * iw * ic_step * pack_iw_len; + + int8x8_t dot2_weight[c_dim][3]; + int16x8_t temp_c[4]; + int8x8_t src_dot2[6]; + uint8x16_t tbl = vld1q_u8(src_idx_buffer); + load_helper<3, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, + 0, tbl); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2, + dot2_weight, temp_c); + cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, src_dot2, + dot2_weight, temp_c); + cal_helper<2, 2, c_dim, Vdot2_s32_h, stride>(c, src_dot2, + dot2_weight, temp_c); + + int16x8_t dot1_weight[c_dim][1]; + int16x8_t src_dot1[4]; + load_helper<1, 3 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( + dot1_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, + nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, + dot1_weight); + weight_ptr += filter_size * pack_iw_len; + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_size = 5; + static const 
uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int ih_step = 2; + constexpr int ic_step = 1; + constexpr int oc_step = 4; + constexpr int pack_iw_len = 4; + constexpr int fh_end = filter_size / ih_step * ih_step; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][4]; + + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += ih_step) { + const int8_t* nchw_src_ptr = + src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; + int8x16_t src[5]; + int8x16_t dot4_weight[c_dim][2]; + int16x8_t temp_c[4]; + load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + cal_helper<1, 1, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + load_helper<1, 2 * 16, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, + 0); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + weight_ptr += filter_size * pack_iw_len * ih_step; + } + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + fh_end * iw * ic_step * pack_iw_len; + + int8x8_t dot2_weight[c_dim][2]; + int16x8_t temp_c[4]; + int8x8_t src_dot2[5]; + uint8x16_t tbl = vld1q_u8(src_idx_buffer); + load_helper<2, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, + ld_dot4_weight_oc); + load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, + 0, tbl); + + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2, + dot2_weight, temp_c); + cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, 
src_dot2, + dot2_weight, temp_c); + + int16x8_t dot1_weight[c_dim][1]; + int16x8_t src_dot1[4]; + load_helper<1, 2 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( + dot1_weight, weight_ptr, ld_dot4_weight_oc); + load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, + nchw_src_ptr, 0); + + cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, + dot1_weight); + weight_ptr += filter_size * pack_iw_len; + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +/** + * filter shape = (oc/4, ic, 3, 3, 4), first 4 oc is f0 = filter[0, 0, :, :, :] + * calculate sequence \ + * f0[0:1, 0:1, 4] dot4, \ + * f0[0:1, 2, 4] dot2, \ + * f0[2, 0:1, 4] dot2, \ + * f0[2, 2, 4] dot1 \ + * look like: + * |---|-| + * |x x|x| + * |x x|x| + * |-----| + * |x x|x| + * |-----| + **/ +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_size = 3; + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int loop_ic_step = 1; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][4]; + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + // first 2 line + { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[4]; + int8x16_t dot4_weight[c_dim][1]; + int16x8_t temp_c[4]; + load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_weight_oc); + load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( + c, src, dot4_weight, temp_c); + + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + load_helper<1, 1 * 16, 8, 
c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_weight_oc); + load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, + 0); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + } + // last line + { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + 2 * iw * ic_step * pack_iw_len; + int16x8_t temp_c[4]; + int8x8_t src_dot2[4]; + int8x8_t dot2_weight[c_dim][1]; + uint8x16_t tbl = vld1q_u8(src_idx_buffer); + load_helper<1, 24, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, + ld_weight_oc); + load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>( + src_dot2, nchw_src_ptr, 0, tbl); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + int16x8_t dot1_weight[c_dim][1]; + int16x8_t src_dot1[4]; + load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>( + dot1_weight, weight_ptr, ld_weight_oc); + load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, + nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, + dot1_weight); + weight_ptr += filter_size * filter_size * pack_iw_len; + } + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_size = 2; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][4]; + init_ocx_ow4(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[4]; + int8x16_t dot4_weight[c_dim][1]; + int16x8_t temp_c[4]; + load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, + ld_weight_oc); + 
load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, + temp_c); + weight_ptr += oc_step * filter_size * filter_size; + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 2; + constexpr int filter_width = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 1; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; + load_helper( + dot4_weight, weight_ptr, ld_weight_oc); + load_helper( + src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, + temp_c); + + load_helper( + dot4_weight, weight_ptr + 1 * filter_width * oc_step, + ld_weight_oc); + load_helper( + src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, + temp_c); + + weight_ptr += oc_step * filter_height * filter_width; + } + + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int 
stride = 1; + constexpr int filter_height = 3; + constexpr int filter_width = 4; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 1; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; + load_helper( + dot4_weight, weight_ptr, ld_weight_oc); + + load_helper( + src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, + temp_c); + load_helper( + dot4_weight, weight_ptr + 1 * filter_width * oc_step, + ld_weight_oc); + + load_helper( + src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, + temp_c); + + load_helper( + dot4_weight, weight_ptr + 2 * filter_width * oc_step, + ld_weight_oc); + load_helper( + src, nchw_src_ptr + 2 * iw * pack_iw_len, 0); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, + temp_c); + + weight_ptr += oc_step * filter_height * filter_width; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 5; + constexpr int filter_width = 8; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int 
weight_reg = 2; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; +#define cb(step) \ + load_helper( \ + dot4_weight, weight_ptr + step * filter_width * oc_step, \ + ld_weight_oc); \ + load_helper( \ + src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); \ + load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ + src, \ + nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \ + 0); \ + cal_helper<4, 1, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); + UNROLL_CALL_RAW(5, cb); +#undef cb + weight_ptr += oc_step * filter_height * filter_width; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_height = 7; + constexpr int filter_width = 8; + constexpr int oc_step = 4; + constexpr int loop_ic_step = 1; + constexpr int simd_len = 16; + constexpr int pack_iw_len = 16; + constexpr int src_reg = 8; + constexpr int weight_reg = 2; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_height * filter_width * ic; + constexpr int c_dim = OCHelper::val; + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + int8x16_t src[src_reg]; + int8x16_t 
dot4_weight[c_dim][weight_reg]; + int16x8_t temp_c[4]; +#define cb(step) \ + load_helper( \ + dot4_weight, weight_ptr + step * filter_width * oc_step, \ + ld_weight_oc); \ + load_helper( \ + src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); \ + load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ + src, \ + nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \ + 0); \ + cal_helper<4, 1, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); + + UNROLL_CALL_RAW(7, cb); +#undef cb + weight_ptr += oc_step * filter_height * filter_width; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; +template +inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad, + int right_pad, const int iw) { + const int8_t* src_row_0 = inptr; + const int8_t* src_row_1 = inptr + iw; + constexpr int combine_row = 2; + constexpr int iw_step = 16; + constexpr int src_expand = 4; + constexpr int out_gap = iw_step * src_expand; + const int iw_end = iw / iw_step * iw_step; + + memset(outptr, 0, combine_row * left_pad * src_expand * sizeof(int8_t)); + outptr += combine_row * left_pad * src_expand; + + for (int iw_idx = 0; iw_idx < iw_end; iw_idx += iw_step) { + int8x16_t row0 = vld1q_s8(src_row_0 + iw_idx); + int8x16_t row1 = vdupq_n_s8(0); + if (mode == PACK_MODE::NO_PAD) { + row1 = vld1q_s8(src_row_1 + iw_idx); + } else if (mode == PACK_MODE::FIRST_PAD) { + row1 = row0; + row0 = vdupq_n_s8(0); + } + int8x16x2_t pack_rows = vzipq_s8(row0, row1); +#define STORE_8S8(step) \ + vst1_s8(outptr + step * 8, \ + vreinterpret_s8_s16(vdup_laneq_s16( \ + vreinterpretq_s16_s8(pack_rows.val[0]), step))); + + UNROLL_CALL_RAW(8, STORE_8S8); +#undef STORE_8S8 +#define STORE_8S8(step) \ + vst1_s8(outptr + out_gap + step * 8, \ + vreinterpret_s8_s16(vdup_laneq_s16( \ + vreinterpretq_s16_s8(pack_rows.val[1]), step))); + + 
UNROLL_CALL_RAW(8, STORE_8S8); +#undef STORE_8S8 + outptr += out_gap * combine_row; + } + for (int iw_idx = iw_end; iw_idx < iw; iw_idx++) { + int8x8_t row0 = vld1_dup_s8(src_row_0 + iw_idx); + int8x8_t row1 = vdup_n_s8(0); + if (mode == PACK_MODE::NO_PAD) { + row1 = vld1_dup_s8(src_row_1 + iw_idx); + } else if (mode == PACK_MODE::FIRST_PAD) { + row1 = row0; + row0 = vdup_n_s8(0); + } + int8x8x2_t pack_rows = vzip_s8(row0, row1); + vst1_s8(outptr, pack_rows.val[0]); + outptr += src_expand * combine_row; + } + memset(outptr, 0, combine_row * right_pad * src_expand * sizeof(int8_t)); + outptr += combine_row * right_pad * src_expand; +} +template +void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr, + const int ic, const int top_pad, + const int bottom_pad, const int left_pad, + const int right_pad, const int ih, + const int iw, const int iw2, const int pw, + int8_t* temp_ptr); +/** + * pack (ic, h, w) to (ic, h / 2, 2 * w) + * pack interleave two adjacent row in src and repeat 4 times, store to one row + * */ +template <> +void pack_nchw_src_for_nchw44_conv<2>(const int8_t* inptr, int8_t* outptr, + const int ic, const int top_pad, + const int bottom_pad, const int left_pad, + const int right_pad, const int ih, + const int iw, const int, const int, + int8_t*) { + constexpr int src_expand = 4; + constexpr int oh_step = 2; + const int oh = ih + top_pad + bottom_pad; + const int oh_end = div_floor(ih + top_pad, oh_step) * oh_step; + const int ow = (iw + left_pad + right_pad) * src_expand; + + for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { + int oh_idx = 0; + for (; oh_idx < top_pad; oh_idx += oh_step) { + if (top_pad - oh_idx >= oh_step) { + memset(outptr, 0, oh_step * ow * sizeof(int8_t)); + } else { + pack_src_one_line(inptr, outptr, left_pad, + right_pad, iw); + inptr += iw; + } + outptr += oh_step * ow; + } + + for (; oh_idx < oh_end; oh_idx += oh_step) { + pack_src_one_line(inptr, outptr, left_pad, + right_pad, iw); + inptr += oh_step * iw; + 
outptr += oh_step * ow; + } + + for (; oh_idx < oh; oh_idx += oh_step) { + const int last_pad = oh_idx - ih - top_pad; + if (last_pad >= 0) { + memset(outptr, 0, oh_step * ow * sizeof(int8_t)); + } else { + pack_src_one_line(inptr, outptr, left_pad, + right_pad, iw); + inptr += iw; + } + outptr += oh_step * ow; + } + } +} +/** + * pack (ic, h, w) to (ic, h, w * 16) + * pack interleave two adjacent row in src and repeat 4 times, store to one row + * */ +template <> +void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin, + int8_t* sptr_base, const int ic, + const int pad_top, const int pad_bottom, + const int, const int, const int ih, + const int iw, const int iw2, const int pw, + int8_t* temp_ptr) { + static uint8_t reorder_idx[16] = {0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3}; + uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]); + + constexpr int iw_step = 4; + constexpr int pack_iw_len = 16; + const int ic_stride = ih * iw; + const int iw_with_pad = iw + 2 * pw; + const int iw_with_pad_end = iw_with_pad / iw_step * iw_step; + rep(ic_idx, ic) { + const int8_t* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, + sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * + pack_iw_len); + sptr_base += iw2 * pad_top * pack_iw_len; + rep(ih_idx, ih) { + memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); + memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw); + for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) { + int8x16_t src[4]; + int8x16_t dst[4]; + src[0] = vld1q_s8(temp_ptr + iw_idx); + src[1] = vld1q_s8(temp_ptr + iw_idx + 1); + src[2] = vld1q_s8(temp_ptr + iw_idx + 2); + src[3] = vld1q_s8(temp_ptr + iw_idx + 3); + dst[0] = vqtbl1q_s8(src[0], tbl_idx); + dst[1] = vqtbl1q_s8(src[1], tbl_idx); + dst[2] = vqtbl1q_s8(src[2], tbl_idx); + dst[3] = vqtbl1q_s8(src[3], tbl_idx); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 
32, dst[2]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]); + } + for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) { + int8x16_t src = vld1q_s8(temp_ptr + iw_idx); + int8x16_t dst = vqtbl1q_s8(src, tbl_idx); + vst1q_s8(sptr_base + iw_idx * pack_iw_len, dst); + } + sptr_base += iw2 * pack_iw_len; + sptr += iw; + } + sptr_base += iw2 * pad_bottom * pack_iw_len; + } +} + +template +void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr, + const int ic, const int fh, const int fw, + const int oc); +/** + * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)} + * pack interleave two adjacent row in filter to one row + * */ +template <> +void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, + const int ic, const int fh, + const int fw, const int oc) { + constexpr int oc_step = 4; + constexpr int ic_step = 2; + constexpr int fh_step = 2; + constexpr int fw_step = 2; + const int ic_end = ic / ic_step * ic_step; + const int ic_remain = ic - ic_end; + const int fh_end = fh / fh_step * fh_step; + const int fh_remain = fh - fh_end; + const int fw_end = fw / fw_step * fw_step; + const int fw_remain = fw - fw_end; + const int filter_stride = ic * oc_step; + static const uint8_t ic2_idx_h_buffer[16] = {0, 8, 1, 9, 2, 10, 3, 11, + 4, 12, 5, 13, 6, 14, 7, 15}; + uint8x16_t ic2_idx_h = vld1q_u8(ic2_idx_h_buffer); + for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + for (int ic_idx = 0; ic_idx < ic_end; ic_idx += ic_step) { + const int ic_offset = ic_idx * oc_step; + int8_t* output_ic0 = outptr + ic_idx * fh * fw * oc_step; + int8_t* output_ic1 = output_ic0 + fh * fw * oc_step; + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { + const int fh_offset = fh_idx * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vld1_s8(filter_ptr); + int8x8_t row_1 = 
vld1_s8(filter_ptr + fw * filter_stride); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + vst1_s8(output_ic1, vget_high_s8(combine_row)); + output_ic0 += 8; + output_ic1 += 8; + } + } + if (fh_remain > 0) { + const int fh_offset = fh_end * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vld1_s8(filter_ptr); + int8x8_t row_1 = vld1_s8(filter_ptr + filter_stride); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + vst1_s8(output_ic1, vget_high_s8(combine_row)); + output_ic0 += 8; + output_ic1 += 8; + } + if (fw_remain > 0) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_end * filter_stride + + ic_offset; + int8x8_t row_0 = vld1_s8(filter_ptr); + vst1_lane_s32((int32_t*)output_ic0, + vreinterpret_s32_s8(row_0), 0); + vst1_lane_s32((int32_t*)output_ic1, + vreinterpret_s32_s8(row_0), 1); + output_ic0 += 4; + output_ic1 += 4; + } + } + } + if (ic_remain > 0) { + const int ic_offset = ic_end * oc_step; + int8_t* output_ic0 = outptr + ic_end * fh * fw * oc_step; + for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { + const int fh_offset = fh_idx * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vreinterpret_s8_s32( + vld1_dup_s32((const int32_t*)(filter_ptr))); + int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( + (const int32_t*)(filter_ptr + fw * filter_stride))); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + output_ic0 += 8; + } + } + if (fh_remain > 0) { + const int 
fh_offset = fh_end * fw * filter_stride; + for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_idx * filter_stride + + ic_offset; + int8x8_t row_0 = vreinterpret_s8_s32( + vld1_dup_s32((const int32_t*)(filter_ptr))); + int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( + (const int32_t*)(filter_ptr + filter_stride))); + int8x16_t combine_row = vcombine_s8(row_0, row_1); + combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); + vst1_s8(output_ic0, vget_low_s8(combine_row)); + output_ic0 += 8; + } + if (fw_remain > 0) { + const int8_t* filter_ptr = inptr + fh_offset + + fw_end * filter_stride + + ic_offset; + *(int32_t*)(output_ic0) = *(const int32_t*)(filter_ptr); + output_ic0 += 4; + } + } + } + inptr += oc_step * fh * fw * ic; + outptr += oc_step * fh * fw * ic; + } +} +/** + * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)} + * pack interleave two adjacent row in filter to one row + * */ +template <> +void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr, + const int ic, const int fh, + const int fw, const int oc) { + constexpr int oc_step = 4; + const int fw2 = round_up(fw, 4); + const int fw_remain = fw2 - fw; + const int dst_ic_stride = fh * fw2; + const int oc_step_stride = fh * fw2 * ic * oc_step; + static const uint8_t transpose_4x4_idx[16] = {0, 4, 1, 5, 2, 6, 3, 7, + 8, 12, 9, 13, 10, 14, 11, 15}; + uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]); + rep_step(oc_idx, oc, oc_step) { + int32_t* dst_temp_ptr = + reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); + const int32_t* src_temp_ptr = reinterpret_cast( + src_ptr + oc_idx * ic * fh * fw); + // transpose ic and pad + rep(fh_idx, fh) { + rep(fw_idx, fw) { + rep(ic_idx, ic) { + *(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr; + src_temp_ptr++; + } + dst_temp_ptr++; + } + rep(ic_idx, ic) { + memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0, + sizeof(int8_t) * oc_step * 
fw_remain); + } + dst_temp_ptr += fw_remain; + } + // transpose fw oc + int8_t* trans_dst_temp_ptr = + reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); + + rep_step(idx, oc_step_stride, 16) { + int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx); + vst1q_s8(trans_dst_temp_ptr + idx, + vqtbl1q_s8(temp, tbl_transpose_4x4)); + } + } +}; +template +struct ConvDiectStrideInt8NchwNchw44 { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, int8_t* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr size_t fh = filter_size; + constexpr size_t fw = + stride == 2 ? filter_size : (filter_size + 3) / 4 * 4; + constexpr size_t ic_step = 1; + constexpr size_t big_oc_step = 8; + constexpr size_t oc_step = 4; + constexpr size_t ih_step = stride == 2 ? 2 : 1; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = stride == 2 ? 4 : 8; + constexpr size_t stride_h = stride; + constexpr size_t stride_w = stride; + constexpr int pack_iw_len = 4; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + kern_small_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + break; + + UNROLL_CALL_RAW(4, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 
0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, + filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + } +}; + +template +struct ConvDiectStrideInt8NchwNchw44 { + static void impl(const int8_t* src, const int8_t* filter, + const int32_t* bias, int32_t* temp, int8_t* dst, + const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, + const Op& op) { + 
MEGDNN_MARK_USED_VAR(temp); + constexpr int stride = 1; + constexpr size_t fh = filter_size; + constexpr size_t fw = (filter_size + 3) / 4 * 4; + constexpr size_t ic_step = 1; + constexpr size_t big_oc_step = 8; + constexpr size_t oc_step = 4; + constexpr size_t ih_step = 1; + constexpr size_t oh_step = 1; + constexpr size_t ow_step = 8; + constexpr size_t stride_h = stride; + constexpr size_t stride_w = stride; + constexpr int pack_iw_len = 16; + + const size_t img_stride = oh * ow; + const size_t ow_end = ow / ow_step * ow_step; + const size_t ow_remain = ow - ow_end; + const size_t oc_end = oc / big_oc_step * big_oc_step; + const size_t oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + kern_small_oc_remain = \ + KerNeonXXs2NchwNchw44::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %zu for kern", ow_remain); + } + + for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_idx) * oc_step; + + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + 
weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + size_t oc_idx = oc_end; + const size_t weight_offset = oc_idx * ic * fh * fw; + for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { + for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const size_t src_offset = (oh_idx * stride_h * iw + + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = oc_idx * img_stride + + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, + filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + } +}; + +template +static void conv_direct_int8_nchw_nchw44(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t oc, + const size_t ic, const size_t ih, + const size_t iw, const size_t oh, + const size_t ow, const Op& op) { + ConvDiectStrideInt8NchwNchw44::impl( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); +} + +} // namespace +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp deleted file mode 100644 index 2c1384a7d..000000000 --- a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp +++ /dev/null @@ -1,305 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp - * MegEngine is Licensed under the Apache License, Version 
2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#include "megdnn/oprs.h" -#include "src/arm_common/conv_bias/int8/algos.h" -#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h" -#include "src/arm_common/conv_bias/int8/strategy.h" -#include "src/arm_common/elemwise_op.h" -#include "src/common/opr_delegate.h" - -#include "midout.h" - -using namespace megdnn; -using namespace arm_common; -using conv_fun = std::function; -MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2) - -static void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { - auto&& fm = param.filter_meta; - size_t IH = param.isz[0]; - size_t IW = param.isz[1]; - size_t OH = param.osz[0]; - size_t OW = param.osz[1]; - - OH2 = OH; - OW2 = OW; - IH2 = round_up(IH + 2 * fm.padding[0], static_cast(2)); - IW2 = IW + 2 * fm.padding[1]; -} -static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { - constexpr size_t src_expand = 4; - auto&& fm = param.filter_meta; - size_t group = fm.group; - size_t batch = param.n; - size_t IC = fm.icpg; - size_t OC = fm.ocpg; - size_t FH = fm.spatial[0]; - size_t FW = fm.spatial[1]; - size_t IH2, IW2, OH2, OW2; - get_rectified_size(param, IH2, IW2, OH2, OW2); - megdnn_assert(group == 1, "only support group == 1 now"); - size_t src_size = - batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; - size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); - return {nullptr, {src_size, weight_size}}; -}; - -static void copy_padding_kern(WorkspaceBundle bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - 
const CpuNDRange& workspace_ids) { - size_t IH = kern_param.isz[0]; - size_t IW = kern_param.isz[1]; - size_t IC = kern_param.filter_meta.icpg; - size_t PH = kern_param.filter_meta.padding[0]; - size_t PW = kern_param.filter_meta.padding[1]; - size_t GROUP = kern_param.filter_meta.group; - - size_t IH2, IW2, OH2, OW2; - get_rectified_size(kern_param, IH2, IW2, OH2, OW2); - size_t padding_group_size = IH2 * IW2 * IC; - bundle.set(kern_param.workspace_ptr); - //! Used for get the workspace offset - constexpr int expend_element = 4; - // TODO: block dim is better to get from arg - size_t workspace_ic_block = 1; - size_t workspace_batch_id = workspace_ids[0]; - size_t workspace_group_id = workspace_ids[1]; - size_t workspace_ic_id = workspace_ids[2]; - size_t workspace_ic = workspace_ic_id * workspace_ic_block; - size_t batch_id = ncb_index.ndrange_id[0]; - size_t group_id = ncb_index.ndrange_id[1]; - - const int8_t* sptr = static_cast( - kern_param.src(batch_id, group_id, workspace_ic_id, 1, 1)); - //! 
copy to sptr_base to eliminate padding effect - int8_t* sptr_base = static_cast(bundle.get(0)) + - (workspace_batch_id * GROUP * padding_group_size + - workspace_group_id * padding_group_size + - workspace_ic * IH2 * IW2) * - expend_element; - conv_bias::pack_nchw_src_for_nchw44_conv(sptr, sptr_base, 1, PH, PH, PW, PW, - IH, IW); -} - -template -static void do_conv_kern(WorkspaceBundle bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids, - const CpuNDRange& ncb_range) { - size_t OH = kern_param.osz[0]; - size_t OW = kern_param.osz[1]; - size_t FH = kern_param.filter_meta.spatial[0]; - size_t FW = kern_param.filter_meta.spatial[1]; - size_t IC = kern_param.filter_meta.icpg; - size_t OC = kern_param.filter_meta.ocpg; - size_t GROUP = kern_param.filter_meta.group; - size_t IH2, IW2, OH2, OW2; - get_rectified_size(kern_param, IH2, IW2, OH2, OW2); - bool need_post_process = - kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; - //! 
if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) - Op op = Op(1.0f, 4.0f); - if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; - float scale_dst = kern_param.dst_type.param().scale; - op = Op(scale_bias, scale_dst); - } - size_t padding_group_size = IH2 * IW2 * IC; - bundle.set(kern_param.workspace_ptr); - - constexpr size_t pack_c = 4; - constexpr size_t src_expand_size = 4; - const size_t workspace_batch_id = workspace_ids[0]; - const size_t workspace_group_id = workspace_ids[1]; - const size_t batch_id = ncb_index.ndrange_id[0]; - const size_t group_id = ncb_index.ndrange_id[1]; - const size_t oc_id = ncb_index.ndrange_id[2]; - const size_t oc_block_num = ncb_range[2]; - size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num); - size_t oc_block = nr_pack_per_step * pack_c; - const size_t oc_idx = oc_id * oc_block; - if (oc_id == (oc_block_num - 1)) { - oc_block = OC - oc_id * nr_pack_per_step * pack_c; - } - megdnn_assert(oc_block % pack_c == 0, - "oc must be devisible by 4, but oc = %zu", oc_block); - const int8_t* sptr = - static_cast(bundle.get(0)) + - workspace_batch_id * GROUP * padding_group_size * src_expand_size + - workspace_group_id * padding_group_size * src_expand_size; - - const int8_t* fptr = - kern_param.filter(group_id) + oc_idx * FH * FW * IC; - void* dst = reinterpret_cast( - reinterpret_cast( - kern_param.dst(batch_id, group_id)) + - oc_idx * OH * OW); - const int32_t* bptr = - kern_param.bias(batch_id, group_id) + oc_idx; - auto packed_weight = reinterpret_cast(bundle.get(1)) + - group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; - - conv_bias::pack_nchw44_weight_for_nchw_conv(fptr, packed_weight, IC, FH, FW, - oc_block); -#define KERN1_NCHW44_CONV(filter) \ - conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw_nchw44< \ - bias_mode, Op>(sptr, packed_weight, bptr, nullptr, \ - static_cast(dst), oc_block, IC, IH2, IW2, \ - OH, OW, op) - DISPATCH_FILTER(filter, 
KERN1_NCHW44_CONV); -#undef KERN1_NCHW44_CONV -} - -/* ===================== stride2 algo ===================== */ -bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::usable( - fallback::ConvBiasImpl*, const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const { - MEGDNN_MARK_USED_VAR(algo_selection_strategy); - auto&& fm = param.filter_meta; - auto FH = fm.spatial[0]; - auto OC = fm.ocpg; - bool avaible = //! src and filter are qint8, dst is qint8 - fm.icpg < 4 && // must be nchw input - ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && - param.filter_type.enumv() == DTypeEnum::QuantizedS8 && - (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && - (fm.format == param::Convolution::Format::NCHW44) && - (OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && (FH == 3 || FH == 5 || FH == 7) && - fm.group == 1 && param.bias_mode != BiasMode::BIAS; - return avaible; -} - -bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::is_preferred( - megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, - const NCBKernSizeParam& param) const { - // TODO: benchmark and fix - MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr); - MEGDNN_MARK_USED_VAR(param); - return false; -} - -size_t ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::get_workspace( - fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { - return get_bundle(param).total_size_in_bytes(); -} - -SmallVector -ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns( - fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { - auto fm = param.filter_meta; - size_t N = param.n; - size_t OC = fm.ocpg; - size_t group = fm.group; - WorkspaceBundle wbundle = get_bundle(param); - conv_fun do_conv_fun = nullptr; -// NOTE: remain_w is not used to gen hash of midout for compatible with changing -// shape runtime -#define DO_CONV_KERN_FUN(filter, 
bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2, \ - midout_iv(#filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ - MIDOUT_END(); - -#define GET_OP_PARAM(filter, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(filter, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(filter, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(filter, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ - } -#define GET_BIAS_MODE_PARAM(filter) \ - switch (param.bias_mode) { \ - case BiasMode::NO_BIAS: \ - GET_OP_PARAM(filter, BiasMode::NO_BIAS) \ - break; \ - case BiasMode::BROADCAST_CHANNEL_BIAS: \ - GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ - } - -#define DISPATCH_CONV_KERN() \ - switch (param.filter_meta.spatial[0]) { \ - case 3: \ - GET_BIAS_MODE_PARAM(3) \ - break; \ - case 5: \ - GET_BIAS_MODE_PARAM(5) \ - break; \ - case 7: \ - GET_BIAS_MODE_PARAM(7) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ - } - - DISPATCH_CONV_KERN(); - -#undef DO_CONV_KERN_FUN -#undef GET_REMAIN_W_PARAM -#undef GET_OP_PARAM -#undef GET_BIAS_MODE_PARAM -#undef DISPATCH_CONV_KERN - - megdnn_assert(do_conv_fun); - - SmallVector ret_kerns; - WorkspaceBundle bundle = wbundle; - - constexpr size_t pack_oc = 8; - size_t oc_step = pack_oc; - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { - copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); - }; - ret_kerns.push_back({copy_padding, {N, group, fm.icpg}}); - - CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; - auto do_conv = [bundle, do_conv_fun, ncb_range]( - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { - 
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, - ncb_range); - }; - ret_kerns.push_back({do_conv, ncb_range}); - - return ret_kerns; -} - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp deleted file mode 100644 index 0a816d957..000000000 --- a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp +++ /dev/null @@ -1,789 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern_nchw.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h" -#include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_op.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/common/unroll_macro.h" -#include "src/common/utils.h" -#include "src/fallback/conv_bias/common.h" - -using namespace megdnn; -using namespace arm_common; -namespace { - -template -struct ShiftCalHelper { - static void impl(T& c, T2& src, T3& weight, T4& temp); - static void impl(T& c, T2& src, T3& weight); -}; -template -struct ShiftCalHelper { - static void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], - temp[0]); - c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0], - temp[1]); - c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], - temp[2]); - c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1], - temp[3]); - c[0][2] = Func::impl(src[2 + src_idx], 
weight[0][weight_idx], c[0][2], - temp[0]); - c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2], - temp[1]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], - temp[2]); - c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], - temp[3]); - } - static void impl(T& c, T2& src, T3& weight) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); - c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]); - c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); - c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1]); - c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); - c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); - c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3]); - } -}; -template -struct ShiftCalHelper { - static void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], - temp[0]); - c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], - temp[2]); - c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], - temp[0]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], - temp[2]); - } - static void impl(T& c, T2& src, T3& weight) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); - c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]); - c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]); - } -}; - -template -inline void cal_helper(T& c, T2& src, T3& weight, T4& temp) { - ShiftCalHelper::impl( - c, src, weight, temp); -} -template -inline void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl( - c, src, weight); -}; - -template -struct 
OCHelper { -public: - static const int val = 0; -}; -template <> -struct OCHelper<4> { -public: - static const int val = 1; -}; -template <> -struct OCHelper<8> { -public: - static const int val = 2; -}; - -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op); -}; -/** - * filter shape = (oc/4, ic, 7, 7, 4), first 4 oc is f0 = filter[0, 0, :, :, :] - * calculate sequence \ - * f0[0:1, 0:1, 4] dot4, \ - * f0[0:1, 2:3, 4] dot4, \ - * f0[0:1, 4:5, 4] dot4, \ - * f0[0:1, 6, 4] dot2, \ - * ... - * f0[6, 0:1, 4] dot2, \ - * f0[6, 2:3, 4] dot2, \ - * f0[6, 4:5, 4] dot2, \ - * f0[6, 6, 4] dot1, \ - * look like: - * |---|---|---|-| - * |x x|x x|x x|x| - * |x x|x x|x x|x| - * |---|---|---|-| - * |x x|x x|x x|x| - * |x x|x x|x x|x| - * |---|---|---|-| - * |x x|x x|x x|x| - * |x x|x x|x x|x| - * |---|---|---|-| - * |x x|x x|x x|x| - * |---|---|---|-| - **/ -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, - 0, 8, 0, 8, 0, 8, 0, 8}; - constexpr int filter_size = 7; - constexpr int ic_step = 1; - constexpr int oc_step = 4; - constexpr int pack_iw_len = 4; - constexpr int fh_step = 2; - constexpr int fh_end = filter_size / fh_step * fh_step; - constexpr int c_dim = OCHelper::val; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; - - int32x4_t c[c_dim][4]; - - init_ocx_ow4(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { - const int8_t* nchw_src_ptr = - src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - int8x16_t 
src[6]; - int8x16_t dot4_weight[c_dim][3]; - int16x8_t temp_c[4]; - load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight, - temp_c); - cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight, - temp_c); - cal_helper<2, 2, c_dim, Vdotq_s32_h>(c, src, dot4_weight, - temp_c); - - int8x8_t src_dot2[4]; - int8x8_t dot2_weight[c_dim][1]; - load_helper<1, 3 * 16, 8, c_dim, Vld1_s8>( - dot2_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, - 0); - cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, - temp_c); - weight_ptr += filter_size * pack_iw_len * fh_step; - } - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + - 6 * iw * ic_step * pack_iw_len; - - int8x8_t dot2_weight[c_dim][3]; - int16x8_t temp_c[4]; - int8x8_t src_dot2[6]; - uint8x16_t tbl = vld1q_u8(src_idx_buffer); - load_helper<3, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, - 0, tbl); - cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, - temp_c); - cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, - temp_c); - cal_helper<2, 2, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, - temp_c); - - int16x8_t dot1_weight[c_dim][1]; - int16x8_t src_dot1[4]; - load_helper<1, 3 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( - dot1_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, - nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight); - weight_ptr += filter_size * pack_iw_len; - } - store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); - } -}; -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int 
iw, int ld_dst_oc, const Op& op) { - constexpr int filter_size = 5; - static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, - 0, 8, 0, 8, 0, 8, 0, 8}; - constexpr int ih_step = 2; - constexpr int ic_step = 1; - constexpr int oc_step = 4; - constexpr int pack_iw_len = 4; - constexpr int fh_end = filter_size / ih_step * ih_step; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic; - constexpr int c_dim = OCHelper::val; - int32x4_t c[c_dim][4]; - - init_ocx_ow4(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { - for (int fh_idx = 0; fh_idx < fh_end; fh_idx += ih_step) { - const int8_t* nchw_src_ptr = - src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; - int8x16_t src[5]; - int8x16_t dot4_weight[c_dim][2]; - int16x8_t temp_c[4]; - load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight, - temp_c); - cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight, - temp_c); - - int8x8_t src_dot2[4]; - int8x8_t dot2_weight[c_dim][1]; - load_helper<1, 2 * 16, 8, c_dim, Vld1_s8>( - dot2_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, - 0); - cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, - temp_c); - weight_ptr += filter_size * pack_iw_len * ih_step; - } - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + - fh_end * iw * ic_step * pack_iw_len; - - int8x8_t dot2_weight[c_dim][2]; - int16x8_t temp_c[4]; - int8x8_t src_dot2[5]; - uint8x16_t tbl = vld1q_u8(src_idx_buffer); - load_helper<2, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, - 0, tbl); - - cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, - temp_c); - 
cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, - temp_c); - - int16x8_t dot1_weight[c_dim][1]; - int16x8_t src_dot1[4]; - load_helper<1, 2 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( - dot1_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, - nchw_src_ptr, 0); - - cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight); - weight_ptr += filter_size * pack_iw_len; - } - store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); - } -}; -/** - * filter shape = (oc/4, ic, 3, 3, 4), first 4 oc is f0 = filter[0, 0, :, :, :] - * calculate sequence \ - * f0[0:1, 0:1, 4] dot4, \ - * f0[0:1, 2, 4] dot2, \ - * f0[2, 0:1, 4] dot2, \ - * f0[2, 2, 4] dot1 \ - * look like: - * |---|-| - * |x x|x| - * |x x|x| - * |-----| - * |x x|x| - * |-----| - **/ -template -struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { - constexpr int filter_size = 3; - static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, - 0, 8, 0, 8, 0, 8, 0, 8}; - constexpr int oc_step = 4; - constexpr int ic_step = 1; - constexpr int loop_ic_step = 1; - constexpr int pack_iw_len = 4; - - const int ic_stride = ih * iw * pack_iw_len; - const int ld_weight_oc = oc_step * filter_size * filter_size * ic; - constexpr int c_dim = OCHelper::val; - - int32x4_t c[c_dim][4]; - init_ocx_ow4(c, bias_ptr, oc_step); - - for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - // first 2 line - { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; - int8x16_t src[4]; - int8x16_t dot4_weight[c_dim][1]; - int16x8_t temp_c[4]; - load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_weight_oc); - load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight, - temp_c); - - int8x8_t src_dot2[4]; - int8x8_t dot2_weight[c_dim][1]; - 
load_helper<1, 1 * 16, 8, c_dim, Vld1_s8>( - dot2_weight, weight_ptr, ld_weight_oc); - load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, - 0); - cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, - temp_c); - } - // last line - { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + - 2 * iw * ic_step * pack_iw_len; - int16x8_t temp_c[4]; - int8x8_t src_dot2[4]; - int8x8_t dot2_weight[c_dim][1]; - uint8x16_t tbl = vld1q_u8(src_idx_buffer); - load_helper<1, 24, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, - ld_weight_oc); - load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>( - src_dot2, nchw_src_ptr, 0, tbl); - cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight, - temp_c); - int16x8_t dot1_weight[c_dim][1]; - int16x8_t src_dot1[4]; - load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>( - dot1_weight, weight_ptr, ld_weight_oc); - load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, - nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight); - weight_ptr += filter_size * filter_size * pack_iw_len; - } - } - store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); - } -}; - -} // namespace -enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; -template -inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad, - int right_pad, const int iw) { - const int8_t* src_row_0 = inptr; - const int8_t* src_row_1 = inptr + iw; - constexpr int combine_row = 2; - constexpr int iw_step = 16; - constexpr int src_expand = 4; - constexpr int out_gap = iw_step * src_expand; - const int iw_end = iw / iw_step * iw_step; - - memset(outptr, 0, combine_row * left_pad * src_expand * sizeof(int8_t)); - outptr += combine_row * left_pad * src_expand; - - for (int iw_idx = 0; iw_idx < iw_end; iw_idx += iw_step) { - int8x16_t row0 = vld1q_s8(src_row_0 + iw_idx); - int8x16_t row1 = vdupq_n_s8(0); - if (mode == PACK_MODE::NO_PAD) { - row1 = vld1q_s8(src_row_1 + iw_idx); - } else if (mode == PACK_MODE::FIRST_PAD) 
{ - row1 = row0; - row0 = vdupq_n_s8(0); - } - int8x16x2_t pack_rows = vzipq_s8(row0, row1); -#define STORE_8S8(step) \ - vst1_s8(outptr + step * 8, \ - vreinterpret_s8_s16(vdup_laneq_s16( \ - vreinterpretq_s16_s8(pack_rows.val[0]), step))); - - UNROLL_CALL_RAW(8, STORE_8S8); -#undef STORE_8S8 -#define STORE_8S8(step) \ - vst1_s8(outptr + out_gap + step * 8, \ - vreinterpret_s8_s16(vdup_laneq_s16( \ - vreinterpretq_s16_s8(pack_rows.val[1]), step))); - - UNROLL_CALL_RAW(8, STORE_8S8); -#undef STORE_8S8 - outptr += out_gap * combine_row; - } - for (int iw_idx = iw_end; iw_idx < iw; iw_idx++) { - int8x8_t row0 = vld1_dup_s8(src_row_0 + iw_idx); - int8x8_t row1 = vdup_n_s8(0); - if (mode == PACK_MODE::NO_PAD) { - row1 = vld1_dup_s8(src_row_1 + iw_idx); - } else if (mode == PACK_MODE::FIRST_PAD) { - row1 = row0; - row0 = vdup_n_s8(0); - } - int8x8x2_t pack_rows = vzip_s8(row0, row1); - vst1_s8(outptr, pack_rows.val[0]); - outptr += src_expand * combine_row; - } - memset(outptr, 0, combine_row * right_pad * src_expand * sizeof(int8_t)); - outptr += combine_row * right_pad * src_expand; -} -/** - * pack (ic, h, w) to (ic, h / 2, 2 * w) - * pack interleave two adjacent row in src and repeat 4 times, store to one row - * */ -void conv_bias::pack_nchw_src_for_nchw44_conv( - const int8_t* inptr, int8_t* outptr, const int ic, const int top_pad, - const int bottom_pad, const int left_pad, const int right_pad, - const int ih, const int iw) { - constexpr int src_expand = 4; - constexpr int oh_step = 2; - const int oh = ih + top_pad + bottom_pad; - const int oh_end = div_floor(ih + top_pad, oh_step) * oh_step; - const int ow = (iw + left_pad + right_pad) * src_expand; - - for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { - int oh_idx = 0; - for (; oh_idx < top_pad; oh_idx += oh_step) { - if (top_pad - oh_idx >= oh_step) { - memset(outptr, 0, oh_step * ow * sizeof(int8_t)); - } else { - pack_src_one_line(inptr, outptr, left_pad, - right_pad, iw); - inptr += iw; - } - outptr += oh_step 
* ow; - } - - for (; oh_idx < oh_end; oh_idx += oh_step) { - pack_src_one_line(inptr, outptr, left_pad, - right_pad, iw); - inptr += oh_step * iw; - outptr += oh_step * ow; - } - - for (; oh_idx < oh; oh_idx += oh_step) { - const int last_pad = oh_idx - ih - top_pad; - if (last_pad >= 0) { - memset(outptr, 0, oh_step * ow * sizeof(int8_t)); - } else { - pack_src_one_line(inptr, outptr, left_pad, - right_pad, iw); - inptr += iw; - } - outptr += oh_step * ow; - } - } -} - -/** - * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)} - * pack interleave two adjacent row in filter to one row - * */ -void conv_bias::pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, - int8_t* outptr, const int ic, - const int fh, const int fw, - const int oc) { - constexpr int oc_step = 4; - constexpr int ic_step = 2; - constexpr int fh_step = 2; - constexpr int fw_step = 2; - const int ic_end = ic / ic_step * ic_step; - const int ic_remain = ic - ic_end; - const int fh_end = fh / fh_step * fh_step; - const int fh_remain = fh - fh_end; - const int fw_end = fw / fw_step * fw_step; - const int fw_remain = fw - fw_end; - const int filter_stride = ic * oc_step; - static const uint8_t ic2_idx_h_buffer[16] = {0, 8, 1, 9, 2, 10, 3, 11, - 4, 12, 5, 13, 6, 14, 7, 15}; - uint8x16_t ic2_idx_h = vld1q_u8(ic2_idx_h_buffer); - for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { - for (int ic_idx = 0; ic_idx < ic_end; ic_idx += ic_step) { - const int ic_offset = ic_idx * oc_step; - int8_t* output_ic0 = outptr + ic_idx * fh * fw * oc_step; - int8_t* output_ic1 = output_ic0 + fh * fw * oc_step; - for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { - const int fh_offset = fh_idx * fw * filter_stride; - for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; - int8x8_t row_0 = vld1_s8(filter_ptr); - int8x8_t row_1 = vld1_s8(filter_ptr + fw * filter_stride); - int8x16_t combine_row = vcombine_s8(row_0, 
row_1); - combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); - vst1_s8(output_ic0, vget_low_s8(combine_row)); - vst1_s8(output_ic1, vget_high_s8(combine_row)); - output_ic0 += 8; - output_ic1 += 8; - } - } - if (fh_remain > 0) { - const int fh_offset = fh_end * fw * filter_stride; - for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; - int8x8_t row_0 = vld1_s8(filter_ptr); - int8x8_t row_1 = vld1_s8(filter_ptr + filter_stride); - int8x16_t combine_row = vcombine_s8(row_0, row_1); - combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); - vst1_s8(output_ic0, vget_low_s8(combine_row)); - vst1_s8(output_ic1, vget_high_s8(combine_row)); - output_ic0 += 8; - output_ic1 += 8; - } - if (fw_remain > 0) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_end * filter_stride + - ic_offset; - int8x8_t row_0 = vld1_s8(filter_ptr); - vst1_lane_s32((int32_t*)output_ic0, - vreinterpret_s32_s8(row_0), 0); - vst1_lane_s32((int32_t*)output_ic1, - vreinterpret_s32_s8(row_0), 1); - output_ic0 += 4; - output_ic1 += 4; - } - } - } - if (ic_remain > 0) { - const int ic_offset = ic_end * oc_step; - int8_t* output_ic0 = outptr + ic_end * fh * fw * oc_step; - for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { - const int fh_offset = fh_idx * fw * filter_stride; - for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; - int8x8_t row_0 = vreinterpret_s8_s32( - vld1_dup_s32((const int32_t*)(filter_ptr))); - int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( - (const int32_t*)(filter_ptr + fw * filter_stride))); - int8x16_t combine_row = vcombine_s8(row_0, row_1); - combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); - vst1_s8(output_ic0, vget_low_s8(combine_row)); - output_ic0 += 8; - } - } - if (fh_remain > 0) { - const int fh_offset = fh_end * fw * filter_stride; - for (int fw_idx = 0; fw_idx < fw_end; fw_idx += 
fw_step) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; - int8x8_t row_0 = vreinterpret_s8_s32( - vld1_dup_s32((const int32_t*)(filter_ptr))); - int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( - (const int32_t*)(filter_ptr + filter_stride))); - int8x16_t combine_row = vcombine_s8(row_0, row_1); - combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); - vst1_s8(output_ic0, vget_low_s8(combine_row)); - output_ic0 += 8; - } - if (fw_remain > 0) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_end * filter_stride + - ic_offset; - *(int32_t*)(output_ic0) = *(const int32_t*)(filter_ptr); - output_ic0 += 4; - } - } - } - inptr += oc_step * fh * fw * ic; - outptr += oc_step * fh * fw * ic; - } -} - -template -static void conv_direct_stride2_int8_nchw_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(temp); - constexpr size_t fh = filter_size; - constexpr size_t fw = filter_size; - constexpr size_t ic_step = 1; - constexpr size_t big_oc_step = 8; - constexpr size_t oc_step = 4; - constexpr size_t ih_step = 2; - constexpr size_t oh_step = 1; - constexpr size_t ow_step = 4; - constexpr size_t stride_h = 2; - constexpr size_t stride_w = 2; - constexpr int pack_iw_len = 4; - - const size_t img_stride = oh * ow; - const size_t ow_end = ow / ow_step * ow_step; - const size_t ow_remain = ow - ow_end; - const size_t oc_end = oc / big_oc_step * big_oc_step; - const size_t oc_remain = oc - oc_end; - const int ld_dst_oc = oc_step * img_stride; - - using remain_fun = - std::function; - remain_fun kern_big_oc_remain = nullptr; - remain_fun kern_small_oc_remain = nullptr; - switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2NchwNchw44::impl; \ - kern_small_oc_remain = \ - KerNeonXXs2NchwNchw44::impl; \ - 
break; - - UNROLL_CALL_RAW(4, cb); - default: - megdnn_assert(0, "no remain %zu for kern", ow_remain); - } - - for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); - } - } - } - if (oc_remain > 0) { - size_t oc_idx = oc_end; - const size_t weight_offset = oc_idx * ic * fh * fw; - for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { - for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); - } - if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = - oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); - } - } - } -} 
-#define CONSTRUCT_FUNC(filter_size) \ - template \ - void conv_bias:: \ - conv_direct_stride2_##filter_size##x##filter_size##_int8_nchw_nchw44( \ - const int8_t* src, const int8_t* filter, \ - const int32_t* bias, int32_t* temp, int8_t* dst, \ - const size_t oc, const size_t ic, const size_t ih, \ - const size_t iw, const size_t oh, const size_t ow, \ - const Op& op) { \ - conv_direct_stride2_int8_nchw_nchw44( \ - src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); \ - } - -CONSTRUCT_FUNC(3); -CONSTRUCT_FUNC(5); -CONSTRUCT_FUNC(7); -#undef CONSTRUCT_FUNC - -template -void conv_bias::conv_direct_stride2_2x2_int8_nchw_nchw44( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, - const size_t ih, const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - MEGDNN_MARK_USED_VAR(src); - MEGDNN_MARK_USED_VAR(filter); - MEGDNN_MARK_USED_VAR(bias); - MEGDNN_MARK_USED_VAR(temp); - MEGDNN_MARK_USED_VAR(dst); - MEGDNN_MARK_USED_VAR(oc); - MEGDNN_MARK_USED_VAR(ic); - MEGDNN_MARK_USED_VAR(ih); - MEGDNN_MARK_USED_VAR(iw); - MEGDNN_MARK_USED_VAR(oh); - MEGDNN_MARK_USED_VAR(ow); - MEGDNN_MARK_USED_VAR(op); - megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv"); -} - -#define INSTANTIATION(stride, i, bias, Op) \ - template void conv_bias:: \ - conv_direct_##stride##_##i##x##i##_int8_nchw_nchw44( \ - const int8_t*, const int8_t*, const int32_t*, int32_t*, \ - int8_t*, const size_t, const size_t, const size_t, \ - const size_t, const size_t, const size_t, const Op&); - -#define FOR_OP(stride, i, bias) \ - INSTANTIATION(stride, i, bias, TypeCvtOp) \ - INSTANTIATION(stride, i, bias, ReluOp) \ - INSTANTIATION(stride, i, bias, HSwishOp) - -#define FOR_BIAS(stride, i) \ - FOR_OP(stride, i, BiasMode::NO_BIAS) \ - FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) - -#define FOR_FILTER(stride) \ - FOR_BIAS(stride, 2) \ - FOR_BIAS(stride, 3) \ - FOR_BIAS(stride, 5) \ - FOR_BIAS(stride, 7) - 
-FOR_FILTER(stride2) - -#undef FOR_STRIDE -#undef FOR_FILTER -#undef FOR_IC -#undef FOR_BIAS -#undef FOR_NONLINEAR -#undef FOR_REMAIN -#undef INSTANTIATION diff --git a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h deleted file mode 100644 index a0f65a65a..000000000 --- a/dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#include "src/arm_common/conv_bias/opr_impl.h" -#include "src/fallback/conv_bias/common.h" - -namespace megdnn { -namespace arm_common { -namespace conv_bias { -#define KERN(stride, i, layout) \ - template \ - void conv_direct_##stride##_##i##x##i##_int8_nchw_##layout( \ - const int8_t* src, const int8_t* filter, const int32_t* bias, \ - int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \ - const size_t IH, const size_t IW, const size_t OH, \ - const size_t OW, const Op& op); - -KERN(stride2, 2, nchw44) -KERN(stride2, 3, nchw44) -KERN(stride2, 5, nchw44) -KERN(stride2, 7, nchw44) -#undef KERN - -void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr, - const int ic, const int fh, const int fw, - const int oc); - -void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr, - const int ic, const int top_pad, - const int bottom_pad, const int left_pad, - const int right_pad, const int ih, - const int iw); -} // namespace conv_bias -} // namespace arm_common -} // namespace megdnn \ No newline at end of file diff --git 
a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index caed88db6..e7738120c 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -47,7 +47,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44; - AlgoS8DirectStride2NCHWNCHW44 s8_direct_stride2_nchw_nchw44; + AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44; @@ -115,7 +115,7 @@ public: direct_algos.emplace_back(&s8_direct_stride2_large_group); direct_algos.emplace_back(&s8_direct_stride2_small_group); direct_algos.emplace_back(&s8_direct_stride2_nchw44); - direct_algos.emplace_back(&s8_direct_stride2_nchw_nchw44); + direct_algos.emplace_back(&s8_direct_nchw_nchw44); direct_algos.emplace_back(&s8_direct_stride1_large_group); direct_algos.emplace_back(&s8_direct_stride1_small_group); direct_algos.emplace_back(&s8_direct_stride1_nchw44); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 3482be1dc..99f6f51e0 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -40,7 +40,7 @@ private: class AlgoS8DirectStride1NCHW44; class AlgoS8DirectStride2; class AlgoS8DirectStride2NCHW44; - class AlgoS8DirectStride2NCHWNCHW44; + class AlgoS8DirectNCHWNCHW44; class AlgoQU8DirectStride1; class AlgoQU8DirectStride2; class AlgoFP32WinogradF23_4x4; diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 78b4d388a..f357319dc 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -244,18 +244,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { #if MEGDNN_AARCH64 
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", "IM2COLMATMUL:AARCH64_F32K8X12X1:192", true); + benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", + "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); #else benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", "IM2COLMATMUL:ARMV7_F32:192", true); + benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", + "IM2COLMATMUL:ARMV7_F32:192", false); #endif } TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { #if MEGDNN_AARCH64 benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", "IM2COLMATMUL:AARCH64_F32K8X12X1:192", true); + benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", + "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); #else benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", "IM2COLMATMUL:ARMV7_F32:192", true); + benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", + "IM2COLMATMUL:ARMV7_F32:192", false); #endif } diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 9e4d5f34a..7119c5ed5 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -541,7 +541,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) { checker_conv_bias_qint8x8x8( - get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true), + get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, + true), + handle(), "S8_CONV_NCHW_NCHW44"); + checker_conv_bias_qint8x8x8( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, + true), handle(), "S8_CONV_NCHW_NCHW44"); } -- GitLab