提交 03808112 编写于 作者: M Megvii Engine Team

feat(dnn/arm_common): add nchw44 8x8x16 stride1 stride2

                    2x2 3x3 5x5 7x7 directconv

GitOrigin-RevId: 3710182af1974775c0960a4ebac3c7cc7e3d93d5
上级 2dbe8194
......@@ -38,6 +38,18 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase {
public:
AlgoS8x8x16DirectNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "S8x8x16_NCHW44_DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
......
/**
* \file dnn/src/arm_common/conv_bias/int8x8x16/conv_direct_int8x8x16_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct)
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2) {
auto&& fm = param.filter_meta;
int ih = param.isz[0];
int iw = param.isz[1];
int ph = fm.padding[0];
int pw = fm.padding[1];
ih2 = ih + ph * 2;
iw2 = iw + pw * 2;
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
size_t group = fm.group;
size_t batch = param.n;
size_t IC = fm.icpg;
int IH2, IW2;
get_rectified_size(param, IH2, IW2);
if (group == 1) {
size_t src_size = 0;
bool need_padding = param.filter_meta.padding[0] > 0 ||
param.filter_meta.padding[1] > 0;
src_size = need_padding
? batch * group * IC * IH2 * IW2 * sizeof(int8_t)
: 0;
#if MEGDNN_ARMV7
if (fm.stride[0] == 1) {
constexpr int src_expand_element = 4;
src_size = batch * group * IC * IH2 * IW2 * sizeof(int8_t) *
src_expand_element;
}
#endif
return {nullptr, {src_size}};
} else {
size_t src_size = 0;
bool need_padding = param.filter_meta.padding[0] > 0 ||
param.filter_meta.padding[1] > 0;
src_size = need_padding
? param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t)
: 0;
#if MEGDNN_ARMV7
if (fm.stride[0] == 1) {
constexpr int src_expand_element = 4;
src_size = param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) *
src_expand_element;
}
#endif
return {nullptr, {src_size}};
}
};
#if MEGDNN_ARMV7
static void copy_padding_kern(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
int IH = kern_param.isz[0];
int IW = kern_param.isz[1];
int IC = kern_param.filter_meta.icpg;
int PH = kern_param.filter_meta.padding[0];
int PW = kern_param.filter_meta.padding[1];
int GROUP = kern_param.filter_meta.group;
int IH2, IW2;
get_rectified_size(kern_param, IH2, IW2);
int padding_group_size = IH2 * IW2 * IC;
//! Used for get the workspace offset
constexpr int pack_ic = 4;
constexpr int src_expand_element = 4;;
size_t workspace_ic_block = 4;
size_t workspace_batch_id = workspace_ids[0];
size_t workspace_group_id = workspace_ids[1];
size_t workspace_ic_id = workspace_ids[2];
size_t workspace_ic = workspace_ic_id * workspace_ic_block;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t group_pack_size = 1;
int nr_pad_w = PW * pack_ic * src_expand_element;
int nr_pad_h = PH * IW2 * pack_ic * src_expand_element;
int row_last_pad = (IW2 - IW - PW) * pack_ic * src_expand_element;
int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * src_expand_element;
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>(
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic));
//! copy to sptr_base to eliminate padding effect
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
(workspace_batch_id * GROUP * padding_group_size +
workspace_group_id * padding_group_size +
workspace_ic * IH2 * IW2) *
src_expand_element;
size_t nr_ic = workspace_ic_block;
if (GROUP > 1) {
nr_ic = IC;
}
rep_step(ic_idx, nr_ic, pack_ic) {
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t));
sptr_base += nr_pad_h;
rep(ih_idx, IH) {
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
sptr_base += nr_pad_w;
int8x8x16_direct_nchw44::nchw44_pack_src(sptr, sptr_base, IW);
sptr_base += IW * pack_ic * src_expand_element;
sptr += IW * pack_ic;
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t));
sptr_base += row_last_pad;
}
std::memset(sptr_base, 0, col_last_pad * sizeof(int8_t));
sptr_base += col_last_pad;
}
}
#endif
static void copy_padding_kern_no_pack_src(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
int IH = kern_param.isz[0];
int IW = kern_param.isz[1];
int IC = kern_param.filter_meta.icpg;
int PH = kern_param.filter_meta.padding[0];
int PW = kern_param.filter_meta.padding[1];
int GROUP = kern_param.filter_meta.group;
int IH2, IW2;
get_rectified_size(kern_param, IH2, IW2);
int padding_group_size = IH2 * IW2 * IC;
//! Used for get the workspace offset
constexpr int pack_ic = 4;
constexpr int src_expand_element = 1;
size_t workspace_ic_block = 4;
size_t workspace_batch_id = workspace_ids[0];
size_t workspace_group_id = workspace_ids[1];
size_t workspace_ic_id = workspace_ids[2];
size_t workspace_ic = workspace_ic_id * workspace_ic_block;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t group_pack_size = 1;
int nr_pad_w = PW * pack_ic * src_expand_element;
int nr_pad_h = PH * IW2 * pack_ic * src_expand_element;
int row_last_pad = (IW2 - IW - PW) * pack_ic * src_expand_element;
int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * src_expand_element;
const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>(
batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic));
//! copy to sptr_base to eliminate padding effect
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
(workspace_batch_id * GROUP * padding_group_size +
workspace_group_id * padding_group_size +
workspace_ic * IH2 * IW2) *
src_expand_element;
size_t nr_ic = workspace_ic_block;
if (GROUP > 1) {
nr_ic = IC;
}
rep_step(ic_idx, nr_ic, pack_ic) {
std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t));
sptr_base += nr_pad_h;
rep(ih_idx, IH) {
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
sptr_base += nr_pad_w;
std::memcpy(sptr_base, sptr, IW * pack_ic);
sptr_base += IW * pack_ic * src_expand_element;
sptr += IW * pack_ic;
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t));
sptr_base += row_last_pad;
}
std::memset(sptr_base, 0, col_last_pad * sizeof(int8_t));
sptr_base += col_last_pad;
}
}
template <size_t filter, BiasMode bias_mode, int stride>
static void do_conv_kern(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids,
const CpuNDRange& ncb_range) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t OC = kern_param.filter_meta.ocpg;
size_t GROUP = kern_param.filter_meta.group;
int IH2, IW2;
get_rectified_size(kern_param, IH2, IW2);
size_t padding_group_size = IH2 * IW2 * IC;
constexpr size_t pack_c = 4;
const size_t workspace_batch_id = workspace_ids[0];
const size_t workspace_group_id = workspace_ids[1];
const size_t batch_id = ncb_index.ndrange_id[0];
const size_t group_id = ncb_index.ndrange_id[1];
const size_t oc_id = ncb_index.ndrange_id[2];
const size_t oc_block_num = ncb_range[2];
megdnn_assert((OC & (pack_c - 1)) == 0, "OC must times of 4");
size_t nr_pack_per_step = div_ceil(OC / pack_c, oc_block_num);
size_t oc_block = nr_pack_per_step * pack_c;
const size_t oc_idx = oc_id * oc_block;
if (oc_id == (oc_block_num - 1)) {
oc_block = OC - oc_id * nr_pack_per_step * pack_c;
}
megdnn_assert(oc_block % pack_c == 0,
"oc must be devisible by 4, but oc = %zu", oc_block);
bool need_padding = kern_param.filter_meta.padding[0] > 0 ||
kern_param.filter_meta.padding[1] > 0;
const int8_t* sptr = need_padding
? static_cast<int8_t*>(bundle.get(0)) +
workspace_batch_id * GROUP * padding_group_size +
workspace_group_id * padding_group_size
: kern_param.src<int8_t>(batch_id, group_id);
//!armv7 use packsrc mode
#if MEGDNN_ARMV7
if (stride == 1) {
constexpr size_t src_expand_size = 4;
sptr = static_cast<int8_t*>(bundle.get(0)) +
workspace_batch_id * GROUP * padding_group_size *
src_expand_size +
workspace_group_id * padding_group_size * src_expand_size;
}
#endif
const int8_t* fptr =
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
int16_t* dst = reinterpret_cast<int16_t*>(
kern_param.dst<void>(batch_id, group_id, oc_idx));
const int16_t* bptr =
kern_param.bias<dt_int16>(batch_id, group_id) + oc_idx;
int8x8x16_direct_nchw44::ConvDirectInt8Nchw44Choose<
bias_mode, filter, stride>::impl(sptr, fptr, bptr, dst, oc_block,
IC, IH2, IW2, OH, OW);
}
bool ConvBiasImpl::AlgoS8x8x16DirectNCHW44::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MEGDNN_MARK_USED_VAR(algo_selection_strategy);
auto&& fm = param.filter_meta;
const int fh = fm.spatial[0];
const int fw = fm.spatial[1];
const int oc = fm.ocpg;
const int ic = fm.icpg;
const bool avaible = //! src and filter are int8, dst is int16_t
(param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int16) &&
(fm.format == param::Convolution::Format::NCHW44) &&
(oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7) &&
param.nonlineMode == NonlineMode::IDENTITY &&
param.bias_mode != BiasMode::BIAS;
return avaible;
}
size_t ConvBiasImpl::AlgoS8x8x16DirectNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
size_t N = param.n;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
size_t group = fm.group;
size_t fh = fm.spatial[0];
size_t fw = fm.spatial[1];
size_t ph = fm.padding[0];
size_t pw = fm.padding[1];
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct, \
midout_iv("int8x8x16_nchw44_direct_" \
"conv" #stride #filter #bias_mode##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, dt_int16, filter, bias_mode) \
break; \
default: \
megdnn_throw(ssprintf("only support IDENTITY mode when dst is " \
"dt_int16 nonlineMode is %d", \
uint32_t(param.nonlineMode)) \
.c_str()); \
break; \
}
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_throw(ssprintf("only support NO_BIAS/BROADCAST biasmode " \
"when dst is " \
"dt_int16 biasmode is %d", \
uint32_t(param.bias_mode)) \
.c_str()); \
break; \
}
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(stride, 5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(stride, 7) \
break; \
default: \
megdnn_throw(ssprintf("only support 2x2 3x3 5x5 7x7 filters size " \
"when dst is " \
"dt_int16 filter size is %u", \
uint32_t(param.filter_meta.spatial[0])) \
.c_str()); \
break; \
}
switch (param.filter_meta.stride[0]) {
case 1:
DISPATCH_CONV_KERN(1);
break;
case 2:
DISPATCH_CONV_KERN(2);
break;
default:
megdnn_throw(ssprintf("Unsupport stride size %u for the 8x8x16 direct conv",
param.filter_meta.stride[0])
.c_str());
break;
}
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
megdnn_assert(do_conv_fun);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
constexpr size_t pack_oc = 4;
size_t oc_step = pack_oc;
if (fh == fw && (fh == 2 || fw == 3) && OC >= 8) {
oc_step = 8;
}
#if MEGDNN_ARMV7
if (param.filter_meta.stride[0] == 1) {
if (group == 1) {
CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
auto copy_padding = [wbundle](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
wbundle.set(kern_param.workspace_ptr);
copy_padding_kern(wbundle, kern_param, ncb_index,
ncb_index.ndrange_id);
};
constexpr size_t pack_ic = 4;
ret_kerns.push_back(
{copy_padding, {N, group, div_ceil(IC, pack_ic)}});
auto do_conv = [wbundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
wbundle.set(kern_param.workspace_ptr);
do_conv_fun(wbundle, kern_param, ncb_index,
ncb_index.ndrange_id, ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
} else {
CpuNDRange ncb_range = {N, group, 1};
auto do_conv = [wbundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
wbundle.set(kern_param.workspace_ptr);
copy_padding_kern(wbundle, kern_param, ncb_index,
{0, ncb_index.thread_id, 0});
do_conv_fun(wbundle, kern_param, ncb_index,
{0, ncb_index.thread_id, 0}, ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
}
return ret_kerns;
}
#endif
bool need_padding = ph > 0 || pw >0;
if (group == 1) {
CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
auto copy_padding = [wbundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
wbundle.set(kern_param.workspace_ptr);
copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index,
ncb_index.ndrange_id);
};
constexpr size_t pack_ic = 4;
if (need_padding) {
ret_kerns.push_back(
{copy_padding, {N, group, div_ceil(IC, pack_ic)}});
}
auto do_conv = [wbundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
wbundle.set(kern_param.workspace_ptr);
do_conv_fun(wbundle, kern_param, ncb_index, ncb_index.ndrange_id,
ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
} else {
CpuNDRange ncb_range = {N, group, 1};
auto do_conv = [wbundle, do_conv_fun, ncb_range, need_padding](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
wbundle.set(kern_param.workspace_ptr);
if (need_padding) {
copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index,
{0, ncb_index.thread_id, 0});
};
do_conv_fun(wbundle, kern_param, ncb_index,
{0, ncb_index.thread_id, 0}, ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
}
return ret_kerns;
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace int8x8x16_direct_nchw44 {
/**
origin src shape <n, ic/4, h, w, 4>
packed src shape <n, ic/4, h, w, 16>
example: (format like <ic>)
origin
<0> <1> <2> <3>
packed
low 64 bit <0> <0> <0> <0> | <1> <1> <1> <1>
---------------------------------------------------------------------
high 64 bit <2> <2> <2> <2> | <3> <3> <3> <3>
**/
static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) {
static const uint8_t src_idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
constexpr int pack_ic = 4;
constexpr int simd_len = 16;
uint8x16_t src_idx = vld1q_u8(src_idx_buffer);
for (int i = 0; i < length; i++) {
int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx);
vst1q_s8(dst + i * simd_len, result);
}
}
template <BiasMode bias_mode, int filter_size, int stride>
struct ConvDirectInt8Nchw44Choose {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow);
};
} // namespace int8_direct_nchw44
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8_direct_nchw44_s1_aarch64.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/simd_macro/marm_neon.h"
#if MEGDNN_AARCH64
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace {
#define INIT_SUM() \
int16x4_t init_sum; \
if (bias_mode == BiasMode::NO_BIAS) { \
init_sum = vdup_n_s16(0); \
} else { \
init_sum = vld1_s16(bias_ptr); \
}
#define STORE_1_LINE_RESULT() \
switch (remain_w) { \
case 8: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
break; \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
break; \
case 5: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1_s16(dst_ptr + 16, c[0][4]); \
break; \
case 6: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
break; \
case 7: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1_s16(dst_ptr + 24, c[0][6]); \
break; \
default: \
megdnn_assert(0, "oc 1 error remainw"); \
};
#define STORE_2_LINE_RESULT_OW4() \
switch (remain_w) { \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \
break; \
default: \
megdnn_assert(0, "oc 2 error remainw"); \
break; \
}
#define STORE_1_LINE_RESULT_OW4_OH2() \
switch (remain_w) { \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5])); \
vst1q_s16(dst_ptr + 8 + ow, vcombine_s16(c[0][6], c[0][7])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
vst1_s16(dst_ptr + ow, c[0][4]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
vst1q_s16(dst_ptr + ow, vcombine_s16(c[0][4], c[0][5])); \
vst1_s16(dst_ptr + ow + 8, c[0][6]); \
break; \
default: \
megdnn_assert(0, "oc 2 error remainw"); \
break; \
}
#define STORE_1_LINE_RESULT_OW4() \
switch (remain_w) { \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
break; \
default: \
megdnn_assert(0, "oc 1 error remainw"); \
};
template <BiasMode bias_mode,int filter_size>
static void ker_neon_dirctconv_2x2s1_oc8_ow4(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic, int ih,
int iw, int remain_w,int ld_dst_oc) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
const int ld_weight_oc4 = oc_step * fh * fw * ic;
int16x4_t c[2][4];
int8x16_t weight[2][2];
int8x16_t src[5];
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
INIT_SUM();
#define cb(_i) \
c[0][_i] = init_sum; \
c[1][_i] = init_sum;
UNROLL_CALL_RAW(4, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* src_row0 =
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
const int8_t* src_row1 =
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_row0 + 0, idx);
src[1] = vld_dup_tbl_s32(src_row0 + 4, idx);
src[2] = vld_dup_tbl_s32(src_row0 + 8, idx);
weight[0][0] = vld1q_s8(weight_ptr);
weight[0][1] = vld1q_s8(weight_ptr + 16);
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 16);
#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c) \
do { \
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
int16x8_t tmp0;
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]);
src[3] = vld_dup_tbl_s32(src_row0 + 12, idx);
src[4] = vld_dup_tbl_s32(src_row0 + 16, idx);
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]);
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]);
CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], c[0][3]);
CALC_ONE_RESULT(src[3], src[4], weight[1][0], weight[1][1], c[1][3]);
src[0] = vld_dup_tbl_s32(src_row1 + 0, idx);
src[1] = vld_dup_tbl_s32(src_row1 + 4, idx);
src[2] = vld_dup_tbl_s32(src_row1 + 8, idx);
weight[0][0] = vld1q_s8(weight_ptr + 32);
weight[0][1] = vld1q_s8(weight_ptr + 48);
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4 + 32);
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 48);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]);
src[3] = vld_dup_tbl_s32(src_row1 + 12, idx);
src[4] = vld_dup_tbl_s32(src_row1 + 16, idx);
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]);
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]);
CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], c[0][3]);
CALC_ONE_RESULT(src[3], src[4], weight[1][0], weight[1][1], c[1][3]);
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_2_LINE_RESULT_OW4();
}
template <BiasMode bias_mode, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc4_ow4(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic, int ih,
int iw, int remain_w,
int /*ld_dst_oc*/) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
const int ic_stride = ih * iw;
int16x4_t c[1][4];
int8x16_t weight[1][2];
int8x16_t src[5];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(4, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx);
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
int16x8_t tmp0;
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1],
c[0][0]);
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1],
c[0][1]);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1],
c[0][2]);
CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1],
c[0][3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT_OW4();
}
#undef CALC_ONE_RESULT
#define CALC_ONE_RESULT(_src0, _src1, _src2, _w, _c) \
do { \
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
template <BiasMode bias_mode, int filter_size>
static void ker_neon_dirctconv_3x3s1_oc4_ow4(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic, int ih,
int iw, int remain_w,
int /*ld_dst_oc*/) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
int16x4_t c[1][4];
int8x16_t weight[1][3];
int8x16_t src[6];
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(4, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* src_row0 =
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
const int8_t* src_row1 =
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
const int8_t* src_row2 =
src_ptr + ic_idx * ic_stride + 2 * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_row0 + 0, idx);
src[1] = vld_dup_tbl_s32(src_row0 + 4, idx);
src[2] = vld_dup_tbl_s32(src_row0 + 8, idx);
weight[0][0] = vld1q_s8(weight_ptr);
weight[0][1] = vld1q_s8(weight_ptr + 16);
weight[0][2] = vld1q_s8(weight_ptr + 32);
src[3] = vld_dup_tbl_s32(src_row0 + 12, idx);
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]);
src[4] = vld_dup_tbl_s32(src_row0 + 16, idx);
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]);
src[5] = vld_dup_tbl_s32(src_row0 + 20, idx);
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]);
CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]);
src[0] = vld_dup_tbl_s32(src_row1 + 0, idx);
src[1] = vld_dup_tbl_s32(src_row1 + 4, idx);
src[2] = vld_dup_tbl_s32(src_row1 + 8, idx);
weight[0][0] = vld1q_s8(weight_ptr + 48);
weight[0][1] = vld1q_s8(weight_ptr + 64);
weight[0][2] = vld1q_s8(weight_ptr + 80);
src[3] = vld_dup_tbl_s32(src_row1 + 12, idx);
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]);
src[4] = vld_dup_tbl_s32(src_row1 + 16, idx);
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]);
src[5] = vld_dup_tbl_s32(src_row1 + 20, idx);
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]);
CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]);
src[0] = vld_dup_tbl_s32(src_row2 + 0, idx);
src[1] = vld_dup_tbl_s32(src_row2 + 4, idx);
src[2] = vld_dup_tbl_s32(src_row2 + 8, idx);
weight[0][0] = vld1q_s8(weight_ptr + 96);
weight[0][1] = vld1q_s8(weight_ptr + 112);
weight[0][2] = vld1q_s8(weight_ptr + 128);
src[3] = vld_dup_tbl_s32(src_row2 + 12, idx);
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]);
src[4] = vld_dup_tbl_s32(src_row2 + 16, idx);
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][1]);
src[5] = vld_dup_tbl_s32(src_row2 + 20, idx);
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][2]);
CALC_ONE_RESULT(src[3], src[4], src[5], weight[0], c[0][3]);
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT_OW4();
}
template <BiasMode bias_mode, int filter_size>
static void ker_neon_dirctconv_3x3s1_oc4_ow4_oh2(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic,
int ih, int iw, int remain_w,
int /*ld_dst_oc*/, int ow) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
int16x4_t c[1][8];
int8x16_t weight[2][3];
int8x16_t src[1][6];
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* src_row0 =
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
const int8_t* src_row1 =
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
const int8_t* src_row2 =
src_ptr + ic_idx * ic_stride + 2 * iw * ic_step;
const int8_t* src_row3 =
src_ptr + ic_idx * ic_stride + 3 * iw * ic_step;
#define LOAD_SRC(_src, _src_ptr) \
_src[0] = vld_dup_tbl_s32(_src_ptr + 0, idx); \
_src[1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \
_src[2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \
_src[3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \
_src[4] = vld_dup_tbl_s32(_src_ptr + 16, idx); \
_src[5] = vld_dup_tbl_s32(_src_ptr + 20, idx);
LOAD_SRC(src[0], src_row0);
weight[0][0] = vld1q_s8(weight_ptr);
weight[0][1] = vld1q_s8(weight_ptr + 16);
weight[0][2] = vld1q_s8(weight_ptr + 32);
weight[1][0] = vld1q_s8(weight_ptr + 48);
weight[1][1] = vld1q_s8(weight_ptr + 64);
weight[1][2] = vld1q_s8(weight_ptr + 80);
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0],
c[0][0]); // row0 src0 w0
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]);
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]);
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]);
LOAD_SRC(src[0], src_row1);
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0],
c[0][4]); // row1 src1 w0
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]);
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]);
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]);
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1],
c[0][0]); // row1 src1 w1
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][1]);
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][2]);
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][3]);
LOAD_SRC(src[0], src_row2);
weight[0][0] = vld1q_s8(weight_ptr + 96);
weight[0][1] = vld1q_s8(weight_ptr + 112);
weight[0][2] = vld1q_s8(weight_ptr + 128);
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1],
c[0][4]); // row2 src0 w1
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][5]);
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][6]);
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][7]);
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0],
c[0][0]); // row2 w0 src[0]
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]);
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]);
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]);
LOAD_SRC(src[0], src_row3);
CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0],
c[0][4]); // row3 w0 src1
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]);
CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]);
CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]);
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT_OW4_OH2();
}
#undef LOAD_SRC
#undef CALC_ONE_RESULT
template <BiasMode bias_mode, int filter_size>
struct KerNeonDirectStride1Int8 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
int iw, int remain_w, int ld_dst_oc);
};
template <BiasMode bias_mode>
struct KerNeonDirectStride1Int8<bias_mode, 5> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
int iw, int remain_w, int /*ld_dst_oc*/) {
constexpr int filter_size = 5;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
int16x4_t c[1][8];
int8x16_t weight[5];
int8x16_t src[8 + 2];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0 * 4, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 1 * 4, idx);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 2 * 4, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 3 * 4, idx);
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 4 * 4, idx);
src[5] = vld_dup_tbl_s32(src_ic_0_3 + 5 * 4, idx);
src[6] = vld_dup_tbl_s32(src_ic_0_3 + 6 * 4, idx);
src[7] = vld_dup_tbl_s32(src_ic_0_3 + 7 * 4, idx);
src[8] = vld_dup_tbl_s32(src_ic_0_3 + 8 * 4, idx);
src[9] = vld_dup_tbl_s32(src_ic_0_3 + 9 * 4, idx);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \
_w4, _c) \
do { \
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \
int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \
tmp0 = vaddq_s16(tmp0, tmp1); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][0]);
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][1]);
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][2]);
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][3]);
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][4]);
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][5]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx);
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][6]);
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][7]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT();
}
};
#undef CALC_ONE_RESULT
template <BiasMode bias_mode>
struct KerNeonDirectStride1Int8<bias_mode, 7> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
int iw, int remain_w, int /*ld_dst_oc*/) {
constexpr int filter_size = 7;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
int16x4_t c[1][8];
int8x16_t weight[7];
int8x16_t src[8 + 2];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 0 * 4, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 1 * 4, idx);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 2 * 4, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 3 * 4, idx);
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 4 * 4, idx);
src[5] = vld_dup_tbl_s32(src_ic_0_3 + 5 * 4, idx);
src[6] = vld_dup_tbl_s32(src_ic_0_3 + 6 * 4, idx);
src[7] = vld_dup_tbl_s32(src_ic_0_3 + 7 * 4, idx);
src[8] = vld_dup_tbl_s32(src_ic_0_3 + 8 * 4, idx);
src[9] = vld_dup_tbl_s32(src_ic_0_3 + 9 * 4, idx);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \
_c) \
do { \
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \
int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \
int16x8_t tmp2 = vmull_s8(vget_low_s8(_src1), vget_low_s8(_w[1])); \
int16x8_t tmp3 = vmull_s8(vget_high_s8(_src1), vget_high_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \
tmp2 = vmlal_s8(tmp2, vget_low_s8(_src3), vget_low_s8(_w[3])); \
tmp3 = vmlal_s8(tmp3, vget_high_s8(_src3), vget_high_s8(_w[3])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \
tmp2 = vmlal_s8(tmp2, vget_low_s8(_src5), vget_low_s8(_w[5])); \
tmp3 = vmlal_s8(tmp3, vget_high_s8(_src5), vget_high_s8(_w[5])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src6), vget_high_s8(_w[6])); \
tmp0 = vaddq_s16(tmp0, tmp1); \
tmp2 = vaddq_s16(tmp2, tmp3); \
tmp0 = vaddq_s16(tmp0, tmp2); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5],
src[6], weight, c[0][0]);
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6],
src[7], weight, c[0][1]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx);
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7],
src[8], weight, c[0][2]);
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8],
src[9], weight, c[0][3]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 12 * 4, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 13 * 4, idx);
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9],
src[0], weight, c[0][4]);
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0],
src[1], weight, c[0][5]);
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1],
src[2], weight, c[0][6]);
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2],
src[3], weight, c[0][7]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT();
}
};
#undef CALC_ONE_RESULT
template <BiasMode bias_mode>
void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src,
const int8_t* filter,
const int16_t* bias, int16_t* dst,
const size_t oc, const size_t ic,
const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_oc = oh * ow * oc_step;
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
size_t oh_idx = 0;
for (; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc8_ow4<bias_mode, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_step, ld_oc);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s1_oc8_ow4<bias_mode, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc4_ow4<bias_mode, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_step, ld_oc);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_2x2s1_oc4_ow4<bias_mode, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc);
}
}
}
}
template <BiasMode bias_mode>
void conv_direct_stride1_3x3_int8x8x16_oh2_nchw44(
const int8_t* src, const int8_t* filter, const int16_t* bias,
int16_t* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow) {
constexpr size_t filter_size = 3;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const int ld_oc = oh * ow * oc_step;
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
size_t oh_idx = 0;
for (; oh_idx + 1 < oh; oh_idx += 2) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_3x3s1_oc4_ow4_oh2<bias_mode, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_step, ld_oc,
ow * oc_step);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_3x3s1_oc4_ow4_oh2<bias_mode, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc,
ow * oc_step);
}
}
for (; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_3x3s1_oc4_ow4<bias_mode, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_step, ld_oc);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
ker_neon_dirctconv_3x3s1_oc4_ow4<bias_mode, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_remain, ld_oc);
}
}
}
}
template <BiasMode bias_mode, int filter_size>
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src,
const int8_t* filter,
const int16_t* bias, int16_t* dst,
const size_t oc, const size_t ic,
const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
const size_t img_stride = oh * ow;
const int ld_dst_oc = oh * ow * oc_step;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * iw + ow_idx) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDirectStride1Int8<bias_mode, filter_size>::impl(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_step, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * iw + ow_end) * ic_step;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
KerNeonDirectStride1Int8<bias_mode, filter_size>::impl(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ow_remain, ld_dst_oc);
}
}
}
}
} // namespace
namespace int8x8x16_direct_nchw44 {
template <BiasMode bias_mode, int filter_size>
struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
conv_direct_stride1_int8_nchw44_kern<bias_mode, filter_size>(
src, filter, bias, dst, oc, ic, ih, iw, oh, ow);
}
};
template <BiasMode bias_mode>
struct ConvDirectInt8Nchw44Choose<bias_mode, 2, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
conv_direct_stride1_2x2_int8_nchw44<bias_mode>(src, filter, bias, dst,
oc, ic, ih, iw, oh, ow);
}
};
template <BiasMode bias_mode>
struct ConvDirectInt8Nchw44Choose<bias_mode, 3, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
conv_direct_stride1_3x3_int8x8x16_oh2_nchw44<bias_mode>(
src, filter, bias, dst, oc, ic, ih, iw, oh, ow);
}
};
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \
template struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, stride>;
#define GET_OP_PARAM(stride, filter, bias_mode) \
DO_CONV_KERN_FUN(stride, filter, bias_mode)
#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)
DISPATCH_CONV_KERN(1);
} // namespace int8x8x16_direct_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8_direct_nchw44_s1_armv7.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/utils.h"
#if MEGDNN_ARMV7
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace {
#define INIT_SUM() \
int16x4_t init_sum; \
if (bias_mode == BiasMode::NO_BIAS) { \
init_sum = vdup_n_s16(0); \
} else { \
init_sum = vld1_s16(bias_ptr); \
}
#define STORE_1_LINE_RESULT() \
switch (remain_w) { \
case 8: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
break; \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
break; \
case 5: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1_s16(dst_ptr + 16, c[0][4]); \
break; \
case 6: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
break; \
case 7: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1_s16(dst_ptr + 24, c[0][6]); \
break; \
default: \
megdnn_assert(0, "oc 1 error remainw"); \
};
#define STORE_2_LINE_RESULT() \
switch (remain_w) { \
case 8: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
vst1q_s16(dst_ptr + ld_dst_oc + 16, \
vcombine_s16(c[1][4], c[1][5])); \
vst1q_s16(dst_ptr + ld_dst_oc + 24, \
vcombine_s16(c[1][6], c[1][7])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \
break; \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
break; \
case 5: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1_s16(dst_ptr + 16, c[0][4]); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \
break; \
case 6: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
vst1q_s16(dst_ptr + ld_dst_oc + 16, \
vcombine_s16(c[1][4], c[1][5])); \
break; \
case 7: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1_s16(dst_ptr + 24, c[0][6]); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
vst1q_s16(dst_ptr + ld_dst_oc + 16, \
vcombine_s16(c[1][4], c[1][5])); \
vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \
break; \
default: \
megdnn_assert(0, "oc 2 error remainw"); \
break; \
}
template <BiasMode bias_mode, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int src_expand_size = 4;
const int ic_stride = ih * iw * src_expand_size;
const int ld_weight_oc4 = oc_step * fh * fw * ic;
int16x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[4];
INIT_SUM();
#define cb(_i) \
c[0][_i] = init_sum; \
c[1][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* src_row0 = src_ptr + ic_idx * ic_stride +
0 * iw * ic_step * src_expand_size;
const int8_t* src_row1 = src_ptr + ic_idx * ic_stride +
1 * iw * ic_step * src_expand_size;
src[0] = vld1q_s8(src_row0);
src[1] = vld1q_s8(src_row0 + 16);
weight[0][0] = vld1q_s8(weight_ptr);
weight[0][1] = vld1q_s8(weight_ptr + 16);
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 16);
#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c) \
do { \
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
int16x8_t tmp0;
src[2] = vld1q_s8(src_row0 + 2 * 16);
src[3] = vld1q_s8(src_row0 + 3 * 16);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]);
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]);
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][1]);
src[0] = vld1q_s8(src_row0 + 4 * 16);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][2]);
src[1] = vld1q_s8(src_row0 + 5 * 16);
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][3]);
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][3]);
src[2] = vld1q_s8(src_row0 + 6 * 16);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][4]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][4]);
src[3] = vld1q_s8(src_row0 + 7 * 16);
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][5]);
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][5]);
src[0] = vld1q_s8(src_row0 + 8 * 16);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][6]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][6]);
src[1] = vld1q_s8(src_row1 + 0 * 16);
src[2] = vld1q_s8(src_row1 + 1 * 16);
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][7]);
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][7]);
weight[0][0] = vld1q_s8(weight_ptr + 32);
weight[0][1] = vld1q_s8(weight_ptr + 48);
src[3] = vld1q_s8(src_row1 + 2 * 16);
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][0]);
weight[1][0] = vld1q_s8(weight_ptr + ld_weight_oc4 + 32);
weight[1][1] = vld1q_s8(weight_ptr + ld_weight_oc4 + 48);
src[0] = vld1q_s8(src_row1 + 3 * 16);
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][0]);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][1]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][1]);
src[1] = vld1q_s8(src_row1 + 4 * 16);
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][2]);
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][2]);
src[2] = vld1q_s8(src_row1 + 5 * 16);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][3]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][3]);
src[3] = vld1q_s8(src_row1 + 6 * 16);
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][4]);
CALC_ONE_RESULT(src[1], src[2], weight[1][0], weight[1][1], c[1][4]);
src[0] = vld1q_s8(src_row1 + 7 * 16);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][5]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][5]);
src[1] = vld1q_s8(src_row1 + 8 * 16);
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][6]);
CALC_ONE_RESULT(src[3], src[0], weight[1][0], weight[1][1], c[1][6]);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][7]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][7]);
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_2_LINE_RESULT();
}
template <BiasMode bias_mode, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic, int ih,
int iw, int /*ld_dst_oc*/) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int src_expand_size = 4;
const int ic_stride = ih * iw * src_expand_size;
int16x4_t c[1][8];
int8x16_t weight[1][2];
int8x16_t src[4];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * src_expand_size;
src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8(src_ic_0_3 + 16);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
int16x8_t tmp0;
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1],
c[0][0]);
src[0] = vld1q_s8(src_ic_0_3 + 4 * 16);
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1],
c[0][1]);
src[1] = vld1q_s8(src_ic_0_3 + 5 * 16);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1],
c[0][2]);
src[2] = vld1q_s8(src_ic_0_3 + 6 * 16);
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1],
c[0][3]);
src[3] = vld1q_s8(src_ic_0_3 + 7 * 16);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1],
c[0][4]);
src[0] = vld1q_s8(src_ic_0_3 + 8 * 16);
CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1],
c[0][5]);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1],
c[0][6]);
CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1],
c[0][7]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT();
}
#undef CALC_ONE_RESULT
template <BiasMode bias_mode, int remain_w, int filter_size>
struct KerNeonDirectStride1Int8 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc);
};
template <BiasMode bias_mode, int remain_w>
struct KerNeonDirectStride1Int8<bias_mode, remain_w, 3> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
int iw, int /*ld_dst_oc*/) {
constexpr int filter_size = 3;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int src_expand_size = 4;
const int ic_stride = ih * iw * src_expand_size;
int16x4_t c[1][8];
int8x16_t weight[3];
int8x16_t src[5];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * src_expand_size;
src[0] = vld1q_s8(src_ic_0_3 + 0 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
#define CALC_ONE_RESULT(_src0, _src1, _src2, _w0, _w1, _w2, _c) \
do { \
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w2)); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
int16x8_t tmp0;
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1],
weight[2], c[0][0]);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1],
weight[2], c[0][1]);
src[0] = vld1q_s8(src_ic_0_3 + 5 * 16);
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1],
weight[2], c[0][2]);
src[1] = vld1q_s8(src_ic_0_3 + 6 * 16);
CALC_ONE_RESULT(src[3], src[4], src[0], weight[0], weight[1],
weight[2], c[0][3]);
src[2] = vld1q_s8(src_ic_0_3 + 7 * 16);
CALC_ONE_RESULT(src[4], src[0], src[1], weight[0], weight[1],
weight[2], c[0][4]);
src[3] = vld1q_s8(src_ic_0_3 + 8 * 16);
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1],
weight[2], c[0][5]);
src[4] = vld1q_s8(src_ic_0_3 + 9 * 16);
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1],
weight[2], c[0][6]);
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1],
weight[2], c[0][7]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT();
}
};
#undef CALC_ONE_RESULT
template <BiasMode bias_mode, int remain_w>
struct KerNeonDirectStride1Int8<bias_mode, remain_w, 5> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
int iw, int /*ld_dst_oc*/) {
constexpr int filter_size = 5;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int src_expand_size = 4;
const int ic_stride = ih * iw * src_expand_size;
int16x4_t c[1][8];
int8x16_t weight[5];
int8x16_t src[8 + 2];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * src_expand_size;
src[0] = vld1q_s8(src_ic_0_3 + 0 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \
_w4, _c) \
do { \
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \
int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \
tmp0 = vaddq_s16(tmp0, tmp1); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][0]);
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][1]);
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][2]);
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][3]);
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][4]);
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][5]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][6]);
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1],
weight[0], weight[1], weight[2], weight[3],
weight[4], c[0][7]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT();
}
};
#undef CALC_ONE_RESULT
template <BiasMode bias_mode, int remain_w>
struct KerNeonDirectStride1Int8<bias_mode, remain_w, 7> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih,
int iw, int /*ld_dst_oc*/) {
constexpr int filter_size = 7;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int src_expand_size = 4;
const int ic_stride = ih * iw * src_expand_size;
int16x4_t c[1][8];
int8x16_t weight[7];
int8x16_t src[8 + 2];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * src_expand_size;
src[0] = vld1q_s8(src_ic_0_3 + 0 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \
_c) \
do { \
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src3), vget_high_s8(_w[3])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src4), vget_high_s8(_w[4])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src5), vget_high_s8(_w[5])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src6), vget_high_s8(_w[6])); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
int16x8_t tmp0;
CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5],
src[6], weight, c[0][0]);
CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6],
src[7], weight, c[0][1]);
src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7],
src[8], weight, c[0][2]);
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8],
src[9], weight, c[0][3]);
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9],
src[0], weight, c[0][4]);
CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0],
src[1], weight, c[0][5]);
CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1],
src[2], weight, c[0][6]);
CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2],
src[3], weight, c[0][7]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT();
}
};
template <BiasMode bias_mode>
void conv_direct_stride1_2x2_int8_oc8_ow8_nchw44(
const int8_t* src, const int8_t* filter, const int16_t* bias,
int16_t* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow) {
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t src_expand_size = 4;
const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_oc = oh * ow * oc_step;
using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr,
int ic, int ih, int iw, int ld_dst_oc)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, step, \
filter_size>; \
kern_small_oc_remain = \
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, step, \
filter_size>; \
break;
UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
size_t oh_idx = 0;
for (; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * src_expand_size;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, ow_step, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * src_expand_size;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_oc);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * src_expand_size;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, ow_step, filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * src_expand_size;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_oc);
}
}
}
}
#undef CALC_ONE_RESULT
template <BiasMode bias_mode, int filter_size>
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src,
const int8_t* filter,
const int16_t* bias, int16_t* dst,
const size_t oc, const size_t ic,
const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t src_expand_size = 4;
const size_t img_stride = oh * ow;
const int ld_dst_oc = oh * ow * oc_step;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr,
int ic, int ih, int iw, int ld_dst_oc)>;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_small_oc_remain = KerNeonDirectStride1Int8<bias_mode, step, \
filter_size>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb
for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * src_expand_size;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDirectStride1Int8<bias_mode, ow_step, filter_size>::impl(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * src_expand_size;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc);
}
}
}
}
} // namespace
namespace int8x8x16_direct_nchw44 {
template <BiasMode bias_mode, int filter_size>
struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
conv_direct_stride1_int8_nchw44_kern<bias_mode, filter_size>(
src, filter, bias, dst, oc, ic, ih, iw, oh, ow);
}
};
template <BiasMode bias_mode>
struct ConvDirectInt8Nchw44Choose<bias_mode, 2, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
conv_direct_stride1_2x2_int8_oc8_ow8_nchw44<bias_mode>(
src, filter, bias, dst, oc, ic, ih, iw, oh, ow);
}
};
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \
template struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, stride>;
#define GET_OP_PARAM(stride, filter, bias_mode) \
DO_CONV_KERN_FUN(stride, filter, bias_mode)
#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)
DISPATCH_CONV_KERN(1);
} // namespace int8x8x16_direct_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace {
#define INIT_SUM() \
int16x4_t init_sum; \
if (bias_mode == BiasMode::NO_BIAS) { \
init_sum = vdup_n_s16(0); \
} else { \
init_sum = vld1_s16(bias_ptr); \
}
#define STORE_1_LINE_RESULT() \
switch (remain_w) { \
case 8: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
break; \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
break; \
case 5: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1_s16(dst_ptr + 16, c[0][4]); \
break; \
case 6: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
break; \
case 7: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1_s16(dst_ptr + 24, c[0][6]); \
break; \
default: \
megdnn_assert(0, "oc 1 error remainw"); \
break; \
};
#define STORE_1_LINE_RESULT_OW4() \
switch (remain_w) { \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
break; \
default: \
megdnn_assert(0, "oc 1 error remainw"); \
break; \
};
#define STORE_2_LINE_RESULT() \
switch (remain_w) { \
case 8: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
vst1q_s16(dst_ptr + ld_dst_oc + 16, \
vcombine_s16(c[1][4], c[1][5])); \
vst1q_s16(dst_ptr + ld_dst_oc + 24, \
vcombine_s16(c[1][6], c[1][7])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \
break; \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
break; \
case 5: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1_s16(dst_ptr + 16, c[0][4]); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \
break; \
case 6: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
vst1q_s16(dst_ptr + ld_dst_oc + 16, \
vcombine_s16(c[1][4], c[1][5])); \
break; \
case 7: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \
vst1_s16(dst_ptr + 24, c[0][6]); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
vst1q_s16(dst_ptr + ld_dst_oc + 16, \
vcombine_s16(c[1][4], c[1][5])); \
vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \
break; \
default: \
megdnn_assert(0, "oc 2 error remainw"); \
break; \
}
#define STORE_2_LINE_RESULT_OW4() \
switch (remain_w) { \
case 4: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1q_s16(dst_ptr + ld_dst_oc + 8, \
vcombine_s16(c[1][2], c[1][3])); \
break; \
case 1: \
vst1_s16(dst_ptr, c[0][0]); \
vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \
break; \
case 2: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
break; \
case 3: \
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \
vst1_s16(dst_ptr + 8, c[0][2]); \
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \
vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \
break; \
default: \
megdnn_assert(0, "oc 2 error remainw"); \
break; \
}
template <BiasMode bias_mode, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
const int ld_weight_oc4 = oc_step * fh * fw * ic;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
int16x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[4];
INIT_SUM();
#define cb(_i) \
c[0][_i] = init_sum; \
c[1][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_ic_0_3, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
#define CALC_ONE_RESULT(_src0, _src1, _w0, _w1, _c) \
do { \
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w0)); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w1)); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
int16x8_t tmp0;
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1],
c[0][0]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1],
c[1][0]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1],
c[0][1]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1],
c[1][1]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1],
c[0][2]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1],
c[1][2]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 36, idx);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1],
c[0][3]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1],
c[1][3]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 40, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 44, idx);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1],
c[0][4]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1],
c[1][4]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 48, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 52, idx);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1],
c[0][5]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1],
c[1][5]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 56, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 60, idx);
CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1],
c[0][6]);
CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1],
c[1][6]);
CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1],
c[0][7]);
CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1],
c[1][7]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_2_LINE_RESULT();
}
template <BiasMode bias_mode, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic, int ih,
int iw, int /*ld_dst_oc*/) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
int16x4_t c[1][8];
int8x16_t weight[2];
int8x16_t src[4];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(8, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_ic_0_3, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
int16x8_t tmp0;
CALC_ONE_RESULT(src[0], src[1], weight[0], weight[1], c[0][0]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx);
CALC_ONE_RESULT(src[2], src[3], weight[0], weight[1], c[0][1]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx);
CALC_ONE_RESULT(src[0], src[1], weight[0], weight[1], c[0][2]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 36, idx);
CALC_ONE_RESULT(src[2], src[3], weight[0], weight[1], c[0][3]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 40, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 44, idx);
CALC_ONE_RESULT(src[0], src[1], weight[0], weight[1], c[0][4]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 48, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 52, idx);
CALC_ONE_RESULT(src[2], src[3], weight[0], weight[1], c[0][5]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 56, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 60, idx);
CALC_ONE_RESULT(src[0], src[1], weight[0], weight[1], c[0][6]);
CALC_ONE_RESULT(src[2], src[3], weight[0], weight[1], c[0][7]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT();
}
#undef CALC_ONE_RESULT
#define CALC_ONE_RESULT(_src0, _src1, _src2, _w, _c) \
do { \
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \
tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \
tmp0 = vaddq_s16(tmp0, tmp1); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
template <BiasMode bias_mode, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s2_oc8_ow4(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
const int ld_weight_oc4 = oc_step * fh * fw * ic;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
int16x4_t c[2][4];
int8x16_t weight[2][3];
int8x16_t src[5];
INIT_SUM();
#define cb(_i) \
c[0][_i] = init_sum; \
c[1][_i] = init_sum;
UNROLL_CALL_RAW(4, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_ic_0_3, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx);
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[0][2] = vld1q_s8(read_weight_ptr + 32);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
weight[1][2] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 32);
int16x8_t tmp0, tmp1;
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]);
CALC_ONE_RESULT(src[0], src[1], src[2], weight[1], c[1][0]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx);
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][1]);
CALC_ONE_RESULT(src[2], src[3], src[4], weight[1], c[1][1]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx);
CALC_ONE_RESULT(src[4], src[0], src[1], weight[0], c[0][2]);
CALC_ONE_RESULT(src[4], src[0], src[1], weight[1], c[1][2]);
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][3]);
CALC_ONE_RESULT(src[1], src[2], src[3], weight[1], c[1][3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));
vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1]));
vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3]));
}
template <BiasMode bias_mode, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s2_oc8_ow4_remain(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic,
int ih, int iw,
int ld_dst_oc) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
const int ld_weight_oc4 = oc_step * fh * fw * ic;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
int16x4_t c[2][4];
int8x16_t weight[2][3];
int8x16_t src[5];
INIT_SUM();
#define cb(_i) \
c[0][_i] = init_sum; \
c[1][_i] = init_sum;
UNROLL_CALL_RAW(4, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_ic_0_3, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx);
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[0][2] = vld1q_s8(read_weight_ptr + 32);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
weight[1][2] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 32);
int16x8_t tmp0, tmp1;
CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], c[0][0]);
CALC_ONE_RESULT(src[0], src[1], src[2], weight[1], c[1][0]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx);
CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], c[0][1]);
CALC_ONE_RESULT(src[2], src[3], src[4], weight[1], c[1][1]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx);
CALC_ONE_RESULT(src[4], src[0], src[1], weight[0], c[0][2]);
CALC_ONE_RESULT(src[4], src[0], src[1], weight[1], c[1][2]);
CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], c[0][3]);
CALC_ONE_RESULT(src[1], src[2], src[3], weight[1], c[1][3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_2_LINE_RESULT_OW4();
}
#undef CALC_ONE_RESULT
#define CALC_ONE_RESULT(_src0, _src1, _src2, _w, _c) \
do { \
int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \
tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
template <BiasMode bias_mode, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s2_oc4_ow4(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic, int ih,
int iw, int /*ld_dst_oc*/) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
int16x4_t c[1][4];
int8x16_t weight[3];
int8x16_t src[5];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(4, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_ic_0_3, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx);
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
CALC_ONE_RESULT(src[0], src[1], src[2], weight, c[0][0]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx);
CALC_ONE_RESULT(src[2], src[3], src[4], weight, c[0][1]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx);
CALC_ONE_RESULT(src[4], src[0], src[1], weight, c[0][2]);
CALC_ONE_RESULT(src[1], src[2], src[3], weight, c[0][3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));
}
template <BiasMode bias_mode, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s2_oc4_ow4_remain(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int16_t* bias_ptr,
int16_t* dst_ptr, int ic,
int ih, int iw,
int /*ld_dst_oc*/) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
int16x4_t c[1][4];
int8x16_t weight[3];
int8x16_t src[5];
INIT_SUM();
#define cb(_i) c[0][_i] = init_sum;
UNROLL_CALL_RAW(4, cb);
#undef cb
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 =
src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step;
src[0] = vld_dup_tbl_s32(src_ic_0_3, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 4, idx);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx);
src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx);
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;
weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
CALC_ONE_RESULT(src[0], src[1], src[2], weight, c[0][0]);
src[0] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx);
src[1] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx);
CALC_ONE_RESULT(src[2], src[3], src[4], weight, c[0][1]);
src[2] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx);
src[3] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx);
CALC_ONE_RESULT(src[4], src[0], src[1], weight, c[0][2]);
CALC_ONE_RESULT(src[1], src[2], src[3], weight, c[0][3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
STORE_1_LINE_RESULT_OW4();
}
#undef CALC_ONE_RESULT
template <BiasMode bias_mode>
void conv_direct_stride2_2x2_int8_nchw44(const int8_t* src,
const int8_t* filter,
const int16_t* bias, int16_t* dst,
const size_t oc, const size_t ic,
const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
const size_t out_img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oh * ow * oc_step;
using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr,
int ic, int ih, int iw, int ld_dst_oc)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, step, \
filter_size>; \
kern_small_oc_remain = \
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, step, \
filter_size>; \
break;
UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, ow_step,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, ow_step,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc);
}
}
}
}
template <BiasMode bias_mode>
void conv_direct_stride2_3x3_int8_nchw44(const int8_t* src,
const int8_t* filter,
const int16_t* bias, int16_t* dst,
const size_t oc, const size_t ic,
const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
constexpr size_t filter_size = 3;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 4;
constexpr size_t ow_step4 = 4;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
const size_t out_img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oh * ow * oc_step;
using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int16_t* bias_ptr, int16_t* dst_ptr,
int ic, int ih, int iw, int ld_dst_oc)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
ker_neon_dirctconv_3x3s2_oc8_ow4_remain<bias_mode, step, \
filter_size>; \
kern_small_oc_remain = \
ker_neon_dirctconv_3x3s2_oc4_ow4_remain<bias_mode, step, \
filter_size>; \
break;
UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step4) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_3x3s2_oc8_ow4<bias_mode, ow_step,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_3x3s2_oc4_ow4<bias_mode, ow_step,
filter_size>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc);
}
}
}
}
#undef CALC_ONE_RESULT
#undef LOAD_SRC
template <BiasMode bias_mode>
void conv_direct_stride2_5x5_int8x8x16_nchw44(
const int8_t* src, const int8_t* filter, const int16_t* bias,
int16_t* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow) {
constexpr size_t filter_size = 5;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step1 = 1;
constexpr size_t ow_step = 4;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
const size_t remain_w = ow & 3;
const size_t out_img_stride = oh * ow;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
size_t oc_idx = 0;
for (; oc_idx + 3 < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
const int16_t* bias_ptr = bias + oc_idx;
int16x4_t init_sum;
if (bias_mode == BiasMode::NO_BIAS) {
init_sum = vdup_n_s16(0);
} else {
init_sum = vld1_s16(bias_ptr);
}
size_t oh_idx = 0;
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w, _c) \
do { \
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \
tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w[3])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \
tmp0 = vaddq_s16(tmp0, tmp1); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
for (; oh_idx < oh; oh_idx += oh_step1) {
size_t ow_idx = 0;
for (; ow_idx + ow_step - 1 < ow; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
int16x4_t c[1][4];
const int8_t* src_ptr = src + src_offset;
int16_t* dst_ptr = dst + dst_offset;
const int8_t* weight_ptr = filter + weight_offset;
c[0][0] = init_sum;
c[0][1] = init_sum;
c[0][2] = init_sum;
c[0][3] = init_sum;
#if MEGDNN_AARCH64
int8x16_t weight[3][5];
int8x16_t ssrc[2][5];
#else
int8x16_t weight[1][5];
int8x16_t ssrc[1][9];
#endif
for (size_t ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const int8_t* src_row0 =
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
const int8_t* src_row1 =
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
const int8_t* src_row2 =
src_ptr + ic_idx * ic_stride + 2 * iw * ic_step;
const int8_t* src_row3 =
src_ptr + ic_idx * ic_stride + 3 * iw * ic_step;
const int8_t* src_row4 =
src_ptr + ic_idx * ic_stride + 4 * iw * ic_step;
#if MEGDNN_AARCH64
#define LOAD_SRC(_src, _src_ptr) \
_src[0] = vld_dup_tbl_s32(_src_ptr, idx); \
_src[1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \
_src[2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \
_src[3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \
_src[4] = vld_dup_tbl_s32(_src_ptr + 16, idx);
#define LOAD_WEIGHT(_w, _w_ptr, _id0, _id1, _id2, _id3, _id4) \
_w[0] = vld1q_s8(_w_ptr + _id0 * 16); \
_w[1] = vld1q_s8(_w_ptr + _id1 * 16); \
_w[2] = vld1q_s8(_w_ptr + _id2 * 16); \
_w[3] = vld1q_s8(_w_ptr + _id3 * 16); \
_w[4] = vld1q_s8(_w_ptr + _id4 * 16);
#define CALC_4_RESULT(_src, _w, _src_ptr) \
CALC_ONE_RESULT(_src[0], _src[1], _src[2], _src[3], _src[4], _w, c[0][0]); \
_src[0] = vld_dup_tbl_s32(_src_ptr + 20, idx); \
_src[1] = vld_dup_tbl_s32(_src_ptr + 24, idx); \
CALC_ONE_RESULT(_src[2], _src[3], _src[4], _src[0], _src[1], _w, c[0][1]); \
_src[2] = vld_dup_tbl_s32(_src_ptr + 28, idx); \
_src[3] = vld_dup_tbl_s32(_src_ptr + 32, idx); \
CALC_ONE_RESULT(_src[4], _src[0], _src[1], _src[2], _src[3], _w, c[0][2]); \
_src[4] = vld_dup_tbl_s32(_src_ptr + 36, idx); \
_src[0] = vld_dup_tbl_s32(_src_ptr + 40, idx); \
CALC_ONE_RESULT(_src[1], _src[2], _src[3], _src[4], _src[0], _w, c[0][3]);
int16x8_t tmp0, tmp1;
LOAD_SRC(ssrc[0], src_row0);
LOAD_WEIGHT(weight[0], weight_ptr, 0, 1, 2, 3, 4);
LOAD_WEIGHT(weight[1], weight_ptr, 5, 6, 7, 8, 9);
CALC_4_RESULT(ssrc[0], weight[0], src_row0);
LOAD_SRC(ssrc[1], src_row1);
LOAD_WEIGHT(weight[2], weight_ptr, 10, 11, 12, 13, 14);
LOAD_SRC(ssrc[0], src_row2);
CALC_4_RESULT(ssrc[1], weight[1], src_row1);
LOAD_SRC(ssrc[1], src_row3);
LOAD_WEIGHT(weight[0], weight_ptr, 15, 16, 17, 18, 19);
CALC_4_RESULT(ssrc[0], weight[2], src_row2);
LOAD_SRC(ssrc[0], src_row4);
LOAD_WEIGHT(weight[1], weight_ptr, 20, 21, 22, 23, 24);
CALC_4_RESULT(ssrc[1], weight[0], src_row3);
CALC_4_RESULT(ssrc[0], weight[1], src_row4);
#else
#define LOAD_SRC(_src_ptr) \
ssrc[0][0] = vld_dup_tbl_s32(_src_ptr, idx); \
ssrc[0][1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \
ssrc[0][2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \
ssrc[0][3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \
ssrc[0][4] = vld_dup_tbl_s32(_src_ptr + 16, idx); \
ssrc[0][5] = vld_dup_tbl_s32(_src_ptr + 20, idx); \
ssrc[0][6] = vld_dup_tbl_s32(_src_ptr + 24, idx); \
ssrc[0][7] = vld_dup_tbl_s32(_src_ptr + 28, idx); \
ssrc[0][8] = vld_dup_tbl_s32(_src_ptr + 32, idx);
#define LOAD_WEIGHT(_w_ptr, _id0, _id1, _id2, _id3, _id4) \
weight[0][0] = vld1q_s8(_w_ptr + _id0 * 16); \
weight[0][1] = vld1q_s8(_w_ptr + _id1 * 16); \
weight[0][2] = vld1q_s8(_w_ptr + _id2 * 16); \
weight[0][3] = vld1q_s8(_w_ptr + _id3 * 16); \
weight[0][4] = vld1q_s8(_w_ptr + _id4 * 16);
#define CALC_4_RESULT(_src_ptr) \
CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \
ssrc[0][4], weight[0], c[0][0]); \
ssrc[0][0] = vld_dup_tbl_s32(_src_ptr + 36, idx); \
ssrc[0][1] = vld_dup_tbl_s32(_src_ptr + 40, idx); \
CALC_ONE_RESULT(ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \
ssrc[0][6], weight[0], c[0][1]); \
CALC_ONE_RESULT(ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \
ssrc[0][8], weight[0], c[0][2]); \
CALC_ONE_RESULT(ssrc[0][6], ssrc[0][7], ssrc[0][8], ssrc[0][0], \
ssrc[0][1], weight[0], c[0][3]);
int16x8_t tmp0, tmp1;
LOAD_WEIGHT(weight_ptr, 0, 1, 2, 3, 4);
LOAD_SRC(src_row0);
CALC_4_RESULT(src_row0);
LOAD_WEIGHT(weight_ptr, 5, 6, 7, 8, 9);
LOAD_SRC(src_row1);
CALC_4_RESULT(src_row1);
LOAD_WEIGHT(weight_ptr, 10, 11, 12, 13, 14);
LOAD_SRC(src_row2);
CALC_4_RESULT(src_row2);
LOAD_WEIGHT(weight_ptr, 15, 16, 17, 18, 19);
LOAD_SRC(src_row3);
CALC_4_RESULT(src_row3);
LOAD_WEIGHT(weight_ptr, 20, 21, 22, 23, 24);
LOAD_SRC(src_row4);
CALC_4_RESULT(src_row4);
#endif
weight_ptr += fh * fw * ld_weight_ic4;
}
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));
}
if (remain_w > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
int16x4_t c[1][3];
const int8_t* src_ptr = src + src_offset;
int16_t* dst_ptr = dst + dst_offset;
const int8_t* weight_ptr = filter + weight_offset;
c[0][0] = init_sum;
c[0][1] = init_sum;
c[0][2] = init_sum;
#if MEGDNN_AARCH64
int8x16_t weight[3][5];
int8x16_t ssrc[2][5];
#else
int8x16_t weight[1][5];
int8x16_t ssrc[1][9];
#endif
for (size_t ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const int8_t* src_row0 =
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
const int8_t* src_row1 =
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
const int8_t* src_row2 =
src_ptr + ic_idx * ic_stride + 2 * iw * ic_step;
const int8_t* src_row3 =
src_ptr + ic_idx * ic_stride + 3 * iw * ic_step;
const int8_t* src_row4 =
src_ptr + ic_idx * ic_stride + 4 * iw * ic_step;
#if MEGDNN_AARCH64
#define LOAD_SRC(_src, _src_ptr) \
_src[0] = vld_dup_tbl_s32(_src_ptr, idx); \
_src[1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \
_src[2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \
_src[3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \
_src[4] = vld_dup_tbl_s32(_src_ptr + 16, idx);
#define LOAD_WEIGHT(_w, _w_ptr, _id0, _id1, _id2, _id3, _id4) \
_w[0] = vld1q_s8(_w_ptr + _id0 * 16); \
_w[1] = vld1q_s8(_w_ptr + _id1 * 16); \
_w[2] = vld1q_s8(_w_ptr + _id2 * 16); \
_w[3] = vld1q_s8(_w_ptr + _id3 * 16); \
_w[4] = vld1q_s8(_w_ptr + _id4 * 16);
#define CALC_3_RESULT(_src, _w, _src_ptr) \
CALC_ONE_RESULT(_src[0], _src[1], _src[2], _src[3], _src[4], _w, c[0][0]); \
_src[0] = vld_dup_tbl_s32(_src_ptr + 20, idx); \
_src[1] = vld_dup_tbl_s32(_src_ptr + 24, idx); \
CALC_ONE_RESULT(_src[2], _src[3], _src[4], _src[0], _src[1], _w, c[0][1]); \
_src[2] = vld_dup_tbl_s32(_src_ptr + 28, idx); \
_src[3] = vld_dup_tbl_s32(_src_ptr + 32, idx); \
CALC_ONE_RESULT(_src[4], _src[0], _src[1], _src[2], _src[3], _w, c[0][2]);
int16x8_t tmp0, tmp1;
LOAD_SRC(ssrc[0], src_row0);
LOAD_WEIGHT(weight[0], weight_ptr, 0, 1, 2, 3, 4);
LOAD_WEIGHT(weight[1], weight_ptr, 5, 6, 7, 8, 9);
CALC_3_RESULT(ssrc[0], weight[0], src_row0);
LOAD_SRC(ssrc[1], src_row1);
LOAD_WEIGHT(weight[2], weight_ptr, 10, 11, 12, 13, 14);
LOAD_SRC(ssrc[0], src_row2);
CALC_3_RESULT(ssrc[1], weight[1], src_row1);
LOAD_SRC(ssrc[1], src_row3);
LOAD_WEIGHT(weight[0], weight_ptr, 15, 16, 17, 18, 19);
CALC_3_RESULT(ssrc[0], weight[2], src_row2);
LOAD_SRC(ssrc[0], src_row4);
LOAD_WEIGHT(weight[1], weight_ptr, 20, 21, 22, 23, 24);
CALC_3_RESULT(ssrc[1], weight[0], src_row3);
CALC_3_RESULT(ssrc[0], weight[1], src_row4);
#else
#define LOAD_SRC(_src_ptr) \
ssrc[0][0] = vld_dup_tbl_s32(_src_ptr, idx); \
ssrc[0][1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \
ssrc[0][2] = vld_dup_tbl_s32(_src_ptr + 8, idx); \
ssrc[0][3] = vld_dup_tbl_s32(_src_ptr + 12, idx); \
ssrc[0][4] = vld_dup_tbl_s32(_src_ptr + 16, idx); \
ssrc[0][5] = vld_dup_tbl_s32(_src_ptr + 20, idx); \
ssrc[0][6] = vld_dup_tbl_s32(_src_ptr + 24, idx); \
ssrc[0][7] = vld_dup_tbl_s32(_src_ptr + 28, idx); \
ssrc[0][8] = vld_dup_tbl_s32(_src_ptr + 32, idx);
#define LOAD_WEIGHT(_w_ptr, _id0, _id1, _id2, _id3, _id4) \
weight[0][0] = vld1q_s8(_w_ptr + _id0 * 16); \
weight[0][1] = vld1q_s8(_w_ptr + _id1 * 16); \
weight[0][2] = vld1q_s8(_w_ptr + _id2 * 16); \
weight[0][3] = vld1q_s8(_w_ptr + _id3 * 16); \
weight[0][4] = vld1q_s8(_w_ptr + _id4 * 16);
#define CALC_3_RESULT(_src_ptr) \
CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \
ssrc[0][4], weight[0], c[0][0]); \
CALC_ONE_RESULT(ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \
ssrc[0][6], weight[0], c[0][1]); \
CALC_ONE_RESULT(ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \
ssrc[0][8], weight[0], c[0][2]);
int16x8_t tmp0, tmp1;
LOAD_WEIGHT(weight_ptr, 0, 1, 2, 3, 4);
LOAD_SRC(src_row0);
CALC_3_RESULT(src_row0);
LOAD_WEIGHT(weight_ptr, 5, 6, 7, 8, 9);
LOAD_SRC(src_row1);
CALC_3_RESULT(src_row1);
LOAD_WEIGHT(weight_ptr, 10, 11, 12, 13, 14);
LOAD_SRC(src_row2);
CALC_3_RESULT(src_row2);
LOAD_WEIGHT(weight_ptr, 15, 16, 17, 18, 19);
LOAD_SRC(src_row3);
CALC_3_RESULT(src_row3);
LOAD_WEIGHT(weight_ptr, 20, 21, 22, 23, 24);
LOAD_SRC(src_row4);
CALC_3_RESULT(src_row4);
#endif
weight_ptr += fh * fw * ld_weight_ic4;
}
switch (remain_w) {
case 1:
vst1_s16(dst_ptr, c[0][0]);
break;
case 2:
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));
break;
case 3:
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));
vst1_s16(dst_ptr + 8, c[0][2]);
break;
default:
megdnn_throw("invalid remain_w");
break;
}
}
}
}
}
#undef CALC_4_RESULT
#undef LOAD_SRC
#undef LOAD_WEIGHT
#undef CALC_ONE_RESULT
template <BiasMode bias_mode>
void conv_direct_stride2_7x7_int8x8x16_nchw44(
const int8_t* src, const int8_t* filter, const int16_t* bias,
int16_t* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow) {
constexpr size_t filter_size = 7;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step1 = 1;
constexpr size_t ow_step = 4;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
const size_t out_img_stride = oh * ow;
static const uint8_t idx_buffer[16] = {0, 0, 0, 0, 1, 1, 1, 1,
2, 2, 2, 2, 3, 3, 3, 3};
static uint8x16_t idx = vld1q_u8(idx_buffer);
size_t oc_idx = 0;
for (; oc_idx + 3 < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
const int16_t* bias_ptr = bias + oc_idx;
int16x4_t init_sum;
if (bias_mode == BiasMode::NO_BIAS) {
init_sum = vdup_n_s16(0);
} else {
init_sum = vld1_s16(bias_ptr);
}
size_t oh_idx = 0;
#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \
_c) \
do { \
tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \
tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w[1])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w[3])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src5), vget_high_s8(_w[5])); \
tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \
tmp1 = vmlal_s8(tmp1, vget_high_s8(_src6), vget_high_s8(_w[6])); \
tmp0 = vaddq_s16(tmp0, tmp1); \
_c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \
} while (0);
for (; oh_idx < oh; oh_idx += oh_step1) {
size_t ow_idx = 0;
for (; ow_idx + ow_step - 1 < ow; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
int16x4_t c[1][4];
int8x16_t weight[1][7];
int8x16_t ssrc[1][9];
const int8_t* src_ptr = src + src_offset;
int16_t* dst_ptr = dst + dst_offset;
const int8_t* weight_ptr = filter + weight_offset;
c[0][0] = init_sum;
c[0][1] = init_sum;
c[0][2] = init_sum;
c[0][3] = init_sum;
for (size_t ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const int8_t* src_row0 =
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
const int8_t* src_row1 =
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
const int8_t* src_row2 =
src_ptr + ic_idx * ic_stride + 2 * iw * ic_step;
const int8_t* src_row3 =
src_ptr + ic_idx * ic_stride + 3 * iw * ic_step;
const int8_t* src_row4 =
src_ptr + ic_idx * ic_stride + 4 * iw * ic_step;
const int8_t* src_row5 =
src_ptr + ic_idx * ic_stride + 5 * iw * ic_step;
const int8_t* src_row6 =
src_ptr + ic_idx * ic_stride + 6 * iw * ic_step;
#define LOAD_SRC(_src) \
ssrc[0][0] = vld_dup_tbl_s32(_src, idx); \
ssrc[0][1] = vld_dup_tbl_s32(_src + 4, idx); \
ssrc[0][2] = vld_dup_tbl_s32(_src + 8, idx); \
ssrc[0][3] = vld_dup_tbl_s32(_src + 12, idx); \
ssrc[0][4] = vld_dup_tbl_s32(_src + 16, idx); \
ssrc[0][5] = vld_dup_tbl_s32(_src + 20, idx); \
ssrc[0][6] = vld_dup_tbl_s32(_src + 24, idx);
#define LOAD_WEIGHT(_id0, _id1, _id2, _id3, _id4, _id5, _id6) \
weight[0][0] = vld1q_s8(weight_ptr + _id0 * 16); \
weight[0][1] = vld1q_s8(weight_ptr + _id1 * 16); \
weight[0][2] = vld1q_s8(weight_ptr + _id2 * 16); \
weight[0][3] = vld1q_s8(weight_ptr + _id3 * 16); \
weight[0][4] = vld1q_s8(weight_ptr + _id4 * 16); \
weight[0][5] = vld1q_s8(weight_ptr + _id5 * 16); \
weight[0][6] = vld1q_s8(weight_ptr + _id6 * 16);
#define CALC_4_RESULT(_row) \
CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \
ssrc[0][4], ssrc[0][5], ssrc[0][6], weight[0], c[0][0]); \
\
ssrc[0][7] = vld_dup_tbl_s32(_row + 28, idx); \
ssrc[0][8] = vld_dup_tbl_s32(_row + 32, idx); \
CALC_ONE_RESULT(ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \
ssrc[0][6], ssrc[0][7], ssrc[0][8], weight[0], c[0][1]); \
\
ssrc[0][0] = vld_dup_tbl_s32(_row + 36, idx); \
ssrc[0][1] = vld_dup_tbl_s32(_row + 40, idx); \
\
CALC_ONE_RESULT(ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \
ssrc[0][8], ssrc[0][0], ssrc[0][1], weight[0], c[0][2]); \
ssrc[0][2] = vld_dup_tbl_s32(_row + 44, idx); \
ssrc[0][3] = vld_dup_tbl_s32(_row + 48, idx); \
\
CALC_ONE_RESULT(ssrc[0][6], ssrc[0][7], ssrc[0][8], ssrc[0][0], \
ssrc[0][1], ssrc[0][2], ssrc[0][3], weight[0], c[0][3]);
int16x8_t tmp0, tmp1;
LOAD_SRC(src_row0);
LOAD_WEIGHT(0, 1, 2, 3, 4, 5, 6);
CALC_4_RESULT(src_row0);
LOAD_SRC(src_row1);
LOAD_WEIGHT(7, 8, 9, 10, 11, 12, 13);
CALC_4_RESULT(src_row1);
LOAD_SRC(src_row2);
LOAD_WEIGHT(14, 15, 16, 17, 18, 19, 20);
CALC_4_RESULT(src_row2);
LOAD_SRC(src_row3);
LOAD_WEIGHT(21, 22, 23, 24, 25, 26, 27);
CALC_4_RESULT(src_row3);
LOAD_SRC(src_row4);
LOAD_WEIGHT(28, 29, 30, 31, 32, 33, 34);
CALC_4_RESULT(src_row4);
LOAD_SRC(src_row5);
LOAD_WEIGHT(35, 36, 37, 38, 39, 40, 41);
CALC_4_RESULT(src_row5);
LOAD_SRC(src_row6);
LOAD_WEIGHT(42, 43, 44, 45, 46, 47, 48);
CALC_4_RESULT(src_row6);
weight_ptr += fh * fw * ld_weight_ic4;
}
vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1]));
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3]));
}
for (; ow_idx < ow; ow_idx++) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
constexpr int ld_weight_ic4 = 16;
const int ic_stride = ih * iw;
int16x4_t c = init_sum;
int8x16_t weight[1][7];
int8x16_t ssrc[1][7];
const int8_t* src_ptr = src + src_offset;
int16_t* dst_ptr = dst + dst_offset;
const int8_t* weight_ptr = filter + weight_offset;
for (size_t ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const int8_t* src_row0 =
src_ptr + ic_idx * ic_stride + 0 * iw * ic_step;
const int8_t* src_row1 =
src_ptr + ic_idx * ic_stride + 1 * iw * ic_step;
const int8_t* src_row2 =
src_ptr + ic_idx * ic_stride + 2 * iw * ic_step;
const int8_t* src_row3 =
src_ptr + ic_idx * ic_stride + 3 * iw * ic_step;
const int8_t* src_row4 =
src_ptr + ic_idx * ic_stride + 4 * iw * ic_step;
const int8_t* src_row5 =
src_ptr + ic_idx * ic_stride + 5 * iw * ic_step;
const int8_t* src_row6 =
src_ptr + ic_idx * ic_stride + 6 * iw * ic_step;
#define CALC_1_RESULT(_row) \
CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \
ssrc[0][4], ssrc[0][5], ssrc[0][6], weight[0], c);
int16x8_t tmp0, tmp1;
LOAD_SRC(src_row0);
LOAD_WEIGHT(0, 1, 2, 3, 4, 5, 6);
CALC_1_RESULT(src_row0);
LOAD_SRC(src_row1);
LOAD_WEIGHT(7, 8, 9, 10, 11, 12, 13);
CALC_1_RESULT(src_row1);
LOAD_SRC(src_row2);
LOAD_WEIGHT(14, 15, 16, 17, 18, 19, 20);
CALC_1_RESULT(src_row2);
LOAD_SRC(src_row3);
LOAD_WEIGHT(21, 22, 23, 24, 25, 26, 27);
CALC_1_RESULT(src_row3);
LOAD_SRC(src_row4);
LOAD_WEIGHT(28, 29, 30, 31, 32, 33, 34);
CALC_1_RESULT(src_row4);
LOAD_SRC(src_row5);
LOAD_WEIGHT(35, 36, 37, 38, 39, 40, 41);
CALC_1_RESULT(src_row5);
LOAD_SRC(src_row6);
LOAD_WEIGHT(42, 43, 44, 45, 46, 47, 48);
CALC_1_RESULT(src_row6);
weight_ptr += fh * fw * ld_weight_ic4;
}
vst1_s16(dst_ptr, c);
}
}
}
}
#undef CALC_ONE_RESULT
#undef CALC_1_RESULT
#undef CALC_4_RESULT
#undef LOAD_SRC
#undef LOAD_WEIGHT
} // namespace
namespace int8x8x16_direct_nchw44 {
template <BiasMode bias_mode>
struct ConvDirectInt8Nchw44Choose<bias_mode, 2, 2> {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
conv_direct_stride2_2x2_int8_nchw44<bias_mode>(src, filter, bias, dst,
oc, ic, ih, iw, oh, ow);
}
};
template <BiasMode bias_mode>
struct ConvDirectInt8Nchw44Choose<bias_mode, 3, 2> {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
conv_direct_stride2_3x3_int8_nchw44<bias_mode>(src, filter, bias, dst,
oc, ic, ih, iw, oh, ow);
}
};
template <BiasMode bias_mode>
struct ConvDirectInt8Nchw44Choose<bias_mode, 5, 2> {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
conv_direct_stride2_5x5_int8x8x16_nchw44<bias_mode>(
src, filter, bias, dst, oc, ic, ih, iw, oh, ow);
}
};
template <BiasMode bias_mode>
struct ConvDirectInt8Nchw44Choose<bias_mode, 7, 2> {
static void impl(const int8_t* src, const int8_t* filter,
const int16_t* bias, int16_t* dst, const size_t oc,
const size_t ic, const size_t ih, const size_t iw,
const size_t oh, const size_t ow) {
conv_direct_stride2_7x7_int8x8x16_nchw44<bias_mode>(
src, filter, bias, dst, oc, ic, ih, iw, oh, ow);
}
};
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode) \
template struct ConvDirectInt8Nchw44Choose<bias_mode, filter_size, stride>;
#define GET_OP_PARAM(stride, filter, bias_mode) \
DO_CONV_KERN_FUN(stride, filter, bias_mode)
#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)
DISPATCH_CONV_KERN(2);
} // namespace int8x8x16_direct_nchw44
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -44,6 +44,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoQU8DirectStride1 qu8_direct_stride1;
AlgoS8DirectStride2 s8_direct_stride2;
AlgoS8DirectNCHW44 s8_direct_nchw44;
AlgoS8x8x16DirectNCHW44 s8x8x16_direct_nchw44;
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
AlgoS8DirectStride1 s8_direct_stride1;
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
......@@ -94,6 +95,7 @@ public:
direct_algos.emplace_back(&qu8_direct_stride1);
direct_algos.emplace_back(&s8_direct_stride2);
direct_algos.emplace_back(&s8_direct_nchw44);
direct_algos.emplace_back(&s8x8x16_direct_nchw44);
direct_algos.emplace_back(&s8_direct_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_stride1);
......
......@@ -39,6 +39,7 @@ private:
class AlgoS8DirectStride1;
class AlgoS8DirectStride2;
class AlgoS8DirectNCHW44;
class AlgoS8x8x16DirectNCHW44;
class AlgoS8DirectNCHWNCHW44;
class AlgoQU8DirectStride1;
class AlgoQU8DirectStride2;
......
......@@ -518,6 +518,116 @@ void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle,
}
}
void benchmark_nchw44_8x8x16_vs_8x8x32(const char* im2col_name, Handle* handle,
size_t kernel, size_t stride,
size_t pack_size = 1) {
megdnn_assert(stride == 1 || stride == 2, "only support stride 1 or 2");
std::vector<conv_bias::TestArg> args;
auto pack = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
size_t p) {
if (ic % pack_size != 0 || oc % pack_size != 0)
return;
if (w + 2 * p < kernel || h + 2 * p < kernel)
return;
param::ConvBias param;
param.format = param::ConvBias::Format::NCHW44;
param.stride_h = stride;
param.stride_w = stride;
param.pad_h = p;
param.pad_w = p;
param.sparse = param::ConvBias::Sparse::DENSE;
args.push_back(conv_bias::TestArg{
param,
TensorShape{1, ic / 4, h, w, 4},
TensorShape{oc / 4, ic / 4, kernel, kernel, 4, 4},
{1, oc / 4, 1, 1, 4}});
};
pack(1, 64, 56, 56, kernel, 0);
pack(8, 64, 56, 56, kernel, 0);
pack(16, 64, 56, 56, kernel, 1);
pack(32, 64, 56, 56, kernel, 1);
pack(1, 64, 100, 100, kernel, 1);
pack(8, 64, 100, 100, kernel, 1);
pack(1, 64, 100, 100, kernel, 0);
pack(8, 64, 100, 100, kernel, 0);
pack(16, 64, 100, 100, kernel, 1);
pack(32, 64, 100, 100, kernel, 1);
pack(64, 64, 100, 100, kernel, 1);
pack(128, 64, 100, 100, kernel, 1);
pack(256, 64, 100, 100, kernel, 1);
pack(512, 64, 100, 100, kernel, 1);
pack(1024, 64, 100, 100, kernel, 1);
pack(1, 32, 200, 200, kernel, 1);
pack(8, 64, 200, 200, kernel, 1);
pack(1, 32, 200, 200, kernel, 0);
pack(8, 64, 200, 200, kernel, 0);
pack(16, 96, 200, 200, kernel, 1);
pack(32, 32, 200, 200, kernel, 1);
pack(64, 64, 200, 200, kernel, 1);
pack(128, 96, 200, 200, kernel, 1);
pack(1, 64, 10, 10, kernel, 1);
pack(8, 64, 10, 10, kernel, 1);
pack(16, 64, 10, 10, kernel, 1);
pack(32, 64, 10, 10, kernel, 1);
pack(64, 64, 10, 10, kernel, 1);
pack(128, 64, 10, 10, kernel, 1);
pack(256, 64, 10, 10, kernel, 1);
pack(512, 64, 10, 10, kernel, 1);
pack(1024, 64, 10, 10, kernel, 1);
using namespace conv_bias;
constexpr size_t RUN = 20;
Benchmarker<ConvBias> benchmark_im2col(handle);
benchmark_im2col.set_display(false);
benchmark_im2col.set_times(RUN);
Benchmarker<ConvBias> benchmark_8832(handle);
benchmark_8832.set_display(false);
benchmark_8832.set_times(RUN);
for (auto&& arg : args) {
TensorLayout dst_layout;
auto opr = handle->create_operator<ConvBias>();
opr->param() = arg.param;
opr->deduce_layout({arg.src, dtype::Float32()},
{arg.filter, dtype::Float32()},
{arg.bias, dtype::Float32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * arg.filter[1] *
arg.filter[2] * arg.filter[3] * 2.0 * 4 /
(1024 * 1024 * 1024) * 1e3;
benchmark_im2col.set_param(arg.param);
benchmark_im2col.set_dtype(0, dtype::Int8());
benchmark_im2col.set_dtype(1, dtype::Int8());
benchmark_im2col.set_dtype(2, dtype::Int16());
benchmark_im2col.set_dtype(4, dtype::Int16());
auto used_8816 =
algo_benchmark<ConvBias>(benchmark_im2col,
{arg.src, arg.filter, {}, {}, {}},
im2col_name) /
RUN;
benchmark_8832.set_param(arg.param);
benchmark_8832.set_dtype(0, dtype::QuantizedS8(2.5));
benchmark_8832.set_dtype(1, dtype::QuantizedS8(2.5));
benchmark_8832.set_dtype(2, dtype::QuantizedS32(6.25));
benchmark_8832.set_dtype(4, {});
auto used_8832 =
algo_benchmark<ConvBias>(benchmark_8832,
{arg.src, arg.filter, {}, {}, {}},
"S8_NCHW44_DIRECT") /
RUN;
printf("%s %s: 8816: %f ms %f GFlops ", arg.src.to_string().c_str(),
arg.filter.to_string().c_str(), used_8816,
computations / used_8816);
printf("%s %s: 8832: %f ms %f GFlops ", arg.src.to_string().c_str(),
arg.filter.to_string().c_str(), used_8832,
computations / used_8832);
printf("speedup %f \n", used_8832 / used_8816);
}
}
void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name,
const char* im2col_name, Handle* handle,
size_t kernel, DType src_type,
......@@ -872,6 +982,28 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_MATMUL) {
#endif
#if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_8X8X16_DIRECT_STRIDE1) {
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 2, 1,
4);
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 3, 1,
4);
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 5, 1,
4);
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 7, 1,
4);
}
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_8X8X16_DIRECT_STRIDE2) {
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 2, 2,
4);
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 3, 2,
4);
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 5, 2,
4);
benchmark_nchw44_8x8x16_vs_8x8x32("S8x8x16_NCHW44_DIRECT", handle(), 7, 2,
4);
}
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) {
#if MEGDNN_AARCH64
benchmark_winograd("WINOGRAD:AARCH64_F32:1:2", handle(), 3);
......
......@@ -534,11 +534,25 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) {
checker_conv_bias_int8x8x16(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true),
handle(), "S8x8x16_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) {
checker_conv_bias_int8x8x16(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true),
handle(), "S8x8x16_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
checker_conv_bias_qint8x8x32(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true),
handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
checker_conv_bias_qint8x8x32(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册