提交 6c29548d 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/arm): fix nchw_nchw44 dot stride1 support

GitOrigin-RevId: c8d3d55b258e2a43c27b903808566f2ea1857842
上级 02cbb13b
/**
* \file dnn/src/arm_common/conv_bias/block_helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/utils.h"
namespace megdnn {
namespace {
// block_helper is used to calculate oh block size
static inline int l2_block_helper(const int nthread, const int amount,
const int size_per_unit) {
constexpr int l2_cache_size = 256 * 1024;
const int block_per_thread = div_ceil(amount, nthread);
const int best_block = std::min(
amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
const int max_block_num = div_ceil(block_per_thread, best_block);
const int min_block_num = std::max(max_block_num - 1, 1);
const int max_block = div_ceil(block_per_thread, max_block_num);
const int min_block = div_ceil(block_per_thread, min_block_num);
const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
int block = max_loss > min_loss ? min_block : max_block;
return block;
}
} // namespace
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -11,6 +11,7 @@
*/
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
......@@ -26,22 +27,7 @@ using conv_fun = std::function<void(
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1)
namespace {
// block_helper is used to calculate oh block size
static inline int block_helper(const int nthread, const int amount,
const int size_per_unit) {
constexpr int l2_cache_size = 256 * 1024;
const int block_per_thread = div_ceil(amount, nthread);
const int best_block = std::min(
amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
const int max_block_num = div_ceil(block_per_thread, best_block);
const int min_block_num = std::max(max_block_num - 1, 1);
const int max_block = div_ceil(block_per_thread, max_block_num);
const int min_block = div_ceil(block_per_thread, min_block_num);
const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
int block = max_loss > min_loss ? min_block : max_block;
return block;
}
static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
const int iw2) {
// border_size is used to avoid read illegal memory
......@@ -60,7 +46,7 @@ static void get_rectified_size(
ow2 = ow;
constexpr int cacheline = 64 / sizeof(float);
int block_oh =
block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2);
l2_block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2);
auto&& fm = param.filter_meta;
const int stride_h = static_cast<int>(fm.stride[0]);
const int filter_h = static_cast<int>(fm.spatial[0]);
......@@ -106,8 +92,8 @@ static void do_conv_kern(WorkspaceBundle bundle,
const int group_id = ncb_index.ndrange_id[1];
constexpr int oc_idx = 0;
int oc_block = oc;
int oh_block = block_helper(kern_param.nr_threads, oh2,
ic * iw * sizeof(float) * stride_h);
int oh_block = l2_block_helper(kern_param.nr_threads, oh2,
ic * iw * sizeof(float) * stride_h);
const int oh_idx = ncb_index.ndrange_id[2];
const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
const int ih_real = oh_block_real * stride_h + fh - stride_h;
......@@ -298,8 +284,8 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns(
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
int stride_h = param.filter_meta.stride[0];
int oh_block = block_helper(param.nr_threads, oh,
ic * iw * sizeof(float) * stride_h);
int oh_block = l2_block_helper(param.nr_threads, oh,
ic * iw * sizeof(float) * stride_h);
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group),
static_cast<size_t>(div_ceil(oh, oh_block))};
......
......@@ -133,6 +133,21 @@ public:
};
#if __ARM_FEATURE_DOTPROD
class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; }
bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam&) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
bool m_large_group;
......
/**
* \file
* dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_dot)
namespace {
static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
const int iw2,
const int stride) {
//! border_size is used to avoid read illegal memory
constexpr int cacheline_size = 64;
constexpr int border_size = 2 * cacheline_size;
const int pack_iw_len = stride == 1 ? 4 : 1;
return round_up(
ic * ih2 * iw2 * pack_iw_len * (int)sizeof(int8_t) + border_size,
cacheline_size);
}
static inline size_t get_temp_bytes(const int iw, const int pw) {
//! border_size is used to avoid read illegal memory
constexpr int cacheline_size = 64;
constexpr int border_size = 1 * cacheline_size;
return round_up(iw + pw * 2, cacheline_size) + border_size;
}
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2) {
auto&& fm = param.filter_meta;
const int stride_h = static_cast<int>(fm.stride[0]);
const int filter_h = static_cast<int>(fm.spatial[0]);
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
int oh = param.osz[0];
int block_oh = l2_block_helper(param.nr_threads, oh,
ic * iw * sizeof(int8_t) * stride_h);
ih2 = block_oh * stride_h + filter_h - stride_h;
iw2 = iw + 2 * static_cast<int>(fm.padding[1]);
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
int ic = fm.icpg;
int fh = fm.spatial[0];
int fw = fm.spatial[1];
int iw = param.isz[1];
int pw = param.filter_meta.padding[1];
int stride_w = param.filter_meta.stride[1];
int ih2, iw2;
get_rectified_size(param, ih2, iw2);
size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w);
size_t weight_size = fm.group * fm.icpg * fm.ocpg * fh * round_up(fw, 4);
size_t temp_size = 0;
if (fm.stride[0] == 1) {
temp_size = get_temp_bytes(iw, pw);
}
return {nullptr,
{src_size * param.nr_threads, weight_size,
temp_size * param.nr_threads}};
};
void do_weight_trans(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex&, const CpuNDRange&) {
const int ic = kern_param.filter_meta.icpg;
const int oc = kern_param.filter_meta.ocpg;
const int fh = kern_param.filter_meta.spatial[0];
const int fw = kern_param.filter_meta.spatial[1];
const int fw2 = round_up(fw, 4);
bundle.set(kern_param.workspace_ptr);
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1));
auto origin_weight = kern_param.filter<dt_int8>();
pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh,
fw, fw2);
}
template <size_t filter, BiasMode bias_mode, typename Op, int stride>
static void do_conv_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange&, const CpuNDRange&) {
const int oh = kern_param.osz[0];
const int ow = kern_param.osz[1];
const int fh = kern_param.filter_meta.spatial[0];
const int fw = kern_param.filter_meta.spatial[1];
const int ic = kern_param.filter_meta.icpg;
const int oc = kern_param.filter_meta.ocpg;
const int ih = kern_param.isz[0];
const int iw = kern_param.isz[1];
const int stride_h = kern_param.filter_meta.stride[0];
const int stride_w = kern_param.filter_meta.stride[1];
const int ph = kern_param.filter_meta.padding[0];
const int pw = kern_param.filter_meta.padding[1];
int ih2 = 0;
int iw2 = 0;
get_rectified_size(kern_param, ih2, iw2);
bundle.set(kern_param.workspace_ptr);
constexpr int pack_c = 4;
const int batch_id = ncb_index.ndrange_id[0];
const int group_id = ncb_index.ndrange_id[1];
constexpr int oc_idx = 0;
int oc_block = oc;
int oh_block = l2_block_helper(kern_param.nr_threads, oh,
ic * iw * sizeof(int8_t) * stride_h);
const int oh_idx = ncb_index.ndrange_id[2];
const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
const int ih_real = oh_block_real * stride_h + fh - stride_h;
const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0);
const int src_bottom_pad = std::max(
(oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph,
0);
const int remain_right_pad = std::max(iw2 - iw - pw, 0);
const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw;
const int8_t* origin_sptr =
static_cast<const int8_t*>(
kern_param.src<int8_t>(batch_id, group_id, 0, 1, 1)) +
src_offset;
const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w);
int8_t* sptr = reinterpret_cast<int8_t*>(bundle.get(0)) +
ncb_index.thread_id * src_size;
int8_t* tmp_ptr = nullptr;
if (stride == 1) {
const size_t tmp_size = get_temp_bytes(iw, pw);
tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
}
pack_src_int8_nchw_nchw44_dot<stride>(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw, tmp_ptr);
const int8_t* fptr =
reinterpret_cast<int8_t*>(bundle.get(1)) + oc_idx * fh * fw * ic;
int8_t* dst = kern_param.dst<int8_t>(batch_id, group_id) +
oh_idx * oh_block * ow * pack_c;
const int bias_offset = oc_idx;
const int32_t* bptr =
kern_param.bias<dt_int32>(batch_id, group_id) + bias_offset;
float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
Op op(scale_bias, scale_dst);
conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter, stride>(
sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
oh_block_real, ow, op);
}
} // namespace
bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto fh = fm.spatial[0];
int oc = fm.ocpg;
int ic = fm.icpg;
bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4);
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2);
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
return avaible;
}
size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
const int batch = param.n;
const int group = fm.group;
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to gen hash of midout for compatible with
// shape runtime
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot, \
midout_iv(#stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(stride, 5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(stride, 7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
switch (param.filter_meta.stride[0]) {
case 1:
DISPATCH_CONV_KERN(1);
break;
case 2:
DISPATCH_CONV_KERN(2);
break;
default:
megdnn_assert(0);
break;
}
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
megdnn_assert(do_conv_fun);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
WorkspaceBundle bundle = wbundle;
int oh = param.osz[0];
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
int stride_h = param.filter_meta.stride[0];
int oh_block = l2_block_helper(param.nr_threads, oh,
ic * iw * sizeof(int8_t) * stride_h);
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group),
static_cast<size_t>(div_ceil(oh, oh_block))};
auto do_trans_weight = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
do_weight_trans(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({do_trans_weight, {1}});
auto do_conv = [bundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
return ret_kerns;
}
#endif
// vim: syntax=cpp.doxygen
/**
* \file
* dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;
namespace {
template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
int stride, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight);
};
template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, stride, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2], weight[1][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]); \
c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2 + 1], weight[1][weight_idx], \
src[1][(src_idx + step) / 4]);
UNROLL_CALL_RAW(4, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, stride, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]);
UNROLL_CALL_RAW(4, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \
c[1][step] = Func::template impl<(src_idx + step) % 4>( \
c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
int stride, typename T, typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, stride, T, T2,
T3, int>::impl(c, src, weight);
};
//! OCHelper is used to trans oc_block to row number of result regs
template <int oc>
struct OCHelper {
public:
static const int val = -1;
};
template <>
struct OCHelper<4> {
public:
static const int val = 1;
};
#if MEGDNN_AARCH64
template <>
struct OCHelper<8> {
public:
static const int val = 2;
};
#endif
/**
* oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel
* */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block, int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op);
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;
const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;
const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 2 * iw, stride);
load_helper<weight_reg, 2 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;
const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += 5 * 32;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
/**
* oc = 8, ow = 8
* dot 4 element, pad last filter and do twice dot every row filter, filter like
* below
* --------------------------
* |x, x, x, x,| x, x, x, 0 |
* --------------------------
**/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;
const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += 7 * 32;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
////////////////////stride 1///////////////////
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 2;
constexpr int src_reg = 2;
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;
const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 3;
constexpr int src_reg = 2;
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;
const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 2 * iw * pack_iw_len, 0);
cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;
const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;
const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <int stride>
void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin,
const int, const int pw, const int,
const int ih, const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride, int8_t*) {
constexpr int ic_step = 1;
rep_step(ic_idx, ic, ic_step) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom));
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
memcpy(sptr_base + pw * ic_step, sptr,
sizeof(int8_t) * iw * ic_step);
sptr_base += iw2 * ic_step;
sptr += iw * ic_step;
}
sptr_base += iw2 * pad_bottom * ic_step;
}
}
template <>
void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base,
const int8_t* sptr_origin, const int,
const int pw, const int, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride,
int8_t* temp_ptr) {
static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4,
2, 3, 4, 5, 3, 4, 5, 6};
uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);
constexpr int iw_step = 16;
constexpr int pack_iw_len = 4;
const int iw_with_pad = iw + 2 * pw;
const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
rep(ic_idx, ic) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
pack_iw_len);
sptr_base += iw2 * pad_top * pack_iw_len;
rep(ih_idx, ih) {
memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
int8x16_t src[4];
int8x16_t dst[4];
src[0] = vld1q_s8(temp_ptr + iw_idx);
src[1] = vld1q_s8(temp_ptr + iw_idx + 4);
src[2] = vld1q_s8(temp_ptr + iw_idx + 8);
src[3] = vld1q_s8(temp_ptr + iw_idx + 12);
dst[0] = vqtbl1q_s8(src[0], tbl_idx);
dst[1] = vqtbl1q_s8(src[1], tbl_idx);
dst[2] = vqtbl1q_s8(src[2], tbl_idx);
dst[3] = vqtbl1q_s8(src[3], tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
}
for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
*(sptr_base + iw_idx * pack_iw_len + 0) =
*(temp_ptr + iw_idx + 0);
*(sptr_base + iw_idx * pack_iw_len + 1) =
*(temp_ptr + iw_idx + 1);
*(sptr_base + iw_idx * pack_iw_len + 2) =
*(temp_ptr + iw_idx + 2);
*(sptr_base + iw_idx * pack_iw_len + 3) =
*(temp_ptr + iw_idx + 3);
}
sptr_base += iw2 * pack_iw_len;
sptr += iw;
}
sptr_base += iw2 * pad_bottom * pack_iw_len;
}
}
static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr,
const int8_t* src_ptr,
const int oc, const int ic,
const int fh, const int fw,
const int fw2) {
constexpr int oc_step = 4;
const int fw_remain = fw2 - fw;
const int dst_ic_stride = fh * fw2;
const int oc_step_stride = fh * fw2 * ic * oc_step;
static const uint8_t transpose_4x4_idx[16] = {0, 4, 8, 12, 1, 5, 9, 13,
2, 6, 10, 14, 3, 7, 11, 15};
uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]);
rep_step(oc_idx, oc, oc_step) {
int32_t* dst_temp_ptr =
reinterpret_cast<int32_t*>(dst_ptr + oc_idx * ic * fh * fw2);
const int32_t* src_temp_ptr = reinterpret_cast<const int32_t*>(
src_ptr + oc_idx * ic * fh * fw);
// transpose ic and pad
rep(fh_idx, fh) {
rep(fw_idx, fw) {
rep(ic_idx, ic) {
*(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr;
src_temp_ptr++;
}
dst_temp_ptr++;
}
rep(ic_idx, ic) {
memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0,
sizeof(int8_t) * oc_step * fw_remain);
}
dst_temp_ptr += fw_remain;
}
// transpose fw oc
int8_t* trans_dst_temp_ptr =
reinterpret_cast<int8_t*>(dst_ptr + oc_idx * ic * fh * fw2);
rep_step(idx, oc_step_stride, 16) {
int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx);
vst1q_s8(trans_dst_temp_ptr + idx,
vqtbl1q_s8(temp, tbl_transpose_4x4));
}
}
}
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
static void conv_direct_int8_nchw_nchw44_dot(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int fh = filter_size;
constexpr int fw = (filter_size + 3) / 4 * 4;
#if MEGDNN_AARCH64
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 4;
#endif
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = stride == 2 ? 1 : 4;
const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;
using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
big_oc_step, ow_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
oc_step, ow_step, stride>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
big_oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
} // namespace
#endif
// vim: syntax=cpp.doxygen
......@@ -176,187 +176,202 @@ inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr,
StoreOcxOw4Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc);
}
////////////////////Store_OCX_OW8_Remain/////////////////////////
template <int c_dim, int ow_remain, typename Op, typename T>
template <int c_dim, int ow_remain, typename Op, typename T, typename T2,
typename T3>
struct StoreOcxOw8Remain {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc);
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc);
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 0, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op({{c[0][6], c[0][7]}}, dst_ptr + 24);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24);
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op({{c[1][6], c[1][7]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 8, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op({{c[0][6], c[0][7]}}, dst_ptr + 24);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24);
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op({{c[1][6], c[1][7]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 7, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op(c[0][6], dst_ptr + 24);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24));
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
op(c[1][6], dst_ptr + ld_dst_oc + 24);
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op(c[1][6], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 6, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 5, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op(c[0][4], dst_ptr + 16);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16));
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op(c[1][4], dst_ptr + ld_dst_oc + 16);
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op(c[1][4], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 4, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 3, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op(c[0][2], dst_ptr + 8);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8));
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op(c[1][2], dst_ptr + ld_dst_oc + 8);
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op(c[1][2], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 2, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 1, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op(c[0][0], dst_ptr);
op(c[1][0], dst_ptr + ld_dst_oc);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op(c[0][0], reinterpret_cast<T3>(dst_ptr));
op(c[1][0], reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 0, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op({{c[0][6], c[0][7]}}, dst_ptr + 24);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 8, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op({{c[0][6], c[0][7]}}, dst_ptr + 24);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 7, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op(c[0][6], dst_ptr + 24);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 6, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 5, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op(c[0][4], dst_ptr + 16);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 4, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 3, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op(c[0][2], dst_ptr + 8);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 2, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 1, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op(c[0][0], dst_ptr);
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> {
static void impl(T& c, const Op& op, T2 dst_ptr, int) {
op(c[0][0], reinterpret_cast<T3>(dst_ptr));
}
};
template <int c_dim, int ow_remain, typename Op, typename T>
inline void store_ocx_ow8_remain_static(T& c, const Op& op, float32_t* dst_ptr,
template <int c_dim, int ow_remain, typename Op, typename T, typename T2>
inline void store_ocx_ow8_remain_static(T& c, const Op& op, T2 dst_ptr,
int ld_dst_oc) {
StoreOcxOw8Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc);
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T2>::impl(c, op, dst_ptr,
ld_dst_oc);
}
template <int c_dim, int ow_remain, typename Op, typename T3, typename T,
typename T2>
inline void store_ocx_ow8_remain_static_dt(T& c, const Op& op, T2 dst_ptr,
int ld_dst_oc) {
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T3>::impl(c, op, dst_ptr,
ld_dst_oc);
}
////////////////////Store_OC8_OW8_Remain/////////////////////////
......@@ -522,68 +537,84 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,
}
}
/////////////////////////init_ocx_ow8////////////////////
inline float32x4_t neon_vdupq_n(float val) {
return vdupq_n_f32(val);
}
inline int32x4_t neon_vdupq_n(int val) {
return vdupq_n_s32(val);
}
inline float32x4_t neon_vld1q(const float* ptr) {
return vld1q_f32(ptr);
}
inline int32x4_t neon_vld1q(const int* ptr) {
return vld1q_s32(ptr);
}
template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2>
struct InitOcxOw8 {
static void impl(T& c, T2 bias_ptr, int oc_step);
static void impl(T& c, const T2* bias_ptr, int oc_step);
};
template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> {
static void impl(T& c, const float32_t*, int) {
#define BAIS_INIT(step) \
c[0][step] = vdupq_n_f32(0); \
c[1][step] = vdupq_n_f32(0);
static void impl(T& c, const T2*, int) {
#define BAIS_INIT(step) \
c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \
c[1][step] = neon_vdupq_n(static_cast<T2>(0));
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> {
static void impl(T& c, const float32_t*, int) {
#define BAIS_INIT(step) \
c[0][step] = vdupq_n_f32(0); \
c[1][step] = vdupq_n_f32(0);
static void impl(T& c, const T2*, int) {
#define BAIS_INIT(step) \
c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \
c[1][step] = neon_vdupq_n(static_cast<T2>(0));
UNROLL_CALL_RAW(4, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int oc_step) {
#define BAIS_INIT(step) \
c[0][step] = vld1q_f32(bias_ptr); \
c[1][step] = vld1q_f32(bias_ptr + oc_step);
static void impl(T& c, const T2* bias_ptr, int oc_step) {
#define BAIS_INIT(step) \
c[0][step] = neon_vld1q(bias_ptr); \
c[1][step] = neon_vld1q(bias_ptr + oc_step);
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int oc_step) {
#define BAIS_INIT(step) \
c[0][step] = vld1q_f32(bias_ptr); \
c[1][step] = vld1q_f32(bias_ptr + oc_step);
static void impl(T& c, const T2* bias_ptr, int oc_step) {
#define BAIS_INIT(step) \
c[0][step] = neon_vld1q(bias_ptr); \
c[1][step] = neon_vld1q(bias_ptr + oc_step);
UNROLL_CALL_RAW(4, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int oc_step) {
static void impl(T& c, const T2* bias_ptr, int oc_step) {
constexpr int simd_len = 4;
#define BAIS_INIT(step) \
c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \
c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len);
#define BAIS_INIT(step) \
c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \
c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len);
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int oc_step) {
static void impl(T& c, const T2* bias_ptr, int oc_step) {
constexpr int simd_len = 4;
#define BAIS_INIT(step) \
c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \
c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len);
#define BAIS_INIT(step) \
c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \
c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len);
UNROLL_CALL_RAW(4, BAIS_INIT);
#undef BAIS_INIT
}
......@@ -591,57 +622,57 @@ struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> {
template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> {
static void impl(T& c, const float32_t*, int) {
#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0);
static void impl(T& c, const T2*, int) {
#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0));
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> {
static void impl(T& c, const float32_t*, int) {
#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0);
static void impl(T& c, const T2*, int) {
#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0));
UNROLL_CALL_RAW(4, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int) {
#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr);
static void impl(T& c, const T2* bias_ptr, int) {
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr);
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int) {
#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr);
static void impl(T& c, const T2* bias_ptr, int) {
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr);
UNROLL_CALL_RAW(4, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int) {
static void impl(T& c, const T2* bias_ptr, int) {
constexpr int simd_len = 4;
#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len);
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len);
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
}
};
template <typename T, typename T2>
struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int) {
static void impl(T& c, const T2* bias_ptr, int) {
constexpr int simd_len = 4;
#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len);
#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len);
UNROLL_CALL_RAW(4, BAIS_INIT);
#undef BAIS_INIT
}
};
template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2>
inline void init_ocx_ow8(T& c, T2 bias_ptr, int oc_step) {
inline void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) {
InitOcxOw8<c_dim, bias_mode, ow_block, T, T2>::impl(c, bias_ptr, oc_step);
}
/////////////////////init_ocx_ow4/////////////////////
......
......@@ -55,6 +55,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
#if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44;
AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true};
AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false};
AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true};
......@@ -93,6 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack() {
#if __ARM_FEATURE_DOTPROD
direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&ds8_direct_stride1_large_group);
direct_algos.emplace_back(&ds8_direct_stride1_small_group);
direct_algos.emplace_back(&ds8_direct_stride2_large_group);
......
......@@ -62,6 +62,7 @@ private:
class AlgoFP16WinogradF23_8x8;
#endif
#if __ARM_FEATURE_DOTPROD
class AlgoDotS8DirectNCHWNCHW44;
class AlgoDotS8DirectStride1;
class AlgoDotS8DirectStride2;
class AlgoDotU8DirectStride1;
......
......@@ -60,6 +60,14 @@ struct Vfmaq_laneq_f32 {
return vfmaq_laneq_f32(a, b, v, lane);
}
};
#if __ARM_FEATURE_DOTPROD
struct Vdotq_laneq_s32 {
template <const int lane>
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_laneq_s32(a, b, v, lane);
}
};
#endif
} // namespace
} // namespace megdnn
......
......@@ -481,37 +481,71 @@ UNROLL_CALL_RAW(4, cb);
#define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec)
namespace {
template <int lane>
struct Vfmap_laneq_f32_armv7 {
struct Vfmaq_laneq_f32_armv7 {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v);
};
template <>
struct Vfmap_laneq_f32_armv7<0> {
struct Vfmaq_laneq_f32_armv7<0> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_low_f32(v), 0);
}
};
template <>
struct Vfmap_laneq_f32_armv7<1> {
struct Vfmaq_laneq_f32_armv7<1> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_low_f32(v), 1);
}
};
template <>
struct Vfmap_laneq_f32_armv7<2> {
struct Vfmaq_laneq_f32_armv7<2> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_high_f32(v), 0);
}
};
template <>
struct Vfmap_laneq_f32_armv7<3> {
struct Vfmaq_laneq_f32_armv7<3> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_high_f32(v), 1);
}
};
} // namespace
#define vfmaq_laneq_f32(a, b, v, lane) \
Vfmap_laneq_f32_armv7<lane>::impl(a, b, v)
Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v)
#if __ARM_FEATURE_DOTPROD
template <int lane>
struct Vdotq_laneq_s32_armv7 {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v);
};
template <>
struct Vdotq_laneq_s32_armv7<0> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_low_s32(v), 0);
}
};
template <>
struct Vdotq_laneq_s32_armv7<1> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_low_s32(v), 1);
}
};
template <>
struct Vdotq_laneq_s32_armv7<2> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_high_s32(v), 0);
}
};
template <>
struct Vdotq_laneq_s32_armv7<3> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_high_f32(v), 1);
}
};
#define vdotq_laneq_s32(a, b, v, lane) \
Vdotq_laneq_s32_armv7<lane>::impl(a, b, v)
#endif
#endif
......
......@@ -109,14 +109,12 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
.set_dtype(4, dtype::QuantizedS8(60.25))
.set_display(false);
benchmarker_int.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"));
conv_bias::ConvBiasAlgoChecker<ConvBias>("IM2COLMATMUL:.+"));
Benchmarker<ConvBias> benchmarker_float(handle);
benchmarker_float.set_display(false).set_times(RUNS);
benchmarker_float.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"));
conv_bias::ConvBiasAlgoChecker<ConvBias>("IM2COLMATMUL:.+"));
Benchmarker<ConvBias> benchmarker_nchw44(handle);
if (is_fp32) {
......@@ -213,6 +211,15 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
run(1, 256, 256, 14, 14, 3, 1, false);
run(1, 512, 512, 7, 7, 3, 1, false);
} else {
run(1, 1, 4, 112, 112, 2, 2, true);
run(1, 3, 32, 224, 224, 3, 2, true);
run(1, 3, 32, 224, 224, 5, 2, true);
run(1, 3, 64, 224, 224, 7, 2, true);
run(1, 1, 4, 112, 112, 2, 1, true);
run(1, 3, 32, 224, 224, 3, 1, true);
run(1, 3, 32, 224, 224, 5, 1, true);
run(1, 3, 64, 224, 224, 7, 1, true);
for (size_t stride : {1, 2}) {
printf("stride %zu\n", stride);
for (size_t filter_size : {2, 3, 5, 7}) {
......@@ -228,9 +235,11 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
}
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
benchmark_convbias(handle(), true);
benchmark_convbias(handle(), false);
}
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
benchmark_convbias(handle(), true);
benchmark_convbias(handle(), false);
}
#endif
......
......@@ -557,6 +557,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
/****************************dot qint8 direct*************************/
#if __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
true),
handle(), "ARMDOTS8_NCHW_NCHW44");
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
true),
handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册