diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h
index f0a027917d97226764a1ea1a25eea911a67d4807..d47024cfc98959cfb57e77124b1d8d43092a5fe2 100644
--- a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h
+++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h
@@ -33,7 +33,7 @@ KERN(stride2, 5)
 
 #undef KERN
 
-} // namesapce conv_bias
+} // namespace channel_wise_nchw44
 } // namespace arm_common
 } // namespace megdnn
diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp
index d82c18ffd805815364bde1f6ceb2371438161486..5ac8e85c793841685b3e67e283145428752360de 100644
--- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp
+++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp
@@ -10,16 +10,15 @@
  */
 
 #include "src/arm_common/conv_bias/int8x8x16/algos.h"
+#include "src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h"
+#include "src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h"
 #include "src/arm_common/conv_bias/int8x8x16/conv_direct.h"
 #include "src/arm_common/conv_bias/int8x8x16/conv_stride2.h"
 
 #include "midout.h"
-#include "src/common/opr_delegate.h"
+
 MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_kimpl)
-#include
-#include
-#include
 
 using namespace megdnn;
 using namespace arm_common;
@@ -550,4 +549,70 @@ ConvBiasImpl::AlgoI8x8x16Stride2Filter2::dispatch_kerns(
     return {{kern, {group, 1_z, 1_z}}};
 }
 
+/* ===================== int8x8x16 channel_wise_nchw44 stride1 stride2 algo ===================== */
+bool ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::usable(
+        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
+    auto&& fm = param.filter_meta;
+    auto FH = fm.spatial[0];
+    bool available =
+            //! src and filter are int8, dst is int16
+            (param.src_type.enumv() == DTypeEnum::Int8 &&
+             param.filter_type.enumv() == DTypeEnum::Int8 &&
+             param.dst_type.enumv() == DTypeEnum::Int16) &&
+            fm.format == param::Convolution::Format::NCHW44 &&
+            param.bias_mode != megdnn::BiasMode::BIAS &&
+            param.nonlineMode == megdnn::NonlineMode::IDENTITY &&
+            !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
+            fm.dilation[1] == 1 &&
+            (fm.stride[0] == fm.stride[1] &&
+             (fm.stride[0] == 1 || fm.stride[0] == 2)) &&
+            FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) &&
+            fm.icpg == 1 && fm.ocpg == 1 && fm.group % 4 == 0;
+    return available;
+}
+
+size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace(
+        const NCBKernSizeParam& param) const {
+    size_t stride_h = param.filter_meta.stride[0];
+    size_t stride_w = param.filter_meta.stride[1];
+    megdnn_assert(stride_h == stride_w);
+    if (stride_h == 1) {
+        return channel_wise_nchw44_8x8x16::stride1::get_bundle(param)
+                .total_size_in_bytes();
+    } else if (stride_h == 2) {
+        return channel_wise_nchw44_8x8x16::stride2::get_bundle(param)
+                .total_size_in_bytes();
+    } else {
+        return 0;
+    }
+}
+
+SmallVector<ConvBiasImpl::NCBKern>
+ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::dispatch_kerns(
+        const NCBKernSizeParam& param) const {
+    size_t stride_h = param.filter_meta.stride[0];
+    size_t stride_w = param.filter_meta.stride[1];
+    if (stride_h == stride_w && stride_h == 1) {
+        MIDOUT_BEGIN(
+                megdnn_arm_common_conv_bias_int8816_kimpl,
+                midout_iv(
+                        "AlgoS8x8x16ChanWiseStride1Stride2NCHW44_dispatch_kerns"_hash)) {
+            return channel_wise_nchw44_8x8x16::stride1::get_kimpls(param);
+        }
+        MIDOUT_END();
+        return {};
+    } else if (stride_h == stride_w && stride_h == 2) {
+        MIDOUT_BEGIN(
+                megdnn_arm_common_conv_bias_int8816_kimpl,
+                midout_iv(
+                        "AlgoS8x8x16ChanWiseStride2NCHW44_dispatch_kerns"_hash)) {
+            return channel_wise_nchw44_8x8x16::stride2::get_kimpls(param);
+        }
+        MIDOUT_END();
+        return {};
+    } else {
+        return {};
+    }
+}
+
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h
index 198fa891d2ac9ace92fe50913af81d5186f07d9c..ffed95c57b1e790e4d80847a77b2210beacc634f 100644
--- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h
+++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h
@@ -72,6 +72,18 @@ public:
             const NCBKernSizeParam& param) const override;
 };
 
+class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; }
+    bool usable(const NCBKernSizeParam& param,
+                AlgoSelectionStrategy algo_selection_strategy) const override;
+    size_t get_workspace(const NCBKernSizeParam& param) const override;
+    virtual SmallVector<NCBKern> dispatch_kerns(
+            const NCBKernSizeParam& param) const override;
+};
+
 } // namespace arm_common
 } // namespace megdnn
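For orientation before the NEON kernels below, here is a minimal scalar sketch of the computation they vectorize: channel-wise convolution on one NCHW44 group (4 interleaved channels), int8 source and filter, int16 bias and accumulator, square FHxFH window at stride 1 or 2, no requantization. This sketch is illustrative only and is not part of the patch; the function name and signature are hypothetical.

#include <cstddef>
#include <cstdint>

// Hypothetical scalar reference for one NCHW44 group of the int8x8x16
// channel-wise kernels. src: IH*IW*4 int8, flt: FH*FH*4 int8,
// bias: 4 int16 (may be null), dst: OH*OW*4 int16.
static void chanwise_nchw44_8x8x16_ref(const int8_t* src, const int8_t* flt,
                                       const int16_t* bias, int16_t* dst,
                                       size_t IW, size_t OH, size_t OW,
                                       size_t FH, size_t stride) {
    for (size_t oh = 0; oh < OH; ++oh)
        for (size_t ow = 0; ow < OW; ++ow)
            for (size_t c = 0; c < 4; ++c) {  // 4 interleaved channels
                int16_t acc = bias ? bias[c] : 0;  // BROADCAST_CHANNEL_BIAS
                for (size_t fh = 0; fh < FH; ++fh)
                    for (size_t fw = 0; fw < FH; ++fw) {
                        size_t ih = oh * stride + fh;
                        size_t iw = ow * stride + fw;
                        // int8 * int8 accumulated in int16, mirroring the
                        // widening vmlal_s8 multiply-accumulate used below
                        acc += int16_t(src[(ih * IW + iw) * 4 + c]) *
                               int16_t(flt[(fh * FH + fw) * 4 + c]);
                    }
                dst[(oh * OW + ow) * 4 + c] = acc;  // no requantization
            }
}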
diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..ddf1c2e3efac371b057a00beda004842ce5cd1b7
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h
@@ -0,0 +1,40 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/arm_common/conv_bias/opr_impl.h"
+#include "src/fallback/conv_bias/common.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace channel_wise_nchw44_8x8x16 {
+
+#define KERN(stride, i)                                                    \
+    template <BiasMode bias_mode>                                          \
+    void direct_##stride##_##i##x##i##_int8x8x16(                          \
+            const int8_t* src, const int8_t* filter, const int16_t* bias,  \
+            void* dst, const size_t IH, const size_t IW, const size_t OH,  \
+            const size_t OW);
+
+KERN(stride1, 2)
+KERN(stride1, 3)
+KERN(stride1, 5)
+
+KERN(stride2, 2)
+KERN(stride2, 3)
+KERN(stride2, 5)
+
+#undef KERN
+
+} // namespace channel_wise_nchw44_8x8x16
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5cd5103dcac93f68f395a5d25526fbe83949244f
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp
@@ -0,0 +1,1924 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h"
+#include "src/arm_common/elemwise_op.h"
+#include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/common/unroll_macro.h"
+#include "src/common/utils.h"
+#include "src/fallback/conv_bias/common.h"
+
+using namespace megdnn;
+using namespace arm_common;
+
+#define INIT_SUM()                                       \
+    int16x8_t init_sum;                                  \
+    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { \
+        int16x4_t tmpsum = vld1_s16(bptr);               \
+        init_sum = vcombine_s16(tmpsum, tmpsum);         \
+    } else {                                             \
+        init_sum = vdupq_n_s16(0);                       \
+    }
+
+#define STORE_1_LINE_RESULT(dst, oh, ow, OW, sum)                        \
+    do {                                                                 \
+        dt_int16* dptr =                                                 \
+                reinterpret_cast<dt_int16*>(dst) + (oh)*OW * 4 + ow * 4; \
+        vst1q_s16(dptr, sum[0]);                                         \
+        vst1q_s16(dptr + 8, sum[1]);                                     \
+        vst1q_s16(dptr + 16, sum[2]);                                    \
+        vst1q_s16(dptr + 24, sum[3]);                                    \
+    } while (0);
+
+#define STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum)                      \
+    do {                                                                 \
+        dt_int16* dptr =                                                 \
+                reinterpret_cast<dt_int16*>(dst) + (oh)*OW * 4 + ow * 4; \
+        vst1q_s16(dptr, sum[0]);                                         \
+        vst1q_s16(dptr + 8, sum[1]);                                     \
+    } while (0);
+
+#define STORE_REMAIN(dst, oh, ow, OW, sum, remain)                       \
+    do {                                                                 \
+        dt_int16* dptr =                                                 \
+                reinterpret_cast<dt_int16*>(dst) + oh * OW * 4 + ow * 4; \
+        if (remain == 1) {                                               \
+            vst1_s16(dptr, vget_low_s16(sum[0]));                        \
+        } else if (remain == 2) {                                        \
+            vst1q_s16(dptr, sum[0]);                                     \
+        } else if (remain == 3) {                                        \
+            vst1q_s16(dptr, sum[0]);                                     \
+            vst1_s16(dptr + 8, vget_low_s16(sum[1]));                    \
+        }                                                                \
+    } while (0);
+
+template <BiasMode bias_mode>
+void channel_wise_nchw44_8x8x16::direct_stride1_2x2_int8x8x16(
+        const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst,
+        const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
+    MEGDNN_MARK_USED_VAR(IH);
+    const int16_t* __restrict bptr = bias;
+    INIT_SUM();
+    const int* fptr = reinterpret_cast<const int*>(filter);
+    int8x8_t kern[4];
+#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(fptr + i));
+    UNROLL_CALL_NOWRAPPER(4, cb);
+#undef cb
+#define LOAD_SRC(_sptr, _src)                \
+    _src[0] = vld1q_s8(_sptr);               \
+    _src[1] = vld1q_s8(_sptr + 16);          \
+    _src[1] = vextq_s8(_src[0], _src[1], 4);
+
+#define CALC_ONE_LINE_4_RESULT(_sum, _src, _kid0, _kid1)             \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]);  \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]);  \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]);
+
+#define LOAD_SRC_8(_sptr, _src)              \
+    _src[0] = vld1q_s8(_sptr);               \
+    _src[2] = vld1q_s8(_sptr + 16);          \
+    _src[3] = vld1q_s8(_sptr + 32);          \
+    _src[1] = vextq_s8(_src[0], _src[2], 4); \
+    _src[3] = vextq_s8(_src[2], _src[3], 4);
+
+#define CALC_ONE_LINE_8_RESULT(_sum, _src, _kid0, _kid1)             \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]);  \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \
+    _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[2]), kern[_kid0]);  \
+    _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[2]), kern[_kid0]); \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]);  \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \
+    _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[3]), kern[_kid1]);  \
+    _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[3]), kern[_kid1]);
+
+    size_t oh = 0_z;
+    for (; oh + 2 <= OH; oh += 2) {
+        size_t ih = oh;
+        size_t ow = 0_z;
+        for (; ow + 8 <= OW; ow += 8) {
+            size_t iw = ow;
+            const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+
+            int16x8_t sum[2][4];
+            int8x16_t src[2][4];
+
+#define cb(i)             \
+    sum[0][i] = init_sum; \
+    sum[1][i] = init_sum;
+
+            UNROLL_CALL_NOWRAPPER(4, cb);
+#undef cb
+            LOAD_SRC_8(sptr0, src[0]);
+            LOAD_SRC_8(sptr1, src[1]);
+
+            CALC_ONE_LINE_8_RESULT(sum[0], src[0], 0, 1);
+            LOAD_SRC_8(sptr2, src[0]);
+            CALC_ONE_LINE_8_RESULT(sum[0], src[1], 2, 3);
+            CALC_ONE_LINE_8_RESULT(sum[1], src[1], 0, 1);
+            CALC_ONE_LINE_8_RESULT(sum[1], src[0], 2, 3);
+
+            STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]);
+            STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]);
+        }
+        for (; ow + 4 <= OW; ow += 4) {
+            size_t iw = ow;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            int16x8_t sum[2][2];
+            int8x16_t src[2][2];
+
+#define cb(i)             \
+    sum[0][i] = init_sum; \
+    sum[1][i] = init_sum;
+
+            UNROLL_CALL_NOWRAPPER(2, cb);
+#undef cb
+            LOAD_SRC(sptr0, src[0]);
+            LOAD_SRC(sptr1, src[1]);
+
+            CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1);
+            LOAD_SRC(sptr2, src[0]);
+            CALC_ONE_LINE_4_RESULT(sum[0], src[1], 2, 3);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[1], 0, 1);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[0], 2, 3);
+
+            STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]);
+            STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]);
+        }
+        if (ow < OW) {
+            size_t iw = ow;
+            size_t remain = OW - ow;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            int16x8_t sum[2][2];
+            int8x16_t src[2][2];
+
+#define cb(i)             \
+    sum[0][i] = init_sum; \
+    sum[1][i] = init_sum;
+
+            UNROLL_CALL_NOWRAPPER(2, cb);
+#undef
cb + LOAD_SRC(sptr0, src[0]); + LOAD_SRC(sptr1, src[1]); + + CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1); + LOAD_SRC(sptr2, src[0]); + CALC_ONE_LINE_4_RESULT(sum[0], src[1], 2, 3); + CALC_ONE_LINE_4_RESULT(sum[1], src[1], 0, 1); + CALC_ONE_LINE_4_RESULT(sum[1], src[0], 2, 3); + STORE_REMAIN(dst, oh, ow, OW, sum[0], remain); + STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain); + } + } + for (; oh < OH; oh++) { + size_t ih = oh; + size_t ow = 0_z; + for (; ow + 8 <= OW; ow += 8) { + size_t iw = ow; + const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + + int16x8_t sum[4]; + int8x16_t src[2][4]; +#define cb(i) sum[i] = init_sum; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + LOAD_SRC_8(sptr0, src[0]); + LOAD_SRC_8(sptr1, src[1]); + + CALC_ONE_LINE_8_RESULT(sum, src[0], 0, 1); + CALC_ONE_LINE_8_RESULT(sum, src[1], 2, 3); + STORE_1_LINE_RESULT(dst, oh, ow, OW, sum); + } + + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + + int16x8_t sum[2]; + int8x16_t src[2][2]; + sum[0] = init_sum; + sum[1] = init_sum; + + LOAD_SRC(sptr0, src[0]); + LOAD_SRC(sptr1, src[1]); + + CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1); + CALC_ONE_LINE_4_RESULT(sum, src[1], 2, 3); + + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum); + + } + + if (ow < OW) { + size_t iw = ow; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + int16x8_t sum[2]; + int8x16_t src[2][2]; + sum[0] = init_sum; + sum[1] = init_sum; + + LOAD_SRC(sptr0, src[0]); + LOAD_SRC(sptr1, src[1]); + + CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1); + CALC_ONE_LINE_4_RESULT(sum, src[1], 2, 3); + STORE_REMAIN(dst, oh, ow, OW, sum, remain); + } + } +} +#undef CALC_ONE_LINE_4_RESULT +#undef CALC_ONE_LINE_8_RESULT +#undef LOAD_SRC +#undef LOAD_SRC_8 + +template +void channel_wise_nchw44_8x8x16::direct_stride1_3x3_int8x8x16( + const int8_t* sptr, const int8_t* fptr, const int16_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW) { + MEGDNN_MARK_USED_VAR(IH); + const int16_t* __restrict bptr = bias; + INIT_SUM(); + const int* filter = reinterpret_cast(fptr); + int8x8_t kern[9]; +#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(filter + i)); + UNROLL_CALL_NOWRAPPER(9, cb); +#undef cb + +#define LOAD_6_SRC(src, sptr0) \ + src[0] = vld1q_s8(sptr0); \ + src[1] = vld1q_s8(sptr0 + 16); \ + tmp_src0 = vld1q_s8(sptr0 + 32); \ + src[2] = vextq_s8(src[0], src[1], 4); \ + src[3] = vextq_s8(src[1], tmp_src0, 4); \ + src[4] = vextq_s8(src[0], src[1], 8); \ + src[5] = vextq_s8(src[1], tmp_src0, 8); + +#define LOAD_3_SRC(sptr, src) \ + src[0] = vld1q_s8(sptr); \ + src[2] = vld1q_s8(sptr + 16); \ + src[1] = vextq_s8(src[0], src[2], 4); \ + src[2] = vextq_s8(src[0], src[2], 8); + +#define CALC_ONE_LINE(_src, _kern0, _kern1, _kern2, _sum) \ + _sum[0] = vmlal_s8(_sum[0], _kern0, vget_low_s8(_src[0])); \ + _sum[1] = vmlal_s8(_sum[1], _kern0, vget_high_s8(_src[0])); \ + _sum[0] = vmlal_s8(_sum[0], _kern1, vget_low_s8(_src[1])); \ + _sum[1] = vmlal_s8(_sum[1], _kern1, vget_high_s8(_src[1])); \ + _sum[0] = vmlal_s8(_sum[0], _kern2, vget_low_s8(_src[2])); \ + _sum[1] = vmlal_s8(_sum[1], _kern2, vget_high_s8(_src[2])); + +#define CALC_ONE(_src, _i, _j, _kern, _sum) \ + _sum[0] = vmlal_s8(_sum[0], _kern, vget_low_s8(_src[_i])); 
\ + _sum[1] = vmlal_s8(_sum[1], _kern, vget_high_s8(_src[_i])); \ + _sum[2] = vmlal_s8(_sum[2], _kern, vget_low_s8(_src[_j])); \ + _sum[3] = vmlal_s8(_sum[3], _kern, vget_high_s8(_src[_j])); + + size_t oh = 0_z; + for (; oh + 3 <= OH; oh += 3) { + size_t ih = oh; + size_t ow = 0_z; +#if MEGDNN_AARCH64 + for (; ow + 8 <= OW; ow += 8) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW * 4 + iw * 4; + int16x8_t sum0[4], sum1[4], sum2[4]; +#define cb(j) \ + sum0[j] = init_sum; \ + sum1[j] = init_sum; \ + sum2[j] = init_sum; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + int8x16_t src[2][6]; + int8x16_t tmp_src0; + LOAD_6_SRC(src[0], sptr0); //! line0 + LOAD_6_SRC(src[1], sptr1); //! line1 + CALC_ONE(src[0], 0, 1, kern[0], sum0); + CALC_ONE(src[0], 2, 3, kern[1], sum0); + CALC_ONE(src[0], 4, 5, kern[2], sum0); + CALC_ONE(src[1], 0, 1, kern[3], sum0); + CALC_ONE(src[1], 2, 3, kern[4], sum0); + CALC_ONE(src[1], 4, 5, kern[5], sum0); + + LOAD_6_SRC(src[0], sptr2); //! line2 + CALC_ONE(src[0], 0, 1, kern[6], sum0); + CALC_ONE(src[0], 2, 3, kern[7], sum0); + CALC_ONE(src[0], 4, 5, kern[8], sum0); + + CALC_ONE(src[1], 0, 1, kern[0], sum1); + CALC_ONE(src[1], 2, 3, kern[1], sum1); + CALC_ONE(src[1], 4, 5, kern[2], sum1); + + CALC_ONE(src[0], 0, 1, kern[3], sum1); + CALC_ONE(src[0], 2, 3, kern[4], sum1); + CALC_ONE(src[0], 4, 5, kern[5], sum1); + + STORE_1_LINE_RESULT(dst, oh, ow, OW, sum0) + LOAD_6_SRC(src[1], sptr3); //! line3 + + CALC_ONE(src[1], 0, 1, kern[6], sum1); + CALC_ONE(src[1], 2, 3, kern[7], sum1); + CALC_ONE(src[1], 4, 5, kern[8], sum1); + + CALC_ONE(src[0], 0, 1, kern[0], sum2); + CALC_ONE(src[0], 2, 3, kern[1], sum2); + CALC_ONE(src[0], 4, 5, kern[2], sum2); + + CALC_ONE(src[1], 0, 1, kern[3], sum2); + CALC_ONE(src[1], 2, 3, kern[4], sum2); + CALC_ONE(src[1], 4, 5, kern[5], sum2); + LOAD_6_SRC(src[0], sptr4); //! 
line4 + STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum1) + + CALC_ONE(src[0], 0, 1, kern[6], sum2); + CALC_ONE(src[0], 2, 3, kern[7], sum2); + CALC_ONE(src[0], 4, 5, kern[8], sum2); + STORE_1_LINE_RESULT(dst, (oh + 2), ow, OW, sum2) + } +#endif + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW * 4 + iw * 4; + + int16x8_t sum0[2], sum1[2], sum2[2]; +#define cb(j) \ + sum0[j] = init_sum; \ + sum1[j] = init_sum; \ + sum2[j] = init_sum; + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + + int8x16_t src[2][3]; + + LOAD_3_SRC(sptr0,src[0]); + LOAD_3_SRC(sptr1,src[1]); + + CALC_ONE_LINE(src[0],kern[0],kern[1],kern[2],sum0);//line0 + CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum0);//line1 + CALC_ONE_LINE(src[1],kern[0],kern[1],kern[2],sum1);//line1 + + LOAD_3_SRC(sptr2,src[0]);//line2 + + CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum0);//line2 + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum0) + + CALC_ONE_LINE(src[0],kern[3],kern[4],kern[5],sum1);//line2 + CALC_ONE_LINE(src[0],kern[0],kern[1],kern[2],sum2);//line2 + LOAD_3_SRC(sptr3,src[1]);//line3 + + CALC_ONE_LINE(src[1],kern[6],kern[7],kern[8],sum1);//line3 + STORE_1_LINE_4_RESULT(dst, (oh+1), ow, OW, sum1) + CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum2);//line3 + LOAD_3_SRC(sptr4,src[0]); + CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum2);//line4 + STORE_1_LINE_4_RESULT(dst, (oh+2), ow, OW, sum2) + } + if (ow < OW) { + size_t iw = ow; + size_t remain = OW - ow; + + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW * 4 + iw * 4; + int16x8_t sum0[2], sum1[2], sum2[2]; + int8x16_t src[2][3]; +#define cb(j) \ + sum0[j] = init_sum; \ + sum1[j] = init_sum; \ + sum2[j] = init_sum; + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + + LOAD_3_SRC(sptr0,src[0]);//line2 + LOAD_3_SRC(sptr1,src[1]);//line2 + CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum0); // line0 + CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum0);//line1 + CALC_ONE_LINE(src[1],kern[0],kern[1],kern[2],sum1);//line1 + + LOAD_3_SRC(sptr2,src[0]);//line2 + + CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum0);//line2 + STORE_REMAIN(dst, (oh+0), ow, OW, sum0,remain) + + CALC_ONE_LINE(src[0],kern[3],kern[4],kern[5],sum1);//line2 + CALC_ONE_LINE(src[0],kern[0],kern[1],kern[2],sum2);//line2 + LOAD_3_SRC(sptr3,src[1]);//line3 + + CALC_ONE_LINE(src[1],kern[6],kern[7],kern[8],sum1);//line3 + STORE_REMAIN(dst, (oh+1), ow, OW, sum1,remain) + CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum2);//line3 + LOAD_3_SRC(sptr4,src[0]); + CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum2);//line4 + STORE_REMAIN(dst, (oh+2), ow, OW, sum2, remain) + } + } + for (; oh < OH; oh++) { + size_t ih = oh; + size_t ow = 0_z; +#if MEGDNN_AARCH64 + for (; ow + 8 <= OW; ow += 8) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + int16x8_t sum0[4]; 
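+            //! Register blocking: each int16x8_t holds two NCHW44 output
+            //! pixels (2 pixels x 4 channels), so the four accumulators in
+            //! sum0 cover the eight output pixels produced per iteration of
+            //! this unrolled loop.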
+ int8x16_t src[2][6]; + int8x16_t tmp_src0; + + sum0[0] = init_sum; + sum0[1] = init_sum; + sum0[2] = init_sum; + sum0[3] = init_sum; + + LOAD_6_SRC(src[0], sptr0); //! line0 + LOAD_6_SRC(src[1], sptr1); //! line1 + CALC_ONE(src[0], 0, 1, kern[0], sum0); + CALC_ONE(src[0], 2, 3, kern[1], sum0); + CALC_ONE(src[0], 4, 5, kern[2], sum0); + CALC_ONE(src[1], 0, 1, kern[3], sum0); + CALC_ONE(src[1], 2, 3, kern[4], sum0); + CALC_ONE(src[1], 4, 5, kern[5], sum0); + LOAD_6_SRC(src[0], sptr2); //! line2 + CALC_ONE(src[0], 0, 1, kern[6], sum0); + CALC_ONE(src[0], 2, 3, kern[7], sum0); + CALC_ONE(src[0], 4, 5, kern[8], sum0); + + STORE_1_LINE_RESULT(dst, oh, ow, OW, sum0); + } +#endif + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + + int16x8_t sum00[2]; + int8x16_t src[2][3]; + + sum00[0] = init_sum; + sum00[1] = init_sum; + + LOAD_3_SRC(sptr0, src[0]); + LOAD_3_SRC(sptr1, src[1]); + + CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum00); // line0 + CALC_ONE_LINE(src[1], kern[3], kern[4], kern[5], sum00); // line1 + + LOAD_3_SRC(sptr2, src[0]); // line2 + + CALC_ONE_LINE(src[0], kern[6], kern[7], kern[8], sum00); // line2 + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum00) + } + if (ow < OW) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; + + int16x8_t sum00[2]; + int8x16_t src[2][3]; + + sum00[0] = init_sum; + sum00[1] = init_sum; + + LOAD_3_SRC(sptr0, src[0]); + LOAD_3_SRC(sptr1, src[1]); + + CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum00); // line0 + CALC_ONE_LINE(src[1], kern[3], kern[4], kern[5], sum00); // line1 + + LOAD_3_SRC(sptr2, src[0]); // line2 + + CALC_ONE_LINE(src[0], kern[6], kern[7], kern[8], sum00); // line2 + STORE_REMAIN(dst, oh, ow, OW, sum00,(OW-ow)) + } + } +#undef LOAD_3_SRC +#undef LOAD_6_SRC +#undef CALC_ONE +#undef CALC_ONE_LINE +} + +template +void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16( + const int8_t* sptr, const int8_t* fptr, const int16_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW) { + MEGDNN_MARK_USED_VAR(IH); + const int16_t* __restrict bptr = bias; + INIT_SUM(); + const int* filter = reinterpret_cast(fptr); + int8x8_t kern[25]; +#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(filter + i)); + UNROLL_CALL_NOWRAPPER(25, cb); +#undef cb +#define LOAD_1_LINE_SRC(sptr, src) \ + src[0] = vld1q_s8(sptr); \ + src[4] = vld1q_s8(sptr + 16); \ + src[1] = vextq_s8(src[0], src[4], 4); \ + src[2] = vextq_s8(src[0], src[4], 8); \ + src[3] = vextq_s8(src[0], src[4], 12); + +#define LOAD_1_LINE_10_SRC(sptr, src) \ + src[0] = vld1q_s8(sptr); \ + src[4] = vld1q_s8(sptr + 16); \ + src[8] = vld1q_s8(sptr + 32); \ + src[1] = vextq_s8(src[0], src[4], 4); \ + src[2] = vextq_s8(src[0], src[4], 8); \ + src[3] = vextq_s8(src[0], src[4], 12); \ + src[5] = vextq_s8(src[4], src[8], 4); \ + src[6] = vextq_s8(src[4], src[8], 8); \ + src[7] = vextq_s8(src[4], src[8], 12); + + +#define CALC_ONE_LINE_4_RESULT(_sum,_src,_kid0,_kid1,_kid2,_kid3,_kid4)\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[0]),kern[_kid0]);\ + _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[0]),kern[_kid0]);\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[1]),kern[_kid1]);\ + 
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[1]),kern[_kid1]);\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[2]),kern[_kid2]);\ + _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[2]),kern[_kid2]);\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[3]),kern[_kid3]);\ + _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[3]),kern[_kid3]);\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[4]),kern[_kid4]);\ + _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[4]),kern[_kid4]); + +#define CALC_ONE_LINE_8_RESULT(_sum,_src,_kid0,_kid1,_kid2,_kid3,_kid4)\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[0]),kern[_kid0]);\ + _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[0]),kern[_kid0]);\ + _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[4]),kern[_kid0]);\ + _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[4]),kern[_kid0]);\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[1]),kern[_kid1]);\ + _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[1]),kern[_kid1]);\ + _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[5]),kern[_kid1]);\ + _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[5]),kern[_kid1]);\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[2]),kern[_kid2]);\ + _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[2]),kern[_kid2]);\ + _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[6]),kern[_kid2]);\ + _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[6]),kern[_kid2]);\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[3]),kern[_kid3]);\ + _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[3]),kern[_kid3]);\ + _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[7]),kern[_kid3]);\ + _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[7]),kern[_kid3]);\ + _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[4]),kern[_kid4]);\ + _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[4]),kern[_kid4]);\ + _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[8]),kern[_kid4]);\ + _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[8]),kern[_kid4]); + + size_t oh = 0_z; + for (; oh + 2 <= OH; oh += 2) { + size_t ih = oh; + size_t ow = 0_z; +#if MEGDNN_AARCH64 + for (; ow + 8 <= OW; ow += 8) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + const int8_t* __restrict sptr5 = sptr4 + IW * 4; + + int16x8_t sum[2][4]; + int8x16_t src[2][9]; +#define cb(j) \ + sum[0][j] = init_sum; \ + sum[1][j] = init_sum; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + LOAD_1_LINE_10_SRC(sptr0,src[0]); + LOAD_1_LINE_10_SRC(sptr1,src[1]); + + CALC_ONE_LINE_8_RESULT(sum[0],src[0],0,1,2,3,4); + LOAD_1_LINE_10_SRC(sptr2,src[0]);//line2 + CALC_ONE_LINE_8_RESULT(sum[0],src[1],5,6,7,8,9);//line1 + CALC_ONE_LINE_8_RESULT(sum[1],src[1],0,1,2,3,4);//line1 + LOAD_1_LINE_10_SRC(sptr3,src[1]);//line3 + CALC_ONE_LINE_8_RESULT(sum[0],src[0],10,11,12,13,14);//line2 + CALC_ONE_LINE_8_RESULT(sum[1],src[0],5,6,7,8,9);//line2 + LOAD_1_LINE_10_SRC(sptr4,src[0]);//line4 + CALC_ONE_LINE_8_RESULT(sum[0],src[1],15,16,17,18,19);//line3 + CALC_ONE_LINE_8_RESULT(sum[1],src[1],10,11,12,13,14);//line3 + LOAD_1_LINE_10_SRC(sptr5,src[1]);//line5 + CALC_ONE_LINE_8_RESULT(sum[0],src[0],20,21,22,23,24);//line4 + CALC_ONE_LINE_8_RESULT(sum[1],src[0],15,16,17,18,19);//line3 + CALC_ONE_LINE_8_RESULT(sum[1],src[1],20,21,22,23,24);//line3 + + STORE_1_LINE_RESULT(dst,oh,ow,OW,sum[0]); + STORE_1_LINE_RESULT(dst,(oh+1),ow,OW,sum[1]); + } +#endif + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = 
sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + const int8_t* __restrict sptr5 = sptr4 + IW * 4; + + int16x8_t sum[2][2]; + int8x16_t src[2][5]; +#define cb(j) \ + sum[0][j] = init_sum; \ + sum[1][j] = init_sum; + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + + + LOAD_1_LINE_SRC(sptr0,src[0]); + LOAD_1_LINE_SRC(sptr1,src[1]); + + CALC_ONE_LINE_4_RESULT(sum[0],src[0],0,1,2,3,4); + LOAD_1_LINE_SRC(sptr2,src[0]);//line2 + CALC_ONE_LINE_4_RESULT(sum[0],src[1],5,6,7,8,9);//line1 + CALC_ONE_LINE_4_RESULT(sum[1],src[1],0,1,2,3,4);//line1 + LOAD_1_LINE_SRC(sptr3,src[1]);//line3 + CALC_ONE_LINE_4_RESULT(sum[0],src[0],10,11,12,13,14);//line2 + CALC_ONE_LINE_4_RESULT(sum[1],src[0],5,6,7,8,9);//line2 + LOAD_1_LINE_SRC(sptr4,src[0]);//line4 + CALC_ONE_LINE_4_RESULT(sum[0],src[1],15,16,17,18,19);//line3 + CALC_ONE_LINE_4_RESULT(sum[1],src[1],10,11,12,13,14);//line3 + LOAD_1_LINE_SRC(sptr5,src[1]);//line5 + CALC_ONE_LINE_4_RESULT(sum[0],src[0],20,21,22,23,24);//line4 + CALC_ONE_LINE_4_RESULT(sum[1],src[0],15,16,17,18,19);//line3 + CALC_ONE_LINE_4_RESULT(sum[1],src[1],20,21,22,23,24);//line3 + + STORE_1_LINE_4_RESULT(dst,oh,ow,OW,sum[0]); + STORE_1_LINE_4_RESULT(dst,(oh+1),ow,OW,sum[1]); + } + if (ow < OW) { + size_t remain = OW - ow; + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + const int8_t* __restrict sptr5 = sptr4 + IW * 4; + + int16x8_t sum[2][2]; + int8x16_t src[2][5]; +#define cb(j) \ + sum[0][j] = init_sum; \ + sum[1][j] = init_sum; + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + LOAD_1_LINE_SRC(sptr0,src[0]); + LOAD_1_LINE_SRC(sptr1,src[1]); + + CALC_ONE_LINE_4_RESULT(sum[0],src[0],0,1,2,3,4); + LOAD_1_LINE_SRC(sptr2,src[0]);//line2 + CALC_ONE_LINE_4_RESULT(sum[0],src[1],5,6,7,8,9);//line1 + CALC_ONE_LINE_4_RESULT(sum[1],src[1],0,1,2,3,4);//line1 + LOAD_1_LINE_SRC(sptr3,src[1]);//line3 + CALC_ONE_LINE_4_RESULT(sum[0],src[0],10,11,12,13,14);//line2 + CALC_ONE_LINE_4_RESULT(sum[1],src[0],5,6,7,8,9);//line2 + LOAD_1_LINE_SRC(sptr4,src[0]);//line4 + CALC_ONE_LINE_4_RESULT(sum[0],src[1],15,16,17,18,19);//line3 + CALC_ONE_LINE_4_RESULT(sum[1],src[1],10,11,12,13,14);//line3 + LOAD_1_LINE_SRC(sptr5,src[1]);//line5 + CALC_ONE_LINE_4_RESULT(sum[0],src[0],20,21,22,23,24);//line4 + CALC_ONE_LINE_4_RESULT(sum[1],src[0],15,16,17,18,19);//line3 + CALC_ONE_LINE_4_RESULT(sum[1],src[1],20,21,22,23,24);//line3 + + STORE_REMAIN(dst,oh,ow,OW,sum[0],remain); + STORE_REMAIN(dst,(oh+1),ow,OW,sum[1],remain); + } + } + for (; oh < OH; oh++) { + size_t ih = oh; + size_t ow = 0_z; +#if MEGDNN_AARCH64 + for (; ow + 8 <= OW; ow += 8) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + + int16x8_t sum[4]; + int8x16_t src[2][9]; +#define cb(j) sum[j] = init_sum; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + LOAD_1_LINE_10_SRC(sptr0,src[0]); + LOAD_1_LINE_10_SRC(sptr1,src[1]); + + CALC_ONE_LINE_8_RESULT(sum,src[0],0,1,2,3,4); + LOAD_1_LINE_10_SRC(sptr2,src[0]);//line2 + CALC_ONE_LINE_8_RESULT(sum,src[1],5,6,7,8,9);//line1 + 
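+            //! src[0] and src[1] act as a ping-pong buffer: while one half
+            //! feeds the multiply-accumulates for the current filter row,
+            //! the other half is refilled with the next input row, hiding
+            //! the vld1q_s8/vextq_s8 load latency.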
LOAD_1_LINE_10_SRC(sptr3,src[1]);//line3 + CALC_ONE_LINE_8_RESULT(sum,src[0],10,11,12,13,14);//line2 + LOAD_1_LINE_10_SRC(sptr4,src[0]);//line4 + CALC_ONE_LINE_8_RESULT(sum,src[1],15,16,17,18,19);//line3 + CALC_ONE_LINE_8_RESULT(sum,src[0],20,21,22,23,24);//line4 + + STORE_1_LINE_RESULT(dst,oh,ow,OW,sum); + } +#endif + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + + int16x8_t sum[2]; + int8x16_t src[2][5]; + sum[0]=init_sum; + sum[1]=init_sum; + + + LOAD_1_LINE_SRC(sptr0,src[0]); + LOAD_1_LINE_SRC(sptr1,src[1]); + + CALC_ONE_LINE_4_RESULT(sum,src[0],0,1,2,3,4); + LOAD_1_LINE_SRC(sptr2,src[0]);//line2 + CALC_ONE_LINE_4_RESULT(sum,src[1],5,6,7,8,9);//line1 + LOAD_1_LINE_SRC(sptr3,src[1]);//line3 + CALC_ONE_LINE_4_RESULT(sum,src[0],10,11,12,13,14);//line2 + LOAD_1_LINE_SRC(sptr4,src[0]);//line4 + CALC_ONE_LINE_4_RESULT(sum,src[1],15,16,17,18,19);//line3 + CALC_ONE_LINE_4_RESULT(sum,src[0],20,21,22,23,24);//line4 + + STORE_1_LINE_4_RESULT(dst,oh,ow,OW,sum); + } + if (ow < OW) { + size_t remain = OW - ow; + size_t iw = ow; + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = sptr0 + IW * 4; + const int8_t* __restrict sptr2 = sptr1 + IW * 4; + const int8_t* __restrict sptr3 = sptr2 + IW * 4; + const int8_t* __restrict sptr4 = sptr3 + IW * 4; + int16x8_t sum[2]; + int8x16_t src[2][5]; + sum[0]=init_sum; + sum[1]=init_sum; + + LOAD_1_LINE_SRC(sptr0,src[0]); + LOAD_1_LINE_SRC(sptr1,src[1]); + + CALC_ONE_LINE_4_RESULT(sum,src[0],0,1,2,3,4); + LOAD_1_LINE_SRC(sptr2,src[0]);//line2 + CALC_ONE_LINE_4_RESULT(sum,src[1],5,6,7,8,9);//line1 + LOAD_1_LINE_SRC(sptr3,src[1]);//line3 + CALC_ONE_LINE_4_RESULT(sum,src[0],10,11,12,13,14);//line2 + LOAD_1_LINE_SRC(sptr4,src[0]);//line4 + CALC_ONE_LINE_4_RESULT(sum,src[1],15,16,17,18,19);//line3 + CALC_ONE_LINE_4_RESULT(sum,src[0],20,21,22,23,24);//line4 + STORE_REMAIN(dst,oh,ow,OW,sum,remain); + } + } +#undef LOAD_1_LINE_SRC +#undef LOAD_1_LINE_10_SRC +#undef CALC_ONE_LINE_4_RESULT +#undef CALC_ONE_LINE_8_RESULT +} + +template +void channel_wise_nchw44_8x8x16::direct_stride2_2x2_int8x8x16( + const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW) { + MEGDNN_MARK_USED_VAR(IH); + const int16_t* __restrict bptr = bias; + INIT_SUM(); +const int* fptr = reinterpret_cast(filter); + int8x8_t kern[4]; +#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(fptr + i)); + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define CALC_ONE_LINE_8_RESULT(_sum, _rowid, _kid0, _kid1) \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##0), kern[_kid0]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##2), kern[_kid0]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##2), kern[_kid0]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##1), kern[_kid1]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##3), kern[_kid1]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##3), kern[_kid1]); + +#define CALC_ONE_LINE_4_RESULT(_sum, _rowid, _kid0, _kid1) \ + _sum[0] = vmlal_s8(_sum[0], 
vget_low_s8(row##_rowid##0), kern[_kid0]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##1), kern[_kid1]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]); + + size_t oh = 0_z; + for (; oh + 2 <= OH; oh += 2) { + size_t ih = oh * 2; + size_t ow = 0_z; +#if MEGDNN_AARCH64 + for (; ow + 8 <= OW; ow += 8) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + + int16x8_t sum[2][4]; +#define cb(i) \ + sum[0][i] = init_sum; \ + sum[1][i] = init_sum; + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb +#define cb(i)\ +const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(4,cb) +#undef cb + +#define cb(i)\ + int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i);\ + int32x4x2_t tmp_row##i##_01 = vld2q_s32(tmp_sptr##i+8); + + UNROLL_CALL_NOWRAPPER(4,cb) +#undef cb + +#define cb(i)\ + int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i##_00.val[0]);\ + int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i##_00.val[1]);\ + int8x16_t row##i##2 =vreinterpretq_s8_s32(tmp_row##i##_01.val[0]);\ + int8x16_t row##i##3 =vreinterpretq_s8_s32(tmp_row##i##_01.val[1]); + + UNROLL_CALL_NOWRAPPER(4,cb) +#undef cb + + CALC_ONE_LINE_8_RESULT(sum[0],0,0,1); + CALC_ONE_LINE_8_RESULT(sum[0],1,2,3); + CALC_ONE_LINE_8_RESULT(sum[1],2,0,1); + CALC_ONE_LINE_8_RESULT(sum[1],3,2,3); + STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]); + STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]); + } +#endif + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + int16x8_t sum[2][2]; +#define cb(i) \ + sum[0][i] = init_sum; \ + sum[1][i] = init_sum; + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i)\ + int32x4x2_t tmp_row##i = vld2q_s32(tmp_sptr##i); + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i)\ + int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\ + int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]);\ + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + CALC_ONE_LINE_4_RESULT(sum[0],0,0,1); + CALC_ONE_LINE_4_RESULT(sum[0],1,2,3); + + CALC_ONE_LINE_4_RESULT(sum[1],2,0,1); + CALC_ONE_LINE_4_RESULT(sum[1],3,2,3); + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]); + STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]); + } + if (ow < OW) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + + int16x8_t sum[2][2]; +#define cb(i) \ + sum[0][i] = init_sum; \ + sum[1][i] = init_sum; + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i) int32x4x2_t tmp_row##i = 
vld2q_s32(tmp_sptr##i); + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i)\ + int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\ + int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]);\ + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + CALC_ONE_LINE_4_RESULT(sum[0],0,0,1); + CALC_ONE_LINE_4_RESULT(sum[0],1,2,3); + + CALC_ONE_LINE_4_RESULT(sum[1],2,0,1); + CALC_ONE_LINE_4_RESULT(sum[1],3,2,3); + + STORE_REMAIN(dst, (oh+0), ow, OW, sum[0], remain); + STORE_REMAIN(dst, (oh+1), ow, OW, sum[1], remain); + } + } + for (; oh < OH; oh++) { + size_t ih = oh * 2; + size_t ow = 0_z; +#if MEGDNN_AARCH64 + for (; ow + 8 <= OW; ow += 8) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + int16x8_t sum[4] = {init_sum, init_sum, init_sum, init_sum}; +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(2, cb) +#undef cb + +#define cb(i) \ + int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i); \ + int32x4x2_t tmp_row##i##_01 = vld2q_s32(tmp_sptr##i + 8); + + UNROLL_CALL_NOWRAPPER(2, cb) +#undef cb + +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_00.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_00.val[1]); \ + int8x16_t row##i##2 = vreinterpretq_s8_s32(tmp_row##i##_01.val[0]); \ + int8x16_t row##i##3 = vreinterpretq_s8_s32(tmp_row##i##_01.val[1]); + + UNROLL_CALL_NOWRAPPER(2, cb) +#undef cb + + CALC_ONE_LINE_8_RESULT(sum, 0, 0, 1); + CALC_ONE_LINE_8_RESULT(sum, 1, 2, 3); + STORE_1_LINE_RESULT(dst, oh, ow, OW, sum); + } +#endif + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + int16x8_t sum[2]={init_sum,init_sum}; +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + +#define cb(i) int32x4x2_t tmp_row##i = vld2q_s32(tmp_sptr##i); + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + +#define cb(i)\ + int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\ + int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]); + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + CALC_ONE_LINE_4_RESULT(sum,0,0,1); + CALC_ONE_LINE_4_RESULT(sum,1,2,3); + + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum); + } + if (OW > ow) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + int16x8_t sum[2]={init_sum,init_sum}; +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + +#define cb(i) int32x4x2_t tmp_row##i = vld2q_s32(tmp_sptr##i); + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + +#define cb(i)\ + int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\ + int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]); + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + CALC_ONE_LINE_4_RESULT(sum,0,0,1); + CALC_ONE_LINE_4_RESULT(sum,1,2,3); + STORE_REMAIN(dst, oh, ow, OW, sum, remain); + } + } +#undef CALC_ONE_LINE_4_RESULT +#undef CALC_ONE_LINE_8_RESULT +} + +template +void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( + const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW) { + MEGDNN_MARK_USED_VAR(IH); + + 
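+    //! Stride-2 handling below: every NCHW44 pixel is 4 bytes, so a source
+    //! row is reloaded as int32 and vld2q_s32 de-interleaves it into even
+    //! pixels (val[0]) and odd pixels (val[1]) in one instruction; the even
+    //! lanes are exactly the stride-2 input columns, and vextq_s32 supplies
+    //! the +2 column needed by the third filter tap.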
const int16_t* __restrict bptr = bias; + INIT_SUM(); + + const int* fptr = reinterpret_cast(filter); + int8x8_t kern[9]; +#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(fptr + i)); + UNROLL_CALL_NOWRAPPER(9, cb); +#undef cb +#define CALC_ONE_LINE_8_RESULT(_sum, _rowid, _kid0, _kid1, _kid2) \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##0), kern[_kid0]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##3), kern[_kid0]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##3), kern[_kid0]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##1), kern[_kid1]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##4), kern[_kid1]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##4), kern[_kid1]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##2), kern[_kid2]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##2), kern[_kid2]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##5), kern[_kid2]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##5), kern[_kid2]); + +#define CALC_ONE_LINE_4_RESULT(_sum, _rowid, _kid0, _kid1, _kid2) \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##0), kern[_kid0]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##1), kern[_kid1]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##2), kern[_kid2]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##2), kern[_kid2]); + + size_t oh = 0_z; + for (; oh + 2 <= OH; oh += 2) { + size_t ih = oh * 2; + size_t ow = 0_z; +#if MEGDNN_AARCH64 + for (; ow + 8 <= OW; ow += 8) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; + int16x8_t sum[2][4]; + +#define cb(j) \ + sum[0][j] = init_sum; \ + sum[1][j] = init_sum; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + +#define cb(i) \ + int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i); \ + int32x4x2_t tmp_row##i##_03 = vld2q_s32(tmp_sptr##i + 8); \ + int32x4_t tmp_row##i = vld1q_s32(tmp_sptr##i + 16); + + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_00.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_00.val[1]); \ + int8x16_t row##i##2 = vreinterpretq_s8_s32( \ + vextq_s32(tmp_row##i##_00.val[0], tmp_row##i##_03.val[0], 1)); \ + int8x16_t row##i##3 = vreinterpretq_s8_s32(tmp_row##i##_03.val[0]); \ + int8x16_t row##i##4 = vreinterpretq_s8_s32(tmp_row##i##_03.val[1]); \ + int8x16_t row##i##5 = vreinterpretq_s8_s32( \ + vextq_s32(tmp_row##i##_03.val[0], tmp_row##i, 1)); + + UNROLL_CALL_NOWRAPPER(5, cb) +#undef cb + + CALC_ONE_LINE_8_RESULT(sum[0], 0, 0, 1, 2); + CALC_ONE_LINE_8_RESULT(sum[0], 1, 3, 4, 5); + CALC_ONE_LINE_8_RESULT(sum[0], 2, 6, 7, 8); + CALC_ONE_LINE_8_RESULT(sum[1], 2, 0, 1, 2); + CALC_ONE_LINE_8_RESULT(sum[1], 3, 3, 4, 5); + 
CALC_ONE_LINE_8_RESULT(sum[1], 4, 6, 7, 8); + + STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]); + STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]); + } +#endif + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; + + int16x8_t sum[2][2]; +#define cb(j) \ + sum[0][j] = init_sum; \ + sum[1][j] = init_sum; + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + +#define cb(i) \ + int32x4x2_t tmp_row##i##_0 = vld2q_s32(tmp_sptr##i); \ + int32x4_t tmp_row##i##_1 = vld1q_s32(tmp_sptr##i + 8); + + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \ + int8x16_t row##i##2 = \ + vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1)); + + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + CALC_ONE_LINE_4_RESULT(sum[0], 0, 0, 1, 2); + CALC_ONE_LINE_4_RESULT(sum[0], 1, 3, 4, 5); + CALC_ONE_LINE_4_RESULT(sum[0], 2, 6, 7, 8); + CALC_ONE_LINE_4_RESULT(sum[1], 2, 0, 1, 2); + CALC_ONE_LINE_4_RESULT(sum[1], 3, 3, 4, 5); + CALC_ONE_LINE_4_RESULT(sum[1], 4, 6, 7, 8); + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]); + STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]); + } + if (ow < OW) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; + const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; + + int16x8_t sum[2][2]; +#define cb(j) \ + sum[0][j] = init_sum; \ + sum[1][j] = init_sum; + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + +#define cb(i) \ + int32x4x2_t tmp_row##i##_0 = vld2q_s32(tmp_sptr##i); \ + int32x4_t tmp_row##i##_1 = vld1q_s32(tmp_sptr##i + 8); + + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \ + int8x16_t row##i##2 = \ + vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1)); + + UNROLL_CALL_NOWRAPPER(5, cb); +#undef cb + CALC_ONE_LINE_4_RESULT(sum[0], 0, 0, 1, 2); + CALC_ONE_LINE_4_RESULT(sum[0], 1, 3, 4, 5); + CALC_ONE_LINE_4_RESULT(sum[0], 2, 6, 7, 8); + CALC_ONE_LINE_4_RESULT(sum[1], 2, 0, 1, 2); + CALC_ONE_LINE_4_RESULT(sum[1], 3, 3, 4, 5); + CALC_ONE_LINE_4_RESULT(sum[1], 4, 6, 7, 8); + + STORE_REMAIN(dst, (oh + 0), ow, OW, sum[0], remain); + STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain); + } + } + for (; oh < OH; oh++) { + size_t ih = oh * 2; + size_t ow = 0_z; +#if MEGDNN_AARCH64 + for (; ow + 8 <= OW; ow += 8) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 
2) * IW * 4 + iw * 4; + + int16x8_t sum[4] = {init_sum, init_sum, init_sum, init_sum}; +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + UNROLL_CALL_NOWRAPPER(3, cb); +#undef cb + +#define cb(i) \ + int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i); \ + int32x4x2_t tmp_row##i##_03 = vld2q_s32(tmp_sptr##i + 8); \ + int32x4_t tmp_row##i = vld1q_s32(tmp_sptr##i + 16); + + UNROLL_CALL_NOWRAPPER(3, cb); +#undef cb + +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_00.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_00.val[1]); \ + int8x16_t row##i##2 = vreinterpretq_s8_s32( \ + vextq_s32(tmp_row##i##_00.val[0], tmp_row##i##_03.val[0], 1)); \ + int8x16_t row##i##3 = vreinterpretq_s8_s32(tmp_row##i##_03.val[0]); \ + int8x16_t row##i##4 = vreinterpretq_s8_s32(tmp_row##i##_03.val[1]); \ + int8x16_t row##i##5 = vreinterpretq_s8_s32( \ + vextq_s32(tmp_row##i##_03.val[0], tmp_row##i, 1)); + + UNROLL_CALL_NOWRAPPER(3, cb) +#undef cb + + CALC_ONE_LINE_8_RESULT(sum, 0, 0, 1, 2); + CALC_ONE_LINE_8_RESULT(sum, 1, 3, 4, 5); + CALC_ONE_LINE_8_RESULT(sum, 2, 6, 7, 8); + STORE_1_LINE_RESULT(dst, oh, ow, OW, sum); + } +#endif + for (; ow + 4 <= OW; ow += 4) { + size_t iw = ow * 2; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + int16x8_t sum[2] = {init_sum, init_sum}; +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(3, cb) +#undef cb + +#define cb(i) \ + int32x4x2_t tmp_row##i##_0 = vld2q_s32(tmp_sptr##i); \ + int32x4_t tmp_row##i##_1 = vld1q_s32(tmp_sptr##i + 8); + + UNROLL_CALL_NOWRAPPER(3, cb) +#undef cb + +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \ + int8x16_t row##i##2 = \ + vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1)); + + UNROLL_CALL_NOWRAPPER(3, cb) +#undef cb + + CALC_ONE_LINE_4_RESULT(sum, 0, 0, 1, 2); + CALC_ONE_LINE_4_RESULT(sum, 1, 3, 4, 5); + CALC_ONE_LINE_4_RESULT(sum, 2, 6, 7, 8); + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum); + } + if (OW > ow) { + size_t iw = ow * 2; + size_t remain = OW - ow; + const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; + const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; + const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; + int16x8_t sum[2] = {init_sum, init_sum}; +#define cb(i) \ + const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + + UNROLL_CALL_NOWRAPPER(3, cb) +#undef cb + +#define cb(i) \ + int32x4x2_t tmp_row##i##_0 = vld2q_s32(tmp_sptr##i); \ + int32x4_t tmp_row##i##_1 = vld1q_s32(tmp_sptr##i + 8); + + UNROLL_CALL_NOWRAPPER(3, cb) +#undef cb + +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \ + int8x16_t row##i##2 = \ + vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1)); + + UNROLL_CALL_NOWRAPPER(3, cb) +#undef cb + + CALC_ONE_LINE_4_RESULT(sum, 0, 0, 1, 2); + CALC_ONE_LINE_4_RESULT(sum, 1, 3, 4, 5); + CALC_ONE_LINE_4_RESULT(sum, 2, 6, 7, 8); + + STORE_REMAIN(dst, oh, ow, OW, sum, remain); + } + } +#undef CALC_ONE_LINE_4_RESULT +#undef CALC_ONE_LINE_8_RESULT +#undef LOAD_5_SRC +} + +#if MEGDNN_AARCH64 + +template +void 
channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16(
+        const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst,
+        const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
+    MEGDNN_MARK_USED_VAR(IH);
+    const int16_t* __restrict bptr = bias;
+    INIT_SUM();
+
+    const int32_t* fptr = reinterpret_cast<const int32_t*>(filter);
+    int8x8_t kern[25];
+#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(fptr + i));
+    UNROLL_CALL_NOWRAPPER(25, cb);
+#undef cb
+
+#define LOAD_5_SRC(_src, _id)                                  \
+    do {                                                       \
+        int32x4x2_t tmp_row_01 = vld2q_s32(tmp_sptr##_id);     \
+        int32x4x2_t tmp_row_23 = vld2q_s32(tmp_sptr##_id + 2); \
+        int32x4_t tmp_row = vld1q_s32(tmp_sptr##_id + 10);     \
+        _src[0] = vreinterpretq_s8_s32(tmp_row_01.val[0]);     \
+        _src[1] = vreinterpretq_s8_s32(tmp_row_01.val[1]);     \
+        _src[2] = vreinterpretq_s8_s32(tmp_row_23.val[0]);     \
+        _src[3] = vreinterpretq_s8_s32(tmp_row_23.val[1]);     \
+        _src[4] = vreinterpretq_s8_s32(                        \
+                vextq_s32(tmp_row_23.val[0], tmp_row, 1));     \
+    } while (0);
+
+#define CALC_ONE_LINE_4_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, \
+                               _kid4)                                  \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]);   \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]);   \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]);   \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]);   \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[4]), kern[_kid4]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[4]), kern[_kid4]);
+
+#define LOAD_10_SRC(_src, _id)                                       \
+    do {                                                             \
+        int32x4x2_t tmp_row_01 = vld2q_s32(tmp_sptr##_id);           \
+        int32x4x2_t tmp_row_23 = vld2q_s32(tmp_sptr##_id + 8);       \
+        int32x4x2_t tmp_row = vld2q_s32(tmp_sptr##_id + 16);         \
+        _src[0] = vreinterpretq_s8_s32(tmp_row_01.val[0]);           \
+        _src[1] = vreinterpretq_s8_s32(tmp_row_01.val[1]);           \
+        _src[2] = vreinterpretq_s8_s32(                              \
+                vextq_s32(tmp_row_01.val[0], tmp_row_23.val[0], 1)); \
+        _src[3] = vreinterpretq_s8_s32(                              \
+                vextq_s32(tmp_row_01.val[1], tmp_row_23.val[1], 1)); \
+        _src[4] = vreinterpretq_s8_s32(                              \
+                vextq_s32(tmp_row_01.val[0], tmp_row_23.val[0], 2)); \
+        _src[5] = vreinterpretq_s8_s32(tmp_row_23.val[0]);           \
+        _src[6] = vreinterpretq_s8_s32(tmp_row_23.val[1]);           \
+        _src[7] = vreinterpretq_s8_s32(                              \
+                vextq_s32(tmp_row_23.val[0], tmp_row.val[0], 1));    \
+        _src[8] = vreinterpretq_s8_s32(                              \
+                vextq_s32(tmp_row_23.val[1], tmp_row.val[1], 1));    \
+        _src[9] = vreinterpretq_s8_s32(                              \
+                vextq_s32(tmp_row_23.val[0], tmp_row.val[0], 2));    \
+    } while (0);
+
+#define CALC_ONE_LINE_8_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, \
+                               _kid4)                                  \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]);   \
+    _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[5]), kern[_kid0]);    \
+    _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[5]), kern[_kid0]);   \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]);   \
+    _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[6]), kern[_kid1]);    \
+    _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[6]), kern[_kid1]);   \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]);   \
+    _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[7]), kern[_kid2]);    \
+    _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[7]), kern[_kid2]);   \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]);   \
+    _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[8]), kern[_kid3]);    \
+    _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[8]), kern[_kid3]);   \
+    _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[4]), kern[_kid4]);    \
+    _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[4]), kern[_kid4]);   \
+    _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[9]), kern[_kid4]);    \
+    _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[9]), kern[_kid4]);
+
+    size_t oh = 0_z;
+    for (; oh + 2 <= OH; oh += 2) {
+        size_t ih = oh * 2;
+        size_t ow = 0_z;
+        for (; ow + 8 <= OW; ow += 8) {
+            size_t iw = ow * 2;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
+            int16x8_t sum[2][4];
+            int8x16_t src[3][10];
+#define cb(j)             \
+    sum[0][j] = init_sum; \
+    sum[1][j] = init_sum;
+
+            UNROLL_CALL_NOWRAPPER(4, cb);
+#undef cb
+#define cb(i) \
+    const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
+            UNROLL_CALL_NOWRAPPER(7, cb);
+#undef cb
+
+            LOAD_10_SRC(src[0], 0);  // line0
+            LOAD_10_SRC(src[1], 1);  // line1
+            CALC_ONE_LINE_8_RESULT(sum[0], src[0], 0, 1, 2, 3, 4);
+            LOAD_10_SRC(src[2], 2);  // line2
+            CALC_ONE_LINE_8_RESULT(sum[0], src[1], 5, 6, 7, 8, 9);
+            LOAD_10_SRC(src[0], 3);  // line3
+            CALC_ONE_LINE_8_RESULT(sum[0], src[2], 10, 11, 12, 13, 14);
+            CALC_ONE_LINE_8_RESULT(sum[1], src[2], 0, 1, 2, 3, 4);
+            LOAD_10_SRC(src[1], 4);  // line4
+            CALC_ONE_LINE_8_RESULT(sum[0], src[0], 15, 16, 17, 18, 19);
+            CALC_ONE_LINE_8_RESULT(sum[0], src[1], 20, 21, 22, 23, 24);
+            LOAD_10_SRC(src[2], 5);  // line5
+            CALC_ONE_LINE_8_RESULT(sum[1], src[0], 5, 6, 7, 8, 9);
+            CALC_ONE_LINE_8_RESULT(sum[1], src[1], 10, 11, 12, 13, 14);
+            LOAD_10_SRC(src[0], 6);  // line6
+            CALC_ONE_LINE_8_RESULT(sum[1], src[2], 15, 16, 17, 18, 19);
+            CALC_ONE_LINE_8_RESULT(sum[1], src[0], 20, 21, 22, 23, 24);
+
+            STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]);
+            STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]);
+        }
+        for (; ow + 4 <= OW; ow += 4) {
+            size_t iw = ow * 2;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
+            int16x8_t sum[2][2];
+            int8x16_t src[3][5];
+#define cb(j)             \
+    sum[0][j] = init_sum; \
+    sum[1][j] = init_sum;
+
+            UNROLL_CALL_NOWRAPPER(2, cb);
+#undef cb
+
+#define cb(i) \
+    const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
+            UNROLL_CALL_NOWRAPPER(7, cb);
+#undef cb
+
+            LOAD_5_SRC(src[0], 0);  // line0
+            LOAD_5_SRC(src[1], 1);  // line1
+            CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1, 2, 3, 4);
+            LOAD_5_SRC(src[2], 2);  // line2
+            CALC_ONE_LINE_4_RESULT(sum[0], src[1], 5, 6, 7, 8, 9);
+            LOAD_5_SRC(src[0], 3);  // line3
+            CALC_ONE_LINE_4_RESULT(sum[0], src[2], 10, 11, 12, 13, 14);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[2], 0, 1, 2, 3, 4);
+            LOAD_5_SRC(src[1], 4);  // line4
+            CALC_ONE_LINE_4_RESULT(sum[0], src[0], 15, 16, 17, 18, 19);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[0], 5, 6, 7, 8, 9);
+            LOAD_5_SRC(src[2], 5);  // line5
+            CALC_ONE_LINE_4_RESULT(sum[0], src[1], 20, 21, 22, 23, 24);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[1], 10, 11, 12, 13, 14);
+            LOAD_5_SRC(src[0], 6);  // line6
+            CALC_ONE_LINE_4_RESULT(sum[1], src[2], 15, 16, 17, 18, 19);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[0], 20, 21, 22, 23, 24);
+
+            STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]);
+            STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]);
+        }
+        if (ow < OW) {
+            size_t iw = ow * 2;
+            size_t remain = OW - ow;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
+            int16x8_t sum[2][2];
+            int8x16_t src[3][5];
+#define cb(j)             \
+    sum[0][j] = init_sum; \
+    sum[1][j] = init_sum;
+
+            UNROLL_CALL_NOWRAPPER(2, cb);
+#undef cb
+#define cb(i) \
+    const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
+            UNROLL_CALL_NOWRAPPER(7, cb);
+#undef cb
+            LOAD_5_SRC(src[0], 0);  // line0
+            LOAD_5_SRC(src[1], 1);  // line1
+            CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1, 2, 3, 4);
+            LOAD_5_SRC(src[2], 2);  // line2
+            CALC_ONE_LINE_4_RESULT(sum[0], src[1], 5, 6, 7, 8, 9);
+            LOAD_5_SRC(src[0], 3);  // line3
+            CALC_ONE_LINE_4_RESULT(sum[0], src[2], 10, 11, 12, 13, 14);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[2], 0, 1, 2, 3, 4);
+            LOAD_5_SRC(src[1], 4);  // line4
+            CALC_ONE_LINE_4_RESULT(sum[0], src[0], 15, 16, 17, 18, 19);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[0], 5, 6, 7, 8, 9);
+            LOAD_5_SRC(src[2], 5);  // line5
+            CALC_ONE_LINE_4_RESULT(sum[0], src[1], 20, 21, 22, 23, 24);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[1], 10, 11, 12, 13, 14);
+            LOAD_5_SRC(src[0], 6);  // line6
+            CALC_ONE_LINE_4_RESULT(sum[1], src[2], 15, 16, 17, 18, 19);
+            CALC_ONE_LINE_4_RESULT(sum[1], src[0], 20, 21, 22, 23, 24);
+
+            STORE_REMAIN(dst, oh, ow, OW, sum[0], remain);
+            STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain);
+        }
+    }
+    for (; oh < OH; oh++) {
+        size_t ih = oh * 2;
+        size_t ow = 0_z;
+        for (; ow + 8 <= OW; ow += 8) {
+            size_t iw = ow * 2;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+            int16x8_t sum[4] = {init_sum, init_sum, init_sum, init_sum};
+            int8x16_t src[3][10];
+#define cb(i) \
+    const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
+            UNROLL_CALL_NOWRAPPER(5, cb);
+#undef cb
+            LOAD_10_SRC(src[0], 0);  // line0
+            LOAD_10_SRC(src[1], 1);  // line1
+            CALC_ONE_LINE_8_RESULT(sum, src[0], 0, 1, 2, 3, 4);
+            LOAD_10_SRC(src[2], 2);  // line2
+            CALC_ONE_LINE_8_RESULT(sum, src[1], 5, 6, 7, 8, 9);
+            LOAD_10_SRC(src[0], 3);  // line3
+            CALC_ONE_LINE_8_RESULT(sum, src[2], 10, 11, 12, 13, 14);
+            LOAD_10_SRC(src[1], 4);  // line4
+            CALC_ONE_LINE_8_RESULT(sum, src[0], 15, 16, 17, 18, 19);
+            CALC_ONE_LINE_8_RESULT(sum, src[1], 20, 21, 22, 23, 24);
+
+            STORE_1_LINE_RESULT(dst, oh, ow, OW, sum);
+        }
+        for (; ow + 4 <= OW; ow += 4) {
+            size_t iw = ow * 2;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+
+            int16x8_t sum[2] = {init_sum, init_sum};
+#define cb(i) \
+    const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
+            UNROLL_CALL_NOWRAPPER(5, cb);
+#undef cb
+
+            int8x16_t src[3][5];
+            LOAD_5_SRC(src[0], 0);  // line0
+            LOAD_5_SRC(src[1], 1);  // line1
+            CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1, 2, 3, 4);
+            LOAD_5_SRC(src[2], 2);  // line2
+            CALC_ONE_LINE_4_RESULT(sum, src[1], 5, 6, 7, 8, 9);
+            LOAD_5_SRC(src[0], 3);  // line3
+            CALC_ONE_LINE_4_RESULT(sum, src[2], 10, 11, 12, 13, 14);
+            LOAD_5_SRC(src[1], 4);  // line4
+            CALC_ONE_LINE_4_RESULT(sum, src[0], 15, 16, 17, 18, 19);
+            CALC_ONE_LINE_4_RESULT(sum, src[1], 20, 21, 22, 23, 24);
+
+            STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum);
+        }
+        if (OW > ow) {
+            size_t iw = ow * 2;
+            size_t remain = OW - ow;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+            int16x8_t sum[2] = {init_sum, init_sum};
+
+#define cb(i) \
+    const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
+            UNROLL_CALL_NOWRAPPER(5, cb);
+#undef cb
+
+            int8x16_t src[3][5];
+            LOAD_5_SRC(src[0], 0);  // line0
+            LOAD_5_SRC(src[1], 1);  // line1
+            CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1, 2, 3, 4);
+            LOAD_5_SRC(src[2], 2);  // line2
+            CALC_ONE_LINE_4_RESULT(sum, src[1], 5, 6, 7, 8, 9);
+            LOAD_5_SRC(src[0], 3);  // line3
+            CALC_ONE_LINE_4_RESULT(sum, src[2], 10, 11, 12, 13, 14);
+            LOAD_5_SRC(src[1], 4);  // line4
+            CALC_ONE_LINE_4_RESULT(sum, src[0], 15, 16, 17, 18, 19);
+            CALC_ONE_LINE_4_RESULT(sum, src[1], 20, 21, 22, 23, 24);
+
+            STORE_REMAIN(dst, oh, ow, OW, sum, remain);
+        }
+    }
+}
+#undef CALC_ONE_LINE_8_RESULT
+#undef CALC_ONE_LINE_4_RESULT
+#undef LOAD_10_SRC
+#undef LOAD_5_SRC
+#elif MEGDNN_ARMV7
+template <BiasMode bias_mode>
+void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16(
+        const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst,
+        const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
+    MEGDNN_MARK_USED_VAR(IH);
+    const int16_t* __restrict bptr = bias;
+    const int32_t* tmp_filter = reinterpret_cast<const int32_t*>(filter);
+    INIT_SUM();
+    int8x8_t kern0[3], kern1[3], kern2[3], kern3[3], kern4[3];
+
+    int32x2_t tmp_kern = vdup_n_s32(tmp_filter[4]);
+    tmp_kern = vset_lane_s32(0, tmp_kern, 1);
+    kern0[0] = vld1_s8(filter);
+    kern0[1] = vld1_s8(filter + 8);
+    kern0[2] = vreinterpret_s8_s32(tmp_kern);
+
+    tmp_kern = vdup_n_s32(tmp_filter[9]);
+    tmp_kern = vset_lane_s32(0, tmp_kern, 1);
+    kern1[0] = vld1_s8(filter + 20);
+    kern1[1] = vld1_s8(filter + 28);
+    kern1[2] = vreinterpret_s8_s32(tmp_kern);
+
+    tmp_kern = vdup_n_s32(tmp_filter[14]);
+    tmp_kern = vset_lane_s32(0, tmp_kern, 1);
+    kern2[0] = vld1_s8(filter + 40);
+    kern2[1] = vld1_s8(filter + 48);
+    kern2[2] = vreinterpret_s8_s32(tmp_kern);
+
+    tmp_kern = vdup_n_s32(tmp_filter[19]);
+    tmp_kern = vset_lane_s32(0, tmp_kern, 1);
+    kern3[0] = vld1_s8(filter + 60);
+    kern3[1] = vld1_s8(filter + 68);
+    kern3[2] = vreinterpret_s8_s32(tmp_kern);
+
+    tmp_kern = vdup_n_s32(tmp_filter[24]);
+    tmp_kern = vset_lane_s32(0, tmp_kern, 1);
+    kern4[0] = vld1_s8(filter + 80);
+    kern4[1] = vld1_s8(filter + 88);
+    kern4[2] = vreinterpret_s8_s32(tmp_kern);
+
+#define LOAD_3_SRC_ARRAY(_src, _sptr)               \
+    _src[0] = vld1q_s8(_sptr);      /* 0 1 2 3 */   \
+    _src[1] = vld1q_s8(_sptr + 16); /* 4 5 6 7 */   \
+    _src[2] = vld1q_s8(_sptr + 32); /* 8 9 10 11 */
+
+#define CALC_ONE_LINE(_src, _kern, _sum)                                 \
+    tmpsum0 = vmull_s8(vget_low_s8(_src[0]), _kern[0]);  /*01*/          \
+    tmpsum1 = vmull_s8(vget_high_s8(_src[0]), _kern[0]); /*23*/          \
+    tmpsum0 = vmlal_s8(tmpsum0, vget_high_s8(_src[0]), _kern[1]); /*23*/ \
+    tmpsum1 = vmlal_s8(tmpsum1, vget_low_s8(_src[1]), _kern[1]);  /*45*/ \
+    tmpsum0 = vmlal_s8(tmpsum0, vget_low_s8(_src[1]), _kern[2]);  /*4*/  \
+    tmpsum1 = vmlal_s8(tmpsum1, vget_high_s8(_src[1]), _kern[2]); /*6*/  \
+    res0 = vadd_s16(vget_low_s16(tmpsum0), vget_high_s16(tmpsum0));      \
+    res1 = vadd_s16(vget_low_s16(tmpsum1), vget_high_s16(tmpsum1));      \
+    _sum[0] = vaddq_s16(_sum[0], vcombine_s16(res0, res1));              \
+                                                                         \
+    tmpsum0 = vmull_s8(vget_low_s8(_src[1]), _kern[0]);  /*45*/          \
+    tmpsum1 = vmull_s8(vget_high_s8(_src[1]), _kern[0]); /*67*/          \
+    tmpsum0 = vmlal_s8(tmpsum0, vget_high_s8(_src[1]), _kern[1]); /*67*/ \
+    tmpsum1 = vmlal_s8(tmpsum1, vget_low_s8(_src[2]), _kern[1]);  /*89*/ \
+    tmpsum0 = vmlal_s8(tmpsum0, vget_low_s8(_src[2]), _kern[2]);  /*8*/  \
+    tmpsum1 = vmlal_s8(tmpsum1, vget_high_s8(_src[2]), _kern[2]); /*10*/ \
+    res0 = vadd_s16(vget_low_s16(tmpsum0), vget_high_s16(tmpsum0));      \
+    res1 = vadd_s16(vget_low_s16(tmpsum1), vget_high_s16(tmpsum1));      \
+    _sum[1] = vaddq_s16(_sum[1], vcombine_s16(res0, res1));
+
+#define CALC_8_RESULT()                 \
+    LOAD_3_SRC_ARRAY(src0, sptr0);      \
+    LOAD_3_SRC_ARRAY(src1, sptr1);      \
+    CALC_ONE_LINE(src0, kern0, sum[0]); \
+                                        \
+    LOAD_3_SRC_ARRAY(src0, sptr2);      \
+    CALC_ONE_LINE(src1, kern1, sum[0]); \
+                                        \
+    LOAD_3_SRC_ARRAY(src1, sptr3);      \
+    CALC_ONE_LINE(src0, kern2, sum[0]); \
+    CALC_ONE_LINE(src0, kern0, sum[1]); \
+                                        \
+    LOAD_3_SRC_ARRAY(src0, sptr4);      \
+    CALC_ONE_LINE(src1, kern3, sum[0]); \
+    CALC_ONE_LINE(src1, kern1, sum[1]); \
+                                        \
+    LOAD_3_SRC_ARRAY(src1, sptr5);      \
+    CALC_ONE_LINE(src0, kern4, sum[0]); \
+    CALC_ONE_LINE(src0, kern2, sum[1]); \
+                                        \
+    LOAD_3_SRC_ARRAY(src0, sptr6);      \
+    CALC_ONE_LINE(src1, kern3, sum[1]); \
+    CALC_ONE_LINE(src0, kern4, sum[1]);
+
+#define CALC_4_RESULT()                 \
+    LOAD_3_SRC_ARRAY(src0, sptr0);      \
+    LOAD_3_SRC_ARRAY(src1, sptr1);      \
+    CALC_ONE_LINE(src0, kern0, sum);    \
+                                        \
+    LOAD_3_SRC_ARRAY(src0, sptr2);      \
+    CALC_ONE_LINE(src1, kern1, sum);    \
+                                        \
+    LOAD_3_SRC_ARRAY(src1, sptr3);      \
+    CALC_ONE_LINE(src0, kern2, sum);    \
+                                        \
+    LOAD_3_SRC_ARRAY(src0, sptr4);      \
+    CALC_ONE_LINE(src1, kern3, sum);    \
+    CALC_ONE_LINE(src0, kern4, sum);
+
+    size_t oh = 0_z;
+    for (; oh + 2 <= OH; oh += 2) {
+        size_t ih = oh * 2;
+        size_t ow = 0_z;
+
+        for (; ow + 4 <= OW; ow += 4) {
+            size_t iw = ow * 2;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
+            int16x8_t sum[2][2];
+#define cb(j)             \
+    sum[0][j] = init_sum; \
+    sum[1][j] = init_sum;
+
+            UNROLL_CALL_NOWRAPPER(2, cb);
+#undef cb
+            int8x16_t src0[3], src1[3];
+            int16x8_t tmpsum0, tmpsum1;
+            int16x4_t res0, res1;
+            CALC_8_RESULT();
+            STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]);
+            STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]);
+        }
+        if (ow < OW) {
+            size_t iw = ow * 2;
+            size_t remain = OW - ow;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
+            int16x8_t sum[2][2];
+#define cb(j)             \
+    sum[0][j] = init_sum; \
+    sum[1][j] = init_sum;
+
+            UNROLL_CALL_NOWRAPPER(2, cb);
+#undef cb
+            int8x16_t src0[3], src1[3];
+            int16x8_t tmpsum0, tmpsum1;
+            int16x4_t res0, res1;
+
+            CALC_8_RESULT();
+            STORE_REMAIN(dst, oh, ow, OW, sum[0], remain);
+            STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain);
+        }
+    }
+    for (; oh < OH; oh++) {
+        size_t ih = oh * 2;
+        size_t ow = 0_z;
+        for (; ow + 4 <= OW; ow += 4) {
+            size_t iw = ow * 2;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+
+            int16x8_t sum[2] = {init_sum, init_sum};
+
+            int8x16_t src0[3], src1[3];
+            int16x8_t tmpsum0, tmpsum1;
+            int16x4_t res0, res1;
+            CALC_4_RESULT();
+            STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum);
+        }
+        if (OW > ow) {
+            size_t iw = ow * 2;
+            size_t remain = OW - ow;
+            const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
+            const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
+            int16x8_t sum[2] = {init_sum, init_sum};
+            int8x16_t src0[3], src1[3];
+            int16x8_t tmpsum0, tmpsum1;
+            int16x4_t res0, res1;
+            CALC_4_RESULT();
+            STORE_REMAIN(dst, oh, ow, OW, sum, remain);
+        }
+    }
+}
+#undef CALC_ONE_LINE
+#undef CALC_4_RESULT
+#undef CALC_8_RESULT
+#undef LOAD_3_SRC_ARRAY
+#endif
+
+#undef INIT_SUM
+#undef STORE_1_LINE_RESULT
+#undef STORE_1_LINE_4_RESULT
+#undef STORE_REMAIN
+
+#define INSTANTIATION(stride, i, bias)                                    \
+    template void channel_wise_nchw44_8x8x16::                            \
+            direct_##stride##_##i##x##i##_int8x8x16<bias>(                \
+                    const int8_t*, const int8_t*, const int16_t*, void*,  \
+                    const size_t, const size_t, const size_t, const size_t);
+
+#define FOR_OP(stride, i, bias) INSTANTIATION(stride, i, bias)
+
+#define FOR_BIAS(stride, i)              \
+    FOR_OP(stride, i, BiasMode::NO_BIAS) \
+    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
+
+#define FOR_FILTER(stride) \
+    FOR_BIAS(stride, 2)    \
+    FOR_BIAS(stride, 3)    \
+    FOR_BIAS(stride, 5)
+
+#define FOR_STRIDE      \
+    FOR_FILTER(stride1) \
+    FOR_FILTER(stride2)
+
+FOR_STRIDE
+
+#undef FOR_STRIDE
+#undef FOR_FILTER
+#undef FOR_BIAS
+#undef FOR_OP
+#undef INSTANTIATION
+
+// vim: syntax=cpp.doxygen
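A note on the stride-2 loads in the kernels above: in NCHW44 every pixel packs 4 int8 channels, so one pixel can be treated as a single int32 word, and a stride-2 window over pixels becomes an even/odd de-interleave over words, which is what a single vld2q_s32 delivers. A scalar sketch of what LOAD_5_SRC gathers for four adjacent outputs (the helper name and buffer sizes here are illustrative, not part of the patch):

    #include <cstdint>
    #include <cstring>

    // row: one padded NCHW44 input row (pixels of 4 int8 channels each).
    // out[k][j]: the packed pixel feeding filter column k of output j.
    void load_5_src_scalar(const int8_t* row, int32_t out[5][4]) {
        int32_t words[11];
        std::memcpy(words, row, sizeof(words));  // 11 packed pixels as words
        for (int j = 0; j < 4; ++j)              // 4 outputs per tile
            for (int k = 0; k < 5; ++k)          // 5 filter columns
                out[k][j] = words[2 * j + k];    // stride-2 pixel indexing
    }

Here out[0] holds words {0, 2, 4, 6}, out[1] holds {1, 3, 5, 7}, and so on up to out[4] = {4, 6, 8, 10}, matching the vld2q_s32/vextq_s32 combination in the macro.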
diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h
new file mode 100644
index 0000000000000000000000000000000000000000..6e214b823e559f739a100f166c02ad6a46d1486d
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h
@@ -0,0 +1,57 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "src/arm_common/conv_bias/opr_impl.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace channel_wise_nchw44 {
+
+using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
+using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
+using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
+
+using conv_fun = std::function<void(const WorkspaceBundle& bundle,
+                                    const NCBKernParam& kern_param,
+                                    const NCBKernIndex& ncb_index)>;
+
+namespace stride1 {
+
+bool is_available(const NCBKernSizeParam& param);
+
+WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
+
+template <size_t filter, BiasMode bias_mode>
+void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
+                  const NCBKernIndex& ncb_index);
+
+SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param);
+} // namespace stride1
+
+namespace stride2 {
+bool is_available(const NCBKernSizeParam& param);
+
+WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
+
+template <size_t filter, BiasMode bias_mode>
+void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
+                  const NCBKernIndex& ncb_index);
+
+SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param);
+
+} // namespace stride2
+} // namespace channel_wise_nchw44
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
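The stride1 and stride2 sub-namespaces declared above deliberately expose the same four entry points, so the algorithm object only has to branch on the stride once. A minimal sketch of that dispatch, using a hypothetical free function pick_kimpls (the real branching lives in AlgoS8x8x16ChanWiseStride1Stride2NCHW44::dispatch_kerns in algos.cpp):

    #include "src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h"

    namespace megdnn {
    namespace arm_common {

    // Hypothetical helper mirroring dispatch_kerns: route by stride and let
    // an empty result mean "not handled by this implementation".
    SmallVector<ConvBiasImpl::NCBKern> pick_kimpls(
            const channel_wise_nchw44::NCBKernSizeParam& param) {
        size_t stride = param.filter_meta.stride[0];
        if (stride == 1)
            return channel_wise_nchw44::stride1::get_kimpls(param);
        if (stride == 2)
            return channel_wise_nchw44::stride2::get_kimpls(param);
        return {};
    }

    } // namespace arm_common
    } // namespace megdnn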
diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e57d26f20864bc24a875bad1ef8d574e327b9b77
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.cpp
@@ -0,0 +1,259 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h"
+#include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h"
+#include "src/common/opr_delegate.h"
+
+#include "midout.h"
+#include "src/fallback/conv_bias/common.h"
+
+using namespace megdnn;
+using namespace arm_common;
+using namespace channel_wise_nchw44_8x8x16;
+
+namespace {
+void get_rectified_size(
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
+        size_t& IH2, size_t& IW2) {
+    auto&& fm = param.filter_meta;
+    auto SW = fm.stride[1];
+    auto OH = param.osz[0];
+    auto OW = param.osz[1];
+    auto FH = fm.spatial[0];
+    auto FW = fm.spatial[1];
+
+    size_t OW2 = (OW + 3) & ~3;
+    IH2 = SW * OH + FH - SW;
+    IW2 = SW * OW2 + FW - SW;
+}
+} // namespace
+
+MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride1)
+MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride2)
+
+WorkspaceBundle stride1::get_bundle(
+        const ConvBiasImpl::NCBKernSizeParam& param) {
+    size_t nr_threads = param.nr_threads;
+    size_t IH2, IW2;
+    get_rectified_size(param, IH2, IW2);
+    constexpr size_t pack_ic_size = 4_z;
+    //! The extra 16B is used to avoid invalid reads in the kernel compute
+    size_t src_size = IH2 * IW2 * pack_ic_size * sizeof(int8_t) + 16;
+    SmallVector<size_t> sizes(nr_threads, src_size);
+    return {nullptr, sizes};
+}
+
+//! compute one output channel
+template <size_t filter, BiasMode bias_mode>
+void stride1::do_conv_kern(const WorkspaceBundle& bundle,
+                           const NCBKernParam& kern_param,
+                           const NCBKernIndex& ncb_index) {
+    size_t PH = kern_param.filter_meta.padding[0];
+    size_t PW = kern_param.filter_meta.padding[1];
+    size_t OH = kern_param.osz[0];
+    size_t OW = kern_param.osz[1];
+    size_t IH = kern_param.isz[0];
+    size_t IW = kern_param.isz[1];
+    size_t IH2, IW2;
+    get_rectified_size(kern_param, IH2, IW2);
+
+    constexpr size_t pack_group_size = 4_z;
+    constexpr size_t pack_ic_size = 4_z;
+
+    size_t thread_id = ncb_index.thread_id, batch_id = ncb_index.ndrange_id[0];
+    size_t group_id = ncb_index.ndrange_id[1];
+    int8_t* padding_src = static_cast<int8_t*>(bundle.get(thread_id));
+    const int8_t* sptr =
+            kern_param.src<int8_t>(batch_id, group_id, 0, pack_group_size);
+    const int8_t* fptr = kern_param.filter<int8_t>(group_id, pack_group_size);
+    void* dst = kern_param.dst<void>(batch_id, group_id, 0, pack_group_size);
+    const int16_t* bptr =
+            kern_param.bias<int16_t>(batch_id, group_id, 0, pack_group_size);
+    //! copy src into the zero-filled buffer to avoid illegal reads at the
+    //! borders, even when padding is zero
+    std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size);
+    rep(ih, IH) {
+        std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size,
+                    sptr + ih * IW * pack_ic_size,
+                    sizeof(int8_t) * IW * pack_ic_size);
+    }
+    sptr = padding_src;
+
+#define KERN(_size)                                          \
+    direct_stride1_##_size##x##_size##_int8x8x16<bias_mode>( \
+            sptr, fptr, bptr, dst, IH2, IW2, OH, OW);
+    DISPATCH_FILTER_CHANNEL_WISE(filter, KERN);
+#undef KERN
+}
+
+SmallVector<ConvBiasImpl::NCBKern> stride1::get_kimpls(
+        const NCBKernSizeParam& param) {
+    auto fm = param.filter_meta;
+    size_t N = param.n;
+    size_t group = fm.group / 4;
+    megdnn_assert(fm.group % 4 == 0,
+                  "nchw44 channel wise conv: group is not a multiple of 4");
+    WorkspaceBundle wbundle = get_bundle(param);
+    conv_fun do_conv_fun = nullptr;
+
+#define DO_CONV_KERN_FUN(filter, bias_mode)                             \
+    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride1, \
+                 midout_iv(#filter #bias_mode##_hash)) {               \
+        do_conv_fun = do_conv_kern<filter, bias_mode>;                 \
+    }                                                                  \
+    MIDOUT_END();
+
+#define GET_OP_PARAM(i, bias_mode)                                  \
+    switch (param.nonlineMode) {                                    \
+        case param::ConvBias::NonlineMode::IDENTITY:                \
+            DO_CONV_KERN_FUN(i, bias_mode)                          \
+            break;                                                  \
+        default:                                                    \
+            megdnn_assert(0, "only support NonlineMode::IDENTITY"); \
+            break;                                                  \
+    }
+
+#define GET_BIAS_MODE_PARAM(i)                                  \
+    switch (param.bias_mode) {                                  \
+        case BiasMode::NO_BIAS:                                 \
+            GET_OP_PARAM(i, BiasMode::NO_BIAS)                  \
+            break;                                              \
+        case BiasMode::BROADCAST_CHANNEL_BIAS:                  \
+            GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS)   \
+            break;                                              \
+        default:                                                \
+            megdnn_assert(0,                                    \
+                          "only support BiasMode::NO_BIAS and " \
+                          "BiasMode::BROADCAST_CHANNEL_BIAS");  \
+            break;                                              \
+    }
+
+#define DISPATCH_CONV_KERN()                                         \
+    switch (param.filter_meta.spatial[0]) {                          \
+        case 2:                                                      \
+            GET_BIAS_MODE_PARAM(2)                                   \
+            break;                                                   \
+        case 3:                                                      \
+            GET_BIAS_MODE_PARAM(3)                                   \
+            break;                                                   \
+        case 5:                                                      \
+            GET_BIAS_MODE_PARAM(5)                                   \
+            break;                                                   \
+        default:                                                     \
+            megdnn_assert(0, "only support filtersize 2x2 3x3 5x5"); \
+            break;                                                   \
+    }
+
+    DISPATCH_CONV_KERN();
+    megdnn_assert(do_conv_fun);
+
+    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
+    auto exec_one_group = [wbundle, do_conv_fun](
+                                  const NCBKernParam& kern_param,
+                                  const NCBKernIndex& ncb_index) mutable {
+        wbundle.set(kern_param.workspace_ptr);
+        do_conv_fun(wbundle, kern_param, ncb_index);
+    };
+    ret_kerns.push_back({exec_one_group, {N, group}});
+    return ret_kerns;
+#undef DO_CONV_KERN_FUN
+}
+
+WorkspaceBundle stride2::get_bundle(
+        const ConvBiasImpl::NCBKernSizeParam& param) {
+    size_t nr_threads = param.nr_threads;
+    size_t IH2, IW2;
+    get_rectified_size(param, IH2, IW2);
+    constexpr size_t pack_ic_size = 4_z;
+    //! The extra 16B is used to avoid invalid reads in the kernel compute
+    size_t src_size = IH2 * IW2 * pack_ic_size * sizeof(int8_t) + 16;
+    SmallVector<size_t> sizes(nr_threads, src_size);
+    return {nullptr, sizes};
+}
+
+//! compute one output channel
+template <size_t filter, BiasMode bias_mode>
+void stride2::do_conv_kern(const WorkspaceBundle& bundle,
+                           const NCBKernParam& kern_param,
+                           const NCBKernIndex& ncb_index) {
+    size_t PH = kern_param.filter_meta.padding[0];
+    size_t PW = kern_param.filter_meta.padding[1];
+    size_t OH = kern_param.osz[0];
+    size_t OW = kern_param.osz[1];
+    size_t IH = kern_param.isz[0];
+    size_t IW = kern_param.isz[1];
+    size_t IH2, IW2;
+    get_rectified_size(kern_param, IH2, IW2);
+
+    constexpr size_t pack_group_size = 4_z;
+    constexpr size_t pack_ic_size = 4_z;
+
+    size_t thread_id = ncb_index.thread_id, batch_id = ncb_index.ndrange_id[0];
+    size_t group_id = ncb_index.ndrange_id[1];
+    int8_t* padding_src = static_cast<int8_t*>(bundle.get(thread_id));
+    const int8_t* sptr =
+            kern_param.src<int8_t>(batch_id, group_id, 0, pack_group_size);
+    const int8_t* fptr = kern_param.filter<int8_t>(group_id, pack_group_size);
+    void* dst = kern_param.dst<void>(batch_id, group_id, 0, pack_group_size);
+    const int16_t* bptr =
+            kern_param.bias<int16_t>(batch_id, group_id, 0, pack_group_size);
+    //! copy src into the zero-filled buffer to avoid illegal reads at the
+    //! borders, even when padding is zero
+    std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size);
+    rep(ih, IH) {
+        std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size,
+                    sptr + ih * IW * pack_ic_size,
+                    sizeof(int8_t) * IW * pack_ic_size);
+    }
+    sptr = padding_src;
+
+#define KERN(_size)                                          \
+    direct_stride2_##_size##x##_size##_int8x8x16<bias_mode>( \
+            sptr, fptr, bptr, dst, IH2, IW2, OH, OW);
+    DISPATCH_FILTER_CHANNEL_WISE(filter, KERN);
+#undef KERN
+}
+
+SmallVector<ConvBiasImpl::NCBKern> stride2::get_kimpls(
+        const NCBKernSizeParam& param) {
+    auto fm = param.filter_meta;
+    size_t N = param.n;
+    size_t group = fm.group / 4;
+    megdnn_assert(fm.group % 4 == 0,
+                  "nchw44 channel wise conv: group is not a multiple of 4");
+    WorkspaceBundle wbundle = get_bundle(param);
+    conv_fun do_conv_fun = nullptr;
+
+#define DO_CONV_KERN_FUN(filter, bias_mode)                             \
+    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride2, \
+                 midout_iv(#filter #bias_mode##_hash)) {               \
+        do_conv_fun = do_conv_kern<filter, bias_mode>;                 \
+    }                                                                  \
+    MIDOUT_END();
+
+    DISPATCH_CONV_KERN();
+    megdnn_assert(do_conv_fun);
+
+    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
+    auto exec_one_group = [wbundle, do_conv_fun](
+                                  const NCBKernParam& kern_param,
+                                  const NCBKernIndex& ncb_index) mutable {
+        wbundle.set(kern_param.workspace_ptr);
+        do_conv_fun(wbundle, kern_param, ncb_index);
+    };
+    ret_kerns.push_back({exec_one_group, {N, group}});
+    return ret_kerns;
+#undef DISPATCH_CONV_KERN
+#undef GET_BIAS_MODE_PARAM
+#undef GET_OP_PARAM
+}
+
+// vim: syntax=cpp.doxygen
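The workspace math in get_rectified_size and get_bundle above is easiest to see with concrete numbers: OW is rounded up to a multiple of 4 so the widest inner tile never steps past the row, the padded input is sized for a full stride window, and the 16 extra bytes absorb the deliberately over-reading vector loads. A standalone rehearsal of that arithmetic (shapes chosen only for illustration):

    #include <cstddef>
    #include <cstdio>

    int main() {
        // assumed example shapes: 5x5 filter, stride 2, 30x30 output
        size_t OH = 30, OW = 30, FH = 5, FW = 5, SW = 2;
        size_t OW2 = (OW + 3) & ~size_t(3);  // 32: OW rounded up to 4
        size_t IH2 = SW * OH + FH - SW;      // padded input height
        size_t IW2 = SW * OW2 + FW - SW;     // padded input width
        size_t pack_ic = 4;                  // NCHW44 packs 4 channels
        size_t per_thread = IH2 * IW2 * pack_ic * sizeof(char) + 16;
        std::printf("IH2=%zu IW2=%zu bytes/thread=%zu\n", IH2, IW2, per_thread);
        return 0;
    }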
diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h
new file mode 100644
index 0000000000000000000000000000000000000000..b7b3dc4c6984fd4a26969a96896e290a9cffb4ab
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h
@@ -0,0 +1,57 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "src/arm_common/conv_bias/opr_impl.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace channel_wise_nchw44_8x8x16 {
+
+using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
+using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
+using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
+
+using conv_fun = std::function<void(const WorkspaceBundle& bundle,
+                                    const NCBKernParam& kern_param,
+                                    const NCBKernIndex& ncb_index)>;
+
+namespace stride1 {
+
+bool is_available(const NCBKernSizeParam& param);
+
+WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
+
+template <size_t filter, BiasMode bias_mode>
+void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
+                  const NCBKernIndex& ncb_index);
+
+SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param);
+} // namespace stride1
+
+namespace stride2 {
+bool is_available(const NCBKernSizeParam& param);
+
+WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
+
+template <size_t filter, BiasMode bias_mode>
+void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
+                  const NCBKernIndex& ncb_index);
+
+SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param);
+
+} // namespace stride2
+} // namespace channel_wise_nchw44_8x8x16
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp
index 515166f6d240be07ffc32e94709ad975237a5489..c7192803a0576b25d8794ad995e148dcc1045f80 100644
--- a/dnn/src/arm_common/conv_bias/opr_impl.cpp
+++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp
@@ -48,6 +48,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
     AlgoS8DirectStride1 s8_direct_stride1;
     AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
     AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
+    AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44;
 
 #if __ARM_FEATURE_DOTPROD
     AlgoDotS8DirectStride1 ds8_direct_stride1;
@@ -95,6 +96,7 @@ public:
 
         direct_algos.emplace_back(&s8_direct_nchw_nchw44);
         direct_algos.emplace_back(&s8_direct_stride1);
+        direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44);
         direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
         direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
 
diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h
index 6a6b1d609938443efd4092cd6c2ddd225e35e9e6..780c07b90b025738e58a4dee33022f7665d98ecb 100644
--- a/dnn/src/arm_common/conv_bias/opr_impl.h
+++ b/dnn/src/arm_common/conv_bias/opr_impl.h
@@ -54,6 +54,7 @@ private:
 
     class AlgoS8ChanWiseStride1NCHW44;
     class AlgoS8ChanWiseStride2NCHW44;
+    class AlgoS8x8x16ChanWiseStride1Stride2NCHW44;
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     class AlgoFP16WinogradF23;
diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp
index ca33ed8a94aacc9bfeae876ed82a56b33e51b29b..7afa3408b399e980129bbb137ec51e8581bad57c 100644
--- a/dnn/test/arm_common/conv_bias.cpp
+++ b/dnn/test/arm_common/conv_bias.cpp
@@ -558,6 +558,142 @@ void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name,
     }
 }
 
+std::vector<conv_bias::TestArg> get_nchw44_channel_wise_benchmark_args(
+        std::vector<size_t> kernel, size_t stride, bool no_bias,
+        bool no_nonlinemode, bool no_full_bias) {
+    using namespace conv_bias;
+    using Param = param::ConvBias;
+    using NLMode = param::ConvBias::NonlineMode;
+    std::vector<TestArg> args;
+
+    auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
+                    size_t stride, NLMode nlmode, bool pad) {
+        Param param;
+        param.stride_h = stride;
+        param.stride_w = stride;
+        if (pad) {
+            param.pad_h = kernel / 2;
+            param.pad_w = kernel / 2;
+        } else {
+            param.pad_h = 0;
+            param.pad_w = 0;
+        }
+        param.nonlineMode = nlmode;
+        param.format = param::ConvBias::Format::NCHW44;
+        param.sparse = param::ConvBias::Sparse::GROUP;
+
+        args.emplace_back(param, TensorShape{n, group, h, w, 4},
+                          TensorShape{group, 1, 1, kernel, kernel, 4},
+                          TensorShape{});
+        if (!no_bias) {
+            args.emplace_back(param, TensorShape{n, group, h, w, 4},
+                              TensorShape{group, 1, 1, kernel, kernel, 4},
+                              TensorShape{1, group, 1, 1, 4});
+        }
+        if (!no_full_bias) {
+            args.emplace_back(
+                    param, TensorShape{n, group, h, w, 4},
+                    TensorShape{group, 1, 1, kernel, kernel, 4},
+                    TensorShape{n, group,
+                                (h + 2 * param.pad_w - kernel) / stride + 1,
+                                (w + 2 * param.pad_w - kernel) / stride + 1,
+                                4});
+        }
+    };
+
+    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
+    if (!no_nonlinemode) {
+        nonlinemode.emplace_back(NLMode::RELU);
+        nonlinemode.emplace_back(NLMode::H_SWISH);
+    }
+    for (size_t n : {1}) {
+        for (auto nlmode : nonlinemode) {
+            for (bool pad : {true}) {
+                for (size_t group : {1, 2, 4, 128}) {
+                    for (size_t size : {40, 89, 100, 200}) {
+                        for (size_t kern : kernel) {
+                            pack(n, group, size, size, kern, stride, nlmode,
+                                 pad);
+                        }
+                    }
+                }
+            }
+            for (bool pad : {false}) {
+                for (size_t group : {1, 2, 4, 8, 16, 32, 64, 128}) {
+                    for (size_t size : {40, 89, 100}) {
+                        for (size_t kern : kernel) {
+                            pack(n, group, size, size, kern, stride, nlmode,
+                                 pad);
+                        }
+                    }
+                }
+            }
+        }
+    }
+    return args;
+}
+
+void BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32(
+        const char* algo_name0, const char* algo_name1, Handle* handle,
+        size_t kernel, size_t stride = 1, size_t pack_size = 1) {
+    MEGDNN_MARK_USED_VAR(kernel);
+    auto args = get_nchw44_channel_wise_benchmark_args({2, 3, 5}, stride,
+                                                       false, true, true);
+
+    using namespace conv_bias;
+    constexpr size_t RUN = 10;
+    Benchmarker<ConvBias> benchmark(handle);
+    benchmark.set_display(false);
+    benchmark.set_times(RUN);
+    benchmark.set_dtype(0, dtype::Int8());
+    benchmark.set_dtype(1, dtype::Int8());
+    benchmark.set_dtype(2, dtype::Int32());
+    benchmark.set_dtype(4, dtype::Int32());
+
+    Benchmarker<ConvBias> benchmark_algo1(handle);
+    benchmark_algo1.set_display(false);
+    benchmark_algo1.set_times(RUN);
+    benchmark_algo1.set_dtype(0, dtype::Int8());
+    benchmark_algo1.set_dtype(1, dtype::Int8());
+    benchmark_algo1.set_dtype(2, dtype::Int16());
+    benchmark_algo1.set_dtype(4, dtype::Int16());
+
+    for (auto&& arg : args) {
+        TensorLayout dst_layout;
+        auto opr = handle->create_operator<ConvBias>();
+        opr->param() = arg.param;
+        opr->deduce_layout({arg.src, dtype::Float32()},
+                           {arg.filter, dtype::Float32()},
+                           {arg.bias, dtype::Float32()}, {}, dst_layout);
+        //! dst.nr_elems * IC * FH * FW * 2
+        float computations = dst_layout.total_nr_elems() * arg.filter[1] *
+                             arg.filter[2] * arg.filter[3] * 2.0 * pack_size /
+                             (1024 * 1024 * 1024) * 1e3;
+
+        benchmark.set_param(arg.param);
+        auto used = algo_benchmark<ConvBias>(benchmark,
+                                             {arg.src, arg.filter, {}, {}, {}},
+                                             algo_name0) /
+                    RUN;
+
+        arg.param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
+        arg.param.format = param::ConvBias::Format::NCHW44;
+        benchmark_algo1.set_param(arg.param);
+
+        auto used_algo1 =
+                algo_benchmark<ConvBias>(benchmark_algo1,
+                                         {arg.src, arg.filter, {}, {}, {}},
+                                         algo_name1) /
+                RUN;
+        printf("%s %s: normal: %f ms %f GFlops 8x8x16: %f ms %f GFlops "
+               "speedup: %f\n",
+               arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
+               used, computations / used, used_algo1,
+               computations / used_algo1, used / used_algo1);
+    }
+}
+
 #if MEGDNN_AARCH64
 TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) {
     printf("=========================compare "
@@ -579,6 +715,17 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) {
 }
 #endif
 
+TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE1) {
+    BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32(
+            "S8_CHAN_WISE_STRD1_NCHW44",
+            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44", handle(), 3, 1, 4);
+}
+TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE2) {
+    BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32(
+            "S8_CHAN_WISE_STRD2_NCHW44",
+            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44", handle(), 3, 2, 4);
+}
+
 TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONVBIAS_QUANTIZED) {
     constexpr size_t RUNS = 50;
     param::ConvBias param;
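On the GFlops figure printed by the benchmark above: `computations` is an operation count scaled by 2^-30 and multiplied by 1e3, so dividing it by a per-run time in milliseconds cancels the 1e3 and yields Gops per second. A standalone rehearsal of the same arithmetic, with assumed shape and timing values (the formula is copied as written, including using only `arg.filter[1..3]` for the filter term):

    #include <cstdio>

    int main() {
        // assumed values: dst {1, 32, 100, 100, 4}, filter {32, 1, 1, 3, 3, 4}
        double dst_elems = 1.0 * 32 * 100 * 100 * 4;
        double f1 = 1, f2 = 1, f3 = 3;  // arg.filter[1..3] in the loop above
        double pack_size = 4;
        double computations = dst_elems * f1 * f2 * f3 * 2.0 * pack_size /
                              (1024 * 1024 * 1024) * 1e3;
        double used_ms = 0.8;  // assumed measured time of one run
        std::printf("%f GFlops\n", computations / used_ms);
        return 0;
    }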
diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp
index 9e99311148d2ed74980783a893fd8b7e1d886fd7..f4f35b508def941e7d813642ee313cecf65b199b 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -9,6 +9,7 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  * implied.
  */
+#include "megdnn/dtype.h"
 #include "test/arm_common/fixture.h"
 #include "test/common/benchmarker.h"
 #include "test/common/conv_bias.h"
@@ -475,6 +476,36 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
             handle(), "S8_CHAN_WISE_STRD2_NCHW44");
 }
 
+TEST_F(ARM_COMMON_MULTI_THREADS,
+       CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
+    Checker<ConvBias> checker(handle());
+    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
+            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
+    checker.set_dtype(0, dtype::Int8());
+    checker.set_dtype(1, dtype::Int8());
+    checker.set_dtype(2, dtype::Int16());
+    checker.set_dtype(4, dtype::Int16());
+    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true);
+    for (auto&& arg : args) {
+        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
+    }
+}
+
+TEST_F(ARM_COMMON_MULTI_THREADS,
+       CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
+    Checker<ConvBias> checker(handle());
+    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
+            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
+    checker.set_dtype(0, dtype::Int8());
+    checker.set_dtype(1, dtype::Int8());
+    checker.set_dtype(2, dtype::Int16());
+    checker.set_dtype(4, dtype::Int16());
+    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true);
+    for (auto&& arg : args) {
+        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
+    }
+}
+
 /********************************qint8 direct******************************/
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
     checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
index 58b62d8a9a41a951ff7f32e24351669fcb02405c..2d07c5b98ff1c2b76e3987a3f6d7f62cc09a5b35 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
@@ -1706,6 +1706,77 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
                    {1, {4}}, data_type);
 }
 
+TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
+       BENCHMARK_CHANNEL_WISE_INT8_INT8_INT16_STRIDE1) {
+    constexpr size_t RUNS = 50;
+
+    param::ConvBias param;
+    param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
+    param.pad_h = 1;
+    param.pad_w = 1;
+    param.stride_h = 1;
+    param.stride_w = 1;
+    param.sparse = param::ConvBias::Sparse::GROUP;
+    param.format = param::ConvBias::Format::NCHW44;
+
+    std::vector<std::pair<SmallVector<TensorShape>, float>>
+            shapes_and_computation;
+    auto bench_case = [&](size_t N, size_t IC, size_t H, size_t W, size_t FS,
+                          size_t P) {
+        size_t group = IC;
+        size_t OC = IC;
+        size_t S = 1;
+        SmallVector<TensorShape> shapes{
+                {N, IC, H, W, 4},
+                {group, 1, 1, FS, FS, 4},
+                {1, OC, 1, 1, 4},
+                {},
+                {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1,
+                 4}};
+        TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1,
+                        (W + 2 * P - FS) / S + 1, 4};
+        float computations =
+                ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
+                 dst.total_nr_elems()) *
+                1e-6;
+        shapes_and_computation.push_back(std::make_pair(shapes, computations));
+    };
+    bench_case(1, 128, 200, 200, 3, 1);
+    bench_case(1, 128, 128, 128, 3, 1);
+    bench_case(1, 128, 100, 100, 3, 1);
+    bench_case(1, 128, 80, 80, 3, 1);
+    bench_case(1, 128, 56, 56, 3, 1);
+    bench_case(1, 128, 28, 28, 3, 1);
+    bench_case(1, 128, 14, 14, 3, 1);
+
+    bench_case(1, 64, 200, 200, 3, 1);
+    bench_case(1, 64, 128, 128, 3, 1);
+    bench_case(1, 64, 100, 100, 3, 1);
+    bench_case(1, 64, 80, 80, 3, 1);
+    bench_case(1, 64, 56, 56, 3, 1);
+    bench_case(1, 64, 28, 28, 3, 1);
+    bench_case(1, 64, 14, 14, 3, 1);
+
+    bench_case(1, 32, 200, 200, 3, 1);
+    bench_case(1, 32, 128, 128, 3, 1);
+    bench_case(1, 32, 100, 100, 3, 1);
+    bench_case(1, 32, 80, 80, 3, 1);
+    bench_case(1, 32, 56, 56, 3, 1);
+    bench_case(1, 32, 28, 28, 3, 1);
+    bench_case(1, 32, 14, 14, 3, 1);
+
+    std::string algo_name = "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44";
+    printf("Benchmarker S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44 algo\n");
+    std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(),
+                                    dtype::Int16(), dtype::Int16()};
+    benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
+                   {4, {4, 5, 6, 7}}, {1, {4}}, data_type);
+    benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
+                   {4, {4, 5, 6, 7}}, {1, {7}}, data_type);
+    benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
+                   {1, {4}}, data_type);
+}
+
+
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
        BENCHMARK_IM2COL_NCHW44_INT8x8x32_STRIDE1) {
     constexpr size_t RUNS = 50;