diff --git a/dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp b/dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp index ba09071d23c3308b74c35c7939f95f427a1cb8ad..18441065001c4e29e850418a2e98accff305e9aa 100644 --- a/dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp +++ b/dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp @@ -21,9 +21,13 @@ public: DirectConvRunner(size_t flt_size, size_t stride) { if (flt_size == 9 && stride == 1) { m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16; - } else { - megdnn_assert(flt_size == 9 && stride == 2); + } else if (flt_size == 9 && stride == 2) { m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16; + } else if (flt_size == 11 && stride == 1) { + m_func = megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16; + } else { + megdnn_assert(flt_size == 11 && stride == 2); + m_func = megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16; } } size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const { @@ -208,8 +212,8 @@ bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable( (bias_mode == BiasMode::NO_BIAS || bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && - SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 && - fm.ocpg == 1; + SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9 || FH == 11) && + fm.icpg == 1 && fm.ocpg == 1; return avaible; } diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h index dd8122889beac8a5043e101e768dff679e5bf2f6..c689ae743dd6dc00a9aff79147d711990b3f2dfe 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h @@ -12,4 +12,13 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( size_t ow, 
size_t OH, size_t OW, size_t pad_iw, const float scale, int8_t relu_val); +void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16( + const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, + size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, + int8_t relu_val); + +void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16( + const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, + size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, + int8_t relu_val); #endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a2e2829b023827f6228a720bb2d5b3cfbac8c991 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s1.cpp @@ -0,0 +1,240 @@ +#include "megdnn/arch.h" +#if MGB_ENABLE_DOT +#include "src/arm_common/simd_macro/marm_neon.h" + +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" +#include "src/common/unroll_macro.h" + +MEGDNN_ATTRIBUTE_TARGET("dotprod") +void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16( + const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, + size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, + int8_t relu_val) { + //! 
4x16 + const size_t SH = 1; + const size_t SW = 1; + + static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4, + 2, 3, 4, 5, 3, 4, 5, 6}; + static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8, + 6, 7, 8, 9, 7, 8, 9, 10}; + static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12, + 10, 11, 12, 13, 11, 12, 13, 14}; + + uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); + uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); + uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]); + + const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; + + //! init + int32x4_t c[4][4]; +#define cb(step) \ + c[step][0] = vdupq_n_s32(bias); \ + c[step][1] = vdupq_n_s32(bias); \ + c[step][2] = vdupq_n_s32(bias); \ + c[step][3] = vdupq_n_s32(bias); + + UNROLL_CALL_RAW(4, cb); +#undef cb +#define flt_reg 4 + int8x16_t flt[flt_reg]; + flt[0] = vld1q_s8(weight + 0 * 16); + flt[1] = vld1q_s8(weight + 1 * 16); + flt[2] = vld1q_s8(weight + 2 * 16); + flt[3] = vld1q_s8(weight + 3 * 16); + //! 
row 0 + + int8x16_t read_w[2]; + read_w[0] = vld1q_s8(src_n + 0 * pad_iw); + read_w[1] = vld1q_s8(src_n + 0 * pad_iw + 16); + int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); + int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); + int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); + int8x16_t ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); + int8x16_t n0123_1 = n4567_0; + int8x16_t n4567_1 = n89ab_0; + int8x16_t n89ab_1 = ncdef_0; + int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); + int8x16_t n0123_2 = n89ab_0; + int8x16_t n4567_2 = ncdef_0; + int8x16_t n89ab_2 = ncdef_1; + int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); +#define CAL_C(oh, flt_start) \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_0, flt[(flt_start + 0) / 4 % flt_reg], \ + (flt_start + 0) % 4); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_0, flt[(flt_start + 0) / 4 % flt_reg], \ + (flt_start + 0) % 4); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_0, flt[(flt_start + 0) / 4 % flt_reg], \ + (flt_start + 0) % 4); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_0, flt[(flt_start + 0) / 4 % flt_reg], \ + (flt_start + 0) % 4); \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_1, flt[(flt_start + 1) / 4 % flt_reg], \ + (flt_start + 1) % 4); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_1, flt[(flt_start + 1) / 4 % flt_reg], \ + (flt_start + 1) % 4); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_1, flt[(flt_start + 1) / 4 % flt_reg], \ + (flt_start + 1) % 4); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_1, flt[(flt_start + 1) / 4 % flt_reg], \ + (flt_start + 1) % 4); \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_2, flt[(flt_start + 2) / 4 % flt_reg], \ + (flt_start + 2) % 4); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_2, flt[(flt_start + 2) / 4 % flt_reg], \ + (flt_start + 2) % 4); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_2, flt[(flt_start + 2) / 4 % flt_reg], \ + (flt_start + 2) % 4); \ + 
c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_2, flt[(flt_start + 2) / 4 % flt_reg], \ + (flt_start + 2) % 4); + + CAL_C(0, 0); + + //! row 1 +#define LOAD_SRC(row_id) \ + read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ + read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ + n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ + n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ + n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); \ + ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); \ + n0123_1 = n4567_0; \ + n4567_1 = n89ab_0; \ + n89ab_1 = ncdef_0; \ + ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ + n0123_2 = n89ab_0; \ + n4567_2 = ncdef_0; \ + n89ab_2 = ncdef_1; \ + ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); + + LOAD_SRC(1); + CAL_C(0, 3); + CAL_C(1, 0); + + //! row 2 + LOAD_SRC(2); + CAL_C(0, 3 * 2); + CAL_C(1, 3 * 1); + CAL_C(2, 3 * 0); + + //! row 3 + LOAD_SRC(3); + CAL_C(0, 3 * 3); + CAL_C(1, 3 * 2); + CAL_C(2, 3 * 1); + CAL_C(3, 3 * 0); + + //! row 4 + LOAD_SRC(4); + CAL_C(0, 3 * 4); + CAL_C(1, 3 * 3); + CAL_C(2, 3 * 2); + CAL_C(3, 3 * 1); + //! update flt 4 -> 0 + flt[0] = vld1q_s8(weight + 4 * 16); + + //! row 5 + LOAD_SRC(5); + CAL_C(0, 3 * 5); + CAL_C(1, 3 * 4); + CAL_C(2, 3 * 3); + CAL_C(3, 3 * 2); + //! update flt 5 -> 1 + flt[1] = vld1q_s8(weight + 5 * 16); + + //! row 6 + LOAD_SRC(6); + CAL_C(0, 3 * 6); + CAL_C(1, 3 * 5); + CAL_C(2, 3 * 4); + CAL_C(3, 3 * 3); + //! update flt 6 -> 2 + flt[2] = vld1q_s8(weight + 6 * 16); + + //! row 7 + LOAD_SRC(7); + CAL_C(0, 3 * 7); + CAL_C(1, 3 * 6); + CAL_C(2, 3 * 5); + CAL_C(3, 3 * 4); + + //! row 8 + LOAD_SRC(8); + CAL_C(3, 3 * 5); + //! update flt 7 -> 3 + flt[3] = vld1q_s8(weight + 7 * 16); + CAL_C(2, 3 * 6); + CAL_C(1, 3 * 7); + CAL_C(0, 3 * 8); + + //! row 9 + LOAD_SRC(9); + CAL_C(0, 3 * 9); + CAL_C(1, 3 * 8); + CAL_C(2, 3 * 7); + CAL_C(3, 3 * 6); + + //! row 10 + LOAD_SRC(10); + //! 
update flt 8 -> 0 + flt[0] = vld1q_s8(weight + 8 * 16); + CAL_C(3, 3 * 7); + CAL_C(2, 3 * 8); + CAL_C(1, 3 * 9); + CAL_C(0, 3 * 10); + + //! row 11 + LOAD_SRC(11); + CAL_C(1, 3 * 10); + CAL_C(2, 3 * 9); + CAL_C(3, 3 * 8); + + //! row 12 + LOAD_SRC(12); + CAL_C(2, 3 * 10); + CAL_C(3, 3 * 9); + + //! row 13 + LOAD_SRC(13); + CAL_C(3, 3 * 10); + + float32x4_t dst_reg[4][4]; +#define cb(step) \ + dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ + dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ + dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ + dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); + + UNROLL_CALL_RAW(4, cb); +#undef cb + +#define cb(step) \ + dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ + dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ + dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ + dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); + + UNROLL_CALL_RAW(4, cb); +#undef cb + int8_t* dst_store = dst + oh * OW + ow; + int8x16_t relu_reg = vdupq_n_s8(relu_val); +#define cb(step) \ + quant_store_s8( \ + dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ + dst_store + step * OW, relu_reg); + + UNROLL_CALL_RAW(4, cb); +#undef cb +} +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4032dc089b1432d533a2bf142b83027fcee3b189 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s2.cpp @@ -0,0 +1,249 @@ +#include "megdnn/arch.h" +#if MGB_ENABLE_DOT +#include "src/arm_common/simd_macro/marm_neon.h" + +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" +#include "src/common/unroll_macro.h" + 
+MEGDNN_ATTRIBUTE_TARGET("dotprod") +void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16( + const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, + size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, + int8_t relu_val) { + //! 4x16 + const size_t SH = 2; + const size_t SW = 2; + + static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5, + 4, 5, 6, 7, 6, 7, 8, 9}; + static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9, + 8, 9, 10, 11, 10, 11, 12, 13}; + + uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); + uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); + + const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; + + //! init + int32x4_t c[4][4]; +#define cb(step) \ + c[step][0] = vdupq_n_s32(bias); \ + c[step][1] = vdupq_n_s32(bias); \ + c[step][2] = vdupq_n_s32(bias); \ + c[step][3] = vdupq_n_s32(bias); + + UNROLL_CALL_RAW(4, cb); +#undef cb + +#define flt_reg 9 +#define flt_per_reg 4 + + int8x16_t flt[flt_reg]; +#define cb(step) flt[step] = vld1q_s8(weight + step * 16); + + UNROLL_CALL_RAW(flt_reg, cb); +#undef cb + +#define CAL_C(oh, flt_start) \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ + (flt_start + 0) % flt_per_reg); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ + (flt_start + 0) % flt_per_reg); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ + (flt_start + 0) % flt_per_reg); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ + (flt_start + 0) % flt_per_reg); \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ + (flt_start + 1) % flt_per_reg); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ + (flt_start + 1) % flt_per_reg); \ + c[oh][2] = vdotq_laneq_s32( \ + 
c[oh][2], n89ab_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ + (flt_start + 1) % flt_per_reg); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ + (flt_start + 1) % flt_per_reg); \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ + (flt_start + 2) % flt_per_reg); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ + (flt_start + 2) % flt_per_reg); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ + (flt_start + 2) % flt_per_reg); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ + (flt_start + 2) % flt_per_reg); + +#define LOAD_SRC(row_id) \ + read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ + read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ + read_w[2] = vld1q_s8(src_n + row_id * pad_iw + 32); \ + ext_8 = vextq_s8(read_w[0], read_w[1], 8); \ + ext_24 = vextq_s8(read_w[1], read_w[2], 8); \ + n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ + n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); \ + n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ + ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); \ + n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ + n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); \ + n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); \ + ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); \ + n0123_2 = n4567_0; \ + n4567_2 = n89ab_0; \ + n89ab_2 = ncdef_0; \ + ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); + + //! 
row 0 + int8x16_t read_w[3]; + read_w[0] = vld1q_s8(src_n); + read_w[1] = vld1q_s8(src_n + 16); + read_w[2] = vld1q_s8(src_n + 32); + int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8); + int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8); + + int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); + int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); + int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); + int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); + + int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); + int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); + int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); + int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); + + int8x16_t n0123_2 = n4567_0; + int8x16_t n4567_2 = n89ab_0; + int8x16_t n89ab_2 = ncdef_0; + int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); + + CAL_C(0, 0); + + //! row 1 + LOAD_SRC(1); + CAL_C(0, 3 * 1); + + //! row 2 + LOAD_SRC(2); + CAL_C(0, 3 * 2); + CAL_C(1, 3 * 0); + + //! row 3 + LOAD_SRC(3); + CAL_C(0, 3 * 3); + CAL_C(1, 3 * 1); + + //! row 4 + LOAD_SRC(4); + CAL_C(0, 3 * 4); + CAL_C(1, 3 * 2); + CAL_C(2, 3 * 0); + + //! row 5 + LOAD_SRC(5); + CAL_C(0, 3 * 5); + CAL_C(1, 3 * 3); + CAL_C(2, 3 * 1); + + //! row 6 + LOAD_SRC(6); + CAL_C(0, 3 * 6); + CAL_C(1, 3 * 4); + CAL_C(2, 3 * 2); + CAL_C(3, 3 * 0); + + //! row 7 + LOAD_SRC(7); + CAL_C(0, 3 * 7); + CAL_C(1, 3 * 5); + CAL_C(2, 3 * 3); + CAL_C(3, 3 * 1); + + //! row 8 + LOAD_SRC(8); + CAL_C(0, 3 * 8); + CAL_C(1, 3 * 6); + CAL_C(2, 3 * 4); + CAL_C(3, 3 * 2); + + //! row 9 + LOAD_SRC(9); + CAL_C(0, 3 * 9); + CAL_C(1, 3 * 7); + CAL_C(2, 3 * 5); + CAL_C(3, 3 * 3); + + //! row 10 + LOAD_SRC(10); + CAL_C(0, 3 * 10); + CAL_C(1, 3 * 8); + CAL_C(2, 3 * 6); + CAL_C(3, 3 * 4); + + //! row 11 + LOAD_SRC(11); + CAL_C(1, 3 * 9); + CAL_C(2, 3 * 7); + CAL_C(3, 3 * 5); + + //! row 12 + LOAD_SRC(12); + CAL_C(1, 3 * 10); + CAL_C(2, 3 * 8); + CAL_C(3, 3 * 6); + + //! row 13 + LOAD_SRC(13); + CAL_C(2, 3 * 9); + CAL_C(3, 3 * 7); + + //! 
row 14 + LOAD_SRC(14); + CAL_C(2, 3 * 10); + CAL_C(3, 3 * 8); + + //! row 15 + LOAD_SRC(15); + CAL_C(3, 3 * 9); + + //! row 16 + LOAD_SRC(16); + CAL_C(3, 3 * 10); + + float32x4_t dst_reg[4][4]; +#define cb(step) \ + dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ + dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ + dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ + dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); + + UNROLL_CALL_RAW(4, cb); +#undef cb + +#define cb(step) \ + dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ + dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ + dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ + dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); + + UNROLL_CALL_RAW(4, cb); +#undef cb + int8_t* dst_store = dst + oh * OW + ow; + int8x16_t relu_reg = vdupq_n_s8(relu_val); +#define cb(step) \ + quant_store_s8( \ + dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ + dst_store + step * OW, relu_reg); + + UNROLL_CALL_RAW(4, cb); +#undef cb +} +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp index f34013ef24d27710fa08e9aa9b9751022806c574..3a47cbc28f139ee300e2af3f4eab702ca416bb17 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp @@ -36,16 +36,14 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( UNROLL_CALL_RAW(4, cb); #undef cb - constexpr int flt_reg = 7; - constexpr int flt_per_reg = 4; - int8x16_t flt[7]; - flt[0] = vld1q_s8(weight + 0 * 16); - flt[1] = vld1q_s8(weight + 1 * 16); - flt[2] = vld1q_s8(weight + 2 * 16); - flt[3] = vld1q_s8(weight + 3 * 16); - flt[4] = vld1q_s8(weight + 4 * 16); - flt[5] = vld1q_s8(weight + 5 * 16); - flt[6] = vld1q_s8(weight + 6 * 16); +#define 
flt_reg 7 +#define flt_per_reg 4 + + int8x16_t flt[flt_reg]; +#define cb(step) flt[step] = vld1q_s8(weight + step * 16); + + UNROLL_CALL_RAW(flt_reg, cb); +#undef cb #define CAL_C(oh, flt_start) \ c[oh][0] = vdotq_laneq_s32( \ diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 6a89a537051e629fe4637b4bf9d4a163fa67dec9..b6d25176bc201a6b235624d1d9b1b058b2b70151 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -2060,6 +2060,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { benchmark1.set_display(false); benchmark1.set_times(RUN); + Benchmarker benchmark2(handle()); + benchmark2.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark2.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("ARMDOTS8")); + benchmark2.set_display(false); + benchmark2.set_times(RUN); + for (auto&& arg : args) { TensorLayout dst_layout; auto opr = handle()->create_operator(); @@ -2070,6 +2080,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { //! 
dst.nr_elems * FH * FW * 2 float computations = dst_layout.total_nr_elems() * arg.filter[3] * arg.filter[4] * 2.0 / 1e6; + float computations_5x5 = dst_layout.total_nr_elems() * 5 * 5 * 2.0 / 1e6; + float computations_11x11 = dst_layout.total_nr_elems() * 11 * 11 * 2.0 / 1e6; + param::ConvBias param_5x5 = arg.param; + param_5x5.pad_h = param_5x5.pad_w = 5 / 2; + param::ConvBias param_11x11 = arg.param; + param_11x11.pad_h = param_11x11.pad_w = 11 / 2; auto used0 = benchmark0.set_param(arg.param).exec( {arg.src, arg.filter, arg.bias, {}, {}}) / @@ -2077,11 +2093,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { RUN; auto used1 = benchmark1.set_param(arg.param).exec( {arg.src, arg.filter, arg.bias, {}, {}}) / RUN; + TensorShape flt_5x5_shape = arg.filter; + flt_5x5_shape[3] = flt_5x5_shape[4] = 5; + + auto used5x5 = benchmark2.set_param(param_5x5).exec( + {arg.src, flt_5x5_shape, arg.bias, {}, {}}) / + RUN; + TensorShape flt_11x11_shape = arg.filter; + flt_11x11_shape[3] = flt_11x11_shape[4] = 11; + auto used11x11 = benchmark0.set_param(param_11x11) + .exec({arg.src, flt_11x11_shape, arg.bias, {}, {}}) / + RUN; - printf("%s %s: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops " - "speedup: %f\n", - arg.src.to_string().c_str(), arg.filter.to_string().c_str(), used0, - computations / used0, used1, computations / used1, used1 / used0); + printf("%s %s s %u: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops " + "speedup: %f, compare 5x5 %f ms %f GFlops speedup %f, compare 11x11 %f " + "ms %f GFlops speedup %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + arg.param.stride_h, used0, computations / used0, used1, + computations / used1, used1 / used0, used5x5, computations_5x5 / used5x5, + used5x5 / used0, used11x11, computations_11x11 / used11x11, + used11x11 / used0); } } diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 
f08bb174baf598e96f1cfd9aa1b8e2b7c9e41118..6ca56acf483d7f00f3ffae97be156ea5205374ea 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -612,13 +612,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { #if MGB_ENABLE_DOT TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) { checker_conv_bias_qint8x8x8( - get_channel_wise_args({9}, 1, false, true, true, true), handle(), + get_channel_wise_args({9, 11}, 1, false, true, true, true), handle(), "ARMDOTS8_DIRECT_CHANWISE_LARGE"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) { checker_conv_bias_qint8x8x8( - get_channel_wise_args({9}, 2, false, true, true, true), handle(), + get_channel_wise_args({9, 11}, 2, false, true, true, true), handle(), "ARMDOTS8_DIRECT_CHANWISE_LARGE"); }