From 6c29548d20a20d83b5c1edfe000f6c109606f541 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 2 Jun 2020 16:07:50 +0800 Subject: [PATCH] fix(dnn/arm): fix nchw_nchw44 dot stride1 support GitOrigin-RevId: c8d3d55b258e2a43c27b903808566f2ea1857842 --- dnn/src/arm_common/conv_bias/block_helper.h | 36 + .../conv_bias/fp32/f32_direct_nchw44_algo.cpp | 28 +- dnn/src/arm_common/conv_bias/int8/algos.h | 15 + .../int8/dot_direct_nchw_nchw44_algo.cpp | 321 ++++++++ .../int8/dot_direct_nchw_nchw44_kern.h | 779 ++++++++++++++++++ .../arm_common/conv_bias/intrinsic_helper.h | 369 +++++---- dnn/src/arm_common/conv_bias/opr_impl.cpp | 2 + dnn/src/arm_common/conv_bias/opr_impl.h | 1 + dnn/src/arm_common/neon_struct.h | 8 + dnn/src/arm_common/simd_macro/marm_neon.h | 46 +- dnn/test/arm_common/conv_bias.cpp | 17 +- .../arm_common/conv_bias_multi_thread.cpp | 10 + 12 files changed, 1432 insertions(+), 200 deletions(-) create mode 100644 dnn/src/arm_common/conv_bias/block_helper.h create mode 100644 dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp create mode 100644 dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h diff --git a/dnn/src/arm_common/conv_bias/block_helper.h b/dnn/src/arm_common/conv_bias/block_helper.h new file mode 100644 index 000000000..7ff14817e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/block_helper.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/arm_common/conv_bias/block_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/common/utils.h" +namespace megdnn { +namespace { +// block_helper is used to calculate oh block size +static inline int l2_block_helper(const int nthread, const int amount, + const int size_per_unit) { + constexpr int l2_cache_size = 256 * 1024; + const int block_per_thread = div_ceil(amount, nthread); + const int best_block = std::min( + amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); + const int max_block_num = div_ceil(block_per_thread, best_block); + const int min_block_num = std::max(max_block_num - 1, 1); + const int max_block = div_ceil(block_per_thread, max_block_num); + const int min_block = div_ceil(block_per_thread, min_block_num); + const int max_loss = std::abs(max_block_num * max_block - block_per_thread); + const int min_loss = std::abs(min_block_num * min_block - block_per_thread); + int block = max_loss > min_loss ? min_block : max_block; + return block; +} + +} // namespace +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp index 8bad554f9..a024525e7 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp @@ -11,6 +11,7 @@ */ #include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" @@ -26,22 +27,7 @@ using conv_fun = std::function; MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1) namespace { -// block_helper is used to calculate oh block size -static inline int block_helper(const int nthread, const int amount, - const int size_per_unit) { - constexpr int l2_cache_size = 256 * 1024; - const int block_per_thread = div_ceil(amount, nthread); - const int best_block = std::min( - amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); - const int max_block_num = div_ceil(block_per_thread, best_block); - const int min_block_num = std::max(max_block_num - 1, 1); - const int max_block = div_ceil(block_per_thread, max_block_num); - const int min_block = div_ceil(block_per_thread, min_block_num); - const int max_loss = std::abs(max_block_num * max_block - block_per_thread); - const int min_loss = std::abs(min_block_num * min_block - block_per_thread); - int block = max_loss > min_loss ? min_block : max_block; - return block; -} + static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, const int iw2) { // border_size is used to avoid read illegal memory @@ -60,7 +46,7 @@ static void get_rectified_size( ow2 = ow; constexpr int cacheline = 64 / sizeof(float); int block_oh = - block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2); + l2_block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2); auto&& fm = param.filter_meta; const int stride_h = static_cast(fm.stride[0]); const int filter_h = static_cast(fm.spatial[0]); @@ -106,8 +92,8 @@ static void do_conv_kern(WorkspaceBundle bundle, const int group_id = ncb_index.ndrange_id[1]; constexpr int oc_idx = 0; int oc_block = oc; - int oh_block = block_helper(kern_param.nr_threads, oh2, - ic * iw * sizeof(float) * stride_h); + int oh_block = l2_block_helper(kern_param.nr_threads, oh2, + ic * iw * sizeof(float) * stride_h); const int oh_idx = ncb_index.ndrange_id[2]; const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); const int ih_real = oh_block_real * stride_h + fh - stride_h; @@ -298,8 +284,8 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns( int ic = param.filter_meta.icpg; int iw = param.isz[1]; int stride_h = param.filter_meta.stride[0]; - int oh_block = block_helper(param.nr_threads, oh, - ic * iw * sizeof(float) * stride_h); + int oh_block = l2_block_helper(param.nr_threads, oh, + ic * iw * sizeof(float) * stride_h); CpuNDRange ncb_range = {static_cast(batch), static_cast(group), static_cast(div_ceil(oh, oh_block))}; diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index b03551f9b..5b2629dd5 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -133,6 +133,21 @@ public: }; #if __ARM_FEATURE_DOTPROD + +class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; } + bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam&) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { bool m_large_group; diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp new file mode 100644 index 000000000..9d3acdf7b --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp @@ -0,0 +1,321 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. + */ +#if __ARM_FEATURE_DOTPROD +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/block_helper.h" +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" +#include "src/arm_common/elemwise_op.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; +using conv_fun = std::function; +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_dot) +namespace { +static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, + const int iw2, + const int stride) { + //! border_size is used to avoid read illegal memory + constexpr int cacheline_size = 64; + constexpr int border_size = 2 * cacheline_size; + const int pack_iw_len = stride == 1 ? 4 : 1; + return round_up( + ic * ih2 * iw2 * pack_iw_len * (int)sizeof(int8_t) + border_size, + cacheline_size); +} +static inline size_t get_temp_bytes(const int iw, const int pw) { + //! border_size is used to avoid read illegal memory + constexpr int cacheline_size = 64; + constexpr int border_size = 1 * cacheline_size; + + return round_up(iw + pw * 2, cacheline_size) + border_size; +} +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, + int& iw2) { + auto&& fm = param.filter_meta; + const int stride_h = static_cast(fm.stride[0]); + const int filter_h = static_cast(fm.spatial[0]); + int ic = param.filter_meta.icpg; + int iw = param.isz[1]; + int oh = param.osz[0]; + int block_oh = l2_block_helper(param.nr_threads, oh, + ic * iw * sizeof(int8_t) * stride_h); + ih2 = block_oh * stride_h + filter_h - stride_h; + iw2 = iw + 2 * static_cast(fm.padding[1]); +} + +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + int ic = fm.icpg; + int fh = fm.spatial[0]; + int fw = fm.spatial[1]; + int iw = param.isz[1]; + int pw = param.filter_meta.padding[1]; + int stride_w = param.filter_meta.stride[1]; + int ih2, iw2; + get_rectified_size(param, ih2, iw2); + + size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w); + size_t weight_size = fm.group * fm.icpg * fm.ocpg * fh * round_up(fw, 4); + size_t temp_size = 0; + if (fm.stride[0] == 1) { + temp_size = get_temp_bytes(iw, pw); + } + return {nullptr, + {src_size * param.nr_threads, weight_size, + temp_size * param.nr_threads}}; +}; + +void do_weight_trans(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex&, const CpuNDRange&) { + const int ic = kern_param.filter_meta.icpg; + const int oc = kern_param.filter_meta.ocpg; + const int fh = kern_param.filter_meta.spatial[0]; + const int fw = kern_param.filter_meta.spatial[1]; + const int fw2 = round_up(fw, 4); + bundle.set(kern_param.workspace_ptr); + auto packed_weight = reinterpret_cast(bundle.get(1)); + auto origin_weight = kern_param.filter(); + pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh, + fw, fw2); +} + +template +static void do_conv_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange&, const CpuNDRange&) { + const int oh = kern_param.osz[0]; + const int ow = kern_param.osz[1]; + const int fh = kern_param.filter_meta.spatial[0]; + const int fw = kern_param.filter_meta.spatial[1]; + const int ic = kern_param.filter_meta.icpg; + const int oc = kern_param.filter_meta.ocpg; + const int ih = kern_param.isz[0]; + const int iw = kern_param.isz[1]; + const int stride_h = kern_param.filter_meta.stride[0]; + const int stride_w = kern_param.filter_meta.stride[1]; + const int ph = kern_param.filter_meta.padding[0]; + const int pw = kern_param.filter_meta.padding[1]; + int ih2 = 0; + int iw2 = 0; + get_rectified_size(kern_param, ih2, iw2); + bundle.set(kern_param.workspace_ptr); + + constexpr int pack_c = 4; + const int batch_id = ncb_index.ndrange_id[0]; + const int group_id = ncb_index.ndrange_id[1]; + constexpr int oc_idx = 0; + int oc_block = oc; + int oh_block = l2_block_helper(kern_param.nr_threads, oh, + ic * iw * sizeof(int8_t) * stride_h); + const int oh_idx = ncb_index.ndrange_id[2]; + const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); + const int ih_real = oh_block_real * stride_h + fh - stride_h; + const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); + const int src_bottom_pad = std::max( + (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, + 0); + const int remain_right_pad = std::max(iw2 - iw - pw, 0); + const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw; + const int8_t* origin_sptr = + static_cast( + kern_param.src(batch_id, group_id, 0, 1, 1)) + + src_offset; + const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w); + int8_t* sptr = reinterpret_cast(bundle.get(0)) + + ncb_index.thread_id * src_size; + int8_t* tmp_ptr = nullptr; + if (stride == 1) { + const size_t tmp_size = get_temp_bytes(iw, pw); + tmp_ptr = reinterpret_cast(bundle.get(2)) + + ncb_index.thread_id * tmp_size; + } + pack_src_int8_nchw_nchw44_dot( + sptr, origin_sptr, ph, pw, remain_right_pad, + ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, + src_bottom_pad, ic, ih * iw, tmp_ptr); + + const int8_t* fptr = + reinterpret_cast(bundle.get(1)) + oc_idx * fh * fw * ic; + int8_t* dst = kern_param.dst(batch_id, group_id) + + oh_idx * oh_block * ow * pack_c; + + const int bias_offset = oc_idx; + const int32_t* bptr = + kern_param.bias(batch_id, group_id) + bias_offset; + + float scale_bias = kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + Op op(scale_bias, scale_dst); + conv_direct_int8_nchw_nchw44_dot( + sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, + oh_block_real, ow, op); +} + +} // namespace + +bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + auto&& fm = param.filter_meta; + auto fh = fm.spatial[0]; + int oc = fm.ocpg; + int ic = fm.icpg; + bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && + (fm.format == param::Convolution::Format::NCHW44); + bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4); + bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && + (fh == 2 || fh == 3 || fh == 5 || fh == 7); + bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == fm.stride[1] && + (fm.stride[0] == 1 || fm.stride[0] == 2); + bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; + bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; + return avaible; +} + +size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + const int batch = param.n; + const int group = fm.group; + WorkspaceBundle wbundle = get_bundle(param); + conv_fun do_conv_fun = nullptr; + // NOTE: remain_w is not used to gen hash of midout for compatible with +// shape runtime +#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot, \ + midout_iv(#stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_KERN(stride) \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(stride, 2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(stride, 3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(stride, 5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(stride, 7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + switch (param.filter_meta.stride[0]) { + case 1: + DISPATCH_CONV_KERN(1); + break; + case 2: + DISPATCH_CONV_KERN(2); + break; + default: + megdnn_assert(0); + break; + } + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN + + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + WorkspaceBundle bundle = wbundle; + int oh = param.osz[0]; + int ic = param.filter_meta.icpg; + int iw = param.isz[1]; + int stride_h = param.filter_meta.stride[0]; + + int oh_block = l2_block_helper(param.nr_threads, oh, + ic * iw * sizeof(int8_t) * stride_h); + + CpuNDRange ncb_range = {static_cast(batch), + static_cast(group), + static_cast(div_ceil(oh, oh_block))}; + + auto do_trans_weight = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_weight_trans(bundle, kern_param, ncb_index, ncb_index.ndrange_id); + }; + ret_kerns.push_back({do_trans_weight, {1}}); + + auto do_conv = [bundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, + ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + return ret_kerns; +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h new file mode 100644 index 000000000..6bc57ebb8 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h @@ -0,0 +1,779 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#if __ARM_FEATURE_DOTPROD + +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +namespace { +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight); +}; + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step * 2], weight[0][weight_idx], \ + src[0][(src_idx + step) / 4]); \ + c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \ + c[1][step * 2], weight[1][weight_idx], \ + src[0][(src_idx + step) / 4]); \ + c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step * 2 + 1], weight[0][weight_idx], \ + src[1][(src_idx + step) / 4]); \ + c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ + c[1][step * 2 + 1], weight[1][weight_idx], \ + src[1][(src_idx + step) / 4]); + + UNROLL_CALL_RAW(4, cb); +#undef cb + } +}; + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step * 2], weight[0][weight_idx], \ + src[0][(src_idx + step) / 4]); \ + c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step * 2 + 1], weight[0][weight_idx], \ + src[1][(src_idx + step) / 4]); + + UNROLL_CALL_RAW(4, cb); +#undef cb + } +}; + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \ + c[1][step] = Func::template impl<(src_idx + step) % 4>( \ + c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]); + + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); + + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; + +template +inline void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl(c, src, weight); +}; +//! OCHelper is used to trans oc_block to row number of result regs +template +struct OCHelper { +public: + static const int val = -1; +}; + +template <> +struct OCHelper<4> { +public: + static const int val = 1; +}; +#if MEGDNN_AARCH64 +template <> +struct OCHelper<8> { +public: + static const int val = 2; +}; +#endif +/** + * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel + * */ +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op); +}; + +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_hight = 2; + constexpr int filter_width = 4; + constexpr int weight_reg = 1; + constexpr int src_reg = 1; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 1; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[2][src_reg]; + int8x16_t weight[c_dim][weight_reg]; + // row 0 + load_helper( + src, src_ptr + 0 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 1 + load_helper( + src, src_ptr + 1 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_hight = 3; + constexpr int filter_width = 4; + constexpr int weight_reg = 1; + constexpr int src_reg = 1; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 1; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[2][src_reg]; + int8x16_t weight[c_dim][weight_reg]; + // row 0 + load_helper( + src, src_ptr + 0 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 1 + load_helper( + src, src_ptr + 1 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 2 + load_helper( + src, src_ptr + 2 * iw, stride); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_hight = 5; + constexpr int filter_width = 8; + constexpr int src_reg = 2; + constexpr int weight_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 1; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[2][src_reg]; + int8x16_t weight[c_dim][weight_reg]; +#define cb(step) \ + load_helper(src, src_ptr + step * iw, \ + stride); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ + weight); \ + cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); + UNROLL_CALL_RAW(5, cb); +#undef cb + src_ptr += ic_stride; + weight_ptr += 5 * 32; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +/** + * oc = 8, ow = 8 + * dot 4 element, pad last filter and do twice dot every row filter, filter like + * below + * -------------------------- + * |x, x, x, x,| x, x, x, 0 | + * -------------------------- + **/ +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_hight = 7; + constexpr int filter_width = 8; + constexpr int src_reg = 2; + constexpr int weight_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 1; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[2][src_reg]; + int8x16_t weight[c_dim][weight_reg]; +#define cb(step) \ + load_helper(src, src_ptr + step * iw, \ + stride); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ + weight); \ + cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); + UNROLL_CALL_RAW(7, cb); +#undef cb + src_ptr += ic_stride; + weight_ptr += 7 * 32; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +////////////////////stride 1/////////////////// +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_hight = 2; + constexpr int filter_width = 4; + constexpr int weight_reg = 2; + constexpr int src_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 4; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[src_reg]; + int8x16_t weight[c_dim][weight_reg]; + // row 0 + load_helper( + src, src_ptr + 0 * iw * pack_iw_len, 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 1 + load_helper( + src, src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_hight = 3; + constexpr int filter_width = 4; + constexpr int weight_reg = 3; + constexpr int src_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 4; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[src_reg]; + int8x16_t weight[c_dim][weight_reg]; + // row 0 + load_helper( + src, src_ptr + 0 * iw * pack_iw_len, 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 1 + load_helper( + src, src_ptr + 1 * iw * pack_iw_len, 0); + cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + // row 2 + load_helper( + src, src_ptr + 2 * iw * pack_iw_len, 0); + cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, + weight); + + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_hight = 5; + constexpr int filter_width = 8; + constexpr int src_reg = 3; + constexpr int weight_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 4; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[src_reg]; + int8x16_t weight[c_dim][weight_reg]; +#define cb(step) \ + load_helper( \ + src, src_ptr + step * iw * pack_iw_len, 0); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ + weight); \ + cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); + + UNROLL_CALL_RAW(5, cb); +#undef cb + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +struct KerNeonDotXXs2Nchw44Int8 { + static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int stride = 1; + constexpr int filter_hight = 7; + constexpr int filter_width = 8; + constexpr int src_reg = 3; + constexpr int weight_reg = 2; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int pack_iw_len = 4; + constexpr int simd_len = 16; + + const int ld_bias = oc_step; + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + int8x16_t src[src_reg]; + int8x16_t weight[c_dim][weight_reg]; +#define cb(step) \ + load_helper( \ + src, src_ptr + step * iw * pack_iw_len, 0); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ + weight); \ + cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); + + UNROLL_CALL_RAW(7, cb); +#undef cb + src_ptr += ic_stride; + weight_ptr += filter_hight * filter_width * oc_step; + } + store_ocx_ow8_remain_static_dt( + c, op, dst_ptr, ld_dst_oc); + } +}; + +template +void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin, + const int, const int pw, const int, + const int ih, const int iw, const int iw2, + const int pad_top, const int pad_bottom, + const int ic, const int ic_stride, int8_t*) { + constexpr int ic_step = 1; + rep_step(ic_idx, ic, ic_step) { + const int8_t* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, + sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom)); + sptr_base += iw2 * pad_top * ic_step; + rep(ih_idx, ih) { + memcpy(sptr_base + pw * ic_step, sptr, + sizeof(int8_t) * iw * ic_step); + sptr_base += iw2 * ic_step; + sptr += iw * ic_step; + } + sptr_base += iw2 * pad_bottom * ic_step; + } +} + +template <> +void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, + const int8_t* sptr_origin, const int, + const int pw, const int, const int ih, + const int iw, const int iw2, + const int pad_top, const int pad_bottom, + const int ic, const int ic_stride, + int8_t* temp_ptr) { + static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4, + 2, 3, 4, 5, 3, 4, 5, 6}; + uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]); + + constexpr int iw_step = 16; + constexpr int pack_iw_len = 4; + const int iw_with_pad = iw + 2 * pw; + const int iw_with_pad_end = iw_with_pad / iw_step * iw_step; + rep(ic_idx, ic) { + const int8_t* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, + sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * + pack_iw_len); + sptr_base += iw2 * pad_top * pack_iw_len; + rep(ih_idx, ih) { + memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); + memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw); + for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) { + int8x16_t src[4]; + int8x16_t dst[4]; + src[0] = vld1q_s8(temp_ptr + iw_idx); + src[1] = vld1q_s8(temp_ptr + iw_idx + 4); + src[2] = vld1q_s8(temp_ptr + iw_idx + 8); + src[3] = vld1q_s8(temp_ptr + iw_idx + 12); + dst[0] = vqtbl1q_s8(src[0], tbl_idx); + dst[1] = vqtbl1q_s8(src[1], tbl_idx); + dst[2] = vqtbl1q_s8(src[2], tbl_idx); + dst[3] = vqtbl1q_s8(src[3], tbl_idx); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]); + vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]); + } + for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) { + *(sptr_base + iw_idx * pack_iw_len + 0) = + *(temp_ptr + iw_idx + 0); + *(sptr_base + iw_idx * pack_iw_len + 1) = + *(temp_ptr + iw_idx + 1); + *(sptr_base + iw_idx * pack_iw_len + 2) = + *(temp_ptr + iw_idx + 2); + *(sptr_base + iw_idx * pack_iw_len + 3) = + *(temp_ptr + iw_idx + 3); + } + sptr_base += iw2 * pack_iw_len; + sptr += iw; + } + sptr_base += iw2 * pad_bottom * pack_iw_len; + } +} + +static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, + const int8_t* src_ptr, + const int oc, const int ic, + const int fh, const int fw, + const int fw2) { + constexpr int oc_step = 4; + const int fw_remain = fw2 - fw; + const int dst_ic_stride = fh * fw2; + const int oc_step_stride = fh * fw2 * ic * oc_step; + static const uint8_t transpose_4x4_idx[16] = {0, 4, 8, 12, 1, 5, 9, 13, + 2, 6, 10, 14, 3, 7, 11, 15}; + uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]); + rep_step(oc_idx, oc, oc_step) { + int32_t* dst_temp_ptr = + reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); + const int32_t* src_temp_ptr = reinterpret_cast( + src_ptr + oc_idx * ic * fh * fw); + // transpose ic and pad + rep(fh_idx, fh) { + rep(fw_idx, fw) { + rep(ic_idx, ic) { + *(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr; + src_temp_ptr++; + } + dst_temp_ptr++; + } + rep(ic_idx, ic) { + memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0, + sizeof(int8_t) * oc_step * fw_remain); + } + dst_temp_ptr += fw_remain; + } + // transpose fw oc + int8_t* trans_dst_temp_ptr = + reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); + + rep_step(idx, oc_step_stride, 16) { + int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx); + vst1q_s8(trans_dst_temp_ptr + idx, + vqtbl1q_s8(temp, tbl_transpose_4x4)); + } + } +} + +template +static void conv_direct_int8_nchw_nchw44_dot( + const int8_t* src, const int8_t* filter, const int32_t* bias, + int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih, + const int iw, const int oh, const int oh_block, const int ow, + const Op& op) { + MEGDNN_MARK_USED_VAR(temp); + constexpr int fh = filter_size; + constexpr int fw = (filter_size + 3) / 4 * 4; +#if MEGDNN_AARCH64 + constexpr int big_oc_step = 8; +#else + constexpr int big_oc_step = 4; +#endif + constexpr int oc_step = 4; + constexpr int ih_step = 1; + constexpr int oh_step = 1; + constexpr int ow_step = 8; + constexpr int stride_h = stride; + constexpr int stride_w = stride; + constexpr int pack_iw_len = stride == 2 ? 1 : 4; + + const int img_stride = oh * ow; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + const int oc_end = oc / big_oc_step * big_oc_step; + const int oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = + std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonDotXXs2Nchw44Int8::impl; \ + kern_small_oc_remain = \ + KerNeonDotXXs2Nchw44Int8::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + int oc_idx = oc_end; + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } +} + +} // namespace +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h index 5909f7ad4..daa7568da 100644 --- a/dnn/src/arm_common/conv_bias/intrinsic_helper.h +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -176,187 +176,202 @@ inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr, StoreOcxOw4Remain::impl(c, op, dst_ptr, ld_dst_oc); } ////////////////////Store_OCX_OW8_Remain///////////////////////// -template +template struct StoreOcxOw8Remain { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc); + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc); }; -template -struct StoreOcxOw8Remain<2, 0, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op({{c[0][4], c[0][5]}}, dst_ptr + 16); - op({{c[0][6], c[0][7]}}, dst_ptr + 24); +template +struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); - op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); - op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); - op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); - op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][6], c[1][7]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 24)); } }; -template -struct StoreOcxOw8Remain<2, 8, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op({{c[0][4], c[0][5]}}, dst_ptr + 16); - op({{c[0][6], c[0][7]}}, dst_ptr + 24); +template +struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); - op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); - op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); - op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); - op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][6], c[1][7]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 24)); } }; -template -struct StoreOcxOw8Remain<2, 7, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op({{c[0][4], c[0][5]}}, dst_ptr + 16); - op(c[0][6], dst_ptr + 24); +template +struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op(c[0][6], reinterpret_cast(dst_ptr + 24)); - op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); - op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); - op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); - op(c[1][6], dst_ptr + ld_dst_oc + 24); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op(c[1][6], reinterpret_cast(dst_ptr + ld_dst_oc + 24)); } }; -template -struct StoreOcxOw8Remain<2, 6, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op({{c[0][4], c[0][5]}}, dst_ptr + 16); +template +struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); - op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); - op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); - op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + 16)); } }; -template -struct StoreOcxOw8Remain<2, 5, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op(c[0][4], dst_ptr + 16); +template +struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op(c[0][4], reinterpret_cast(dst_ptr + 16)); - op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); - op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); - op(c[1][4], dst_ptr + ld_dst_oc + 16); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op(c[1][4], reinterpret_cast(dst_ptr + ld_dst_oc + 16)); } }; -template -struct StoreOcxOw8Remain<2, 4, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); +template +struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); - op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); } }; -template -struct StoreOcxOw8Remain<2, 3, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op(c[0][2], dst_ptr + 8); +template +struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op(c[0][2], reinterpret_cast(dst_ptr + 8)); - op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); - op(c[1][2], dst_ptr + ld_dst_oc + 8); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op(c[1][2], reinterpret_cast(dst_ptr + ld_dst_oc + 8)); } }; -template -struct StoreOcxOw8Remain<2, 2, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); +template +struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); } }; -template -struct StoreOcxOw8Remain<2, 1, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { - op(c[0][0], dst_ptr); - op(c[1][0], dst_ptr + ld_dst_oc); +template +struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op(c[0][0], reinterpret_cast(dst_ptr)); + op(c[1][0], reinterpret_cast(dst_ptr + ld_dst_oc)); } }; -template -struct StoreOcxOw8Remain<1, 0, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op({{c[0][4], c[0][5]}}, dst_ptr + 16); - op({{c[0][6], c[0][7]}}, dst_ptr + 24); +template +struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); } }; -template -struct StoreOcxOw8Remain<1, 8, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op({{c[0][4], c[0][5]}}, dst_ptr + 16); - op({{c[0][6], c[0][7]}}, dst_ptr + 24); +template +struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); } }; -template -struct StoreOcxOw8Remain<1, 7, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op({{c[0][4], c[0][5]}}, dst_ptr + 16); - op(c[0][6], dst_ptr + 24); +template +struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op(c[0][6], reinterpret_cast(dst_ptr + 24)); } }; -template -struct StoreOcxOw8Remain<1, 6, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op({{c[0][4], c[0][5]}}, dst_ptr + 16); +template +struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); } }; -template -struct StoreOcxOw8Remain<1, 5, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); - op(c[0][4], dst_ptr + 16); +template +struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op(c[0][4], reinterpret_cast(dst_ptr + 16)); } }; -template -struct StoreOcxOw8Remain<1, 4, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op({{c[0][2], c[0][3]}}, dst_ptr + 8); +template +struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); } }; -template -struct StoreOcxOw8Remain<1, 3, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { - op({{c[0][0], c[0][1]}}, dst_ptr); - op(c[0][2], dst_ptr + 8); +template +struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op(c[0][2], reinterpret_cast(dst_ptr + 8)); } }; -template -struct StoreOcxOw8Remain<1, 2, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { - op({{c[0][0], c[0][1]}}, dst_ptr); +template +struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); } }; -template -struct StoreOcxOw8Remain<1, 1, Op, T> { - static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { - op(c[0][0], dst_ptr); +template +struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> { + static void impl(T& c, const Op& op, T2 dst_ptr, int) { + op(c[0][0], reinterpret_cast(dst_ptr)); } }; -template -inline void store_ocx_ow8_remain_static(T& c, const Op& op, float32_t* dst_ptr, +template +inline void store_ocx_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { - StoreOcxOw8Remain::impl(c, op, dst_ptr, ld_dst_oc); + StoreOcxOw8Remain::impl(c, op, dst_ptr, + ld_dst_oc); +} +template +inline void store_ocx_ow8_remain_static_dt(T& c, const Op& op, T2 dst_ptr, + int ld_dst_oc) { + StoreOcxOw8Remain::impl(c, op, dst_ptr, + ld_dst_oc); } ////////////////////Store_OC8_OW8_Remain///////////////////////// @@ -522,68 +537,84 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr, } } /////////////////////////init_ocx_ow8//////////////////// + +inline float32x4_t neon_vdupq_n(float val) { + return vdupq_n_f32(val); +} + +inline int32x4_t neon_vdupq_n(int val) { + return vdupq_n_s32(val); +} +inline float32x4_t neon_vld1q(const float* ptr) { + return vld1q_f32(ptr); +} + +inline int32x4_t neon_vld1q(const int* ptr) { + return vld1q_s32(ptr); +} + template struct InitOcxOw8 { - static void impl(T& c, T2 bias_ptr, int oc_step); + static void impl(T& c, const T2* bias_ptr, int oc_step); }; template struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> { - static void impl(T& c, const float32_t*, int) { -#define BAIS_INIT(step) \ - c[0][step] = vdupq_n_f32(0); \ - c[1][step] = vdupq_n_f32(0); + static void impl(T& c, const T2*, int) { +#define BAIS_INIT(step) \ + c[0][step] = neon_vdupq_n(static_cast(0)); \ + c[1][step] = neon_vdupq_n(static_cast(0)); UNROLL_CALL_RAW(8, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> { - static void impl(T& c, const float32_t*, int) { -#define BAIS_INIT(step) \ - c[0][step] = vdupq_n_f32(0); \ - c[1][step] = vdupq_n_f32(0); + static void impl(T& c, const T2*, int) { +#define BAIS_INIT(step) \ + c[0][step] = neon_vdupq_n(static_cast(0)); \ + c[1][step] = neon_vdupq_n(static_cast(0)); UNROLL_CALL_RAW(4, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { - static void impl(T& c, const float32_t* bias_ptr, int oc_step) { -#define BAIS_INIT(step) \ - c[0][step] = vld1q_f32(bias_ptr); \ - c[1][step] = vld1q_f32(bias_ptr + oc_step); + static void impl(T& c, const T2* bias_ptr, int oc_step) { +#define BAIS_INIT(step) \ + c[0][step] = neon_vld1q(bias_ptr); \ + c[1][step] = neon_vld1q(bias_ptr + oc_step); UNROLL_CALL_RAW(8, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { - static void impl(T& c, const float32_t* bias_ptr, int oc_step) { -#define BAIS_INIT(step) \ - c[0][step] = vld1q_f32(bias_ptr); \ - c[1][step] = vld1q_f32(bias_ptr + oc_step); + static void impl(T& c, const T2* bias_ptr, int oc_step) { +#define BAIS_INIT(step) \ + c[0][step] = neon_vld1q(bias_ptr); \ + c[1][step] = neon_vld1q(bias_ptr + oc_step); UNROLL_CALL_RAW(4, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> { - static void impl(T& c, const float32_t* bias_ptr, int oc_step) { + static void impl(T& c, const T2* bias_ptr, int oc_step) { constexpr int simd_len = 4; -#define BAIS_INIT(step) \ - c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \ - c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len); +#define BAIS_INIT(step) \ + c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ + c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); UNROLL_CALL_RAW(8, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> { - static void impl(T& c, const float32_t* bias_ptr, int oc_step) { + static void impl(T& c, const T2* bias_ptr, int oc_step) { constexpr int simd_len = 4; -#define BAIS_INIT(step) \ - c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \ - c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len); +#define BAIS_INIT(step) \ + c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ + c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); UNROLL_CALL_RAW(4, BAIS_INIT); #undef BAIS_INIT } @@ -591,57 +622,57 @@ struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> { template struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> { - static void impl(T& c, const float32_t*, int) { -#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0); + static void impl(T& c, const T2*, int) { +#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast(0)); UNROLL_CALL_RAW(8, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> { - static void impl(T& c, const float32_t*, int) { -#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0); + static void impl(T& c, const T2*, int) { +#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast(0)); UNROLL_CALL_RAW(4, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { - static void impl(T& c, const float32_t* bias_ptr, int) { -#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr); + static void impl(T& c, const T2* bias_ptr, int) { +#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); UNROLL_CALL_RAW(8, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { - static void impl(T& c, const float32_t* bias_ptr, int) { -#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr); + static void impl(T& c, const T2* bias_ptr, int) { +#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); UNROLL_CALL_RAW(4, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> { - static void impl(T& c, const float32_t* bias_ptr, int) { + static void impl(T& c, const T2* bias_ptr, int) { constexpr int simd_len = 4; -#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len); +#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); UNROLL_CALL_RAW(8, BAIS_INIT); #undef BAIS_INIT } }; template struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> { - static void impl(T& c, const float32_t* bias_ptr, int) { + static void impl(T& c, const T2* bias_ptr, int) { constexpr int simd_len = 4; -#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len); +#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); UNROLL_CALL_RAW(4, BAIS_INIT); #undef BAIS_INIT } }; template -inline void init_ocx_ow8(T& c, T2 bias_ptr, int oc_step) { +inline void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) { InitOcxOw8::impl(c, bias_ptr, oc_step); } /////////////////////init_ocx_ow4///////////////////// diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index b902a8720..a1bab1b2d 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -55,6 +55,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; #if __ARM_FEATURE_DOTPROD + AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44; AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; @@ -93,6 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { public: AlgoPack() { #if __ARM_FEATURE_DOTPROD + direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44); direct_algos.emplace_back(&ds8_direct_stride1_large_group); direct_algos.emplace_back(&ds8_direct_stride1_small_group); direct_algos.emplace_back(&ds8_direct_stride2_large_group); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 58db42a72..e1451126a 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -62,6 +62,7 @@ private: class AlgoFP16WinogradF23_8x8; #endif #if __ARM_FEATURE_DOTPROD + class AlgoDotS8DirectNCHWNCHW44; class AlgoDotS8DirectStride1; class AlgoDotS8DirectStride2; class AlgoDotU8DirectStride1; diff --git a/dnn/src/arm_common/neon_struct.h b/dnn/src/arm_common/neon_struct.h index 973fce7b7..6aaf14099 100644 --- a/dnn/src/arm_common/neon_struct.h +++ b/dnn/src/arm_common/neon_struct.h @@ -60,6 +60,14 @@ struct Vfmaq_laneq_f32 { return vfmaq_laneq_f32(a, b, v, lane); } }; +#if __ARM_FEATURE_DOTPROD +struct Vdotq_laneq_s32 { + template + static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { + return vdotq_laneq_s32(a, b, v, lane); + } +}; +#endif } // namespace } // namespace megdnn diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h index a923ff872..9443e67c3 100644 --- a/dnn/src/arm_common/simd_macro/marm_neon.h +++ b/dnn/src/arm_common/simd_macro/marm_neon.h @@ -481,37 +481,71 @@ UNROLL_CALL_RAW(4, cb); #define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7::impl(vec) namespace { template -struct Vfmap_laneq_f32_armv7 { +struct Vfmaq_laneq_f32_armv7 { static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v); }; template <> -struct Vfmap_laneq_f32_armv7<0> { +struct Vfmaq_laneq_f32_armv7<0> { static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { return vmlaq_lane_f32(a, b, vget_low_f32(v), 0); } }; template <> -struct Vfmap_laneq_f32_armv7<1> { +struct Vfmaq_laneq_f32_armv7<1> { static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { return vmlaq_lane_f32(a, b, vget_low_f32(v), 1); } }; template <> -struct Vfmap_laneq_f32_armv7<2> { +struct Vfmaq_laneq_f32_armv7<2> { static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { return vmlaq_lane_f32(a, b, vget_high_f32(v), 0); } }; template <> -struct Vfmap_laneq_f32_armv7<3> { +struct Vfmaq_laneq_f32_armv7<3> { static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { return vmlaq_lane_f32(a, b, vget_high_f32(v), 1); } }; } // namespace #define vfmaq_laneq_f32(a, b, v, lane) \ - Vfmap_laneq_f32_armv7::impl(a, b, v) + Vfmaq_laneq_f32_armv7::impl(a, b, v) + +#if __ARM_FEATURE_DOTPROD +template +struct Vdotq_laneq_s32_armv7 { + static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v); +}; +template <> +struct Vdotq_laneq_s32_armv7<0> { + static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { + return vdotq_lane_s32(a, b, vget_low_s32(v), 0); + } +}; +template <> +struct Vdotq_laneq_s32_armv7<1> { + static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { + return vdotq_lane_s32(a, b, vget_low_s32(v), 1); + } +}; +template <> +struct Vdotq_laneq_s32_armv7<2> { + static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { + return vdotq_lane_s32(a, b, vget_high_s32(v), 0); + } +}; +template <> +struct Vdotq_laneq_s32_armv7<3> { + static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { + return vdotq_lane_s32(a, b, vget_high_f32(v), 1); + } +}; +#define vdotq_laneq_s32(a, b, v, lane) \ + Vdotq_laneq_s32_armv7::impl(a, b, v) + +#endif #endif diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index e8ddd8f00..55c01ec6f 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -109,14 +109,12 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { .set_dtype(4, dtype::QuantizedS8(60.25)) .set_display(false); benchmarker_int.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker( - "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384")); + conv_bias::ConvBiasAlgoChecker("IM2COLMATMUL:.+")); Benchmarker benchmarker_float(handle); benchmarker_float.set_display(false).set_times(RUNS); benchmarker_float.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker( - "IM2COLMATMUL:AARCH64_F32K8X12X1:192")); + conv_bias::ConvBiasAlgoChecker("IM2COLMATMUL:.+")); Benchmarker benchmarker_nchw44(handle); if (is_fp32) { @@ -213,6 +211,15 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { run(1, 256, 256, 14, 14, 3, 1, false); run(1, 512, 512, 7, 7, 3, 1, false); } else { + run(1, 1, 4, 112, 112, 2, 2, true); + run(1, 3, 32, 224, 224, 3, 2, true); + run(1, 3, 32, 224, 224, 5, 2, true); + run(1, 3, 64, 224, 224, 7, 2, true); + run(1, 1, 4, 112, 112, 2, 1, true); + run(1, 3, 32, 224, 224, 3, 1, true); + run(1, 3, 32, 224, 224, 5, 1, true); + run(1, 3, 64, 224, 224, 7, 1, true); + for (size_t stride : {1, 2}) { printf("stride %zu\n", stride); for (size_t filter_size : {2, 3, 5, 7}) { @@ -228,9 +235,11 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { } TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { benchmark_convbias(handle(), true); + benchmark_convbias(handle(), false); } TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { benchmark_convbias(handle(), true); + benchmark_convbias(handle(), false); } #endif diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 58f5ee81a..0f41a0e30 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -557,6 +557,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { /****************************dot qint8 direct*************************/ #if __ARM_FEATURE_DOTPROD +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { + checker_conv_bias_qint8x8x8( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, + true), + handle(), "ARMDOTS8_NCHW_NCHW44"); + checker_conv_bias_qint8x8x8( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, + true), + handle(), "ARMDOTS8_NCHW_NCHW44"); +} TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( -- GitLab