提交 3117bfb7 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/arm): nchw44 direct int8 support 8832

GitOrigin-RevId: 696fa05d943b28fcec3a236bb8518fb255eae9db
上级 4e0c9ad3
......@@ -38,23 +38,6 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoS8DirectStride1NCHW44 final : public AlgoBase {
public:
AlgoS8DirectStride1NCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "S8_NCHW44_DIRECT_STRD1"; }
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
bool is_preferred(megdnn::fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
bool m_large_group;
......@@ -74,11 +57,11 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoS8DirectStride2NCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
public:
AlgoS8DirectStride2NCHW44() {}
AlgoS8DirectNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "S8_NCHW44_DIRECT_STRD2"; }
const char* name() const override { return "S8_NCHW44_DIRECT"; }
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
......@@ -245,8 +228,8 @@ private:
//=======================input int8 compute fp32 output int8============
class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase {
public:
AlgoS8CF32WinogradF23_4x4_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
AlgoS8CF32WinogradF23_4x4_NCHW44(
fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
......@@ -277,7 +260,7 @@ private:
class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase {
public:
AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
uint32_t tile_size)
uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
......
......@@ -36,26 +36,6 @@ KERN(stride2, 7, nchw)
#undef KERN
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op, int remain_w> \
void conv_direct_##stride##_##i##x##i##_int8_##layout( \
const int8_t* src, const int8_t* filter, const int32_t* bias, \
int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \
const size_t IH, const size_t IW, const size_t OH, \
const size_t OW, const Op& op);
KERN(stride1, 2, nchw44)
KERN(stride1, 3, nchw44)
KERN(stride1, 5, nchw44)
KERN(stride1, 7, nchw44)
KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN
void nchw44_pack_filter(const int8_t* src, int8_t* dst, int filter);
void nchw44_pack_src(const int8_t* src, int8_t* dst, int length);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn
......
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp
* \file dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -13,6 +13,7 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
......@@ -25,28 +26,19 @@ using conv_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride2)
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44)
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2) {
auto&& fm = param.filter_meta;
size_t SW = fm.stride[1];
size_t IH = param.isz[0];
size_t IW = param.isz[1];
size_t OH = param.osz[0];
size_t OW = param.osz[1];
size_t FH = fm.spatial[0];
size_t FW = fm.spatial[1];
int ih = param.isz[0];
int iw = param.isz[1];
int ph = fm.padding[0];
int pw = fm.padding[1];
OH2 = OH;
OW2 = (OW + 7) & ~7;
IH2 = SW * OH + FH - SW;
IW2 = SW * OW2 + FW - SW;
// Because stride is 2, sometimes IW == IW2+1. Do a max update to
// handle this case.
IH2 = std::max(IH2, IH);
IW2 = std::max(IW2, IW);
ih2 = ih + ph * 2;
iw2 = iw + pw * 2;
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
constexpr size_t src_expand = 4;
......@@ -57,8 +49,8 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
size_t OC = fm.ocpg;
size_t FH = fm.spatial[0];
size_t FW = fm.spatial[1];
size_t IH2, IW2, OH2, OW2;
get_rectified_size(param, IH2, IW2, OH2, OW2);
int IH2, IW2;
get_rectified_size(param, IH2, IW2);
if (group == 1) {
size_t src_size =
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand;
......@@ -76,16 +68,16 @@ static void copy_padding_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t GROUP = kern_param.filter_meta.group;
int IH = kern_param.isz[0];
int IW = kern_param.isz[1];
int IC = kern_param.filter_meta.icpg;
int PH = kern_param.filter_meta.padding[0];
int PW = kern_param.filter_meta.padding[1];
int GROUP = kern_param.filter_meta.group;
size_t IH2, IW2, OH2, OW2;
get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
size_t padding_group_size = IH2 * IW2 * IC;
int IH2, IW2;
get_rectified_size(kern_param, IH2, IW2);
int padding_group_size = IH2 * IW2 * IC;
bundle.set(kern_param.workspace_ptr);
//! Used for get the workspace offset
constexpr int pack_ic = 4;
......@@ -100,16 +92,10 @@ static void copy_padding_kern(WorkspaceBundle bundle,
size_t group_id = ncb_index.ndrange_id[1];
size_t group_pack_size = 1;
int nr_pad_h = PH * IW2 * pack_ic * expend_element;
int nr_pad_w = PW * pack_ic * expend_element;
int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element;
int row_last_pad = ((int)IW2 - (int)IW - 2 * (int)PW) >= 0
? nr_pad_w + over_pad
: (IW2 - IW - PW) * pack_ic * expend_element;
int col_last_pad =
((int)IH2 - (int)IH - 2 * (int)PH) >= 0
? nr_pad_h
: (IH2 - IH - PH) * IW2 * pack_ic * expend_element;
int nr_pad_h = PH * IW2 * pack_ic * expend_element;
int row_last_pad = (IW2 - IW - PW) * pack_ic * expend_element;
int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * expend_element;
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>(
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic));
......@@ -129,7 +115,7 @@ static void copy_padding_kern(WorkspaceBundle bundle,
rep(ih_idx, IH) {
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
sptr_base += nr_pad_w;
conv_bias::nchw44_pack_src(sptr, sptr_base, IW);
nchw44_pack_src(sptr, sptr_base, IW);
sptr_base += IW * pack_ic * expend_element;
sptr += IW * pack_ic;
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t));
......@@ -140,7 +126,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
}
}
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain>
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain,
typename DstType, int stride>
static void do_conv_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
......@@ -153,12 +140,12 @@ static void do_conv_kern(WorkspaceBundle bundle,
size_t IC = kern_param.filter_meta.icpg;
size_t OC = kern_param.filter_meta.ocpg;
size_t GROUP = kern_param.filter_meta.group;
size_t IH2, IW2, OH2, OW2;
get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
int IH2, IW2;
get_rectified_size(kern_param, IH2, IW2);
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
//! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f)
Op op = Op(1.0f, 4.0f);
Op op(1.f, 4.f);
if (need_post_process) {
float scale_bias =
kern_param.bias_type.param<dtype::QuantizedS32>().scale;
......@@ -191,49 +178,43 @@ static void do_conv_kern(WorkspaceBundle bundle,
const int8_t* fptr =
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
void* dst = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(
kern_param.dst<void>(batch_id, group_id)) +
oc_idx * OH * OW);
DstType* dst = reinterpret_cast<DstType*>(
kern_param.dst<void>(batch_id, group_id, oc_idx));
const int32_t* bptr =
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
conv_bias::nchw44_pack_filter(fptr, packed_weight,
oc_block / 4 * IC / 4 * FH * FW);
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw44< \
bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \
static_cast<int8_t*>(dst), oc_block, IC, \
IH2, IW2, OH, OW, op)
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV)
#undef KERN1_NCHW44_CONV
nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW);
conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType, stride>(
sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst),
oc_block, IC, IH2, IW2, OH, OW, op);
}
/* ===================== stride2 algo ===================== */
bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::usable(
bool ConvBiasImpl::AlgoS8DirectNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MEGDNN_MARK_USED_VAR(algo_selection_strategy);
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto OC = fm.ocpg;
auto IC = fm.icpg;
bool avaible = //! src and filter are qint8, dst is qint8 or qint32
const int fh = fm.spatial[0];
const int fw = fm.spatial[1];
const int oc = fm.ocpg;
const int ic = fm.icpg;
const bool avaible = //! src and filter are qint8, dst is qint8 or qint32
((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32))) &&
(fm.format == param::Convolution::Format::NCHW44) &&
(OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip &&
(oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7) &&
param.bias_mode != BiasMode::BIAS;
return avaible;
}
bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred(
bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred(
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
const NCBKernSizeParam& param) const {
// TODO: benchmark and fix
......@@ -242,13 +223,13 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred(
return false;
}
size_t ConvBiasImpl::AlgoS8DirectStride2NCHW44::get_workspace(
size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns(
ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
size_t N = param.n;
......@@ -261,97 +242,129 @@ ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns(
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
int ow_remain = OW % 8;
bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
// NOTE: remain_w is not used to gen hash of midout for compatible with changing
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride2, \
midout_iv(#filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \
} \
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \
midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w, dst_type, \
stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(filter, bias_mode, remain_w) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
#define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \
if (need_post_process) { \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0, "no supported noline mode"); \
break; \
} \
} else { \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \
remain_w, NoneOp<dt_int32>) \
break; \
default: \
megdnn_assert( \
0, \
"only support IDENTITY mode when dst is not qint8"); \
break; \
} \
}
#define GET_REMAIN_W_PARAM(filter, bias_mode) \
switch (ow_remain) { \
case 0: \
GET_OP_PARAM(filter, bias_mode, 0); \
break; \
case 1: \
GET_OP_PARAM(filter, bias_mode, 1); \
break; \
case 2: \
GET_OP_PARAM(filter, bias_mode, 2); \
break; \
case 3: \
GET_OP_PARAM(filter, bias_mode, 3); \
break; \
case 4: \
GET_OP_PARAM(filter, bias_mode, 4); \
break; \
case 5: \
GET_OP_PARAM(filter, bias_mode, 5); \
break; \
case 6: \
GET_OP_PARAM(filter, bias_mode, 6); \
break; \
case 7: \
GET_OP_PARAM(filter, bias_mode, 7); \
break; \
default: \
megdnn_assert(0); \
#define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \
switch (ow_remain) { \
case 0: \
GET_OP_PARAM(stride, filter, bias_mode, 0); \
break; \
case 1: \
GET_OP_PARAM(stride, filter, bias_mode, 1); \
break; \
case 2: \
GET_OP_PARAM(stride, filter, bias_mode, 2); \
break; \
case 3: \
GET_OP_PARAM(stride, filter, bias_mode, 3); \
break; \
case 4: \
GET_OP_PARAM(stride, filter, bias_mode, 4); \
break; \
case 5: \
GET_OP_PARAM(stride, filter, bias_mode, 5); \
break; \
case 6: \
GET_OP_PARAM(stride, filter, bias_mode, 6); \
break; \
case 7: \
GET_OP_PARAM(stride, filter, bias_mode, 7); \
break; \
default: \
megdnn_assert(0); \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_REMAIN_W_PARAM(stride, filter, \
BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN() \
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(2) \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(3) \
GET_BIAS_MODE_PARAM(stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
GET_BIAS_MODE_PARAM(stride, 5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
GET_BIAS_MODE_PARAM(stride, 7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
DISPATCH_CONV_KERN();
switch (param.filter_meta.stride[0]) {
case 1:
DISPATCH_CONV_KERN(1);
break;
case 2:
DISPATCH_CONV_KERN(2);
break;
default:
megdnn_throw(ssprintf("Unsupport stride size %u for the first conv",
param.filter_meta.stride[0])
.c_str());
break;
}
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
......
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace {
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc4 = oc_step * fh * fw * ic;
int32x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[1][8];
int8x16_t weight[1][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
struct KerNeonDirectStride1Int8 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc);
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 2, c_dim, DstType> {
static void impl(const int8_t*, const int8_t*, const int32_t*, DstType*,
int, int, int, const Op&, int) {
megdnn_throw("no impl");
}
};
/**
dot like impl. dot 4 ic to 1 oc, accumale to c <ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
//! TODO: can try oh = 2 impl, oc = 8 impl
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 3;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[c_dim][8];
int8x16_t weight[3];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 5;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[c_dim][8];
int8x16_t weight[5];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 7;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[c_dim][8];
int8x16_t weight[7];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[5], src[6], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[6], src[7], c[0][1], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[5], src[7], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[5], src[8], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[6], src[8], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[6], src[9], c[0][3], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[5], src[9], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[5], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[6], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[6], src[1], c[0][5], temp_c[1]);
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[5], src[1], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[5], src[2], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[6], src[2], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[6], src[3], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
/**
origin weight shape <oc/4, ic/4, fh, fw, 4, 4>
packed weight shape <oc/4, ic/4, fh, fw, 16>
example: (format like weight<oc, ic>)
origin
<0, 0> <1, 0> <2, 0> <3, 0>
<0, 1> <1, 1> <2, 1> <3, 1>
<0, 2> <1, 2> <2, 2> <3, 2>
<0, 3> <1, 3> <2, 3> <3, 3>
packed
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
**/
static inline void nchw44_pack_filter(const int8_t* src, int8_t* dst,
int length) {
static const uint8_t weight_idx_buffer[16] = {0, 4, 9, 13, 2, 6, 11, 15,
12, 8, 5, 1, 14, 10, 7, 3};
constexpr int simd_len = 16;
uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer);
for (int i = 0; i < length; i++) {
int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx);
vst1q_s8(dst + i * simd_len, result);
}
}
/**
origin src shape <n, ic/4, h, w, 4>
packed src shape <n, ic/4, h, w, 16>
example: (format like <ic>)
origin
<0> <0> <0> <0>
packed
low 64 bit <0> <1> <2> <3> | <0> <1> <2> <3>
---------------------------------------------------------------------
high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0>
**/
static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) {
static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3,
3, 2, 1, 0, 3, 2, 1, 0};
constexpr int pack_ic = 4;
constexpr int simd_len = 16;
uint8x16_t src_idx = vld1q_u8(src_idx_buffer);
for (int i = 0; i < length; i++) {
int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx);
vst1q_s8(dst + i * simd_len, result);
}
}
template <BiasMode bias_mode, typename Op, int remain_w, typename DstType>
void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src,
const int8_t* filter,
const int32_t* bias, int32_t* temp,
DstType* dst, const size_t oc,
const size_t ic, const size_t ih,
const size_t iw, const size_t oh,
const size_t ow, const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_oc = oh * ow * oc_step;
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size,
2, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, remain_w,
filter_size, 2, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size,
1, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, remain_w,
filter_size, 1, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
}
}
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
typename DstType>
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src,
const int8_t* filter,
const int32_t* bias, int32_t* temp,
DstType* dst, const size_t oc,
const size_t ic, const size_t ih,
const size_t iw, const size_t oh,
const size_t ow, const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const int ld_dst_oc = oh * ow * oc_step;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDirectStride1Int8<bias_mode, Op, ow_step, filter_size, 1,
DstType>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, op, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
KerNeonDirectStride1Int8<bias_mode, Op, remain_w, filter_size,
1,
DstType>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, op, ld_dst_oc);
}
}
}
}
/////////////////////stride 2/////////////////
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
struct KerNeonDirectStride2Int8 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc);
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 2, c_dim, DstType> {
static void impl(const int8_t*, const int8_t*, const int32_t*, DstType*,
int, int, int, const Op&, int) {
megdnn_throw("no impl");
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc4 = oc_step * fh * fw * ic;
int32x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8(src_ic_0_3 + 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);
src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);
src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[c_dim][8];
int8x16_t weight[2];
int8x16_t src[8 + 1];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[1], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[2], c[0][5], temp_c[1]);
src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
c[0][6] = vdotq_s32_h(weight[0], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[4], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[6], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
/**
dot like impl. dot 4 ic to 1 oc, accumale to c <ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
// TODO: can try oh = 2 impl, oc = 8 impl
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 3;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[c_dim][8];
int8x16_t weight[3];
int8x16_t src[8 + 2];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 5;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[c_dim][8];
int8x16_t weight[5];
int8x16_t src[8 + 2];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);
src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 17 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 18 * 16));
c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 7;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[c_dim][8];
int8x16_t weight[7];
int8x16_t src[8 + 2];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[5], src[7], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[6], src[8], c[0][1], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[5], src[9], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[5], src[1], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[6], src[0], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[6], src[2], c[0][3], temp_c[3]);
src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[5], src[3], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[5], src[5], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[6], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[6], src[6], c[0][5], temp_c[1]);
src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 19 * 16);
src[0] = vld1q_s8(src_ic_0_3 + 20 * 16);
c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[5], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[5], src[9], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[6], src[8], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[6], src[0], c[0][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, typename DstType>
void conv_direct_stride2_2x2_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*,
DstType* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow, const Op& op) {
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;
const size_t out_img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oh * ow * oc_step;
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, ow_step,
filter_size, 2, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, remain_w,
filter_size, 2, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, ow_step,
filter_size, 1, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, remain_w,
filter_size, 1, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
}
}
}
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
typename DstType>
void conv_direct_stride2_int8_nchw44_kern(
const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*,
DstType* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow, const Op& op) {
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const int ld_dst_oc = oh * ow * oc_step;
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDirectStride2Int8<bias_mode, Op, ow_step, filter_size, 1,
DstType>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, op, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
KerNeonDirectStride2Int8<bias_mode, Op, remain_w, filter_size,
1,
DstType>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, op, ld_dst_oc);
}
}
}
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
typename DstType, int stride>
struct ConvDirectInt8Nchw44Choose {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op);
};
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, remain_w, filter_size, DstType,
1> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
if (filter_size == 2) {
conv_direct_stride1_2x2_int8_nchw44<bias_mode, Op, remain_w,
DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
} else {
conv_direct_stride1_int8_nchw44_kern<bias_mode, Op, remain_w,
filter_size, DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, remain_w, filter_size, DstType,
2> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
if (filter_size == 2) {
conv_direct_stride2_2x2_int8_nchw44<bias_mode, Op, remain_w,
DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
} else {
conv_direct_stride2_int8_nchw44_kern<bias_mode, Op, remain_w,
filter_size, DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
typename DstType, int stride>
void conv_direct_int8_nchw44(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
ConvDirectInt8Nchw44Choose<bias_mode, Op, remain_w, filter_size, DstType,
stride>::impl(src, filter, bias, temp, dst, oc,
ic, ih, iw, oh, ow, op);
}
} // namespace
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride1)
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
auto&& fm = param.filter_meta;
auto SW = fm.stride[1];
auto OH = param.osz[0];
auto OW = param.osz[1];
auto FH = fm.spatial[0];
auto FW = fm.spatial[1];
OH2 = OH;
OW2 = (OW + 7) & ~7;
IH2 = SW * OH + FH - SW;
IW2 = SW * OW2 + FW - SW;
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
constexpr size_t src_expand = 4;
auto&& fm = param.filter_meta;
size_t group = fm.group;
size_t batch = param.n;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
size_t FH = fm.spatial[0];
size_t FW = fm.spatial[1];
size_t IH2, IW2, OH2, OW2;
get_rectified_size(param, IH2, IW2, OH2, OW2);
if (group == 1) {
size_t src_size =
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand;
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t);
return {nullptr, {src_size, weight_size}};
} else {
size_t src_size =
param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * src_expand;
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t);
return {nullptr, {src_size, weight_size}};
}
};
static void copy_padding_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t GROUP = kern_param.filter_meta.group;
size_t IH2, IW2, OH2, OW2;
get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
size_t padding_group_size = IH2 * IW2 * IC;
bundle.set(kern_param.workspace_ptr);
//! Used for get the workspace offset
constexpr int pack_ic = 4;
constexpr int expend_element = 4;
// TODO: block dim is better to get from arg
size_t workspace_ic_block = 4;
size_t workspace_batch_id = workspace_ids[0];
size_t workspace_group_id = workspace_ids[1];
size_t workspace_ic_id = workspace_ids[2];
size_t workspace_ic = workspace_ic_id * workspace_ic_block;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t group_pack_size = 1;
int nr_pad_h = PH * IW2 * pack_ic * expend_element;
int nr_pad_w = PW * pack_ic * expend_element;
int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element;
//! copy to sptr_base to eliminate padding effect
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>(
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic));
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
(workspace_batch_id * GROUP * padding_group_size +
workspace_group_id * padding_group_size +
workspace_ic * IH2 * IW2) *
expend_element;
size_t nr_ic = workspace_ic_block;
if (GROUP > 1) {
nr_ic = IC;
}
rep_step(ic_idx, nr_ic, pack_ic) {
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t));
sptr_base += nr_pad_h;
rep(ih_idx, IH) {
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
sptr_base += nr_pad_w;
conv_bias::nchw44_pack_src(sptr, sptr_base, IW);
sptr_base += IW * pack_ic * expend_element;
sptr += IW * pack_ic;
std::memset(sptr_base, 0, (nr_pad_w + over_pad) * sizeof(int8_t));
sptr_base += nr_pad_w + over_pad;
}
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t));
sptr_base += nr_pad_h;
}
}
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain>
static void do_conv_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids,
const CpuNDRange& ncb_range) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t OC = kern_param.filter_meta.ocpg;
size_t GROUP = kern_param.filter_meta.group;
size_t IH2, IW2, OH2, OW2;
get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
//! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f)
Op op = Op(1.0f, 4.0f);
if (need_post_process) {
float scale_bias =
kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
op = Op(scale_bias, scale_dst);
}
size_t padding_group_size = IH2 * IW2 * IC;
bundle.set(kern_param.workspace_ptr);
constexpr size_t pack_c = 4;
constexpr size_t src_expand_size = 4;
const size_t workspace_batch_id = workspace_ids[0];
const size_t workspace_group_id = workspace_ids[1];
const size_t batch_id = ncb_index.ndrange_id[0];
const size_t group_id = ncb_index.ndrange_id[1];
const size_t oc_id = ncb_index.ndrange_id[2];
const size_t oc_block_num = ncb_range[2];
size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num);
size_t oc_block = nr_pack_per_step * pack_c;
const size_t oc_idx = oc_id * oc_block;
if (oc_id == (oc_block_num - 1)) {
oc_block = OC - oc_id * nr_pack_per_step * pack_c;
}
megdnn_assert(oc_block % pack_c == 0,
"oc must be devisible by 4, but oc = %zu", oc_block);
const int8_t* sptr =
static_cast<int8_t*>(bundle.get(0)) +
workspace_batch_id * GROUP * padding_group_size * src_expand_size +
workspace_group_id * padding_group_size * src_expand_size;
const int8_t* fptr =
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
void* dst = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(
kern_param.dst<void>(batch_id, group_id)) +
oc_idx * OH * OW);
const int32_t* bptr =
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
conv_bias::nchw44_pack_filter(fptr, packed_weight,
oc_block / 4 * IC / 4 * FH * FW);
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw44< \
bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \
static_cast<int8_t*>(dst), oc_block, IC, \
IH2, IW2, OH, OW, op)
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV)
#undef KERN1_NCHW44_CONV
}
/* ===================== stride1 algo ===================== */
bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MEGDNN_MARK_USED_VAR(algo_selection_strategy);
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto OC = fm.ocpg;
auto IC = fm.icpg;
bool avaible = //! src and filter are qint8, dst is qint8 or qint32
((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32))) &&
(fm.format == param::Convolution::Format::NCHW44) &&
(OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 &&
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
param.bias_mode != BiasMode::BIAS;
return avaible;
}
bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::is_preferred(
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
const NCBKernSizeParam& param) const {
// TODO: benchmark and fix
MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr);
MEGDNN_MARK_USED_VAR(param);
return false;
}
size_t ConvBiasImpl::AlgoS8DirectStride1NCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride1NCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
size_t N = param.n;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
size_t OW = param.osz[1];
size_t group = fm.group;
size_t fh = fm.spatial[0];
size_t fw = fm.spatial[1];
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
int ow_remain = OW % 8;
// NOTE: remain_w is not used to gen hash of midout for compatible with changing
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride1, \
midout_iv(#filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(filter, bias_mode, remain_w) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_REMAIN_W_PARAM(filter, bias_mode) \
switch (ow_remain) { \
case 0: \
GET_OP_PARAM(filter, bias_mode, 0); \
break; \
case 1: \
GET_OP_PARAM(filter, bias_mode, 1); \
break; \
case 2: \
GET_OP_PARAM(filter, bias_mode, 2); \
break; \
case 3: \
GET_OP_PARAM(filter, bias_mode, 3); \
break; \
case 4: \
GET_OP_PARAM(filter, bias_mode, 4); \
break; \
case 5: \
GET_OP_PARAM(filter, bias_mode, 5); \
break; \
case 6: \
GET_OP_PARAM(filter, bias_mode, 6); \
break; \
case 7: \
GET_OP_PARAM(filter, bias_mode, 7); \
break; \
default: \
megdnn_assert(0); \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
DISPATCH_CONV_KERN();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
megdnn_assert(do_conv_fun);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
WorkspaceBundle bundle = wbundle;
constexpr size_t pack_oc = 4;
size_t oc_step = pack_oc;
if (fh == 2 && fw == 2 && OC >= 8) {
oc_step = 8;
}
if (group == 1) {
CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
copy_padding_kern(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
};
constexpr size_t pack_ic = 4;
ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}});
auto do_conv = [bundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
} else {
CpuNDRange ncb_range = {N, group, 1};
auto do_conv = [bundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
copy_padding_kern(bundle, kern_param, ncb_index,
{0, ncb_index.thread_id, 0});
do_conv_fun(bundle, kern_param, ncb_index,
{0, ncb_index.thread_id, 0}, ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
}
return ret_kerns;
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;
namespace {
/**
dot like impl. dot 4 ic to 1 oc, accumale to c <ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
// TODO: can try oh = 2 impl, oc = 8 impl
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s1_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[2 * 4];
int8x16_t weight[3];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_oc4_ow8<bias_mode>(c, bias_ptr);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc4 = oc_step * fh * fw * ic;
int32x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[4];
init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[2 * 4];
int8x16_t weight[2];
int8x16_t src[8 + 1];
int16x8_t temp_c[2];
init_oc4_ow8<bias_mode>(c, bias_ptr);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_5x5s1_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[2 * 4];
int8x16_t weight[5];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_oc4_ow8<bias_mode>(c, bias_ptr);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]);
c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]);
c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_7x7s1_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[2 * 4];
int8x16_t weight[7];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_oc4_ow8<bias_mode>(c, bias_ptr);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[5], src[6], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[6], src[7], c[1], temp_c[1]);
c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[5], src[7], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[5], src[8], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[6], src[8], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[6], src[9], c[3], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[5], src[9], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[5], src[0], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[6], src[0], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[6], src[1], c[5], temp_c[1]);
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[5], src[1], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[5], src[2], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[6], src[2], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[6], src[3], c[7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
} // namespace
/**
origin weight shape <oc/4, ic/4, fh, fw, 4, 4>
packed weight shape <oc/4, ic/4, fh, fw, 16>
example: (format like weight<oc, ic>)
origin
<0, 0> <1, 0> <2, 0> <3, 0>
<0, 1> <1, 1> <2, 1> <3, 1>
<0, 2> <1, 2> <2, 2> <3, 2>
<0, 3> <1, 3> <2, 3> <3, 3>
packed
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
**/
void conv_bias::nchw44_pack_filter(const int8_t* src, int8_t* dst, int length) {
static const uint8_t weight_idx_buffer[16] = {0, 4, 9, 13, 2, 6, 11, 15,
12, 8, 5, 1, 14, 10, 7, 3};
constexpr int simd_len = 16;
uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer);
for (int i = 0; i < length; i++) {
int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx);
vst1q_s8(dst + i * simd_len, result);
}
}
/**
origin src shape <n, ic/4, h, w, 4>
packed src shape <n, ic/4, h, w, 16>
example: (format like <ic>)
origin
<0> <0> <0> <0>
packed
low 64 bit <0> <1> <2> <3> | <0> <1> <2> <3>
---------------------------------------------------------------------
high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0>
**/
void conv_bias::nchw44_pack_src(const int8_t* src, int8_t* dst, int length) {
static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3,
3, 2, 1, 0, 3, 2, 1, 0};
constexpr int pack_ic = 4;
constexpr int simd_len = 16;
uint8x16_t src_idx = vld1q_u8(src_idx_buffer);
for (int i = 0; i < length; i++) {
int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx);
vst1q_s8(dst + i * simd_len, result);
}
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride1_2x2_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_oc = oh * ow * ic_step;
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
}
}
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride1_3x3_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 3;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
}
}
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride1_5x5_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 5;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
}
}
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride1_7x7_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 7;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
}
}
}
#define INSTANTIATION(stride, i, bias, remain_w, Op) \
template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \
bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \
int32_t*, int8_t*, const size_t, const size_t, \
const size_t, const size_t, const size_t, \
const size_t, const Op&);
#define FOR_OP(stride, i, bias, remain_w) \
INSTANTIATION(stride, i, bias, remain_w, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, remain_w, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, remain_w, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_REMAIN(stride, i, bias) \
FOR_OP(stride, i, bias, 0) \
FOR_OP(stride, i, bias, 1) \
FOR_OP(stride, i, bias, 2) \
FOR_OP(stride, i, bias, 3) \
FOR_OP(stride, i, bias, 4) \
FOR_OP(stride, i, bias, 5) \
FOR_OP(stride, i, bias, 6) \
FOR_OP(stride, i, bias, 7)
#define FOR_BIAS(stride, i) \
FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \
FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER(stride1)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;
namespace {
/**
dot like impl. dot 4 ic to 1 oc, accumale to c <ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
// TODO: can try oh = 2 impl, oc = 8 impl
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s2_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[2 * 4];
int8x16_t weight[3];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_oc4_ow8<bias_mode>(c, bias_ptr);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc4 = oc_step * fh * fw * ic;
int32x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[4];
init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8(src_ic_0_3 + 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);
src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);
src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[2 * 4];
int8x16_t weight[2];
int8x16_t src[8 + 1];
int16x8_t temp_c[2];
init_oc4_ow8<bias_mode>(c, bias_ptr);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[0], src[1], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[1], src[0], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[1], src[2], c[5], temp_c[1]);
src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
c[6] = vdotq_s32_h(weight[0], src[3], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[0], src[5], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[1], src[4], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[1], src[6], c[7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_5x5s2_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[2 * 4];
int8x16_t weight[5];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_oc4_ow8<bias_mode>(c, bias_ptr);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]);
src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 17 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 18 * 16));
c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_7x7s2_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
int8_t* dst_ptr, int ic, int ih,
int iw, const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
int32x4_t c[2 * 4];
int8x16_t weight[7];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_oc4_ow8<bias_mode>(c, bias_ptr);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[5], src[7], c[1], temp_c[1]);
c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]);
c[1] = vdotq_s32_h(weight[6], src[8], c[1], temp_c[1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[5], src[9], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[5], src[1], c[3], temp_c[1]);
c[2] = vdotq_s32_h(weight[6], src[0], c[2], temp_c[0]);
c[3] = vdotq_s32_h(weight[6], src[2], c[3], temp_c[1]);
src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[5], src[3], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[5], src[5], c[5], temp_c[1]);
c[4] = vdotq_s32_h(weight[6], src[4], c[4], temp_c[0]);
c[5] = vdotq_s32_h(weight[6], src[6], c[5], temp_c[1]);
src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 19 * 16);
src[0] = vld1q_s8(src_ic_0_3 + 20 * 16);
c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[5], src[7], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[5], src[9], c[7], temp_c[1]);
c[6] = vdotq_s32_h(weight[6], src[8], c[6], temp_c[0]);
c[7] = vdotq_s32_h(weight[6], src[0], c[7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
} // namespace
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride2_2x2_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;
const size_t out_img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_oc = oh * ow * ic_step;
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
}
}
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride2_3x3_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 3;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
}
}
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride2_5x5_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 5;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
}
}
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride2_7x7_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 7;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, 0, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, remain_w,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, op);
}
}
}
}
#define INSTANTIATION(stride, i, bias, remain_w, Op) \
template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \
bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \
int32_t*, int8_t*, const size_t, const size_t, \
const size_t, const size_t, const size_t, \
const size_t, const Op&);
#define FOR_OP(stride, i, bias, remain_w) \
INSTANTIATION(stride, i, bias, remain_w, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, remain_w, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, remain_w, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_REMAIN(stride, i, bias) \
FOR_OP(stride, i, bias, 0) \
FOR_OP(stride, i, bias, 1) \
FOR_OP(stride, i, bias, 2) \
FOR_OP(stride, i, bias, 3) \
FOR_OP(stride, i, bias, 4) \
FOR_OP(stride, i, bias, 5) \
FOR_OP(stride, i, bias, 6) \
FOR_OP(stride, i, bias, 7)
#define FOR_BIAS(stride, i) \
FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \
FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER(stride2)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
......@@ -46,11 +46,10 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false};
AlgoS8DirectStride2 s8_direct_stride2_large_group{true};
AlgoS8DirectStride2 s8_direct_stride2_small_group{false};
AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44;
AlgoS8DirectNCHW44 s8_direct_nchw44;
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
AlgoS8DirectStride1 s8_direct_stride1_large_group{true};
AlgoS8DirectStride1 s8_direct_stride1_small_group{false};
AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44;
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
......@@ -114,11 +113,10 @@ public:
direct_algos.emplace_back(&qu8_direct_stride1_small_group);
direct_algos.emplace_back(&s8_direct_stride2_large_group);
direct_algos.emplace_back(&s8_direct_stride2_small_group);
direct_algos.emplace_back(&s8_direct_stride2_nchw44);
direct_algos.emplace_back(&s8_direct_nchw44);
direct_algos.emplace_back(&s8_direct_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_stride1_large_group);
direct_algos.emplace_back(&s8_direct_stride1_small_group);
direct_algos.emplace_back(&s8_direct_stride1_nchw44);
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
......
......@@ -37,9 +37,8 @@ protected:
private:
class AlgoS8DirectStride1;
class AlgoS8DirectStride1NCHW44;
class AlgoS8DirectStride2;
class AlgoS8DirectStride2NCHW44;
class AlgoS8DirectNCHW44;
class AlgoS8DirectNCHWNCHW44;
class AlgoQU8DirectStride1;
class AlgoQU8DirectStride2;
......
......@@ -27,6 +27,8 @@ struct NoneOp;
#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \
template <> \
struct NoneOp<_ctype> : NoneOpBase<_ctype> { \
NoneOp(){}; \
NoneOp(float, float){}; \
using NoneOpBase::NoneOpBase; \
using NoneOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
......
......@@ -226,7 +226,15 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
run(1, 3, 32, 224, 224, 5, 1, true);
run(1, 3, 64, 224, 224, 7, 1, true);
for (size_t stride : {1, 2}) {
run(1, 64, 128, 56, 56, 3, 2, false);
run(1, 128, 256, 28, 28, 3, 2, false);
run(1, 256, 512, 14, 14, 3, 2, false);
run(1, 128, 128, 28, 28, 3, 1, false);
run(1, 256, 256, 14, 14, 3, 1, false);
run(1, 512, 512, 7, 7, 3, 1, false);
for (size_t stride : {1}) {
printf("stride %zu\n", stride);
for (size_t filter_size : {2, 3, 5, 7}) {
for (size_t img_size : {32}) {
......
......@@ -527,12 +527,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
handle(), "S8_NCHW44_DIRECT_STRD1");
handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
checker_conv_bias_qint8x8x32(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true),
handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
checker_conv_bias_qint8x8x32(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true),
handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
handle(), "S8_NCHW44_DIRECT_STRD2");
handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
checker_conv_bias_qint8x8x8(
......@@ -1085,7 +1095,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
using namespace conv_bias;
......@@ -1096,17 +1105,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
param::MatrixMul::Format format, float eps) {
for (auto&& arg : args) {
for (uint32_t m : out_size) {
checker.set_extra_opr_impl(std::bind(
winograd_algo_extra_impl, std::placeholders::_1, m,
arg.param, handle, format));
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype)
.set_dtype(4, D_dtype)
.set_epsilon(eps)
.set_param(arg.param)
.execs({arg.src, arg.filter, arg.bias, {}, {}});
}
checker.set_extra_opr_impl(std::bind(
winograd_algo_extra_impl, std::placeholders::_1, m,
arg.param, handle, format));
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype)
.set_dtype(4, D_dtype)
.set_epsilon(eps)
.set_param(arg.param)
.execs({arg.src, arg.filter, arg.bias, {}, {}});
}
}
};
......@@ -1118,7 +1127,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
std::vector<TestArg> quantized_args = get_int8_nchw44_args (3,4);
std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
UniformIntRNG int_rng{-50, 50};
checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
......@@ -1126,8 +1135,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
using namespace conv_bias;
Checker<ConvBiasForward> checker(handle());
......@@ -1137,17 +1146,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM
param::MatrixMul::Format format, float eps) {
for (auto&& arg : args) {
for (uint32_t m : out_size) {
checker.set_extra_opr_impl(std::bind(
winograd_algo_extra_impl, std::placeholders::_1, m,
arg.param, handle, format));
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype)
.set_dtype(4, D_dtype)
.set_epsilon(eps)
.set_param(arg.param)
.execs({arg.src, arg.filter, arg.bias, {}, {}});
}
checker.set_extra_opr_impl(std::bind(
winograd_algo_extra_impl, std::placeholders::_1, m,
arg.param, handle, format));
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype)
.set_dtype(4, D_dtype)
.set_epsilon(eps)
.set_param(arg.param)
.execs({arg.src, arg.filter, arg.bias, {}, {}});
}
}
};
......@@ -1168,7 +1177,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM
dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
using namespace conv_bias;
Checker<ConvBiasForward> checker(handle());
......@@ -1196,21 +1206,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F
#if MEGDNN_AARCH64
const char* matmul_name = "AARCH64_F32_MK4_4x16";
#else
const char* matmul_name = "ARMV7_F32_MK4_4x8";
const char* matmul_name = "ARMV7_F32_MK4_4x8";
#endif
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
std::vector<TestArg> quantized_args =
get_int8_nchw44_args(3, 4, true);
std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
UniformIntRNG int_rng{-50, 50};
checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
dtype::QuantizedS8(0.01887994f),
dtype::QuantizedS32(0.41113496f * 0.01887994f),
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon);
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
epsilon);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
using namespace conv_bias;
Checker<ConvBiasForward> checker(handle());
......@@ -1238,7 +1249,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F
#if MEGDNN_AARCH64
const char* matmul_name = "AARCH64_F32_MK4_4x16";
#else
const char* matmul_name = "ARMV7_F32_MK4_4x8";
const char* matmul_name = "ARMV7_F32_MK4_4x8";
#endif
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
......@@ -1249,10 +1260,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F
run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
dtype::QuantizedS8(0.01887994f),
dtype::QuantizedS32(0.41113496f * 0.01887994f),
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon);
dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
epsilon);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
using namespace conv_bias;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册