Commit 7b0dbe6a authored by Megvii Engine Team, committed by Xu Xinran

fix(dnn/arm): fix stride 1 support for int8 nchw_nchw44

GitOrigin-RevId: 9d718eb7a4dae3c2724ea07ba2b639fbfb319f78
Parent 198f3eb5
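The gist of the patch: the int8 NCHW→NCHW44 direct convolution, previously stride-2 only, now dispatches on stride, and two layout rules drive most of the changes below. A minimal C++ sketch of those rules (hypothetical helper names, distilled from the diff rather than taken from it):

static inline int src_expand_for(int stride_h) {
    // stride 2 interleaves row pairs (4x expansion); stride 1 broadcasts
    // each element pair across a full 16-byte lane (16x expansion)
    return stride_h == 2 ? 4 : 16;
}
static inline int packed_filter_width(int stride_h, int fw) {
    // stride 1 zero-pads the filter width up to a multiple of 4
    return stride_h == 2 ? fw : (fw + 3) / 4 * 4;  // round_up(fw, 4)
}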
......@@ -37,7 +37,7 @@ static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2, int& oh2, int& ow2) {
constexpr int cacheline = 64 / sizeof(float);
constexpr int nr_elements_in_cacheline = 64 / sizeof(float);
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
int oh = param.osz[0];
......@@ -52,7 +52,8 @@ static void get_rectified_size(
int block_oh = l2_block_helper(param.nr_threads, oh,
ic * iw * sizeof(float) * stride_h);
ih2 = block_oh * stride_h + filter_h - stride_h;
iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]), cacheline);
iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]),
nr_elements_in_cacheline);
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
......
......@@ -90,9 +90,9 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
public:
AlgoS8DirectStride2NCHWNCHW44() {}
AlgoS8DirectNCHWNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "S8_CONV_NCHW_NCHW44"; }
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
......
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp
* \file dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -12,7 +12,7 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
......@@ -25,93 +25,147 @@ using conv_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2)
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44)
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2, int& oh2, int& ow2) {
auto&& fm = param.filter_meta;
size_t IH = param.isz[0];
size_t IW = param.isz[1];
size_t OH = param.osz[0];
size_t OW = param.osz[1];
int ih = param.isz[0];
int iw = param.isz[1];
int oh = param.osz[0];
int ow = param.osz[1];
int ph = fm.padding[0];
int pw = fm.padding[1];
int stride_h = fm.stride[0];
OH2 = OH;
OW2 = OW;
IH2 = round_up(IH + 2 * fm.padding[0], static_cast<size_t>(2));
IW2 = IW + 2 * fm.padding[1];
oh2 = oh;
ow2 = ow;
ih2 = stride_h == 2 ? round_up(ih + 2 * ph, 2) : ih + 2 * ph;
iw2 = iw + 2 * pw;
}
static inline size_t get_temp_bytes(const int iw, const int pw) {
//! border_size is used to avoid reading out-of-bounds memory
constexpr int cacheline_size = 64;
constexpr int border_size = 1 * cacheline_size;
return round_up(iw + pw * 2, cacheline_size) + border_size;
}
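//! Example of the sizing above, assuming a 64-byte cacheline: for iw = 100,
//! pw = 1 the padded row needs 102 bytes, rounded up to 128, plus the
//! 64-byte border => 192 bytes per thread. The border gives the 16-byte
//! NEON loads in the stride-1 packer room to read past the last column.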
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
constexpr size_t src_expand = 4;
auto&& fm = param.filter_meta;
size_t group = fm.group;
size_t batch = param.n;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
size_t FH = fm.spatial[0];
size_t FW = fm.spatial[1];
size_t IH2, IW2, OH2, OW2;
get_rectified_size(param, IH2, IW2, OH2, OW2);
int group = fm.group;
int batch = param.n;
int ic = fm.icpg;
int oc = fm.ocpg;
int fh = fm.spatial[0];
int fw = fm.spatial[1];
int stride_h = fm.stride[0];
int iw = param.isz[1];
int pw = fm.padding[1];
int ih2, iw2, oh2, ow2;
const size_t src_expand = stride_h == 2 ? 4 : 16;
get_rectified_size(param, ih2, iw2, oh2, ow2);
megdnn_assert(group == 1, "only support group == 1 now");
size_t src_size =
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand;
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t);
return {nullptr, {src_size, weight_size}};
batch * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand;
size_t weight_size = group * oc * ic * fh * fw * sizeof(int8_t);
size_t tmp_size = 0;
if (stride_h == 1) {
weight_size = group * oc * ic * fh * round_up(fw, 4) * sizeof(int8_t);
tmp_size = get_temp_bytes(iw, pw);
}
return {nullptr, {src_size, weight_size, tmp_size * param.nr_threads}};
};
static void copy_padding_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t GROUP = kern_param.filter_meta.group;
int ih = kern_param.isz[0];
int iw = kern_param.isz[1];
int ic = kern_param.filter_meta.icpg;
int ph = kern_param.filter_meta.padding[0];
int pw = kern_param.filter_meta.padding[1];
int group = kern_param.filter_meta.group;
int stride_h = kern_param.filter_meta.stride[0];
size_t IH2, IW2, OH2, OW2;
get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
size_t padding_group_size = IH2 * IW2 * IC;
int ih2, iw2, oh2, ow2;
get_rectified_size(kern_param, ih2, iw2, oh2, ow2);
int padding_group_size = ih2 * iw2 * ic;
bundle.set(kern_param.workspace_ptr);
//! Used to get the workspace offset
constexpr int expend_element = 4;
// TODO: better to get the block dim from the arg
size_t workspace_ic_block = 1;
size_t workspace_batch_id = workspace_ids[0];
size_t workspace_group_id = workspace_ids[1];
size_t workspace_ic_id = workspace_ids[2];
size_t workspace_ic = workspace_ic_id * workspace_ic_block;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
const int src_expand = stride_h == 2 ? 4 : 16;
//! TODO: better to get the block dim from the arg
int workspace_ic_block = 1;
int workspace_batch_id = workspace_ids[0];
int workspace_group_id = workspace_ids[1];
int workspace_ic_id = workspace_ids[2];
int workspace_ic = workspace_ic_id * workspace_ic_block;
int batch_id = ncb_index.ndrange_id[0];
int group_id = ncb_index.ndrange_id[1];
const int8_t* sptr = static_cast<const int8_t*>(
kern_param.src<int8_t>(batch_id, group_id, workspace_ic_id, 1, 1));
//! copy to sptr_base to eliminate padding effect
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
(workspace_batch_id * GROUP * padding_group_size +
(workspace_batch_id * group * padding_group_size +
workspace_group_id * padding_group_size +
workspace_ic * IH2 * IW2) *
expend_element;
conv_bias::pack_nchw_src_for_nchw44_conv(sptr, sptr_base, 1, PH, PH, PW, PW,
IH, IW);
workspace_ic * ih2 * iw2) *
src_expand;
if (stride_h == 1) {
const size_t tmp_size = get_temp_bytes(iw, pw);
int8_t* tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, tmp_ptr);
} else {
pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, nullptr);
}
}
static void pack_weight(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
bundle.set(kern_param.workspace_ptr);
const int group_id = ncb_index.ndrange_id[0];
int fh = kern_param.filter_meta.spatial[0];
int fw = kern_param.filter_meta.spatial[1];
int oc = kern_param.filter_meta.ocpg;
int ic = kern_param.filter_meta.icpg;
int stride_h = kern_param.filter_meta.stride[0];
int fw2 = stride_h == 2 ? fw : round_up(fw, 4);
int oc_block = oc;
int oc_idx = 0;
const int8_t* fptr =
kern_param.filter<dt_int8>(group_id) + oc_idx * fh * fw * ic;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2;
template <size_t filter, BiasMode bias_mode, typename Op>
if (stride_h == 1) {
pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw,
oc_block);
} else {
pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw,
oc_block);
}
}
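//! Worked example of the rounding above: fw2 = round_up(fw, 4), so 2x2 and
//! 3x3 filters pack rows of 4 taps while 5x5 and 7x7 filters pack rows of
//! 8; the zero-padded taps let the stride-1 kernels run whole 4-tap dot
//! products per 16-byte register with no tail case.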
template <size_t filter, BiasMode bias_mode, typename Op, int stride>
static void do_conv_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids,
const CpuNDRange& ncb_range) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t OC = kern_param.filter_meta.ocpg;
size_t GROUP = kern_param.filter_meta.group;
size_t IH2, IW2, OH2, OW2;
get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
int oh = kern_param.osz[0];
int ow = kern_param.osz[1];
int fh = kern_param.filter_meta.spatial[0];
int fw = kern_param.filter_meta.spatial[1];
int fw2 = stride == 2 ? fw : round_up(fw, 4);
int ic = kern_param.filter_meta.icpg;
int oc = kern_param.filter_meta.ocpg;
int group = kern_param.filter_meta.group;
int ih2, iw2, oh2, ow2;
get_rectified_size(kern_param, ih2, iw2, oh2, ow2);
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
//! if dst_type is qint32, the op is not used; just fill it with (1.0f, 4.0f)
......@@ -122,54 +176,46 @@ static void do_conv_kern(WorkspaceBundle bundle,
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
op = Op(scale_bias, scale_dst);
}
size_t padding_group_size = IH2 * IW2 * IC;
int padding_group_size = ih2 * iw2 * ic;
bundle.set(kern_param.workspace_ptr);
constexpr size_t pack_c = 4;
constexpr size_t src_expand_size = 4;
const size_t workspace_batch_id = workspace_ids[0];
const size_t workspace_group_id = workspace_ids[1];
const size_t batch_id = ncb_index.ndrange_id[0];
const size_t group_id = ncb_index.ndrange_id[1];
const size_t oc_id = ncb_index.ndrange_id[2];
const size_t oc_block_num = ncb_range[2];
size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num);
size_t oc_block = nr_pack_per_step * pack_c;
const size_t oc_idx = oc_id * oc_block;
constexpr int pack_c = 4;
constexpr int src_expand_size = stride == 2 ? 4 : 16;
const int workspace_batch_id = workspace_ids[0];
const int workspace_group_id = workspace_ids[1];
const int batch_id = ncb_index.ndrange_id[0];
const int group_id = ncb_index.ndrange_id[1];
const int oc_id = ncb_index.ndrange_id[2];
const int oc_block_num = ncb_range[2];
int nr_pack_per_step = div_ceil(div_ceil(oc, pack_c), oc_block_num);
int oc_block = nr_pack_per_step * pack_c;
const int oc_idx = oc_id * oc_block;
if (oc_id == (oc_block_num - 1)) {
oc_block = OC - oc_id * nr_pack_per_step * pack_c;
oc_block = oc - oc_id * nr_pack_per_step * pack_c;
}
megdnn_assert(oc_block % pack_c == 0,
"oc must be devisible by 4, but oc = %zu", oc_block);
"oc must be devisible by 4, but oc = %d", oc_block);
const int8_t* sptr =
static_cast<int8_t*>(bundle.get(0)) +
workspace_batch_id * GROUP * padding_group_size * src_expand_size +
workspace_batch_id * group * padding_group_size * src_expand_size +
workspace_group_id * padding_group_size * src_expand_size;
const int8_t* fptr =
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
void* dst = reinterpret_cast<void*>(
int8_t* dst = reinterpret_cast<int8_t*>(
reinterpret_cast<ptrdiff_t>(
kern_param.dst<void>(batch_id, group_id)) +
oc_idx * OH * OW);
oc_idx * oh * ow);
const int32_t* bptr =
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
conv_bias::pack_nchw44_weight_for_nchw_conv(fptr, packed_weight, IC, FH, FW,
oc_block);
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw_nchw44< \
bias_mode, Op>(sptr, packed_weight, bptr, nullptr, \
static_cast<int8_t*>(dst), oc_block, IC, IH2, IW2, \
OH, OW, op)
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
int8_t* packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * oc * ic * fh * fw2 +
oc_idx * ic * fh * fw2;
conv_direct_int8_nchw_nchw44<bias_mode, Op, filter, stride>(
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh,
ow, op);
}
/* ===================== stride2 algo ===================== */
bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::usable(
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MEGDNN_MARK_USED_VAR(algo_selection_strategy);
......@@ -184,13 +230,14 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::usable(
(fm.format == param::Convolution::Format::NCHW44) &&
(OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] && (FH == 3 || FH == 5 || FH == 7) &&
fm.group == 1 && param.bias_mode != BiasMode::BIAS;
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 &&
param.bias_mode != BiasMode::BIAS;
return avaible;
}
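// In effect: square stride-1 or stride-2 convolutions with 2x2/3x3/5x5/7x7
// filters, no dilation or flipping, group == 1, OC a positive multiple of
// 4, and any bias mode except full BIAS are accepted.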
bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::is_preferred(
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred(
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
const NCBKernSizeParam& param) const {
// TODO: benchmark and fix
......@@ -199,13 +246,13 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::is_preferred(
return false;
}
size_t ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::get_workspace(
size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns(
ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
size_t N = param.n;
......@@ -215,61 +262,76 @@ ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns(
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to generate the midout hash, to stay compatible
// with shapes that change at runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2, \
midout_iv(#filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op>; \
} \
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44, \
midout_iv(#stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(filter, bias_mode) \
#define GET_OP_PARAM(stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(filter, bias_mode, \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(filter, bias_mode, \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(filter, bias_mode, \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN() \
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(3) \
GET_BIAS_MODE_PARAM(stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
GET_BIAS_MODE_PARAM(stride, 5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
GET_BIAS_MODE_PARAM(stride, 7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
DISPATCH_CONV_KERN();
switch (param.filter_meta.stride[0]) {
case 1:
DISPATCH_CONV_KERN(1);
break;
case 2:
DISPATCH_CONV_KERN(2);
break;
default:
megdnn_throw(ssprintf("Unsupport stride size %u for the first conv",
param.filter_meta.stride[0])
.c_str());
break;
}
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
......@@ -290,6 +352,12 @@ ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns(
};
ret_kerns.push_back({copy_padding, {N, group, fm.icpg}});
auto do_pack_weight = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
pack_weight(bundle, kern_param, ncb_index);
};
ret_kerns.push_back({do_pack_weight, {static_cast<size_t>(group)}});
CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
auto do_conv = [bundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
......
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern_nchw.cpp
* \file dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -9,28 +9,40 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h"
#pragma once
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;
namespace megdnn {
namespace arm_common {
namespace {
template <int src_idx, int weight_idx, int c_dim, typename Func, typename T,
typename T2, typename T3, typename T4>
/**
 * @brief core code for the calculation pattern
 *
 * @tparam src_idx is the offset of the src reg
 * @tparam weight_idx is the offset of the weight reg
 * @tparam c_dim is the output channel block dim
 * @tparam Func is the mla operation function
 * @tparam stride
 * @tparam T is the output regs type
 * @tparam T2 is the src regs type
 * @tparam T3 is the weight regs type
 * @tparam T4 is the temp regs type
 */
template <int src_idx, int weight_idx, int c_dim, typename Func, int stride,
typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight, T4& temp);
static void impl(T& c, T2& src, T3& weight);
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, T, T2, T3, T4> {
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 2, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0],
temp[0]);
......@@ -62,7 +74,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, T, T2, T3, T4> {
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, T, T2, T3, T4> {
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 2, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0],
temp[0]);
......@@ -81,17 +93,81 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, T, T2, T3, T4> {
}
};
template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T,
typename T2, typename T3, typename T4>
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx],
c[0][0], temp[0]);
c[1][0] = Func::impl(src[(0 + src_idx) % 8], weight[1][weight_idx],
c[1][0], temp[1]);
c[0][1] = Func::impl(src[(1 + src_idx) % 8], weight[0][weight_idx],
c[0][1], temp[2]);
c[1][1] = Func::impl(src[(1 + src_idx) % 8], weight[1][weight_idx],
c[1][1], temp[3]);
c[0][2] = Func::impl(src[(2 + src_idx) % 8], weight[0][weight_idx],
c[0][2], temp[0]);
c[1][2] = Func::impl(src[(2 + src_idx) % 8], weight[1][weight_idx],
c[1][2], temp[1]);
c[0][3] = Func::impl(src[(3 + src_idx) % 8], weight[0][weight_idx],
c[0][3], temp[2]);
c[1][3] = Func::impl(src[(3 + src_idx) % 8], weight[1][weight_idx],
c[1][3], temp[3]);
c[0][4] = Func::impl(src[(4 + src_idx) % 8], weight[0][weight_idx],
c[0][4], temp[0]);
c[1][4] = Func::impl(src[(4 + src_idx) % 8], weight[1][weight_idx],
c[1][4], temp[1]);
c[0][5] = Func::impl(src[(5 + src_idx) % 8], weight[0][weight_idx],
c[0][5], temp[2]);
c[1][5] = Func::impl(src[(5 + src_idx) % 8], weight[1][weight_idx],
c[1][5], temp[3]);
c[0][6] = Func::impl(src[(6 + src_idx) % 8], weight[0][weight_idx],
c[0][6], temp[0]);
c[1][6] = Func::impl(src[(6 + src_idx) % 8], weight[1][weight_idx],
c[1][6], temp[1]);
c[0][7] = Func::impl(src[(7 + src_idx) % 8], weight[0][weight_idx],
c[0][7], temp[2]);
c[1][7] = Func::impl(src[(7 + src_idx) % 8], weight[1][weight_idx],
c[1][7], temp[3]);
}
static void impl(T&, T2&, T3&);
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[(0 + src_idx) % 8], weight[0][weight_idx],
c[0][0], temp[0]);
c[0][1] = Func::impl(src[(1 + src_idx) % 8], weight[0][weight_idx],
c[0][1], temp[1]);
c[0][2] = Func::impl(src[(2 + src_idx) % 8], weight[0][weight_idx],
c[0][2], temp[2]);
c[0][3] = Func::impl(src[(3 + src_idx) % 8], weight[0][weight_idx],
c[0][3], temp[3]);
c[0][4] = Func::impl(src[(4 + src_idx) % 8], weight[0][weight_idx],
c[0][4], temp[0]);
c[0][5] = Func::impl(src[(5 + src_idx) % 8], weight[0][weight_idx],
c[0][5], temp[1]);
c[0][6] = Func::impl(src[(6 + src_idx) % 8], weight[0][weight_idx],
c[0][6], temp[2]);
c[0][7] = Func::impl(src[(7 + src_idx) % 8], weight[0][weight_idx],
c[0][7], temp[3]);
}
static void impl(T&, T2&, T3&);
};
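// Note on the (i + src_idx) % 8 indexing above: the eight src registers of
// the stride-1 path form a ring. With src_idx == 4 the taps read registers
// 4,5,6,7 and then wrap to 0,1,2,3, so after a cal_helper<0, 0, ...> pass
// the caller only reloads the first four registers with the next columns
// before cal_helper<4, 1, ...> reuses the remaining four.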
template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride,
typename T, typename T2, typename T3, typename T4>
inline void cal_helper(T& c, T2& src, T3& weight, T4& temp) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, T4>::impl(
c, src, weight, temp);
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3,
T4>::impl(c, src, weight, temp);
}
template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T,
typename T2, typename T3>
template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride,
typename T, typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, int>::impl(
c, src, weight);
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3,
int>::impl(c, src, weight);
};
template <int oc>
......@@ -111,7 +187,7 @@ public:
};
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block>
int oc_block, int stride>
struct KerNeonXXs2NchwNchw44 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
......@@ -143,8 +219,9 @@ struct KerNeonXXs2NchwNchw44 {
* |x x|x x|x x|x|
* |---|---|---|-|
**/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
......@@ -176,12 +253,12 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
ld_dot4_weight_oc);
load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);
cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);
cal_helper<2, 2, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(
c, src, dot4_weight, temp_c);
cal_helper<1, 1, c_dim, Vdotq_s32_h, stride>(
c, src, dot4_weight, temp_c);
cal_helper<2, 2, c_dim, Vdotq_s32_h, stride>(
c, src, dot4_weight, temp_c);
int8x8_t src_dot2[4];
int8x8_t dot2_weight[c_dim][1];
......@@ -189,8 +266,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
dot2_weight, weight_ptr, ld_dot4_weight_oc);
load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr,
0);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(
c, src_dot2, dot2_weight, temp_c);
weight_ptr += filter_size * pack_iw_len * fh_step;
}
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride +
......@@ -204,12 +281,12 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
ld_dot4_weight_oc);
load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr,
0, tbl);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<2, 2, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2,
dot2_weight, temp_c);
cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, src_dot2,
dot2_weight, temp_c);
cal_helper<2, 2, c_dim, Vdot2_s32_h, stride>(c, src_dot2,
dot2_weight, temp_c);
int16x8_t dot1_weight[c_dim][1];
int16x8_t src_dot1[4];
......@@ -217,14 +294,16 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
dot1_weight, weight_ptr, ld_dot4_weight_oc);
load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1,
nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight);
cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1,
dot1_weight);
weight_ptr += filter_size * pack_iw_len;
}
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
......@@ -255,10 +334,10 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
ld_dot4_weight_oc);
load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);
cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(
c, src, dot4_weight, temp_c);
cal_helper<1, 1, c_dim, Vdotq_s32_h, stride>(
c, src, dot4_weight, temp_c);
int8x8_t src_dot2[4];
int8x8_t dot2_weight[c_dim][1];
......@@ -266,8 +345,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
dot2_weight, weight_ptr, ld_dot4_weight_oc);
load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr,
0);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(
c, src_dot2, dot2_weight, temp_c);
weight_ptr += filter_size * pack_iw_len * ih_step;
}
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride +
......@@ -282,10 +361,10 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr,
0, tbl);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2,
dot2_weight, temp_c);
cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, src_dot2,
dot2_weight, temp_c);
int16x8_t dot1_weight[c_dim][1];
int16x8_t src_dot1[4];
......@@ -294,7 +373,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1,
nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight);
cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1,
dot1_weight);
weight_ptr += filter_size * pack_iw_len;
}
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
......@@ -315,8 +395,9 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
* |x x|x|
* |-----|
**/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> {
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
......@@ -345,8 +426,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> {
load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
ld_weight_oc);
load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(
c, src, dot4_weight, temp_c);
int8x8_t src_dot2[4];
int8x8_t dot2_weight[c_dim][1];
......@@ -354,8 +435,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> {
dot2_weight, weight_ptr, ld_weight_oc);
load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr,
0);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(
c, src_dot2, dot2_weight, temp_c);
}
// last line
{
......@@ -369,23 +450,257 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> {
ld_weight_oc);
load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>(
src_dot2, nchw_src_ptr, 0, tbl);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(
c, src_dot2, dot2_weight, temp_c);
int16x8_t dot1_weight[c_dim][1];
int16x8_t src_dot1[4];
load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>(
dot1_weight, weight_ptr, ld_weight_oc);
load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1,
nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight);
cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1,
dot1_weight);
weight_ptr += filter_size * filter_size * pack_iw_len;
}
}
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_size = 2;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int pack_iw_len = 4;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][4];
init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[4];
int8x16_t dot4_weight[c_dim][1];
int16x8_t temp_c[4];
load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
ld_weight_oc);
load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight,
temp_c);
weight_ptr += oc_step * filter_size * filter_size;
}
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 2;
constexpr int filter_width = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 1;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr, ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight,
temp_c);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr + 1 * filter_width * oc_step,
ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight,
temp_c);
weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 3;
constexpr int filter_width = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 1;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr, ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight,
temp_c);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr + 1 * filter_width * oc_step,
ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight,
temp_c);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr + 2 * filter_width * oc_step,
ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 2 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight,
temp_c);
weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 5;
constexpr int filter_width = 8;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 2;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
#define cb(step) \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
dot4_weight, weight_ptr + step * filter_width * oc_step, \
ld_weight_oc); \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, nchw_src_ptr + step * iw * pack_iw_len, 0); \
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); \
load_helper<4, 0, simd_len, 0, Vld1q_s8>( \
src, \
nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
0); \
cal_helper<4, 1, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c);
UNROLL_CALL_RAW(5, cb);
#undef cb
weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 7;
constexpr int filter_width = 8;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 2;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
#define cb(step) \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
dot4_weight, weight_ptr + step * filter_width * oc_step, \
ld_weight_oc); \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, nchw_src_ptr + step * iw * pack_iw_len, 0); \
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); \
load_helper<4, 0, simd_len, 0, Vld1q_s8>( \
src, \
nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
0); \
cal_helper<4, 1, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c);
UNROLL_CALL_RAW(7, cb);
#undef cb
weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
} // namespace
enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 };
template <PACK_MODE mode>
inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad,
......@@ -443,14 +758,24 @@ inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad,
memset(outptr, 0, combine_row * right_pad * src_expand * sizeof(int8_t));
outptr += combine_row * right_pad * src_expand;
}
template <int stride>
void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr,
const int ic, const int top_pad,
const int bottom_pad, const int left_pad,
const int right_pad, const int ih,
const int iw, const int iw2, const int pw,
int8_t* temp_ptr);
/**
* pack (ic, h, w) to (ic, h / 2, 2 * w)
 * interleave two adjacent rows of src, repeat 4 times, and store them to one row
* */
void conv_bias::pack_nchw_src_for_nchw44_conv(
const int8_t* inptr, int8_t* outptr, const int ic, const int top_pad,
const int bottom_pad, const int left_pad, const int right_pad,
const int ih, const int iw) {
template <>
void pack_nchw_src_for_nchw44_conv<2>(const int8_t* inptr, int8_t* outptr,
const int ic, const int top_pad,
const int bottom_pad, const int left_pad,
const int right_pad, const int ih,
const int iw, const int, const int,
int8_t*) {
constexpr int src_expand = 4;
constexpr int oh_step = 2;
const int oh = ih + top_pad + bottom_pad;
......@@ -490,15 +815,75 @@ void conv_bias::pack_nchw_src_for_nchw44_conv(
}
}
}
/**
* pack (ic, h, w) to (ic, h, w * 16)
 * interleave each pair of adjacent columns in a row and repeat it 4 times, storing to one row
* */
template <>
void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin,
int8_t* sptr_base, const int ic,
const int pad_top, const int pad_bottom,
const int, const int, const int ih,
const int iw, const int iw2, const int pw,
int8_t* temp_ptr) {
static uint8_t reorder_idx[16] = {0, 1, 0, 1, 0, 1, 0, 1,
2, 3, 2, 3, 2, 3, 2, 3};
uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);
constexpr int iw_step = 4;
constexpr int pack_iw_len = 16;
const int ic_stride = ih * iw;
const int iw_with_pad = iw + 2 * pw;
const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
rep(ic_idx, ic) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
pack_iw_len);
sptr_base += iw2 * pad_top * pack_iw_len;
rep(ih_idx, ih) {
memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
int8x16_t src[4];
int8x16_t dst[4];
src[0] = vld1q_s8(temp_ptr + iw_idx);
src[1] = vld1q_s8(temp_ptr + iw_idx + 1);
src[2] = vld1q_s8(temp_ptr + iw_idx + 2);
src[3] = vld1q_s8(temp_ptr + iw_idx + 3);
dst[0] = vqtbl1q_s8(src[0], tbl_idx);
dst[1] = vqtbl1q_s8(src[1], tbl_idx);
dst[2] = vqtbl1q_s8(src[2], tbl_idx);
dst[3] = vqtbl1q_s8(src[3], tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
}
for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
int8x16_t src = vld1q_s8(temp_ptr + iw_idx);
int8x16_t dst = vqtbl1q_s8(src, tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len, dst);
}
sptr_base += iw2 * pack_iw_len;
sptr += iw;
}
sptr_base += iw2 * pad_bottom * pack_iw_len;
}
}
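// Worked example of the reorder above (host-side sketch of vqtbl1q_s8, an
// assumed equivalence): with the padded temp row t = [0 a b c d ...] and
// reorder_idx = {0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3}, column 0 expands to
// 0a0a0a0a bcbcbcbc and column 1 to abababab cdcdcdcd: each 16-byte group
// repeats the pairs (i, i+1) and (i+2, i+3) four times, so the stride-1
// kernels can consume four taps per register without further shuffling.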
template <int stride>
void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr,
const int ic, const int fh, const int fw,
const int oc);
/**
* pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)}
 * interleave two adjacent rows of the filter into one row
* */
void conv_bias::pack_nchw44_weight_for_nchw_conv(const int8_t* inptr,
int8_t* outptr, const int ic,
const int fh, const int fw,
const int oc) {
template <>
void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr,
const int ic, const int fh,
const int fw, const int oc) {
constexpr int oc_step = 4;
constexpr int ic_step = 2;
constexpr int fh_step = 2;
......@@ -610,180 +995,293 @@ void conv_bias::pack_nchw44_weight_for_nchw_conv(const int8_t* inptr,
outptr += oc_step * fh * fw * ic;
}
}
/**
 * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh, fw/4, 4(oc)*4(fw)}
 * transpose ic inward, zero-pad fw to a multiple of 4, then transpose each
 * 4(oc) x 4(fw) block
* */
template <>
void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr,
const int ic, const int fh,
const int fw, const int oc) {
constexpr int oc_step = 4;
const int fw2 = round_up(fw, 4);
const int fw_remain = fw2 - fw;
const int dst_ic_stride = fh * fw2;
const int oc_step_stride = fh * fw2 * ic * oc_step;
static const uint8_t transpose_4x4_idx[16] = {0, 4, 1, 5, 2, 6, 3, 7,
8, 12, 9, 13, 10, 14, 11, 15};
uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]);
rep_step(oc_idx, oc, oc_step) {
int32_t* dst_temp_ptr =
reinterpret_cast<int32_t*>(dst_ptr + oc_idx * ic * fh * fw2);
const int32_t* src_temp_ptr = reinterpret_cast<const int32_t*>(
src_ptr + oc_idx * ic * fh * fw);
// transpose ic and pad
rep(fh_idx, fh) {
rep(fw_idx, fw) {
rep(ic_idx, ic) {
*(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr;
src_temp_ptr++;
}
dst_temp_ptr++;
}
rep(ic_idx, ic) {
memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0,
sizeof(int8_t) * oc_step * fw_remain);
}
dst_temp_ptr += fw_remain;
}
// transpose fw oc
int8_t* trans_dst_temp_ptr =
reinterpret_cast<int8_t*>(dst_ptr + oc_idx * ic * fh * fw2);
template <BiasMode bias_mode, typename Op, size_t filter_size>
static void conv_direct_stride2_int8_nchw_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 1;
constexpr size_t big_oc_step = 8;
constexpr size_t oc_step = 4;
constexpr size_t ih_step = 2;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 4;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;
using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
rep_step(idx, oc_step_stride, 16) {
int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx);
vst1q_s8(trans_dst_temp_ptr + idx,
vqtbl1q_s8(temp, tbl_transpose_4x4));
}
}
};
template <BiasMode bias_mode, typename Op, size_t filter_size, int stride>
struct ConvDiectStrideInt8NchwNchw44 {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, int8_t* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t fh = filter_size;
constexpr size_t fw =
stride == 2 ? filter_size : (filter_size + 3) / 4 * 4;
constexpr size_t ic_step = 1;
constexpr size_t big_oc_step = 8;
constexpr size_t oc_step = 4;
constexpr size_t ih_step = stride == 2 ? 2 : 1;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = stride == 2 ? 4 : 8;
constexpr size_t stride_h = stride;
constexpr size_t stride_w = stride;
constexpr int pack_iw_len = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;
using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
big_oc_step>::impl; \
big_oc_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
oc_step>::impl; \
oc_step, stride>::impl; \
break;
UNROLL_CALL_RAW(4, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
UNROLL_CALL_RAW(4, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size,
big_oc_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size,
big_oc_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
if (oc_remain > 0) {
size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size,
oc_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset,
filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
}
if (oc_remain > 0) {
size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size,
oc_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
};
template <BiasMode bias_mode, typename Op, size_t filter_size>
struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, int8_t* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int stride = 1;
constexpr size_t fh = filter_size;
constexpr size_t fw = (filter_size + 3) / 4 * 4;
constexpr size_t ic_step = 1;
constexpr size_t big_oc_step = 8;
constexpr size_t oc_step = 4;
constexpr size_t ih_step = 1;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = stride;
constexpr size_t stride_w = stride;
constexpr int pack_iw_len = 16;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;
using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
big_oc_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
oc_step, stride>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, ow_step, filter_size,
big_oc_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (oc_remain > 0) {
size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, ow_step, filter_size,
oc_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset,
filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
}
}
#define CONSTRUCT_FUNC(filter_size) \
template <BiasMode bias_mode, typename Op> \
void conv_bias:: \
conv_direct_stride2_##filter_size##x##filter_size##_int8_nchw_nchw44( \
const int8_t* src, const int8_t* filter, \
const int32_t* bias, int32_t* temp, int8_t* dst, \
const size_t oc, const size_t ic, const size_t ih, \
const size_t iw, const size_t oh, const size_t ow, \
const Op& op) { \
conv_direct_stride2_int8_nchw_nchw44<bias_mode, Op, filter_size>( \
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); \
}
};
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC
template <BiasMode bias_mode, typename Op>
void conv_bias::conv_direct_stride2_2x2_int8_nchw_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
MEGDNN_MARK_USED_VAR(bias);
MEGDNN_MARK_USED_VAR(temp);
MEGDNN_MARK_USED_VAR(dst);
MEGDNN_MARK_USED_VAR(oc);
MEGDNN_MARK_USED_VAR(ic);
MEGDNN_MARK_USED_VAR(ih);
MEGDNN_MARK_USED_VAR(iw);
MEGDNN_MARK_USED_VAR(oh);
MEGDNN_MARK_USED_VAR(ow);
MEGDNN_MARK_USED_VAR(op);
megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv");
template <BiasMode bias_mode, typename Op, size_t filter_size, int stride>
static void conv_direct_int8_nchw_nchw44(const int8_t* src,
const int8_t* filter,
const int32_t* bias, int32_t* temp,
int8_t* dst, const size_t oc,
const size_t ic, const size_t ih,
const size_t iw, const size_t oh,
const size_t ow, const Op& op) {
ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, stride>::impl(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
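// Usage note: the algo layer reaches all of the kernels above through this
// single entry point, e.g. conv_direct_int8_nchw_nchw44<bias_mode, Op, 3, 1>
// for a 3x3 stride-1 filter, which forwards to the stride-1 specialization
// of ConvDiectStrideInt8NchwNchw44.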
#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias:: \
conv_direct_##stride##_##i##x##i##_int8_nchw_nchw44<bias, Op>( \
const int8_t*, const int8_t*, const int32_t*, int32_t*, \
int8_t*, const size_t, const size_t, const size_t, \
const size_t, const size_t, const size_t, const Op&);
#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER(stride2)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
} // namespace
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_int8_nchw_##layout( \
const int8_t* src, const int8_t* filter, const int32_t* bias, \
int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \
const size_t IH, const size_t IW, const size_t OH, \
const size_t OW, const Op& op);
KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN
void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr,
const int ic, const int fh, const int fw,
const int oc);
void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr,
const int ic, const int top_pad,
const int bottom_pad, const int left_pad,
const int right_pad, const int ih,
const int iw);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn
\ No newline at end of file
......@@ -47,7 +47,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8DirectStride2 s8_direct_stride2_large_group{true};
AlgoS8DirectStride2 s8_direct_stride2_small_group{false};
AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44;
AlgoS8DirectStride2NCHWNCHW44 s8_direct_stride2_nchw_nchw44;
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
AlgoS8DirectStride1 s8_direct_stride1_large_group{true};
AlgoS8DirectStride1 s8_direct_stride1_small_group{false};
AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44;
......@@ -115,7 +115,7 @@ public:
direct_algos.emplace_back(&s8_direct_stride2_large_group);
direct_algos.emplace_back(&s8_direct_stride2_small_group);
direct_algos.emplace_back(&s8_direct_stride2_nchw44);
direct_algos.emplace_back(&s8_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_stride1_large_group);
direct_algos.emplace_back(&s8_direct_stride1_small_group);
direct_algos.emplace_back(&s8_direct_stride1_nchw44);
......
......@@ -40,7 +40,7 @@ private:
class AlgoS8DirectStride1NCHW44;
class AlgoS8DirectStride2;
class AlgoS8DirectStride2NCHW44;
class AlgoS8DirectStride2NCHWNCHW44;
class AlgoS8DirectNCHWNCHW44;
class AlgoQU8DirectStride1;
class AlgoQU8DirectStride2;
class AlgoFP32WinogradF23_4x4;
......
......@@ -244,18 +244,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", false);
#endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:ARMV7_F32:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:ARMV7_F32:192", false);
#endif
}
......
......@@ -541,7 +541,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true),
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
true),
handle(), "S8_CONV_NCHW_NCHW44");
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
true),
handle(), "S8_CONV_NCHW_NCHW44");
}
......