提交 3597a6db 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(dnn/arm): nchw_nchw44 conv support 1x1s1

GitOrigin-RevId: 8c8f7d7c763b603961ca27b1fd17425bafd019cd
上级 c64b1c94
......@@ -47,6 +47,52 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
}
};
////////////////////stride 1///////////////////
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block,
1> {
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 1;
constexpr int filter_width = 4;
constexpr int weight_reg = 2;
constexpr int src_reg = 2;
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;
const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
......@@ -441,6 +487,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 1) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
......
......@@ -58,6 +58,17 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> {
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block,
2> {
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int,
int, int, int, const Op&) {
megdnn_assert(0, "not impl");
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
......@@ -429,6 +440,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 1) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
......
......@@ -112,6 +112,47 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 1;
constexpr int filter_width = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 1;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr, ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
......@@ -547,6 +588,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define INSTANCE_CONV_KERN(stride) \
INSTANCE_BIAS_MODE_PARAM(stride, 1) \
INSTANCE_BIAS_MODE_PARAM(stride, 2) \
INSTANCE_BIAS_MODE_PARAM(stride, 3) \
INSTANCE_BIAS_MODE_PARAM(stride, 5) \
......
......@@ -1033,6 +1033,15 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, stride> {
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, stride> {
static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int,
int, int, int, const Op&) {
megdnn_assert(0, "not impl nchw_nchw44 1x1 s2");
}
};
enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 };
template <PACK_MODE mode>
MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr,
......@@ -1398,6 +1407,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 2> {
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define INSTANCE_CONV_KERN(stride) \
INSTANCE_BIAS_MODE_PARAM(stride, 1) \
INSTANCE_BIAS_MODE_PARAM(stride, 2) \
INSTANCE_BIAS_MODE_PARAM(stride, 3) \
INSTANCE_BIAS_MODE_PARAM(stride, 5) \
......
......@@ -291,6 +291,9 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns(
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 1: \
GET_BIAS_MODE_PARAM(stride, 1) \
break; \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
......
......@@ -245,6 +245,9 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns(
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 1: \
GET_BIAS_MODE_PARAM(stride, 1) \
break; \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
......
......@@ -74,9 +74,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8>(
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
bool ok_filter =
fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
fm.spatial[0] == 7 ||
(fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1));
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[1] == 2);
......@@ -126,9 +128,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>(
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
bool ok_filter =
fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
fm.spatial[0] == 7 ||
(fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1));
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[1] == 2);
......
......@@ -487,6 +487,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
handle(), "S8_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) {
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1,
false, true),
handle(), "S8_CONV_NCHW_NCHW44");
}
/*****************************quint8 direct****************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
......@@ -517,6 +524,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) {
auto args = get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE,
1, false, true);
for (auto&& arg : args) {
arg.param.format = param::ConvBias::Format::NCHW44_DOT;
}
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),
......
......@@ -635,6 +635,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) {
benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}},
{1, {7}}, data_type);
};
bench_case(1, 2, 64, 160, 160, 1, 1, 0, 1, true);
bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true);
bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1);
bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1);
......
......@@ -131,7 +131,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
if (param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW44_DOT) {
//! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
if (filter_shape[1] == 1 && filter_shape[2] == 1) {
if (filter_shape[1] == 1 && filter_shape[2] == 1 &&
filter_shape.ndim == 6) {
group *= 4;
}
size_t computation = dst_shape.total_nr_elems() * fh * fw *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册