提交 9e0583e1 编写于 作者: M Megvii Engine Team

feat(dnn/arm_common): add arm_common chanwise dot 11x11

GitOrigin-RevId: 84e0815a5943d2efcdcb79d32196e7a405e315b0
上级 115bcbce
...@@ -21,9 +21,13 @@ public: ...@@ -21,9 +21,13 @@ public:
DirectConvRunner(size_t flt_size, size_t stride) { DirectConvRunner(size_t flt_size, size_t stride) {
if (flt_size == 9 && stride == 1) { if (flt_size == 9 && stride == 1) {
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16; m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16;
} else { } else if (flt_size == 9 && stride == 2) {
megdnn_assert(flt_size == 9 && stride == 2);
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16; m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16;
} else if (flt_size == 11 && stride == 1) {
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16;
} else {
megdnn_assert(flt_size == 11 && stride == 2);
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16;
} }
} }
size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const { size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const {
...@@ -208,8 +212,8 @@ bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable( ...@@ -208,8 +212,8 @@ bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable(
(bias_mode == BiasMode::NO_BIAS || (bias_mode == BiasMode::NO_BIAS ||
bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) && bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 && SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9 || FH == 11) &&
fm.ocpg == 1; fm.icpg == 1 && fm.ocpg == 1;
return avaible; return avaible;
} }
......
...@@ -12,4 +12,13 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( ...@@ -12,4 +12,13 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16(
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
int8_t relu_val); int8_t relu_val);
//! 11x11 stride-1 chanwise int8 direct-conv micro-kernel (Armv8.2 dot
//! product). One call computes a 4-row (OH) x 16-column (OW) output tile at
//! (oh, ow); `pad_iw` is the padded input row stride, `scale` the requant
//! scale and `relu_val` the clamp value forwarded to the store helper.
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val);
//! 11x11 stride-2 variant; same tile shape and parameter contract as the
//! stride-1 kernel.
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val);
#endif #endif
\ No newline at end of file
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h"
#include "src/common/unroll_macro.h"
MEGDNN_ATTRIBUTE_TARGET("dotprod")
//! Direct chanwise int8 convolution micro-kernel: 11x11 filter, stride 1,
//! built on the Armv8.2-A dot-product (sdot) extension. One call produces a
//! 4-row (OH) x 16-column (OW) tile of the output feature map.
//!
//! \param src      base of the (padded) input plane.
//! \param weight   packed filter; this kernel streams 9 x 16 bytes
//!                 (weight + 0*16 .. weight + 8*16). Each 11-tap filter row
//!                 occupies 12 consecutive bytes (three 4-byte sdot lanes);
//!                 the 12th byte is assumed zero-padded -- confirm against
//!                 the weight-packing routine.
//! \param bias     int32 bias broadcast into every accumulator lane.
//! \param dst      int8 output plane, row stride OW.
//! \param oh, ow   top-left coordinate of this output tile.
//! \param OH, OW   output spatial sizes (OH is not read inside this kernel;
//!                 OW is the dst row stride).
//! \param pad_iw   row stride (padded width) of the input plane.
//! \param scale    requantization scale applied to the int32 accumulators.
//! \param relu_val value broadcast and handed to quant_store_s8; presumably
//!                 the lower clamp bound implementing ReLU -- see
//!                 quant_store_s8.
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val) {
    //! 4x16 output tile, stride 1.
    const size_t SH = 1;
    const size_t SW = 1;
    //! vqtbl1q_s8 shuffle patterns: every 32-bit lane gathers 4 consecutive
    //! input pixels, consecutive lanes sliding by SW == 1 pixel. Table k
    //! starts at source byte 4*k, so for one filter row the shuffled groups
    //! n*_0 / n*_1 / n*_2 line up with filter taps 0-3 / 4-7 / 8-11.
    static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4,
                                            2, 3, 4, 5, 3, 4, 5, 6};
    static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8,
                                            6, 7, 8, 9, 7, 8, 9, 10};
    static const uint8_t tbl_array_2[16] = {8,  9,  10, 11, 9,  10, 11, 12,
                                            10, 11, 12, 13, 11, 12, 13, 14};
    uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
    uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);
    uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]);
    //! top-left input pixel feeding this tile
    const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
    //! init: c[r][g] accumulates output row r, columns 4*g .. 4*g + 3.
    int32x4_t c[4][4];
#define cb(step)                    \
    c[step][0] = vdupq_n_s32(bias); \
    c[step][1] = vdupq_n_s32(bias); \
    c[step][2] = vdupq_n_s32(bias); \
    c[step][3] = vdupq_n_s32(bias);
    UNROLL_CALL_RAW(4, cb);
#undef cb
//! Ring buffer of 4 live filter registers; the 9 x 16B of weights are
//! streamed through it, each register being reloaded (rows 4..8 below) once
//! all taps it held have been consumed.
#define flt_reg 4
    int8x16_t flt[flt_reg];
    flt[0] = vld1q_s8(weight + 0 * 16);
    flt[1] = vld1q_s8(weight + 1 * 16);
    flt[2] = vld1q_s8(weight + 2 * 16);
    flt[3] = vld1q_s8(weight + 3 * 16);
    //! row 0: load and shuffle input row 0 inline (subsequent rows use
    //! LOAD_SRC). For stride 1 the "+4" and "+8" tap groups of columns 0..3
    //! equal the "+0" groups of columns 4..7 / 8..11, hence the aliasing
    //! assignments (n0123_1 = n4567_0, ...).
    int8x16_t read_w[2];
    read_w[0] = vld1q_s8(src_n + 0 * pad_iw);
    read_w[1] = vld1q_s8(src_n + 0 * pad_iw + 16);
    int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
    int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1);
    int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2);
    int8x16_t ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0);
    int8x16_t n0123_1 = n4567_0;
    int8x16_t n4567_1 = n89ab_0;
    int8x16_t n89ab_1 = ncdef_0;
    int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0);
    int8x16_t n0123_2 = n89ab_0;
    int8x16_t n4567_2 = ncdef_0;
    int8x16_t n89ab_2 = ncdef_1;
    int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);
//! CAL_C(oh, flt_start): accumulate one packed filter row (starting at
//! 4-byte weight group `flt_start`, i.e. flt_start == 3 * filter_row) into
//! output row `oh` -- three sdot ops per 4-column group, taps 0-3 / 4-7 /
//! 8-11 against the matching shuffled input groups. The register index
//! (flt_start + k) / 4 % flt_reg walks the filter ring buffer.
#define CAL_C(oh, flt_start)                                       \
    c[oh][0] = vdotq_laneq_s32(                                    \
            c[oh][0], n0123_0, flt[(flt_start + 0) / 4 % flt_reg], \
            (flt_start + 0) % 4);                                  \
    c[oh][1] = vdotq_laneq_s32(                                    \
            c[oh][1], n4567_0, flt[(flt_start + 0) / 4 % flt_reg], \
            (flt_start + 0) % 4);                                  \
    c[oh][2] = vdotq_laneq_s32(                                    \
            c[oh][2], n89ab_0, flt[(flt_start + 0) / 4 % flt_reg], \
            (flt_start + 0) % 4);                                  \
    c[oh][3] = vdotq_laneq_s32(                                    \
            c[oh][3], ncdef_0, flt[(flt_start + 0) / 4 % flt_reg], \
            (flt_start + 0) % 4);                                  \
    c[oh][0] = vdotq_laneq_s32(                                    \
            c[oh][0], n0123_1, flt[(flt_start + 1) / 4 % flt_reg], \
            (flt_start + 1) % 4);                                  \
    c[oh][1] = vdotq_laneq_s32(                                    \
            c[oh][1], n4567_1, flt[(flt_start + 1) / 4 % flt_reg], \
            (flt_start + 1) % 4);                                  \
    c[oh][2] = vdotq_laneq_s32(                                    \
            c[oh][2], n89ab_1, flt[(flt_start + 1) / 4 % flt_reg], \
            (flt_start + 1) % 4);                                  \
    c[oh][3] = vdotq_laneq_s32(                                    \
            c[oh][3], ncdef_1, flt[(flt_start + 1) / 4 % flt_reg], \
            (flt_start + 1) % 4);                                  \
    c[oh][0] = vdotq_laneq_s32(                                    \
            c[oh][0], n0123_2, flt[(flt_start + 2) / 4 % flt_reg], \
            (flt_start + 2) % 4);                                  \
    c[oh][1] = vdotq_laneq_s32(                                    \
            c[oh][1], n4567_2, flt[(flt_start + 2) / 4 % flt_reg], \
            (flt_start + 2) % 4);                                  \
    c[oh][2] = vdotq_laneq_s32(                                    \
            c[oh][2], n89ab_2, flt[(flt_start + 2) / 4 % flt_reg], \
            (flt_start + 2) % 4);                                  \
    c[oh][3] = vdotq_laneq_s32(                                    \
            c[oh][3], ncdef_2, flt[(flt_start + 2) / 4 % flt_reg], \
            (flt_start + 2) % 4);
    CAL_C(0, 0);
//! LOAD_SRC(row_id): load input row `row_id` of the tile and rebuild the
//! shuffled column groups (same layout as the inline row-0 code above).
#define LOAD_SRC(row_id)                                                 \
    read_w[0] = vld1q_s8(src_n + row_id * pad_iw);                       \
    read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16);                  \
    n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);                          \
    n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1);                          \
    n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2);                          \
    ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); \
    n0123_1 = n4567_0;                                                   \
    n4567_1 = n89ab_0;                                                   \
    n89ab_1 = ncdef_0;                                                   \
    ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0);                          \
    n0123_2 = n89ab_0;                                                   \
    n4567_2 = ncdef_0;                                                   \
    n89ab_2 = ncdef_1;                                                   \
    ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);
    //! Rows 1..13: with FH == 11, SH == 1 the tile consumes 14 input rows;
    //! input row r contributes filter row (r - o) to each output row o with
    //! 0 <= r - o <= 10, hence the CAL_C(o, 3 * (r - o)) pattern below. The
    //! interleaved flt reloads (rows 4, 5, 6, 8, 10) refill ring-buffer
    //! slots only after their previous contents' last use.
    //! row 1
    LOAD_SRC(1);
    CAL_C(0, 3);
    CAL_C(1, 0);
    //! row 2
    LOAD_SRC(2);
    CAL_C(0, 3 * 2);
    CAL_C(1, 3 * 1);
    CAL_C(2, 3 * 0);
    //! row 3
    LOAD_SRC(3);
    CAL_C(0, 3 * 3);
    CAL_C(1, 3 * 2);
    CAL_C(2, 3 * 1);
    CAL_C(3, 3 * 0);
    //! row 4
    LOAD_SRC(4);
    CAL_C(0, 3 * 4);
    CAL_C(1, 3 * 3);
    CAL_C(2, 3 * 2);
    CAL_C(3, 3 * 1);
    //! update flt 4 -> 0
    flt[0] = vld1q_s8(weight + 4 * 16);
    //! row 5
    LOAD_SRC(5);
    CAL_C(0, 3 * 5);
    CAL_C(1, 3 * 4);
    CAL_C(2, 3 * 3);
    CAL_C(3, 3 * 2);
    //! update flt 5 -> 1
    flt[1] = vld1q_s8(weight + 5 * 16);
    //! row 6
    LOAD_SRC(6);
    CAL_C(0, 3 * 6);
    CAL_C(1, 3 * 5);
    CAL_C(2, 3 * 4);
    CAL_C(3, 3 * 3);
    //! update flt 6 -> 2
    flt[2] = vld1q_s8(weight + 6 * 16);
    //! row 7
    LOAD_SRC(7);
    CAL_C(0, 3 * 7);
    CAL_C(1, 3 * 6);
    CAL_C(2, 3 * 5);
    CAL_C(3, 3 * 4);
    //! row 8 (CAL_C(3, 3*5) still needs the old flt[3], so it runs before
    //! the reload)
    LOAD_SRC(8);
    CAL_C(3, 3 * 5);
    //! update flt 7 -> 3
    flt[3] = vld1q_s8(weight + 7 * 16);
    CAL_C(2, 3 * 6);
    CAL_C(1, 3 * 7);
    CAL_C(0, 3 * 8);
    //! row 9
    LOAD_SRC(9);
    CAL_C(0, 3 * 9);
    CAL_C(1, 3 * 8);
    CAL_C(2, 3 * 7);
    CAL_C(3, 3 * 6);
    //! row 10
    LOAD_SRC(10);
    //! update flt 8 -> 0
    flt[0] = vld1q_s8(weight + 8 * 16);
    CAL_C(3, 3 * 7);
    CAL_C(2, 3 * 8);
    CAL_C(1, 3 * 9);
    CAL_C(0, 3 * 10);
    //! row 11
    LOAD_SRC(11);
    CAL_C(1, 3 * 10);
    CAL_C(2, 3 * 9);
    CAL_C(3, 3 * 8);
    //! row 12
    LOAD_SRC(12);
    CAL_C(2, 3 * 10);
    CAL_C(3, 3 * 9);
    //! row 13
    LOAD_SRC(13);
    CAL_C(3, 3 * 10);
    //! Requantize: int32 accumulators -> float, multiply by `scale`, then
    //! each output row is stored back to int8 by quant_store_s8 (which also
    //! applies the relu_reg clamp).
    float32x4_t dst_reg[4][4];
#define cb(step)                                  \
    dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \
    dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \
    dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \
    dst_reg[step][3] = vcvtq_f32_s32(c[step][3]);
    UNROLL_CALL_RAW(4, cb);
#undef cb
#define cb(step)                                             \
    dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \
    dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \
    dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \
    dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale);
    UNROLL_CALL_RAW(4, cb);
#undef cb
    int8_t* dst_store = dst + oh * OW + ow;
    int8x16_t relu_reg = vdupq_n_s8(relu_val);
#define cb(step)                                                                    \
    quant_store_s8(                                                                 \
            dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \
            dst_store + step * OW, relu_reg);
    UNROLL_CALL_RAW(4, cb);
#undef cb
//! keep the helper macros local to this kernel (they would otherwise leak
//! past the function and collide with the stride-2 kernel's definitions in
//! a unity/amalgamated build)
#undef LOAD_SRC
#undef CAL_C
#undef flt_reg
}
#endif
\ No newline at end of file
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h"
#include "src/common/unroll_macro.h"
MEGDNN_ATTRIBUTE_TARGET("dotprod")
//! Direct chanwise int8 convolution micro-kernel: 11x11 filter, stride 2,
//! built on the Armv8.2-A dot-product (sdot) extension. One call produces a
//! 4-row (OH) x 16-column (OW) tile of the output feature map.
//!
//! \param src      base of the (padded) input plane.
//! \param weight   packed filter; this kernel loads 9 x 16 bytes
//!                 (weight + 0*16 .. weight + 8*16). Each 11-tap filter row
//!                 occupies 12 consecutive bytes (three 4-byte sdot lanes);
//!                 the 12th byte is assumed zero-padded -- confirm against
//!                 the weight-packing routine.
//! \param bias     int32 bias broadcast into every accumulator lane.
//! \param dst      int8 output plane, row stride OW.
//! \param oh, ow   top-left coordinate of this output tile.
//! \param OH, OW   output spatial sizes (OH is not read inside this kernel;
//!                 OW is the dst row stride).
//! \param pad_iw   row stride (padded width) of the input plane.
//! \param scale    requantization scale applied to the int32 accumulators.
//! \param relu_val value broadcast and handed to quant_store_s8; presumably
//!                 the lower clamp bound implementing ReLU -- see
//!                 quant_store_s8.
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val) {
    //! 4x16 output tile, stride 2.
    const size_t SH = 2;
    const size_t SW = 2;
    //! vqtbl1q_s8 shuffle patterns: every 32-bit lane gathers 4 consecutive
    //! input pixels, consecutive lanes sliding by SW == 2 pixels. Table 0
    //! aligns with filter taps 0-3, table 1 (offset +4) with taps 4-7; the
    //! taps 8-11 groups reuse table 0 applied 8 bytes further along.
    static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5,
                                            4, 5, 6, 7, 6, 7, 8, 9};
    static const uint8_t tbl_array_1[16] = {4, 5, 6,  7,  6,  7,  8,  9,
                                            8, 9, 10, 11, 10, 11, 12, 13};
    uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
    uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);
    //! top-left input pixel feeding this tile
    const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
    //! init: c[r][g] accumulates output row r, columns 4*g .. 4*g + 3.
    int32x4_t c[4][4];
#define cb(step)                    \
    c[step][0] = vdupq_n_s32(bias); \
    c[step][1] = vdupq_n_s32(bias); \
    c[step][2] = vdupq_n_s32(bias); \
    c[step][3] = vdupq_n_s32(bias);
    UNROLL_CALL_RAW(4, cb);
#undef cb
//! All 9 filter registers (9 x 16B = 11 rows x 12 packed bytes, rounded up)
//! stay resident for the whole kernel -- no ring buffer here, unlike the
//! stride-1 variant.
#define flt_reg 9
#define flt_per_reg 4
    int8x16_t flt[flt_reg];
#define cb(step) flt[step] = vld1q_s8(weight + step * 16);
    UNROLL_CALL_RAW(flt_reg, cb);
#undef cb
//! CAL_C(oh, flt_start): accumulate one packed filter row (starting at
//! 4-byte weight group `flt_start`, i.e. flt_start == 3 * filter_row) into
//! output row `oh` -- three sdot ops per 4-column group, taps 0-3 / 4-7 /
//! 8-11 against the matching shuffled input groups.
#define CAL_C(oh, flt_start)                                                 \
    c[oh][0] = vdotq_laneq_s32(                                              \
            c[oh][0], n0123_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
            (flt_start + 0) % flt_per_reg);                                  \
    c[oh][1] = vdotq_laneq_s32(                                              \
            c[oh][1], n4567_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
            (flt_start + 0) % flt_per_reg);                                  \
    c[oh][2] = vdotq_laneq_s32(                                              \
            c[oh][2], n89ab_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
            (flt_start + 0) % flt_per_reg);                                  \
    c[oh][3] = vdotq_laneq_s32(                                              \
            c[oh][3], ncdef_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
            (flt_start + 0) % flt_per_reg);                                  \
    c[oh][0] = vdotq_laneq_s32(                                              \
            c[oh][0], n0123_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
            (flt_start + 1) % flt_per_reg);                                  \
    c[oh][1] = vdotq_laneq_s32(                                              \
            c[oh][1], n4567_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
            (flt_start + 1) % flt_per_reg);                                  \
    c[oh][2] = vdotq_laneq_s32(                                              \
            c[oh][2], n89ab_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
            (flt_start + 1) % flt_per_reg);                                  \
    c[oh][3] = vdotq_laneq_s32(                                              \
            c[oh][3], ncdef_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
            (flt_start + 1) % flt_per_reg);                                  \
    c[oh][0] = vdotq_laneq_s32(                                              \
            c[oh][0], n0123_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
            (flt_start + 2) % flt_per_reg);                                  \
    c[oh][1] = vdotq_laneq_s32(                                              \
            c[oh][1], n4567_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
            (flt_start + 2) % flt_per_reg);                                  \
    c[oh][2] = vdotq_laneq_s32(                                              \
            c[oh][2], n89ab_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
            (flt_start + 2) % flt_per_reg);                                  \
    c[oh][3] = vdotq_laneq_s32(                                              \
            c[oh][3], ncdef_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
            (flt_start + 2) % flt_per_reg);
//! LOAD_SRC(row_id): load 48 bytes of input row `row_id` (16 output columns
//! at stride 2 span up to input offset 41) and build the shuffled column
//! groups. ext_8 / ext_24 shift by 8 bytes so table 0 / table 1 also cover
//! columns 4-7 and 12-15; the taps 8-11 groups (n*_2) alias the taps 0-3
//! groups shifted 8 input pixels. (This macro is defined before the
//! variables it writes -- it expands only at its use sites below, after the
//! row-0 declarations.)
#define LOAD_SRC(row_id)                                \
    read_w[0] = vld1q_s8(src_n + row_id * pad_iw);      \
    read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \
    read_w[2] = vld1q_s8(src_n + row_id * pad_iw + 32); \
    ext_8 = vextq_s8(read_w[0], read_w[1], 8);          \
    ext_24 = vextq_s8(read_w[1], read_w[2], 8);         \
    n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);         \
    n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0);             \
    n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0);         \
    ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0);            \
    n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1);         \
    n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1);             \
    n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1);         \
    ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1);            \
    n0123_2 = n4567_0;                                  \
    n4567_2 = n89ab_0;                                  \
    n89ab_2 = ncdef_0;                                  \
    ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);
    //! row 0: declare the registers and shuffle input row 0 inline (same
    //! layout LOAD_SRC rebuilds for subsequent rows).
    int8x16_t read_w[3];
    read_w[0] = vld1q_s8(src_n);
    read_w[1] = vld1q_s8(src_n + 16);
    read_w[2] = vld1q_s8(src_n + 32);
    int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8);
    int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8);
    int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
    int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0);
    int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0);
    int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0);
    int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1);
    int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1);
    int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1);
    int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1);
    int8x16_t n0123_2 = n4567_0;
    int8x16_t n4567_2 = n89ab_0;
    int8x16_t n89ab_2 = ncdef_0;
    int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);
    CAL_C(0, 0);
    //! Rows 1..16: with FH == 11, SH == 2 the tile consumes 17 input rows;
    //! input row r contributes filter row (r - 2*o) to each output row o
    //! with 0 <= r - 2*o <= 10, hence the CAL_C(o, 3 * (r - 2*o)) pattern.
    //! row 1
    LOAD_SRC(1);
    CAL_C(0, 3 * 1);
    //! row 2
    LOAD_SRC(2);
    CAL_C(0, 3 * 2);
    CAL_C(1, 3 * 0);
    //! row 3
    LOAD_SRC(3);
    CAL_C(0, 3 * 3);
    CAL_C(1, 3 * 1);
    //! row 4
    LOAD_SRC(4);
    CAL_C(0, 3 * 4);
    CAL_C(1, 3 * 2);
    CAL_C(2, 3 * 0);
    //! row 5
    LOAD_SRC(5);
    CAL_C(0, 3 * 5);
    CAL_C(1, 3 * 3);
    CAL_C(2, 3 * 1);
    //! row 6
    LOAD_SRC(6);
    CAL_C(0, 3 * 6);
    CAL_C(1, 3 * 4);
    CAL_C(2, 3 * 2);
    CAL_C(3, 3 * 0);
    //! row 7
    LOAD_SRC(7);
    CAL_C(0, 3 * 7);
    CAL_C(1, 3 * 5);
    CAL_C(2, 3 * 3);
    CAL_C(3, 3 * 1);
    //! row 8
    LOAD_SRC(8);
    CAL_C(0, 3 * 8);
    CAL_C(1, 3 * 6);
    CAL_C(2, 3 * 4);
    CAL_C(3, 3 * 2);
    //! row 9
    LOAD_SRC(9);
    CAL_C(0, 3 * 9);
    CAL_C(1, 3 * 7);
    CAL_C(2, 3 * 5);
    CAL_C(3, 3 * 3);
    //! row 10
    LOAD_SRC(10);
    CAL_C(0, 3 * 10);
    CAL_C(1, 3 * 8);
    CAL_C(2, 3 * 6);
    CAL_C(3, 3 * 4);
    //! row 11
    LOAD_SRC(11);
    CAL_C(1, 3 * 9);
    CAL_C(2, 3 * 7);
    CAL_C(3, 3 * 5);
    //! row 12
    LOAD_SRC(12);
    CAL_C(1, 3 * 10);
    CAL_C(2, 3 * 8);
    CAL_C(3, 3 * 6);
    //! row 13
    LOAD_SRC(13);
    CAL_C(2, 3 * 9);
    CAL_C(3, 3 * 7);
    //! row 14
    LOAD_SRC(14);
    CAL_C(2, 3 * 10);
    CAL_C(3, 3 * 8);
    //! row 15
    LOAD_SRC(15);
    CAL_C(3, 3 * 9);
    //! row 16
    LOAD_SRC(16);
    CAL_C(3, 3 * 10);
    //! Requantize: int32 accumulators -> float, multiply by `scale`, then
    //! each output row is stored back to int8 by quant_store_s8 (which also
    //! applies the relu_reg clamp).
    float32x4_t dst_reg[4][4];
#define cb(step)                                  \
    dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \
    dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \
    dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \
    dst_reg[step][3] = vcvtq_f32_s32(c[step][3]);
    UNROLL_CALL_RAW(4, cb);
#undef cb
#define cb(step)                                             \
    dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \
    dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \
    dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \
    dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale);
    UNROLL_CALL_RAW(4, cb);
#undef cb
    int8_t* dst_store = dst + oh * OW + ow;
    int8x16_t relu_reg = vdupq_n_s8(relu_val);
#define cb(step)                                                                    \
    quant_store_s8(                                                                 \
            dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \
            dst_store + step * OW, relu_reg);
    UNROLL_CALL_RAW(4, cb);
#undef cb
//! keep the helper macros local to this kernel (they would otherwise leak
//! past the function and collide with the stride-1 kernel's definitions in
//! a unity/amalgamated build)
#undef LOAD_SRC
#undef CAL_C
#undef flt_per_reg
#undef flt_reg
}
#endif
\ No newline at end of file
...@@ -36,16 +36,14 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( ...@@ -36,16 +36,14 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16(
UNROLL_CALL_RAW(4, cb); UNROLL_CALL_RAW(4, cb);
#undef cb #undef cb
constexpr int flt_reg = 7; #define flt_reg 7
constexpr int flt_per_reg = 4; #define flt_per_reg 4
int8x16_t flt[7];
flt[0] = vld1q_s8(weight + 0 * 16); int8x16_t flt[flt_reg];
flt[1] = vld1q_s8(weight + 1 * 16); #define cb(step) flt[step] = vld1q_s8(weight + step * 16);
flt[2] = vld1q_s8(weight + 2 * 16);
flt[3] = vld1q_s8(weight + 3 * 16); UNROLL_CALL_RAW(flt_reg, cb);
flt[4] = vld1q_s8(weight + 4 * 16); #undef cb
flt[5] = vld1q_s8(weight + 5 * 16);
flt[6] = vld1q_s8(weight + 6 * 16);
#define CAL_C(oh, flt_start) \ #define CAL_C(oh, flt_start) \
c[oh][0] = vdotq_laneq_s32( \ c[oh][0] = vdotq_laneq_s32( \
......
...@@ -2060,6 +2060,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { ...@@ -2060,6 +2060,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) {
benchmark1.set_display(false); benchmark1.set_display(false);
benchmark1.set_times(RUN); benchmark1.set_times(RUN);
Benchmarker<ConvBias> benchmark2(handle());
benchmark2.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.set_dtype(4, dtype::QuantizedS8(60.25f));
benchmark2.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("ARMDOTS8"));
benchmark2.set_display(false);
benchmark2.set_times(RUN);
for (auto&& arg : args) { for (auto&& arg : args) {
TensorLayout dst_layout; TensorLayout dst_layout;
auto opr = handle()->create_operator<ConvBias>(); auto opr = handle()->create_operator<ConvBias>();
...@@ -2070,6 +2080,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { ...@@ -2070,6 +2080,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) {
//! dst.nr_elems * FH * FW * 2 //! dst.nr_elems * FH * FW * 2
float computations = float computations =
dst_layout.total_nr_elems() * arg.filter[3] * arg.filter[4] * 2.0 / 1e6; dst_layout.total_nr_elems() * arg.filter[3] * arg.filter[4] * 2.0 / 1e6;
float computations_5x5 = dst_layout.total_nr_elems() * 5 * 5 * 2.0 / 1e6;
float computations_11x11 = dst_layout.total_nr_elems() * 11 * 11 * 2.0 / 1e6;
param::ConvBias param_5x5 = arg.param;
param_5x5.pad_h = param_5x5.pad_w = 5 / 2;
param::ConvBias param_11x11 = arg.param;
param_11x11.pad_h = param_11x11.pad_w = 11 / 2;
auto used0 = benchmark0.set_param(arg.param).exec( auto used0 = benchmark0.set_param(arg.param).exec(
{arg.src, arg.filter, arg.bias, {}, {}}) / {arg.src, arg.filter, arg.bias, {}, {}}) /
...@@ -2077,11 +2093,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { ...@@ -2077,11 +2093,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) {
auto used1 = benchmark1.set_param(arg.param).exec( auto used1 = benchmark1.set_param(arg.param).exec(
{arg.src, arg.filter, arg.bias, {}, {}}) / {arg.src, arg.filter, arg.bias, {}, {}}) /
RUN; RUN;
TensorShape flt_5x5_shape = arg.filter;
flt_5x5_shape[3] = flt_5x5_shape[4] = 5;
auto used5x5 = benchmark2.set_param(param_5x5).exec(
{arg.src, flt_5x5_shape, arg.bias, {}, {}}) /
RUN;
TensorShape flt_11x11_shape = arg.filter;
flt_11x11_shape[3] = flt_11x11_shape[4] = 11;
auto used11x11 = benchmark0.set_param(param_11x11)
.exec({arg.src, flt_11x11_shape, arg.bias, {}, {}}) /
RUN;
printf("%s %s: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops " printf("%s %s s %u: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops "
"speedup: %f\n", "speedup: %f, compare 5x5 %f ms %f GFlops speedup %f, compare 11x11 %f "
arg.src.to_string().c_str(), arg.filter.to_string().c_str(), used0, "ms %f GFops speedup %f\n",
computations / used0, used1, computations / used1, used1 / used0); arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
arg.param.stride_h, used0, computations / used0, used1,
computations / used1, used1 / used0, used5x5, computations_5x5 / used5x5,
used5x5 / used0, used11x11, computations_11x11 / used11x11,
used11x11 / used0);
} }
} }
......
...@@ -612,13 +612,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { ...@@ -612,13 +612,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) {
checker_conv_bias_qint8x8x8( checker_conv_bias_qint8x8x8(
get_channel_wise_args({9}, 1, false, true, true, true), handle(), get_channel_wise_args({9, 11}, 1, false, true, true, true), handle(),
"ARMDOTS8_DIRECT_CHANWISE_LARGE"); "ARMDOTS8_DIRECT_CHANWISE_LARGE");
} }
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) {
checker_conv_bias_qint8x8x8( checker_conv_bias_qint8x8x8(
get_channel_wise_args({9}, 2, false, true, true, true), handle(), get_channel_wise_args({9, 11}, 2, false, true, true, true), handle(),
"ARMDOTS8_DIRECT_CHANWISE_LARGE"); "ARMDOTS8_DIRECT_CHANWISE_LARGE");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册