提交 a773d076 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn/arm_common): add nchw44 8x8x16 channel wise conv

stride1 2x2 3x3 5x5 stride2 2x2 3x3 5x5

GitOrigin-RevId: 43d76311c2c914911e41d677fcf77d27f7ee8058
上级 09b5f3d4
......@@ -33,7 +33,7 @@ KERN(stride2, 5)
#undef KERN
} // namesapce conv_bias
} // namespace channel_wise_nchw44
} // namespace arm_common
} // namespace megdnn
......
......@@ -10,16 +10,15 @@
*/
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h"
#include "src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h"
#include "src/arm_common/conv_bias/int8x8x16/conv_direct.h"
#include "src/arm_common/conv_bias/int8x8x16/conv_stride2.h"
#include "midout.h"
#include "src/common/opr_delegate.h"
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_kimpl)
#include <atomic>
#include <cstring>
#include <mutex>
using namespace megdnn;
using namespace arm_common;
......@@ -550,4 +549,70 @@ ConvBiasImpl::AlgoI8x8x16Stride2Filter2::dispatch_kerns(
return {{kern, {group, 1_z, 1_z}}};
}
/* =====================8int8x8x16 channel_wise_nchw44 stride1 stride2 algo ===================== */
bool ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool avaible =
//! src and filter are int8, dst is int16
(param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int16) &&
fm.format == param::Convolution::Format::NCHW44 &&
param.bias_mode != megdnn::BiasMode::BIAS &&
param.nonlineMode == megdnn::NonlineMode::IDENTITY &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 &&
(fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2)) &&
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) &&
fm.icpg == 1 && fm.ocpg == 1 && fm.group % 4 == 0;
return avaible;
}
size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace(
const NCBKernSizeParam& param) const {
size_t stride_h = param.filter_meta.stride[0];
size_t stride_w = param.filter_meta.stride[1];
megdnn_assert(stride_h == stride_w);
if (stride_h == 1) {
return channel_wise_nchw44_8x8x16::stride1::get_bundle(param)
.total_size_in_bytes();
} else if (stride_h == 2) {
return channel_wise_nchw44_8x8x16::stride2::get_bundle(param)
.total_size_in_bytes();
} else {
return 0;
}
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::dispatch_kerns(
const NCBKernSizeParam& param) const {
size_t stride_h = param.filter_meta.stride[0];
size_t stride_w = param.filter_meta.stride[1];
if (stride_h == stride_w && stride_h == 1) {
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv(
"AlgoS8x8x16ChanWiseStride1Stride2NCHW44_dispatch_kerns"_hash)) {
return channel_wise_nchw44_8x8x16::stride1::get_kimpls(param);
}
MIDOUT_END();
return {};
} else if (stride_h == stride_w && stride_h == 2) {
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv(
"AlgoS8x8x16ChanWiseStride2NCHW44_dispatch_kerns"_hash)) {
return channel_wise_nchw44_8x8x16::stride2::get_kimpls(param);
}
MIDOUT_END();
return {};
} else {
return {};
}
}
// vim: syntax=cpp.doxygen
......@@ -72,6 +72,18 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
};
} // namespace arm_common
} // namespace megdnn
......
/**
* \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace channel_wise_nchw44_8x8x16 {
#define KERN(stride, i) \
template <BiasMode bias_mode> \
void direct_##stride##_##i##x##i##_int8x8x16( \
const int8_t* src, const int8_t* filter, const int16_t* bias, \
void* dst, const size_t IH, const size_t IW, const size_t OH, \
const size_t OW);
KERN(stride1, 2)
KERN(stride1, 3)
KERN(stride1, 5)
KERN(stride2, 2)
KERN(stride2, 3)
KERN(stride2, 5)
#undef KERN
} // namespace channel_wise_nchw44_8x8x16
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;
#define INIT_SUM() \
int16x8_t init_sum; \
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { \
int16x4_t tmpsum = vld1_s16(bptr); \
init_sum = vcombine_s16(tmpsum, tmpsum); \
} else { \
init_sum = vdupq_n_s16(0); \
}
#define STORE_1_LINE_RESULT(dst, oh, ow, OW, sum) \
do { \
dt_int16* dptr = \
reinterpret_cast<dt_int16*>(dst) + (oh)*OW * 4 + ow * 4; \
vst1q_s16(dptr, sum[0]); \
vst1q_s16(dptr + 8, sum[1]); \
vst1q_s16(dptr + 16, sum[2]); \
vst1q_s16(dptr + 24, sum[3]); \
} while (0);
#define STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum) \
do { \
dt_int16* dptr = \
reinterpret_cast<dt_int16*>(dst) + (oh)*OW * 4 + ow * 4; \
vst1q_s16(dptr, sum[0]); \
vst1q_s16(dptr + 8, sum[1]); \
} while (0);
#define STORE_REMAIN(dst, oh, ow, OW, sum, remain) \
do { \
dt_int16* dptr = \
reinterpret_cast<dt_int16*>(dst) + oh * OW * 4 + ow * 4; \
if (remain == 1) { \
vst1_s16(dptr, vget_low_s16(sum[0])); \
} else if (remain == 2) { \
vst1q_s16(dptr, sum[0]); \
} else if (remain == 3) { \
vst1q_s16(dptr, sum[0]); \
vst1_s16(dptr + 8, vget_low_s16(sum[1])); \
} \
} while (0);
template <BiasMode bias_mode>
void channel_wise_nchw44_8x8x16::direct_stride1_2x2_int8x8x16(
const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
MEGDNN_MARK_USED_VAR(IH);
const int16_t* __restrict bptr = bias;
INIT_SUM();
const int* fptr = reinterpret_cast<const int*>(filter);
int8x8_t kern[4];
#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(fptr + i));
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define LOAD_SRC(_sptr, _src) \
_src[0] = vld1q_s8(_sptr); \
_src[1] = vld1q_s8(_sptr + 16); \
_src[1] = vextq_s8(_src[0], _src[1], 4);
#define CALC_ONE_LINE_4_RESULT(_sum, _src, _kid0, _kid1) \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]);
#define LOAD_SRC_8(_sptr, _src) \
_src[0] = vld1q_s8(_sptr); \
_src[2] = vld1q_s8(_sptr + 16); \
_src[3] = vld1q_s8(_sptr + 32); \
_src[1] = vextq_s8(_src[0], _src[2], 4); \
_src[3] = vextq_s8(_src[2], _src[3], 4);
#define CALC_ONE_LINE_8_RESULT(_sum,_src,_kid0,_kid1)\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[0]),kern[_kid0]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[0]),kern[_kid0]);\
_sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[2]),kern[_kid0]);\
_sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[2]),kern[_kid0]);\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[1]),kern[_kid1]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[1]),kern[_kid1]);\
_sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[3]),kern[_kid1]);\
_sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[3]),kern[_kid1]);
size_t oh = 0_z;
for (; oh + 2 <= OH; oh += 2) {
size_t ih = oh;
size_t ow = 0_z;
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
int16x8_t sum[2][4];
int8x16_t src[2][4];
#define cb(i) \
sum[0][i] = init_sum; \
sum[1][i] = init_sum;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
LOAD_SRC_8(sptr0, src[0]);
LOAD_SRC_8(sptr1, src[1]);
CALC_ONE_LINE_8_RESULT(sum[0], src[0], 0, 1);
LOAD_SRC_8(sptr2, src[0]);
CALC_ONE_LINE_8_RESULT(sum[0], src[1], 2, 3);
CALC_ONE_LINE_8_RESULT(sum[1], src[1], 0, 1);
CALC_ONE_LINE_8_RESULT(sum[1], src[0], 2, 3);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]);
STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]);
}
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
int8x16_t src[2][2];
#define cb(i) \
sum[0][i] = init_sum; \
sum[1][i] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
LOAD_SRC(sptr0, src[0]);
LOAD_SRC(sptr1, src[1]);
CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1);
LOAD_SRC(sptr2, src[0]);
CALC_ONE_LINE_4_RESULT(sum[0], src[1], 2, 3);
CALC_ONE_LINE_4_RESULT(sum[1], src[1], 0, 1);
CALC_ONE_LINE_4_RESULT(sum[1], src[0], 2, 3);
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]);
STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]);
}
if (ow < OW) {
size_t iw = ow;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
int8x16_t src[2][2];
#define cb(i) \
sum[0][i] = init_sum; \
sum[1][i] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
LOAD_SRC(sptr0, src[0]);
LOAD_SRC(sptr1, src[1]);
CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1);
LOAD_SRC(sptr2, src[0]);
CALC_ONE_LINE_4_RESULT(sum[0], src[1], 2, 3);
CALC_ONE_LINE_4_RESULT(sum[1], src[1], 0, 1);
CALC_ONE_LINE_4_RESULT(sum[1], src[0], 2, 3);
STORE_REMAIN(dst, oh, ow, OW, sum[0], remain);
STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain);
}
}
for (; oh < OH; oh++) {
size_t ih = oh;
size_t ow = 0_z;
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
int16x8_t sum[4];
int8x16_t src[2][4];
#define cb(i) sum[i] = init_sum;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
LOAD_SRC_8(sptr0, src[0]);
LOAD_SRC_8(sptr1, src[1]);
CALC_ONE_LINE_8_RESULT(sum, src[0], 0, 1);
CALC_ONE_LINE_8_RESULT(sum, src[1], 2, 3);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum);
}
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
int16x8_t sum[2];
int8x16_t src[2][2];
sum[0] = init_sum;
sum[1] = init_sum;
LOAD_SRC(sptr0, src[0]);
LOAD_SRC(sptr1, src[1]);
CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1);
CALC_ONE_LINE_4_RESULT(sum, src[1], 2, 3);
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum);
}
if (ow < OW) {
size_t iw = ow;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
int16x8_t sum[2];
int8x16_t src[2][2];
sum[0] = init_sum;
sum[1] = init_sum;
LOAD_SRC(sptr0, src[0]);
LOAD_SRC(sptr1, src[1]);
CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1);
CALC_ONE_LINE_4_RESULT(sum, src[1], 2, 3);
STORE_REMAIN(dst, oh, ow, OW, sum, remain);
}
}
}
#undef CALC_ONE_LINE_4_RESULT
#undef CALC_ONE_LINE_8_RESULT
#undef LOAD_SRC
#undef LOAD_SRC_8
template <BiasMode bias_mode>
void channel_wise_nchw44_8x8x16::direct_stride1_3x3_int8x8x16(
const int8_t* sptr, const int8_t* fptr, const int16_t* bias, void* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
MEGDNN_MARK_USED_VAR(IH);
const int16_t* __restrict bptr = bias;
INIT_SUM();
const int* filter = reinterpret_cast<const int*>(fptr);
int8x8_t kern[9];
#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(filter + i));
UNROLL_CALL_NOWRAPPER(9, cb);
#undef cb
#define LOAD_6_SRC(src, sptr0) \
src[0] = vld1q_s8(sptr0); \
src[1] = vld1q_s8(sptr0 + 16); \
tmp_src0 = vld1q_s8(sptr0 + 32); \
src[2] = vextq_s8(src[0], src[1], 4); \
src[3] = vextq_s8(src[1], tmp_src0, 4); \
src[4] = vextq_s8(src[0], src[1], 8); \
src[5] = vextq_s8(src[1], tmp_src0, 8);
#define LOAD_3_SRC(sptr, src) \
src[0] = vld1q_s8(sptr); \
src[2] = vld1q_s8(sptr + 16); \
src[1] = vextq_s8(src[0], src[2], 4); \
src[2] = vextq_s8(src[0], src[2], 8);
#define CALC_ONE_LINE(_src, _kern0, _kern1, _kern2, _sum) \
_sum[0] = vmlal_s8(_sum[0], _kern0, vget_low_s8(_src[0])); \
_sum[1] = vmlal_s8(_sum[1], _kern0, vget_high_s8(_src[0])); \
_sum[0] = vmlal_s8(_sum[0], _kern1, vget_low_s8(_src[1])); \
_sum[1] = vmlal_s8(_sum[1], _kern1, vget_high_s8(_src[1])); \
_sum[0] = vmlal_s8(_sum[0], _kern2, vget_low_s8(_src[2])); \
_sum[1] = vmlal_s8(_sum[1], _kern2, vget_high_s8(_src[2]));
#define CALC_ONE(_src, _i, _j, _kern, _sum) \
_sum[0] = vmlal_s8(_sum[0], _kern, vget_low_s8(_src[_i])); \
_sum[1] = vmlal_s8(_sum[1], _kern, vget_high_s8(_src[_i])); \
_sum[2] = vmlal_s8(_sum[2], _kern, vget_low_s8(_src[_j])); \
_sum[3] = vmlal_s8(_sum[3], _kern, vget_high_s8(_src[_j]));
size_t oh = 0_z;
for (; oh + 3 <= OH; oh += 3) {
size_t ih = oh;
size_t ow = 0_z;
#if MEGDNN_AARCH64
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum0[4], sum1[4], sum2[4];
#define cb(j) \
sum0[j] = init_sum; \
sum1[j] = init_sum; \
sum2[j] = init_sum;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
int8x16_t src[2][6];
int8x16_t tmp_src0;
LOAD_6_SRC(src[0], sptr0); //! line0
LOAD_6_SRC(src[1], sptr1); //! line1
CALC_ONE(src[0], 0, 1, kern[0], sum0);
CALC_ONE(src[0], 2, 3, kern[1], sum0);
CALC_ONE(src[0], 4, 5, kern[2], sum0);
CALC_ONE(src[1], 0, 1, kern[3], sum0);
CALC_ONE(src[1], 2, 3, kern[4], sum0);
CALC_ONE(src[1], 4, 5, kern[5], sum0);
LOAD_6_SRC(src[0], sptr2); //! line2
CALC_ONE(src[0], 0, 1, kern[6], sum0);
CALC_ONE(src[0], 2, 3, kern[7], sum0);
CALC_ONE(src[0], 4, 5, kern[8], sum0);
CALC_ONE(src[1], 0, 1, kern[0], sum1);
CALC_ONE(src[1], 2, 3, kern[1], sum1);
CALC_ONE(src[1], 4, 5, kern[2], sum1);
CALC_ONE(src[0], 0, 1, kern[3], sum1);
CALC_ONE(src[0], 2, 3, kern[4], sum1);
CALC_ONE(src[0], 4, 5, kern[5], sum1);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum0)
LOAD_6_SRC(src[1], sptr3); //! line3
CALC_ONE(src[1], 0, 1, kern[6], sum1);
CALC_ONE(src[1], 2, 3, kern[7], sum1);
CALC_ONE(src[1], 4, 5, kern[8], sum1);
CALC_ONE(src[0], 0, 1, kern[0], sum2);
CALC_ONE(src[0], 2, 3, kern[1], sum2);
CALC_ONE(src[0], 4, 5, kern[2], sum2);
CALC_ONE(src[1], 0, 1, kern[3], sum2);
CALC_ONE(src[1], 2, 3, kern[4], sum2);
CALC_ONE(src[1], 4, 5, kern[5], sum2);
LOAD_6_SRC(src[0], sptr4); //! line4
STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum1)
CALC_ONE(src[0], 0, 1, kern[6], sum2);
CALC_ONE(src[0], 2, 3, kern[7], sum2);
CALC_ONE(src[0], 4, 5, kern[8], sum2);
STORE_1_LINE_RESULT(dst, (oh + 2), ow, OW, sum2)
}
#endif
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum0[2], sum1[2], sum2[2];
#define cb(j) \
sum0[j] = init_sum; \
sum1[j] = init_sum; \
sum2[j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
int8x16_t src[2][3];
LOAD_3_SRC(sptr0,src[0]);
LOAD_3_SRC(sptr1,src[1]);
CALC_ONE_LINE(src[0],kern[0],kern[1],kern[2],sum0);//line0
CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum0);//line1
CALC_ONE_LINE(src[1],kern[0],kern[1],kern[2],sum1);//line1
LOAD_3_SRC(sptr2,src[0]);//line2
CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum0);//line2
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum0)
CALC_ONE_LINE(src[0],kern[3],kern[4],kern[5],sum1);//line2
CALC_ONE_LINE(src[0],kern[0],kern[1],kern[2],sum2);//line2
LOAD_3_SRC(sptr3,src[1]);//line3
CALC_ONE_LINE(src[1],kern[6],kern[7],kern[8],sum1);//line3
STORE_1_LINE_4_RESULT(dst, (oh+1), ow, OW, sum1)
CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum2);//line3
LOAD_3_SRC(sptr4,src[0]);
CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum2);//line4
STORE_1_LINE_4_RESULT(dst, (oh+2), ow, OW, sum2)
}
if (ow < OW) {
size_t iw = ow;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum0[2], sum1[2], sum2[2];
int8x16_t src[2][3];
#define cb(j) \
sum0[j] = init_sum; \
sum1[j] = init_sum; \
sum2[j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
LOAD_3_SRC(sptr0,src[0]);//line2
LOAD_3_SRC(sptr1,src[1]);//line2
CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum0); // line0
CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum0);//line1
CALC_ONE_LINE(src[1],kern[0],kern[1],kern[2],sum1);//line1
LOAD_3_SRC(sptr2,src[0]);//line2
CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum0);//line2
STORE_REMAIN(dst, (oh+0), ow, OW, sum0,remain)
CALC_ONE_LINE(src[0],kern[3],kern[4],kern[5],sum1);//line2
CALC_ONE_LINE(src[0],kern[0],kern[1],kern[2],sum2);//line2
LOAD_3_SRC(sptr3,src[1]);//line3
CALC_ONE_LINE(src[1],kern[6],kern[7],kern[8],sum1);//line3
STORE_REMAIN(dst, (oh+1), ow, OW, sum1,remain)
CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum2);//line3
LOAD_3_SRC(sptr4,src[0]);
CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum2);//line4
STORE_REMAIN(dst, (oh+2), ow, OW, sum2, remain)
}
}
for (; oh < OH; oh++) {
size_t ih = oh;
size_t ow = 0_z;
#if MEGDNN_AARCH64
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4;
int16x8_t sum0[4];
int8x16_t src[2][6];
int8x16_t tmp_src0;
sum0[0] = init_sum;
sum0[1] = init_sum;
sum0[2] = init_sum;
sum0[3] = init_sum;
LOAD_6_SRC(src[0], sptr0); //! line0
LOAD_6_SRC(src[1], sptr1); //! line1
CALC_ONE(src[0], 0, 1, kern[0], sum0);
CALC_ONE(src[0], 2, 3, kern[1], sum0);
CALC_ONE(src[0], 4, 5, kern[2], sum0);
CALC_ONE(src[1], 0, 1, kern[3], sum0);
CALC_ONE(src[1], 2, 3, kern[4], sum0);
CALC_ONE(src[1], 4, 5, kern[5], sum0);
LOAD_6_SRC(src[0], sptr2); //! line2
CALC_ONE(src[0], 0, 1, kern[6], sum0);
CALC_ONE(src[0], 2, 3, kern[7], sum0);
CALC_ONE(src[0], 4, 5, kern[8], sum0);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum0);
}
#endif
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4;
int16x8_t sum00[2];
int8x16_t src[2][3];
sum00[0] = init_sum;
sum00[1] = init_sum;
LOAD_3_SRC(sptr0, src[0]);
LOAD_3_SRC(sptr1, src[1]);
CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum00); // line0
CALC_ONE_LINE(src[1], kern[3], kern[4], kern[5], sum00); // line1
LOAD_3_SRC(sptr2, src[0]); // line2
CALC_ONE_LINE(src[0], kern[6], kern[7], kern[8], sum00); // line2
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum00)
}
if (ow < OW) {
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4;
int16x8_t sum00[2];
int8x16_t src[2][3];
sum00[0] = init_sum;
sum00[1] = init_sum;
LOAD_3_SRC(sptr0, src[0]);
LOAD_3_SRC(sptr1, src[1]);
CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum00); // line0
CALC_ONE_LINE(src[1], kern[3], kern[4], kern[5], sum00); // line1
LOAD_3_SRC(sptr2, src[0]); // line2
CALC_ONE_LINE(src[0], kern[6], kern[7], kern[8], sum00); // line2
STORE_REMAIN(dst, oh, ow, OW, sum00,(OW-ow))
}
}
#undef LOAD_3_SRC
#undef LOAD_6_SRC
#undef CALC_ONE
#undef CALC_ONE_LINE
}
template <BiasMode bias_mode>
void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16(
const int8_t* sptr, const int8_t* fptr, const int16_t* bias, void* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
MEGDNN_MARK_USED_VAR(IH);
const int16_t* __restrict bptr = bias;
INIT_SUM();
const int* filter = reinterpret_cast<const int*>(fptr);
int8x8_t kern[25];
#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(filter + i));
UNROLL_CALL_NOWRAPPER(25, cb);
#undef cb
#define LOAD_1_LINE_SRC(sptr, src) \
src[0] = vld1q_s8(sptr); \
src[4] = vld1q_s8(sptr + 16); \
src[1] = vextq_s8(src[0], src[4], 4); \
src[2] = vextq_s8(src[0], src[4], 8); \
src[3] = vextq_s8(src[0], src[4], 12);
#define LOAD_1_LINE_10_SRC(sptr, src) \
src[0] = vld1q_s8(sptr); \
src[4] = vld1q_s8(sptr + 16); \
src[8] = vld1q_s8(sptr + 32); \
src[1] = vextq_s8(src[0], src[4], 4); \
src[2] = vextq_s8(src[0], src[4], 8); \
src[3] = vextq_s8(src[0], src[4], 12); \
src[5] = vextq_s8(src[4], src[8], 4); \
src[6] = vextq_s8(src[4], src[8], 8); \
src[7] = vextq_s8(src[4], src[8], 12);
#define CALC_ONE_LINE_4_RESULT(_sum,_src,_kid0,_kid1,_kid2,_kid3,_kid4)\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[0]),kern[_kid0]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[0]),kern[_kid0]);\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[1]),kern[_kid1]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[1]),kern[_kid1]);\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[2]),kern[_kid2]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[2]),kern[_kid2]);\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[3]),kern[_kid3]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[3]),kern[_kid3]);\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[4]),kern[_kid4]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[4]),kern[_kid4]);
#define CALC_ONE_LINE_8_RESULT(_sum,_src,_kid0,_kid1,_kid2,_kid3,_kid4)\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[0]),kern[_kid0]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[0]),kern[_kid0]);\
_sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[4]),kern[_kid0]);\
_sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[4]),kern[_kid0]);\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[1]),kern[_kid1]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[1]),kern[_kid1]);\
_sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[5]),kern[_kid1]);\
_sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[5]),kern[_kid1]);\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[2]),kern[_kid2]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[2]),kern[_kid2]);\
_sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[6]),kern[_kid2]);\
_sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[6]),kern[_kid2]);\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[3]),kern[_kid3]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[3]),kern[_kid3]);\
_sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[7]),kern[_kid3]);\
_sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[7]),kern[_kid3]);\
_sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[4]),kern[_kid4]);\
_sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[4]),kern[_kid4]);\
_sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[8]),kern[_kid4]);\
_sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[8]),kern[_kid4]);
size_t oh = 0_z;
for (; oh + 2 <= OH; oh += 2) {
size_t ih = oh;
size_t ow = 0_z;
#if MEGDNN_AARCH64
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr0 + IW * 4;
const int8_t* __restrict sptr2 = sptr1 + IW * 4;
const int8_t* __restrict sptr3 = sptr2 + IW * 4;
const int8_t* __restrict sptr4 = sptr3 + IW * 4;
const int8_t* __restrict sptr5 = sptr4 + IW * 4;
int16x8_t sum[2][4];
int8x16_t src[2][9];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
LOAD_1_LINE_10_SRC(sptr0,src[0]);
LOAD_1_LINE_10_SRC(sptr1,src[1]);
CALC_ONE_LINE_8_RESULT(sum[0],src[0],0,1,2,3,4);
LOAD_1_LINE_10_SRC(sptr2,src[0]);//line2
CALC_ONE_LINE_8_RESULT(sum[0],src[1],5,6,7,8,9);//line1
CALC_ONE_LINE_8_RESULT(sum[1],src[1],0,1,2,3,4);//line1
LOAD_1_LINE_10_SRC(sptr3,src[1]);//line3
CALC_ONE_LINE_8_RESULT(sum[0],src[0],10,11,12,13,14);//line2
CALC_ONE_LINE_8_RESULT(sum[1],src[0],5,6,7,8,9);//line2
LOAD_1_LINE_10_SRC(sptr4,src[0]);//line4
CALC_ONE_LINE_8_RESULT(sum[0],src[1],15,16,17,18,19);//line3
CALC_ONE_LINE_8_RESULT(sum[1],src[1],10,11,12,13,14);//line3
LOAD_1_LINE_10_SRC(sptr5,src[1]);//line5
CALC_ONE_LINE_8_RESULT(sum[0],src[0],20,21,22,23,24);//line4
CALC_ONE_LINE_8_RESULT(sum[1],src[0],15,16,17,18,19);//line3
CALC_ONE_LINE_8_RESULT(sum[1],src[1],20,21,22,23,24);//line3
STORE_1_LINE_RESULT(dst,oh,ow,OW,sum[0]);
STORE_1_LINE_RESULT(dst,(oh+1),ow,OW,sum[1]);
}
#endif
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr0 + IW * 4;
const int8_t* __restrict sptr2 = sptr1 + IW * 4;
const int8_t* __restrict sptr3 = sptr2 + IW * 4;
const int8_t* __restrict sptr4 = sptr3 + IW * 4;
const int8_t* __restrict sptr5 = sptr4 + IW * 4;
int16x8_t sum[2][2];
int8x16_t src[2][5];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
LOAD_1_LINE_SRC(sptr0,src[0]);
LOAD_1_LINE_SRC(sptr1,src[1]);
CALC_ONE_LINE_4_RESULT(sum[0],src[0],0,1,2,3,4);
LOAD_1_LINE_SRC(sptr2,src[0]);//line2
CALC_ONE_LINE_4_RESULT(sum[0],src[1],5,6,7,8,9);//line1
CALC_ONE_LINE_4_RESULT(sum[1],src[1],0,1,2,3,4);//line1
LOAD_1_LINE_SRC(sptr3,src[1]);//line3
CALC_ONE_LINE_4_RESULT(sum[0],src[0],10,11,12,13,14);//line2
CALC_ONE_LINE_4_RESULT(sum[1],src[0],5,6,7,8,9);//line2
LOAD_1_LINE_SRC(sptr4,src[0]);//line4
CALC_ONE_LINE_4_RESULT(sum[0],src[1],15,16,17,18,19);//line3
CALC_ONE_LINE_4_RESULT(sum[1],src[1],10,11,12,13,14);//line3
LOAD_1_LINE_SRC(sptr5,src[1]);//line5
CALC_ONE_LINE_4_RESULT(sum[0],src[0],20,21,22,23,24);//line4
CALC_ONE_LINE_4_RESULT(sum[1],src[0],15,16,17,18,19);//line3
CALC_ONE_LINE_4_RESULT(sum[1],src[1],20,21,22,23,24);//line3
STORE_1_LINE_4_RESULT(dst,oh,ow,OW,sum[0]);
STORE_1_LINE_4_RESULT(dst,(oh+1),ow,OW,sum[1]);
}
if (ow < OW) {
size_t remain = OW - ow;
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr0 + IW * 4;
const int8_t* __restrict sptr2 = sptr1 + IW * 4;
const int8_t* __restrict sptr3 = sptr2 + IW * 4;
const int8_t* __restrict sptr4 = sptr3 + IW * 4;
const int8_t* __restrict sptr5 = sptr4 + IW * 4;
int16x8_t sum[2][2];
int8x16_t src[2][5];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
LOAD_1_LINE_SRC(sptr0,src[0]);
LOAD_1_LINE_SRC(sptr1,src[1]);
CALC_ONE_LINE_4_RESULT(sum[0],src[0],0,1,2,3,4);
LOAD_1_LINE_SRC(sptr2,src[0]);//line2
CALC_ONE_LINE_4_RESULT(sum[0],src[1],5,6,7,8,9);//line1
CALC_ONE_LINE_4_RESULT(sum[1],src[1],0,1,2,3,4);//line1
LOAD_1_LINE_SRC(sptr3,src[1]);//line3
CALC_ONE_LINE_4_RESULT(sum[0],src[0],10,11,12,13,14);//line2
CALC_ONE_LINE_4_RESULT(sum[1],src[0],5,6,7,8,9);//line2
LOAD_1_LINE_SRC(sptr4,src[0]);//line4
CALC_ONE_LINE_4_RESULT(sum[0],src[1],15,16,17,18,19);//line3
CALC_ONE_LINE_4_RESULT(sum[1],src[1],10,11,12,13,14);//line3
LOAD_1_LINE_SRC(sptr5,src[1]);//line5
CALC_ONE_LINE_4_RESULT(sum[0],src[0],20,21,22,23,24);//line4
CALC_ONE_LINE_4_RESULT(sum[1],src[0],15,16,17,18,19);//line3
CALC_ONE_LINE_4_RESULT(sum[1],src[1],20,21,22,23,24);//line3
STORE_REMAIN(dst,oh,ow,OW,sum[0],remain);
STORE_REMAIN(dst,(oh+1),ow,OW,sum[1],remain);
}
}
for (; oh < OH; oh++) {
size_t ih = oh;
size_t ow = 0_z;
#if MEGDNN_AARCH64
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr0 + IW * 4;
const int8_t* __restrict sptr2 = sptr1 + IW * 4;
const int8_t* __restrict sptr3 = sptr2 + IW * 4;
const int8_t* __restrict sptr4 = sptr3 + IW * 4;
int16x8_t sum[4];
int8x16_t src[2][9];
#define cb(j) sum[j] = init_sum;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
LOAD_1_LINE_10_SRC(sptr0,src[0]);
LOAD_1_LINE_10_SRC(sptr1,src[1]);
CALC_ONE_LINE_8_RESULT(sum,src[0],0,1,2,3,4);
LOAD_1_LINE_10_SRC(sptr2,src[0]);//line2
CALC_ONE_LINE_8_RESULT(sum,src[1],5,6,7,8,9);//line1
LOAD_1_LINE_10_SRC(sptr3,src[1]);//line3
CALC_ONE_LINE_8_RESULT(sum,src[0],10,11,12,13,14);//line2
LOAD_1_LINE_10_SRC(sptr4,src[0]);//line4
CALC_ONE_LINE_8_RESULT(sum,src[1],15,16,17,18,19);//line3
CALC_ONE_LINE_8_RESULT(sum,src[0],20,21,22,23,24);//line4
STORE_1_LINE_RESULT(dst,oh,ow,OW,sum);
}
#endif
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr0 + IW * 4;
const int8_t* __restrict sptr2 = sptr1 + IW * 4;
const int8_t* __restrict sptr3 = sptr2 + IW * 4;
const int8_t* __restrict sptr4 = sptr3 + IW * 4;
int16x8_t sum[2];
int8x16_t src[2][5];
sum[0]=init_sum;
sum[1]=init_sum;
LOAD_1_LINE_SRC(sptr0,src[0]);
LOAD_1_LINE_SRC(sptr1,src[1]);
CALC_ONE_LINE_4_RESULT(sum,src[0],0,1,2,3,4);
LOAD_1_LINE_SRC(sptr2,src[0]);//line2
CALC_ONE_LINE_4_RESULT(sum,src[1],5,6,7,8,9);//line1
LOAD_1_LINE_SRC(sptr3,src[1]);//line3
CALC_ONE_LINE_4_RESULT(sum,src[0],10,11,12,13,14);//line2
LOAD_1_LINE_SRC(sptr4,src[0]);//line4
CALC_ONE_LINE_4_RESULT(sum,src[1],15,16,17,18,19);//line3
CALC_ONE_LINE_4_RESULT(sum,src[0],20,21,22,23,24);//line4
STORE_1_LINE_4_RESULT(dst,oh,ow,OW,sum);
}
if (ow < OW) {
size_t remain = OW - ow;
size_t iw = ow;
const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = sptr0 + IW * 4;
const int8_t* __restrict sptr2 = sptr1 + IW * 4;
const int8_t* __restrict sptr3 = sptr2 + IW * 4;
const int8_t* __restrict sptr4 = sptr3 + IW * 4;
int16x8_t sum[2];
int8x16_t src[2][5];
sum[0]=init_sum;
sum[1]=init_sum;
LOAD_1_LINE_SRC(sptr0,src[0]);
LOAD_1_LINE_SRC(sptr1,src[1]);
CALC_ONE_LINE_4_RESULT(sum,src[0],0,1,2,3,4);
LOAD_1_LINE_SRC(sptr2,src[0]);//line2
CALC_ONE_LINE_4_RESULT(sum,src[1],5,6,7,8,9);//line1
LOAD_1_LINE_SRC(sptr3,src[1]);//line3
CALC_ONE_LINE_4_RESULT(sum,src[0],10,11,12,13,14);//line2
LOAD_1_LINE_SRC(sptr4,src[0]);//line4
CALC_ONE_LINE_4_RESULT(sum,src[1],15,16,17,18,19);//line3
CALC_ONE_LINE_4_RESULT(sum,src[0],20,21,22,23,24);//line4
STORE_REMAIN(dst,oh,ow,OW,sum,remain);
}
}
#undef LOAD_1_LINE_SRC
#undef LOAD_1_LINE_10_SRC
#undef CALC_ONE_LINE_4_RESULT
#undef CALC_ONE_LINE_8_RESULT
}
template <BiasMode bias_mode>
void channel_wise_nchw44_8x8x16::direct_stride2_2x2_int8x8x16(
const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
MEGDNN_MARK_USED_VAR(IH);
const int16_t* __restrict bptr = bias;
INIT_SUM();
const int* fptr = reinterpret_cast<const int*>(filter);
int8x8_t kern[4];
#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(fptr + i));
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define CALC_ONE_LINE_8_RESULT(_sum, _rowid, _kid0, _kid1) \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##0), kern[_kid0]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##2), kern[_kid0]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##2), kern[_kid0]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##1), kern[_kid1]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##3), kern[_kid1]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##3), kern[_kid1]);
#define CALC_ONE_LINE_4_RESULT(_sum, _rowid, _kid0, _kid1) \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##0), kern[_kid0]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##1), kern[_kid1]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]);
size_t oh = 0_z;
for (; oh + 2 <= OH; oh += 2) {
size_t ih = oh * 2;
size_t ow = 0_z;
#if MEGDNN_AARCH64
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
int16x8_t sum[2][4];
#define cb(i) \
sum[0][i] = init_sum; \
sum[1][i] = init_sum;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i)\
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(4,cb)
#undef cb
#define cb(i)\
int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i);\
int32x4x2_t tmp_row##i##_01 = vld2q_s32(tmp_sptr##i+8);
UNROLL_CALL_NOWRAPPER(4,cb)
#undef cb
#define cb(i)\
int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i##_00.val[0]);\
int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i##_00.val[1]);\
int8x16_t row##i##2 =vreinterpretq_s8_s32(tmp_row##i##_01.val[0]);\
int8x16_t row##i##3 =vreinterpretq_s8_s32(tmp_row##i##_01.val[1]);
UNROLL_CALL_NOWRAPPER(4,cb)
#undef cb
CALC_ONE_LINE_8_RESULT(sum[0],0,0,1);
CALC_ONE_LINE_8_RESULT(sum[0],1,2,3);
CALC_ONE_LINE_8_RESULT(sum[1],2,0,1);
CALC_ONE_LINE_8_RESULT(sum[1],3,2,3);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]);
STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]);
}
#endif
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
#define cb(i) \
sum[0][i] = init_sum; \
sum[1][i] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i)\
int32x4x2_t tmp_row##i = vld2q_s32(tmp_sptr##i);
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i)\
int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\
int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]);\
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
CALC_ONE_LINE_4_RESULT(sum[0],0,0,1);
CALC_ONE_LINE_4_RESULT(sum[0],1,2,3);
CALC_ONE_LINE_4_RESULT(sum[1],2,0,1);
CALC_ONE_LINE_4_RESULT(sum[1],3,2,3);
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]);
STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]);
}
if (ow < OW) {
size_t iw = ow * 2;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
#define cb(i) \
sum[0][i] = init_sum; \
sum[1][i] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) int32x4x2_t tmp_row##i = vld2q_s32(tmp_sptr##i);
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i)\
int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\
int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]);\
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
CALC_ONE_LINE_4_RESULT(sum[0],0,0,1);
CALC_ONE_LINE_4_RESULT(sum[0],1,2,3);
CALC_ONE_LINE_4_RESULT(sum[1],2,0,1);
CALC_ONE_LINE_4_RESULT(sum[1],3,2,3);
STORE_REMAIN(dst, (oh+0), ow, OW, sum[0], remain);
STORE_REMAIN(dst, (oh+1), ow, OW, sum[1], remain);
}
}
for (; oh < OH; oh++) {
size_t ih = oh * 2;
size_t ow = 0_z;
#if MEGDNN_AARCH64
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
int16x8_t sum[4] = {init_sum, init_sum, init_sum, init_sum};
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(2, cb)
#undef cb
#define cb(i) \
int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i); \
int32x4x2_t tmp_row##i##_01 = vld2q_s32(tmp_sptr##i + 8);
UNROLL_CALL_NOWRAPPER(2, cb)
#undef cb
#define cb(i) \
int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_00.val[0]); \
int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_00.val[1]); \
int8x16_t row##i##2 = vreinterpretq_s8_s32(tmp_row##i##_01.val[0]); \
int8x16_t row##i##3 = vreinterpretq_s8_s32(tmp_row##i##_01.val[1]);
UNROLL_CALL_NOWRAPPER(2, cb)
#undef cb
CALC_ONE_LINE_8_RESULT(sum, 0, 0, 1);
CALC_ONE_LINE_8_RESULT(sum, 1, 2, 3);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum);
}
#endif
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
int16x8_t sum[2]={init_sum,init_sum};
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i) int32x4x2_t tmp_row##i = vld2q_s32(tmp_sptr##i);
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i)\
int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\
int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]);
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
CALC_ONE_LINE_4_RESULT(sum,0,0,1);
CALC_ONE_LINE_4_RESULT(sum,1,2,3);
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum);
}
if (OW > ow) {
size_t iw = ow * 2;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
int16x8_t sum[2]={init_sum,init_sum};
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i) int32x4x2_t tmp_row##i = vld2q_s32(tmp_sptr##i);
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i)\
int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\
int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]);
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
CALC_ONE_LINE_4_RESULT(sum,0,0,1);
CALC_ONE_LINE_4_RESULT(sum,1,2,3);
STORE_REMAIN(dst, oh, ow, OW, sum, remain);
}
}
#undef CALC_ONE_LINE_4_RESULT
#undef CALC_ONE_LINE_8_RESULT
}
template <BiasMode bias_mode>
void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16(
const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
MEGDNN_MARK_USED_VAR(IH);
const int16_t* __restrict bptr = bias;
INIT_SUM();
const int* fptr = reinterpret_cast<const int*>(filter);
int8x8_t kern[9];
#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(fptr + i));
UNROLL_CALL_NOWRAPPER(9, cb);
#undef cb
#define CALC_ONE_LINE_8_RESULT(_sum, _rowid, _kid0, _kid1, _kid2) \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##0), kern[_kid0]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##3), kern[_kid0]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##3), kern[_kid0]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##1), kern[_kid1]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##4), kern[_kid1]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##4), kern[_kid1]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##2), kern[_kid2]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##2), kern[_kid2]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##5), kern[_kid2]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##5), kern[_kid2]);
#define CALC_ONE_LINE_4_RESULT(_sum, _rowid, _kid0, _kid1, _kid2) \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##0), kern[_kid0]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##1), kern[_kid1]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##2), kern[_kid2]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##2), kern[_kid2]);
size_t oh = 0_z;
for (; oh + 2 <= OH; oh += 2) {
size_t ih = oh * 2;
size_t ow = 0_z;
#if MEGDNN_AARCH64
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum[2][4];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
#define cb(i) \
int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i); \
int32x4x2_t tmp_row##i##_03 = vld2q_s32(tmp_sptr##i + 8); \
int32x4_t tmp_row##i = vld1q_s32(tmp_sptr##i + 16);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
#define cb(i) \
int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_00.val[0]); \
int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_00.val[1]); \
int8x16_t row##i##2 = vreinterpretq_s8_s32( \
vextq_s32(tmp_row##i##_00.val[0], tmp_row##i##_03.val[0], 1)); \
int8x16_t row##i##3 = vreinterpretq_s8_s32(tmp_row##i##_03.val[0]); \
int8x16_t row##i##4 = vreinterpretq_s8_s32(tmp_row##i##_03.val[1]); \
int8x16_t row##i##5 = vreinterpretq_s8_s32( \
vextq_s32(tmp_row##i##_03.val[0], tmp_row##i, 1));
UNROLL_CALL_NOWRAPPER(5, cb)
#undef cb
CALC_ONE_LINE_8_RESULT(sum[0], 0, 0, 1, 2);
CALC_ONE_LINE_8_RESULT(sum[0], 1, 3, 4, 5);
CALC_ONE_LINE_8_RESULT(sum[0], 2, 6, 7, 8);
CALC_ONE_LINE_8_RESULT(sum[1], 2, 0, 1, 2);
CALC_ONE_LINE_8_RESULT(sum[1], 3, 3, 4, 5);
CALC_ONE_LINE_8_RESULT(sum[1], 4, 6, 7, 8);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]);
STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]);
}
#endif
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
#define cb(i) \
int32x4x2_t tmp_row##i##_0 = vld2q_s32(tmp_sptr##i); \
int32x4_t tmp_row##i##_1 = vld1q_s32(tmp_sptr##i + 8);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
#define cb(i) \
int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \
int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \
int8x16_t row##i##2 = \
vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1));
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
CALC_ONE_LINE_4_RESULT(sum[0], 0, 0, 1, 2);
CALC_ONE_LINE_4_RESULT(sum[0], 1, 3, 4, 5);
CALC_ONE_LINE_4_RESULT(sum[0], 2, 6, 7, 8);
CALC_ONE_LINE_4_RESULT(sum[1], 2, 0, 1, 2);
CALC_ONE_LINE_4_RESULT(sum[1], 3, 3, 4, 5);
CALC_ONE_LINE_4_RESULT(sum[1], 4, 6, 7, 8);
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]);
STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]);
}
if (ow < OW) {
size_t iw = ow * 2;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
#define cb(i) \
int32x4x2_t tmp_row##i##_0 = vld2q_s32(tmp_sptr##i); \
int32x4_t tmp_row##i##_1 = vld1q_s32(tmp_sptr##i + 8);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
#define cb(i) \
int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \
int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \
int8x16_t row##i##2 = \
vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1));
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
CALC_ONE_LINE_4_RESULT(sum[0], 0, 0, 1, 2);
CALC_ONE_LINE_4_RESULT(sum[0], 1, 3, 4, 5);
CALC_ONE_LINE_4_RESULT(sum[0], 2, 6, 7, 8);
CALC_ONE_LINE_4_RESULT(sum[1], 2, 0, 1, 2);
CALC_ONE_LINE_4_RESULT(sum[1], 3, 3, 4, 5);
CALC_ONE_LINE_4_RESULT(sum[1], 4, 6, 7, 8);
STORE_REMAIN(dst, (oh + 0), ow, OW, sum[0], remain);
STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain);
}
}
for (; oh < OH; oh++) {
size_t ih = oh * 2;
size_t ow = 0_z;
#if MEGDNN_AARCH64
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
int16x8_t sum[4] = {init_sum, init_sum, init_sum, init_sum};
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(3, cb);
#undef cb
#define cb(i) \
int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i); \
int32x4x2_t tmp_row##i##_03 = vld2q_s32(tmp_sptr##i + 8); \
int32x4_t tmp_row##i = vld1q_s32(tmp_sptr##i + 16);
UNROLL_CALL_NOWRAPPER(3, cb);
#undef cb
#define cb(i) \
int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_00.val[0]); \
int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_00.val[1]); \
int8x16_t row##i##2 = vreinterpretq_s8_s32( \
vextq_s32(tmp_row##i##_00.val[0], tmp_row##i##_03.val[0], 1)); \
int8x16_t row##i##3 = vreinterpretq_s8_s32(tmp_row##i##_03.val[0]); \
int8x16_t row##i##4 = vreinterpretq_s8_s32(tmp_row##i##_03.val[1]); \
int8x16_t row##i##5 = vreinterpretq_s8_s32( \
vextq_s32(tmp_row##i##_03.val[0], tmp_row##i, 1));
UNROLL_CALL_NOWRAPPER(3, cb)
#undef cb
CALC_ONE_LINE_8_RESULT(sum, 0, 0, 1, 2);
CALC_ONE_LINE_8_RESULT(sum, 1, 3, 4, 5);
CALC_ONE_LINE_8_RESULT(sum, 2, 6, 7, 8);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum);
}
#endif
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
int16x8_t sum[2] = {init_sum, init_sum};
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(3, cb)
#undef cb
#define cb(i) \
int32x4x2_t tmp_row##i##_0 = vld2q_s32(tmp_sptr##i); \
int32x4_t tmp_row##i##_1 = vld1q_s32(tmp_sptr##i + 8);
UNROLL_CALL_NOWRAPPER(3, cb)
#undef cb
#define cb(i) \
int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \
int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \
int8x16_t row##i##2 = \
vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1));
UNROLL_CALL_NOWRAPPER(3, cb)
#undef cb
CALC_ONE_LINE_4_RESULT(sum, 0, 0, 1, 2);
CALC_ONE_LINE_4_RESULT(sum, 1, 3, 4, 5);
CALC_ONE_LINE_4_RESULT(sum, 2, 6, 7, 8);
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum);
}
if (OW > ow) {
size_t iw = ow * 2;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
int16x8_t sum[2] = {init_sum, init_sum};
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(3, cb)
#undef cb
#define cb(i) \
int32x4x2_t tmp_row##i##_0 = vld2q_s32(tmp_sptr##i); \
int32x4_t tmp_row##i##_1 = vld1q_s32(tmp_sptr##i + 8);
UNROLL_CALL_NOWRAPPER(3, cb)
#undef cb
#define cb(i) \
int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \
int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \
int8x16_t row##i##2 = \
vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1));
UNROLL_CALL_NOWRAPPER(3, cb)
#undef cb
CALC_ONE_LINE_4_RESULT(sum, 0, 0, 1, 2);
CALC_ONE_LINE_4_RESULT(sum, 1, 3, 4, 5);
CALC_ONE_LINE_4_RESULT(sum, 2, 6, 7, 8);
STORE_REMAIN(dst, oh, ow, OW, sum, remain);
}
}
#undef CALC_ONE_LINE_4_RESULT
#undef CALC_ONE_LINE_8_RESULT
#undef LOAD_5_SRC
}
#if MEGDNN_AARCH64
template <BiasMode bias_mode>
void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16(
const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
MEGDNN_MARK_USED_VAR(IH);
const int16_t* __restrict bptr = bias;
INIT_SUM();
const int* fptr = reinterpret_cast<const int*>(filter);
int8x8_t kern[25];
#define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(fptr + i));
UNROLL_CALL_NOWRAPPER(25, cb);
#undef cb
#define LOAD_5_SRC(_src, _id) \
do { \
int32x4x2_t tmp_row_01 = vld2q_s32(tmp_sptr##_id); \
int32x4x2_t tmp_row_23 = vld2q_s32(tmp_sptr##_id + 2); \
int32x4_t tmp_row = vld1q_s32(tmp_sptr##_id + 10); \
_src[0] = vreinterpretq_s8_s32(tmp_row_01.val[0]); \
_src[1] = vreinterpretq_s8_s32(tmp_row_01.val[1]); \
_src[2] = vreinterpretq_s8_s32(tmp_row_23.val[0]); \
_src[3] = vreinterpretq_s8_s32(tmp_row_23.val[1]); \
_src[4] = vreinterpretq_s8_s32( \
vextq_s32(tmp_row_23.val[0], tmp_row, 1)); \
} while (0);
#define CALC_ONE_LINE_4_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, \
_kid4) \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[4]), kern[_kid4]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[4]), kern[_kid4]);
#define LOAD_10_SRC(_src, _id) \
do { \
int32x4x2_t tmp_row_01 = vld2q_s32(tmp_sptr##_id); \
int32x4x2_t tmp_row_23 = vld2q_s32(tmp_sptr##_id + 8); \
int32x4x2_t tmp_row = vld2q_s32(tmp_sptr##_id + 16); \
_src[0] = vreinterpretq_s8_s32(tmp_row_01.val[0]); \
_src[1] = vreinterpretq_s8_s32(tmp_row_01.val[1]); \
_src[2] = vreinterpretq_s8_s32( \
vextq_s32(tmp_row_01.val[0], tmp_row_23.val[0], 1)); \
_src[3] = vreinterpretq_s8_s32( \
vextq_s32(tmp_row_01.val[1], tmp_row_23.val[1], 1)); \
_src[4] = vreinterpretq_s8_s32( \
vextq_s32(tmp_row_01.val[0], tmp_row_23.val[0], 2)); \
_src[5] = vreinterpretq_s8_s32(tmp_row_23.val[0]); \
_src[6] = vreinterpretq_s8_s32(tmp_row_23.val[1]); \
_src[7] = vreinterpretq_s8_s32( \
vextq_s32(tmp_row_23.val[0], tmp_row.val[0], 1)); \
_src[8] = vreinterpretq_s8_s32( \
vextq_s32(tmp_row_23.val[1], tmp_row.val[1], 1)); \
_src[9] = vreinterpretq_s8_s32( \
vextq_s32(tmp_row_23.val[0], tmp_row.val[0], 2)); \
} while (0);
#define CALC_ONE_LINE_8_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, \
_kid4) \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[5]), kern[_kid0]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[5]), kern[_kid0]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[6]), kern[_kid1]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[6]), kern[_kid1]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[7]), kern[_kid2]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[7]), kern[_kid2]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[8]), kern[_kid3]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[8]), kern[_kid3]); \
_sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[4]), kern[_kid4]); \
_sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[4]), kern[_kid4]); \
_sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[9]), kern[_kid4]); \
_sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[9]), kern[_kid4]);
size_t oh = 0_z;
for (; oh + 2 <= OH; oh += 2) {
size_t ih = oh * 2;
size_t ow = 0_z;
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
int16x8_t sum[2][4];
int8x16_t src[3][10];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(7, cb);
#undef cb
LOAD_10_SRC(src[0], 0); // line0
LOAD_10_SRC(src[1], 1); // line1
CALC_ONE_LINE_8_RESULT(sum[0], src[0], 0, 1, 2, 3, 4);
LOAD_10_SRC(src[2], 2); // line2
CALC_ONE_LINE_8_RESULT(sum[0], src[1], 5, 6, 7, 8, 9);
LOAD_10_SRC(src[0], 3); // line3
CALC_ONE_LINE_8_RESULT(sum[0], src[2], 10, 11, 12, 13, 14);
CALC_ONE_LINE_8_RESULT(sum[1], src[2], 0, 1, 2, 3, 4);
LOAD_10_SRC(src[1], 4); // line4
CALC_ONE_LINE_8_RESULT(sum[0], src[0], 15, 16, 17, 18, 19);
CALC_ONE_LINE_8_RESULT(sum[0], src[1], 20, 21, 22, 23, 24);
LOAD_10_SRC(src[2], 5); // line5
CALC_ONE_LINE_8_RESULT(sum[1], src[0], 5, 6, 7, 8, 9);
CALC_ONE_LINE_8_RESULT(sum[1], src[1], 10, 11, 12, 13, 14);
LOAD_10_SRC(src[0], 6); // line6
CALC_ONE_LINE_8_RESULT(sum[1], src[2], 15, 16, 17, 18, 19);
CALC_ONE_LINE_8_RESULT(sum[1], src[0], 20, 21, 22, 23, 24);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]);
STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]);
}
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
int8x16_t src[3][5];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(7, cb);
#undef cb
LOAD_5_SRC(src[0], 0); // line0
LOAD_5_SRC(src[1], 1); // line1
CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1, 2, 3, 4);
LOAD_5_SRC(src[2], 2); // line2
CALC_ONE_LINE_4_RESULT(sum[0], src[1], 5, 6, 7, 8, 9);
LOAD_5_SRC(src[0], 3); // line3
CALC_ONE_LINE_4_RESULT(sum[0], src[2], 10, 11, 12, 13, 14);
CALC_ONE_LINE_4_RESULT(sum[1], src[2], 0, 1, 2, 3, 4);
LOAD_5_SRC(src[1], 4); // line4
CALC_ONE_LINE_4_RESULT(sum[0], src[0], 15, 16, 17, 18, 19);
CALC_ONE_LINE_4_RESULT(sum[1], src[0], 5, 6, 7, 8, 9);
LOAD_5_SRC(src[2], 5); // line5
CALC_ONE_LINE_4_RESULT(sum[0], src[1], 20, 21, 22, 23, 24);
CALC_ONE_LINE_4_RESULT(sum[1], src[1], 10, 11, 12, 13, 14);
LOAD_5_SRC(src[0], 6); // line6
CALC_ONE_LINE_4_RESULT(sum[1], src[2], 15, 16, 17, 18, 19);
CALC_ONE_LINE_4_RESULT(sum[1], src[0], 20, 21, 22, 23, 24);
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]);
STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]);
}
if (ow < OW) {
size_t iw = ow * 2;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
int8x16_t src[3][5];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(7, cb);
#undef cb
LOAD_5_SRC(src[0], 0); // line0
LOAD_5_SRC(src[1], 1); // line1
CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1, 2, 3, 4);
LOAD_5_SRC(src[2], 2); // line2
CALC_ONE_LINE_4_RESULT(sum[0], src[1], 5, 6, 7, 8, 9);
LOAD_5_SRC(src[0], 3); // line3
CALC_ONE_LINE_4_RESULT(sum[0], src[2], 10, 11, 12, 13, 14);
CALC_ONE_LINE_4_RESULT(sum[1], src[2], 0, 1, 2, 3, 4);
LOAD_5_SRC(src[1], 4); // line4
CALC_ONE_LINE_4_RESULT(sum[0], src[0], 15, 16, 17, 18, 19);
CALC_ONE_LINE_4_RESULT(sum[1], src[0], 5, 6, 7, 8, 9);
LOAD_5_SRC(src[2], 5); // line5
CALC_ONE_LINE_4_RESULT(sum[0], src[1], 20, 21, 22, 23, 24);
CALC_ONE_LINE_4_RESULT(sum[1], src[1], 10, 11, 12, 13, 14);
LOAD_5_SRC(src[0], 6); // line6
CALC_ONE_LINE_4_RESULT(sum[1], src[2], 15, 16, 17, 18, 19);
CALC_ONE_LINE_4_RESULT(sum[1], src[0], 20, 21, 22, 23, 24);
STORE_REMAIN(dst, oh, ow, OW, sum[0], remain);
STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain);
}
}
for (; oh < OH; oh++) {
size_t ih = oh * 2;
size_t ow = 0_z;
for (; ow + 8 <= OW; ow += 8) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum[4] = {init_sum, init_sum, init_sum, init_sum};
int8x16_t src[3][10];
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
LOAD_10_SRC(src[0], 0); // line0
LOAD_10_SRC(src[1], 1); // line1
CALC_ONE_LINE_8_RESULT(sum, src[0], 0, 1, 2, 3, 4);
LOAD_10_SRC(src[2], 2); // line2
CALC_ONE_LINE_8_RESULT(sum, src[1], 5, 6, 7, 8, 9);
LOAD_10_SRC(src[0], 3); // line3
CALC_ONE_LINE_8_RESULT(sum, src[2], 10, 11, 12, 13, 14);
LOAD_10_SRC(src[1], 4); // line4
CALC_ONE_LINE_8_RESULT(sum, src[0], 15, 16, 17, 18, 19);
CALC_ONE_LINE_8_RESULT(sum, src[1], 20, 21, 22, 23, 24);
STORE_1_LINE_RESULT(dst, oh, ow, OW, sum);
}
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum[2] = {init_sum, init_sum};
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
int8x16_t src[3][5];
LOAD_5_SRC(src[0], 0); // line0
LOAD_5_SRC(src[1], 1); // line1
CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1, 2, 3, 4);
LOAD_5_SRC(src[2], 2); // line2
CALC_ONE_LINE_4_RESULT(sum, src[1], 5, 6, 7, 8, 9);
LOAD_5_SRC(src[0], 3); // line3
CALC_ONE_LINE_4_RESULT(sum, src[2], 10, 11, 12, 13, 14);
LOAD_5_SRC(src[1], 4); // line4
CALC_ONE_LINE_4_RESULT(sum, src[0], 15, 16, 17, 18, 19);
CALC_ONE_LINE_4_RESULT(sum, src[1], 20, 21, 22, 23, 24);
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum);
}
if (OW > ow) {
size_t iw = ow * 2;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum[2] = {init_sum, init_sum};
#define cb(i) \
const int32_t* tmp_sptr##i = reinterpret_cast<const int32_t*>(sptr##i);
UNROLL_CALL_NOWRAPPER(5, cb);
#undef cb
int8x16_t src[3][5];
LOAD_5_SRC(src[0], 0); // line0
LOAD_5_SRC(src[1], 1); // line1
CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1, 2, 3, 4);
LOAD_5_SRC(src[2], 2); // line2
CALC_ONE_LINE_4_RESULT(sum, src[1], 5, 6, 7, 8, 9);
LOAD_5_SRC(src[0], 3); // line3
CALC_ONE_LINE_4_RESULT(sum, src[2], 10, 11, 12, 13, 14);
LOAD_5_SRC(src[1], 4); // line4
CALC_ONE_LINE_4_RESULT(sum, src[0], 15, 16, 17, 18, 19);
CALC_ONE_LINE_4_RESULT(sum, src[1], 20, 21, 22, 23, 24);
STORE_REMAIN(dst, oh, ow, OW, sum, remain);
}
}
}
#undef CALC_ONE_LINE_8_RESULT
#undef CALC_ONE_LINE_4_RESULT
#undef LOAD_10_SRC
#undef LOAD_5_SRC
#elif MEGDNN_ARMV7
template <BiasMode bias_mode>
void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16(
const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW) {
MEGDNN_MARK_USED_VAR(IH);
const int16_t* __restrict bptr = bias;
const int32_t* tmp_filter = reinterpret_cast<const int32_t*>(filter);
INIT_SUM();
int8x8_t kern0[3], kern1[3], kern2[3], kern3[3], kern4[3];
int32x2_t tmp_kern = vdup_n_s32(tmp_filter[4]);
tmp_kern = vset_lane_s32(0,tmp_kern,1);
kern0[0] = vld1_s8(filter);
kern0[1] = vld1_s8(filter + 8);
kern0[2] = vreinterpret_s8_s32(tmp_kern);
tmp_kern = vdup_n_s32(tmp_filter[9]);
tmp_kern = vset_lane_s32(0,tmp_kern,1);
kern1[0] = vld1_s8(filter + 20);
kern1[1] = vld1_s8(filter + 28);
kern1[2] = vreinterpret_s8_s32(tmp_kern);
tmp_kern = vdup_n_s32(tmp_filter[14]);
tmp_kern = vset_lane_s32(0,tmp_kern,1);
kern2[0] = vld1_s8(filter + 40);
kern2[1] = vld1_s8(filter + 48);
kern2[2] = vreinterpret_s8_s32(tmp_kern);
tmp_kern = vdup_n_s32(tmp_filter[19]);
tmp_kern = vset_lane_s32(0,tmp_kern,1);
kern3[0] = vld1_s8(filter + 60);
kern3[1] = vld1_s8(filter + 68);
kern3[2] = vreinterpret_s8_s32(tmp_kern);
tmp_kern = vdup_n_s32(tmp_filter[24]);
tmp_kern = vset_lane_s32(0,tmp_kern,1);
kern4[0] = vld1_s8(filter + 80);
kern4[1] = vld1_s8(filter + 88);
kern4[2] = vreinterpret_s8_s32(tmp_kern);
#define LOAD_3_SRC_ARRAY(_src,_sptr)\
_src[0] = vld1q_s8(_sptr);/*0 1 2 3 */\
_src[1] = vld1q_s8(_sptr + 16);/*4 5 6 7 */\
_src[2] = vld1q_s8(_sptr + 32);/*8 9 10 11*/
#define CALC_ONE_LINE(_src, _kern, _sum) \
tmpsum0 = vmull_s8(vget_low_s8(_src[0]), _kern[0]); /*01*/ \
tmpsum1 = vmull_s8(vget_high_s8(_src[0]), _kern[0]); /*23*/ \
tmpsum0 = vmlal_s8(tmpsum0, vget_high_s8(_src[0]), _kern[1]); /*23*/ \
tmpsum1 = vmlal_s8(tmpsum1, vget_low_s8(_src[1]), _kern[1]); /*45*/ \
tmpsum0 = vmlal_s8(tmpsum0, vget_low_s8(_src[1]), _kern[2]); /*4*/ \
tmpsum1 = vmlal_s8(tmpsum1, vget_high_s8(_src[1]), _kern[2]); /*6*/ \
res0 = vadd_s16(vget_low_s16(tmpsum0), vget_high_s16(tmpsum0)); \
res1 = vadd_s16(vget_low_s16(tmpsum1), vget_high_s16(tmpsum1)); \
_sum[0] = vaddq_s16(_sum[0], vcombine_s16(res0, res1)); \
\
tmpsum0 = vmull_s8(vget_low_s8(_src[1]), _kern[0]); /*45*/ \
tmpsum1 = vmull_s8(vget_high_s8(_src[1]), _kern[0]); /*67*/ \
tmpsum0 = vmlal_s8(tmpsum0, vget_high_s8(_src[1]), _kern[1]); /*67*/ \
tmpsum1 = vmlal_s8(tmpsum1, vget_low_s8(_src[2]), _kern[1]); /*89*/ \
tmpsum0 = vmlal_s8(tmpsum0, vget_low_s8(_src[2]), _kern[2]); /*8*/ \
tmpsum1 = vmlal_s8(tmpsum1, vget_high_s8(_src[2]), _kern[2]); /*10*/ \
res0 = vadd_s16(vget_low_s16(tmpsum0), vget_high_s16(tmpsum0)); \
res1 = vadd_s16(vget_low_s16(tmpsum1), vget_high_s16(tmpsum1)); \
_sum[1] = vaddq_s16(_sum[1], vcombine_s16(res0, res1));
#define CALC_8_RESULT() \
LOAD_3_SRC_ARRAY(src0, sptr0); \
LOAD_3_SRC_ARRAY(src1, sptr1); \
CALC_ONE_LINE(src0, kern0, sum[0]); \
\
LOAD_3_SRC_ARRAY(src0, sptr2); \
CALC_ONE_LINE(src1, kern1, sum[0]); \
\
LOAD_3_SRC_ARRAY(src1, sptr3); \
CALC_ONE_LINE(src0, kern2, sum[0]); \
CALC_ONE_LINE(src0, kern0, sum[1]); \
\
LOAD_3_SRC_ARRAY(src0, sptr4); \
CALC_ONE_LINE(src1, kern3, sum[0]); \
CALC_ONE_LINE(src1, kern1, sum[1]); \
\
LOAD_3_SRC_ARRAY(src1, sptr5); \
CALC_ONE_LINE(src0, kern4, sum[0]); \
CALC_ONE_LINE(src0, kern2, sum[1]); \
\
LOAD_3_SRC_ARRAY(src0, sptr6); \
CALC_ONE_LINE(src1, kern3, sum[1]); \
CALC_ONE_LINE(src0, kern4, sum[1]);
#define CALC_4_RESULT() \
LOAD_3_SRC_ARRAY(src0, sptr0); \
LOAD_3_SRC_ARRAY(src1, sptr1); \
CALC_ONE_LINE(src0, kern0, sum); \
\
LOAD_3_SRC_ARRAY(src0, sptr2); \
CALC_ONE_LINE(src1, kern1, sum); \
\
LOAD_3_SRC_ARRAY(src1, sptr3); \
CALC_ONE_LINE(src0, kern2, sum); \
\
LOAD_3_SRC_ARRAY(src0, sptr4); \
CALC_ONE_LINE(src1, kern3, sum); \
CALC_ONE_LINE(src0, kern4, sum);
size_t oh = 0_z;
for (; oh + 2 <= OH; oh += 2) {
size_t ih = oh * 2;
size_t ow = 0_z;
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
int8x16_t src0[3], src1[3];
int16x8_t tmpsum0, tmpsum1;
int16x4_t res0, res1;
CALC_8_RESULT();
STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]);
STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]);
}
if (ow < OW) {
size_t iw = ow * 2;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
const int8_t* __restrict sptr5 = src + (ih + 5) * IW * 4 + iw * 4;
const int8_t* __restrict sptr6 = src + (ih + 6) * IW * 4 + iw * 4;
int16x8_t sum[2][2];
#define cb(j) \
sum[0][j] = init_sum; \
sum[1][j] = init_sum;
UNROLL_CALL_NOWRAPPER(2, cb);
#undef cb
int8x16_t src0[3], src1[3];
int16x8_t tmpsum0, tmpsum1;
int16x4_t res0, res1;
CALC_8_RESULT();
STORE_REMAIN(dst, oh, ow, OW, sum[0],remain);
STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1],remain);
}
}
for (; oh < OH; oh++) {
size_t ih = oh * 2;
size_t ow = 0_z;
for (; ow + 4 <= OW; ow += 4) {
size_t iw = ow * 2;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum[2]={init_sum,init_sum};
int8x16_t src0[3], src1[3];
int16x8_t tmpsum0, tmpsum1;
int16x4_t res0, res1;
CALC_4_RESULT();
STORE_1_LINE_4_RESULT(dst, oh,ow, OW, sum);
}
if (OW > ow) {
size_t iw = ow * 2;
size_t remain = OW - ow;
const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4;
int16x8_t sum[2] = {init_sum, init_sum};
int8x16_t src0[3], src1[3];
int16x8_t tmpsum0, tmpsum1;
int16x4_t res0, res1;
CALC_4_RESULT();
STORE_REMAIN(dst, oh, ow, OW, sum, remain);
}
}
}
#undef CALC_ONE_LINE
#undef CALC_4_RESULT
#undef CALC_8_RESULT
#undef LOAD_3_SRC_ARRAY
#endif
#undef INIT_SUM
#undef STORE_1_LINE_RESULT
#undef STORE_1_LINE_4_RESULT
#undef STORE_REMAIN
#define INSTANTIATION(stride, i, bias) \
template void channel_wise_nchw44_8x8x16:: \
direct_##stride##_##i##x##i##_int8x8x16<bias>( \
const int8_t*, const int8_t*, const int16_t*, void*, \
const size_t, const size_t, const size_t, const size_t);
#define FOR_OP(stride, i, bias) INSTANTIATION(stride, i, bias)
#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5)
#define FOR_STRIDE \
FOR_FILTER(stride1) \
FOR_FILTER(stride2)
FOR_STRIDE
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
namespace megdnn {
namespace arm_common {
namespace channel_wise_nchw44 {
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
using conv_fun = std::function<void(const WorkspaceBundle& bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index)>;
namespace stride1 {
bool is_available(const NCBKernSizeParam& param);
WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
template <bool quantized, size_t filter, BiasMode bias_mode, typename Op>
void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param);
} // namespace stride1
namespace stride2 {
bool is_available(const NCBKernSizeParam& param);
WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
template <bool quantized, size_t filter, BiasMode bias_mode, typename Op>
void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param);
} // namespace stride2
} // namespace direct_int8_stride1
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h"
#include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h"
#include "src/common/opr_delegate.h"
#include "midout.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;
using namespace channel_wise_nchw44_8x8x16;
namespace {
void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
size_t& IH2, size_t& IW2) {
auto&& fm = param.filter_meta;
auto SW = fm.stride[1];
auto OH = param.osz[0];
auto OW = param.osz[1];
auto FH = fm.spatial[0];
auto FW = fm.spatial[1];
size_t OW2 = (OW + 3) & ~3;
IH2 = SW * OH + FH - SW;
IW2 = SW * OW2 + FW - SW;
}
} // namespace
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride1)
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride2)
WorkspaceBundle stride1::get_bundle(
const ConvBiasImpl::NCBKernSizeParam& param) {
size_t nr_threads = param.nr_threads;
size_t IH2, IW2;
get_rectified_size(param, IH2, IW2);
constexpr size_t pack_ic_size = 4_z;
//! The extra 16B is used to void ivalid read in kernel compute
size_t src_size = IH2 * IW2 * pack_ic_size * sizeof(int8_t) + 16;
SmallVector<size_t> sizes(nr_threads, src_size);
return {nullptr, sizes};
}
//! compute one output channel
template <size_t filter, BiasMode bias_mode>
void stride1::do_conv_kern(const WorkspaceBundle& bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IH2, IW2;
get_rectified_size(kern_param, IH2, IW2);
constexpr size_t pack_group_size = 4_z;
constexpr size_t pack_ic_size = 4_z;
size_t thread_id = ncb_index.thread_id, batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
int8_t* padding_src = static_cast<int8_t*>(bundle.get(thread_id));
const int8_t* sptr =
kern_param.src<dt_int8>(batch_id, group_id, 0, pack_group_size);
const int8_t* fptr = kern_param.filter<dt_int8>(group_id, pack_group_size);
void* dst = kern_param.dst<void>(batch_id, group_id, 0, pack_group_size);
const int16_t* bptr =
kern_param.bias<dt_int16>(batch_id, group_id, 0, pack_group_size);
//! copy in case of illegal read src when padding is zero
std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size);
rep(ih, IH) {
std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size,
sptr + ih * IW * pack_ic_size,
sizeof(int8_t) * IW * pack_ic_size);
}
sptr = padding_src;
#define KERN(_size) \
direct_stride1_##_size##x##_size##_int8x8x16<bias_mode>( \
sptr, fptr, bptr, dst, IH2, IW2, OH, OW);
DISPATCH_FILTER_CHANNEL_WISE(filter, KERN);
#undef KERN
}
SmallVector<ConvBiasImpl::NCBKern> stride1::get_kimpls(
const NCBKernSizeParam& param) {
auto fm = param.filter_meta;
size_t N = param.n;
size_t group = fm.group / 4;
megdnn_assert(fm.group % 4 == 0,
"nchw44 channel wise conv with group is not times of 4");
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
#define DO_CONV_KERN_FUN(filter, bias_mode) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride1, \
midout_iv(#filter #bias_mode##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(i, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(i, bias_mode) \
break; \
default: \
megdnn_assert(0, "only support NonlineMode::IDENTITY"); \
break; \
}
#define GET_BIAS_MODE_PARAM(i) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(i, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0, \
"only support BiasMode::NO_BIAS and " \
"BiasMode::BROADCAST_CHANNEL_BIAS"); \
break; \
}
#define DISPATCH_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
break; \
default: \
megdnn_assert(0, "only support filtersize 2x2 3x3 5x5"); \
break; \
}
DISPATCH_CONV_KERN();
megdnn_assert(do_conv_fun);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
auto exec_one_group = [wbundle, do_conv_fun](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
wbundle.set(kern_param.workspace_ptr);
do_conv_fun(wbundle, kern_param, ncb_index);
};
ret_kerns.push_back({exec_one_group, {N, group}});
return ret_kerns;
#undef DO_CONV_KERN_FUN
}
WorkspaceBundle stride2::get_bundle(
const ConvBiasImpl::NCBKernSizeParam& param) {
size_t nr_threads = param.nr_threads;
size_t IH2, IW2;
get_rectified_size(param, IH2, IW2);
constexpr size_t pack_ic_size = 4_z;
//! The extra 16B is used to void ivalid read in kernel compute
size_t src_size = IH2 * IW2 * pack_ic_size * sizeof(int8_t) + 16;
SmallVector<size_t> sizes(nr_threads, src_size);
return {nullptr, sizes};
}
//! compute one output channel
template <size_t filter, BiasMode bias_mode>
void stride2::do_conv_kern(const WorkspaceBundle& bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IH2, IW2;
get_rectified_size(kern_param, IH2, IW2);
constexpr size_t pack_group_size = 4_z;
constexpr size_t pack_ic_size = 4_z;
size_t thread_id = ncb_index.thread_id, batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
int8_t* padding_src = static_cast<int8_t*>(bundle.get(thread_id));
const int8_t* sptr =
kern_param.src<dt_int8>(batch_id, group_id, 0, pack_group_size);
const int8_t* fptr = kern_param.filter<dt_int8>(group_id, pack_group_size);
void* dst = kern_param.dst<void>(batch_id, group_id, 0, pack_group_size);
const int16_t* bptr =
kern_param.bias<dt_int16>(batch_id, group_id, 0, pack_group_size);
//! copy in case of illegal read src when padding is zero
std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size);
rep(ih, IH) {
std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size,
sptr + ih * IW * pack_ic_size,
sizeof(int8_t) * IW * pack_ic_size);
}
sptr = padding_src;
#define KERN(_size) \
direct_stride2_##_size##x##_size##_int8x8x16<bias_mode>( \
sptr, fptr, bptr, dst, IH2, IW2, OH, OW);
DISPATCH_FILTER_CHANNEL_WISE(filter, KERN);
#undef KERN
}
SmallVector<ConvBiasImpl::NCBKern> stride2::get_kimpls(
const NCBKernSizeParam& param) {
auto fm = param.filter_meta;
size_t N = param.n;
size_t group = fm.group / 4;
megdnn_assert(fm.group % 4 == 0,
"nchw44 channel wise conv with group is not times of 4");
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
#define DO_CONV_KERN_FUN(filter, bias_mode) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride2, \
midout_iv(#filter #bias_mode##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode>; \
} \
MIDOUT_END();
DISPATCH_CONV_KERN();
megdnn_assert(do_conv_fun);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
auto exec_one_group = [wbundle, do_conv_fun](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
wbundle.set(kern_param.workspace_ptr);
do_conv_fun(wbundle, kern_param, ncb_index);
};
ret_kerns.push_back({exec_one_group, {N, group}});
return ret_kerns;
#undef DISPATCH_CONV_KERN
#undef GET_BIAS_MODE_PARAM
#undef GET_OP_PARAM
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
namespace megdnn {
namespace arm_common {
namespace channel_wise_nchw44_8x8x16 {
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
using conv_fun = std::function<void(const WorkspaceBundle& bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index)>;
namespace stride1 {
bool is_available(const NCBKernSizeParam& param);
WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
template <size_t filter, BiasMode bias_mode>
void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param);
} // namespace stride1
namespace stride2 {
bool is_available(const NCBKernSizeParam& param);
WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
template <size_t filter, BiasMode bias_mode>
void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
SmallVector<ConvBiasImpl::NCBKern> get_kimpls(const NCBKernSizeParam& param);
} // namespace stride2
} // namespace direct_int8_stride1
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -48,6 +48,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8DirectStride1 s8_direct_stride1;
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44;
#if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectStride1 ds8_direct_stride1;
......@@ -95,6 +96,7 @@ public:
direct_algos.emplace_back(&s8_direct_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_stride1);
direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44);
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
......
......@@ -54,6 +54,7 @@ private:
class AlgoS8ChanWiseStride1NCHW44;
class AlgoS8ChanWiseStride2NCHW44;
class AlgoS8x8x16ChanWiseStride1Stride2NCHW44;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoFP16WinogradF23;
......
......@@ -558,6 +558,142 @@ void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name,
}
}
std::vector<conv_bias::TestArg> get_nchw44_channel_wise_benchmark_args(
std::vector<size_t> kernel, size_t stride, bool no_bias,
bool no_nonlinemode, bool no_full_bias) {
using namespace conv_bias;
using Param = param::ConvBias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args;
auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
size_t stride, NLMode nlmode, bool pad) {
Param param;
param.stride_h = stride;
param.stride_w = stride;
if (pad) {
param.pad_h = kernel / 2;
param.pad_w = kernel / 2;
} else {
param.pad_h = 0;
param.pad_w = 0;
}
param.nonlineMode = nlmode;
param.format = param::ConvBias::Format::NCHW44;
param.sparse = param::ConvBias::Sparse::GROUP;
args.emplace_back(param, TensorShape{n, group, h, w, 4},
TensorShape{group, 1, 1, kernel, kernel, 4},
TensorShape{});
if (!no_bias) {
args.emplace_back(param, TensorShape{n, group, h, w, 4},
TensorShape{group, 1, 1, kernel, kernel, 4},
TensorShape{1, group, 1, 1, 4});
}
if (!no_full_bias) {
args.emplace_back(
param, TensorShape{n, group, h, w, 4},
TensorShape{group, 1, 1, kernel, kernel, 4},
TensorShape{n, group,
(h + 2 * param.pad_w - kernel) / stride + 1,
(w + 2 * param.pad_w - kernel) / stride + 1,
4});
}
};
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
if (!no_nonlinemode) {
nonlinemode.emplace_back(NLMode::RELU);
nonlinemode.emplace_back(NLMode::H_SWISH);
}
for (size_t n : {1}) {
for (auto nlmode : nonlinemode) {
for (bool pad : {true}) {
for (size_t group : {1, 2, 4, 128}) {
for (size_t size : {40,89,100,200}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
}
}
}
}
for (bool pad : {false}) {
for (size_t group : {1, 2, 4, 8, 16, 32, 64, 128}) {
for (size_t size : {40, 89, 100}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
}
}
}
}
}
}
return args;
}
void BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32(const char* algo_name0,
const char* algo_name1, Handle* handle,
size_t kernel,size_t stride = 1, size_t pack_size = 1) {
auto args = get_nchw44_channel_wise_benchmark_args({2, 3, 5}, stride, false, true, true);
using namespace conv_bias;
constexpr size_t RUN = 10;
Benchmarker<ConvBias> benchmark(handle);
benchmark.set_display(false);
benchmark.set_times(RUN);
benchmark.set_dtype(0, dtype::Int8());
benchmark.set_dtype(1, dtype::Int8());
benchmark.set_dtype(2, dtype::Int32());
benchmark.set_dtype(4, dtype::Int32());
Benchmarker<ConvBias> benchmark_algo1(handle);
benchmark_algo1.set_display(false);
benchmark_algo1.set_times(RUN);
benchmark_algo1.set_dtype(0, dtype::Int8());
benchmark_algo1.set_dtype(1, dtype::Int8());
benchmark_algo1.set_dtype(2, dtype::Int16());
benchmark_algo1.set_dtype(4, dtype::Int16());
for (auto&& arg : args) {
TensorLayout dst_layout;
auto opr = handle->create_operator<ConvBias>();
opr->param() = arg.param;
opr->deduce_layout({arg.src, dtype::Float32()},
{arg.filter, dtype::Float32()},
{arg.bias, dtype::Float32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * arg.filter[1] *
arg.filter[2] * arg.filter[3] * 2.0 * pack_size/
(1024 * 1024 * 1024) * 1e3;
benchmark.set_param(arg.param);
auto used = algo_benchmark<ConvBias>(benchmark,
{arg.src, arg.filter, {}, {}, {}},
algo_name0) /
RUN;
arg.param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
arg.param.format = param::ConvBias::Format::NCHW44;
benchmark_algo1.set_param(arg.param);
auto used_algo1 =
algo_benchmark<ConvBias>(
benchmark_algo1,
{arg.src, arg.filter, {}, {}, {}},
algo_name1) /
RUN;
printf("%s %s: normal: %f ms %f Gflops 8x8x16: %f ms %f GFlops "
"speedup: "
"%f\n",
arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
used, computations / used, used_algo1,
computations / used_algo1, used / used_algo1);
}
}
#if MEGDNN_AARCH64
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) {
printf("=========================compare "
......@@ -579,6 +715,17 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) {
}
#endif
TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE1) {
BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32("S8_CHAN_WISE_STRD1_NCHW44",
"S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44",
handle(), 3,1,4);
}
TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE2) {
BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32("S8_CHAN_WISE_STRD2_NCHW44",
"S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44",
handle(), 3,2, 4);
}
TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONVBIAS_QUANTIZED) {
constexpr size_t RUNS = 50;
param::ConvBias param;
......
......@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/dtype.h"
#include "test/arm_common/fixture.h"
#include "test/common/benchmarker.h"
#include "test/common/conv_bias.h"
......@@ -475,6 +476,36 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
TEST_F(ARM_COMMON,
CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
Checker<ConvBias> checker(handle());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
checker.set_dtype(0, dtype::Int8());
checker.set_dtype(1, dtype::Int8());
checker.set_dtype(2, dtype::Int16());
checker.set_dtype(4, dtype::Int16());
auto args = get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true);
for (auto&& arg : args) {
checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
}
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
Checker<ConvBias> checker(handle());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
checker.set_dtype(0, dtype::Int8());
checker.set_dtype(1, dtype::Int8());
checker.set_dtype(2, dtype::Int16());
checker.set_dtype(4, dtype::Int16());
auto args = get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true);
for (auto&& arg : args) {
checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
}
}
/********************************qint8 direct******************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
......
......@@ -1706,6 +1706,77 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_CHANNEL_WISE_INT8_INT8_INT16_STRIDE1) {
constexpr size_t RUNS = 50;
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 1;
param.stride_w = 1;
param.sparse = param::ConvBias::Sparse::GROUP;
param.format = param::ConvBias::Format::NCHW44;
std::vector<std::pair<SmallVector<TensorShape>, float>>
shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t H, size_t W, size_t FS,
size_t P) {
size_t group = IC;
size_t OC = IC;
size_t S = 1;
SmallVector<TensorShape> shapes{
{N, IC, H, W, 4},
{group, 1, 1, FS, FS, 4},
{1, OC, 1, 1, 4},
{},
{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1, 4}};
TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1,
(W + 2 * P - FS) / S + 1, 4};
float computations =
((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};
bench_case(1, 128, 200, 200, 3, 1);
bench_case(1, 128, 128, 128, 3, 1);
bench_case(1, 128, 100, 100, 3, 1);
bench_case(1, 128, 80, 80, 3, 1);
bench_case(1, 128, 56, 56, 3, 1);
bench_case(1, 128, 28, 28, 3, 1);
bench_case(1, 128, 14, 14, 3, 1);
bench_case(1, 64, 200, 200, 3, 1);
bench_case(1, 64, 128, 128, 3, 1);
bench_case(1, 64, 100, 100, 3, 1);
bench_case(1, 64, 80, 80, 3, 1);
bench_case(1, 64, 56, 56, 3, 1);
bench_case(1, 64, 28, 28, 3, 1);
bench_case(1, 64, 14, 14, 3, 1);
bench_case(1, 32, 200, 200, 3, 1);
bench_case(1, 32, 128, 128, 3, 1);
bench_case(1, 32, 100, 100, 3, 1);
bench_case(1, 32, 80, 80, 3, 1);
bench_case(1, 32, 56, 56, 3, 1);
bench_case(1, 32, 28, 28, 3, 1);
bench_case(1, 32, 14, 14, 3, 1);
std::string algo_name = "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44";
printf("Benchmarker S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44 algo\n");
std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(),
dtype::Int16(), dtype::Int16()};
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {4}}, data_type);
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {7}}, data_type);
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
{1, {4}}, data_type);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_IM2COL_NCHW44_INT8x8x32_STRIDE1) {
constexpr size_t RUNS = 50;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册