diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
index db6bbf65e69963693c2581fa4b4229c06892670b..4715d80956eeadfeacb2ddd79ac572d364639c8e 100644
--- a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
+++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
@@ -302,6 +302,470 @@ struct KerNeonXXs2NchwNchw44 {
         store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
     }
 };
+#if MEGDNN_AARCH64
+template <BiasMode bias_mode, typename Op>
+struct KerNeonXXs2NchwNchw44<bias_mode, Op, 0, 7, 8, 2> {
+    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
+                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
+                     int iw, int ld_dst_oc, const Op& op) {
+        static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8,
+                                                   0, 8, 0, 8, 0, 8, 0, 8};
+        uint8x16_t vtbl = vld1q_u8(src_idx_buffer);
+
+        // constexpr int stride = 2;
+        constexpr int oc_block = 8;
+        constexpr int remain_w = 0;
+        constexpr int filter_size = 7;
+        constexpr int ic_step = 1;
+        constexpr int oc_step = 4;
+        constexpr int pack_iw_len = 4;
+        constexpr int fh_step = 2;
+        constexpr int c_dim = OCHelper<oc_block>::val;
+
+        const int ic_stride = ih * iw * pack_iw_len;
+        const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic;
+        const size_t src_step = fh_step * iw * ic_step * pack_iw_len;
+        const size_t weight_step = filter_size * pack_iw_len * fh_step;
+        const size_t weight_step_small = filter_size * pack_iw_len;
+        int32x4_t c[c_dim][4];
+
+        init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);
+
+        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
+            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
+
+            const int8_t* weight_ptr_oc = weight_ptr + ld_dot4_weight_oc;
+
+            const int8_t* nchw_src_ptr_last_line =
+                    src_ptr + ic_idx * ic_stride +
+                    6 * iw * ic_step * pack_iw_len;
+            /**
+             * r0-r7 c
+             * r24-r31 temp
+             * r8-r15 src
+             * r16-r22 weight
+             * r23 vtbl
+             */
+            asm volatile(
+
+                    "ldp q8, q9, [%[nchw_src_ptr]]\n"
+                    "ldp q16, q17, [%[weight_ptr]]\n"
+                    "ldp q10, q11, [%[nchw_src_ptr], #32]\n"
+                    "smull v24.8h, v8.8b, v16.8b\n"
+                    "ldp q19, q20, [%[weight_ptr_oc]]\n"
+                    "smull v25.8h, v9.8b, v16.8b\n"
+                    "ldp q12, q13, [%[nchw_src_ptr], #64]\n"
+                    "smull v26.8h, v10.8b, v16.8b\n"
+                    "ldr q18, [%[weight_ptr],#32]\n"
+                    "smull v27.8h, v11.8b, v16.8b\n"
+                    "ldr q21, [%[weight_ptr_oc],#32]\n"
+                    "smull v28.8h, v8.8b, v19.8b\n"
+                    "smlal2 v24.8h, v8.16b, v16.16b\n"
+                    "smlal2 v25.8h, v9.16b, v16.16b\n"
+                    "smlal2 v26.8h, v10.16b, v16.16b\n"
+                    "smlal2 v27.8h, v11.16b, v16.16b\n"
+                    "sadalp %[c00].4s, v24.8h\n"
+                    "smull v29.8h, v9.8b, v19.8b\n"
+                    "sadalp %[c01].4s, v25.8h\n"
+                    "smull v30.8h, v10.8b, v19.8b\n"
+                    "sadalp %[c02].4s, v26.8h\n"
+                    "smull v31.8h, v11.8b, v19.8b\n"
+                    "sadalp %[c03].4s, v27.8h\n"
+                    "smlal2 v28.8h, v8.16b, v19.16b\n"
+                    "ldr d8, [%[nchw_src_ptr],#48]\n"
+                    "smlal2 v29.8h, v9.16b, v19.16b\n"
+                    "smlal2 v30.8h, v10.16b, v19.16b\n"
+                    "smlal2 v31.8h, v11.16b, v19.16b\n"
+                    "smull v24.8h, v9.8b, v17.8b\n"
+                    "sadalp %[c10].4s, v28.8h\n"
+                    "smull v25.8h, v10.8b, v17.8b\n"
+                    "sadalp %[c11].4s, v29.8h\n"
+                    "smull v26.8h, v11.8b, v17.8b\n"
+                    "sadalp %[c12].4s, v30.8h\n"
+                    "smull v27.8h, v12.8b, v17.8b\n"
+                    "sadalp %[c13].4s, v31.8h\n"
+                    "smlal2 v24.8h, v9.16b, v17.16b\n"
+                    "smlal2 v25.8h, v10.16b, v17.16b\n"
+                    "smlal2 v26.8h, v11.16b, v17.16b\n"
+                    "smlal2 v27.8h, v12.16b, v17.16b\n"
+                    "smull v28.8h, v9.8b, v20.8b\n"
+                    "sadalp %[c00].4s, v24.8h\n"
+                    "smull v29.8h, v10.8b, v20.8b\n"
+                    "sadalp %[c01].4s, v25.8h\n"
+                    "smull v30.8h, v11.8b, v20.8b\n"
+                    "sadalp %[c02].4s, v26.8h\n"
+                    "smull v31.8h, v12.8b, v20.8b\n"
+                    "sadalp %[c03].4s, v27.8h\n"
+                    "smlal2 v28.8h, v9.16b, v20.16b\n"
+                    "ldr d9, [%[nchw_src_ptr],#64]\n"
+                    "smlal2 v29.8h, v10.16b, v20.16b\n"
+                    "ldr d14, [%[nchw_src_ptr],#80]\n"
+                    "smlal2 v30.8h, v11.16b, v20.16b\n"
+                    "smlal2 v31.8h, v12.16b, v20.16b\n"
+                    "smull v24.8h, v10.8b, v18.8b\n"
+                    "sadalp %[c10].4s, v28.8h\n"
+                    "smull v25.8h, v11.8b, v18.8b\n"
+                    "sadalp %[c11].4s, v29.8h\n"
+                    "smull v26.8h, v12.8b, v18.8b\n"
+                    "sadalp %[c12].4s, v30.8h\n"
+                    "smull v27.8h, v13.8b, v18.8b\n"
+                    "sadalp %[c13].4s, v31.8h\n"
+                    "smlal2 v24.8h, v10.16b, v18.16b\n"
+                    "ldr d19, [%[weight_ptr_oc],#48]\n"
+                    "smlal2 v25.8h, v11.16b, v18.16b\n"
+                    "ldr d15, [%[nchw_src_ptr],#96]\n"
+                    "smlal2 v26.8h, v12.16b, v18.16b\n"
+                    "smlal2 v27.8h, v13.16b, v18.16b\n"
+                    "ldr d18, [%[weight_ptr],#48]\n"
+                    "smull v28.8h, v10.8b, v21.8b\n"
+                    "sadalp %[c00].4s, v24.8h\n"
+                    "smull v29.8h, v11.8b, v21.8b\n"
+                    "sadalp %[c01].4s, v25.8h\n"
+                    "smull v30.8h, v12.8b, v21.8b\n"
+                    "sadalp %[c02].4s, v26.8h\n"
+                    "smull v31.8h, v13.8b, v21.8b\n"
+                    "sadalp %[c03].4s, v27.8h\n"
+                    "smlal2 v28.8h, v10.16b, v21.16b\n"
+                    "add %[nchw_src_ptr], %[nchw_src_ptr], %[src_step]\n"
+                    "smlal2 v29.8h, v11.16b, v21.16b\n"
+                    "ldp q10, q11, [%[nchw_src_ptr], #32]\n"
+                    "add %[weight_ptr], %[weight_ptr], %[weight_step]\n"
+                    "smlal2 v30.8h, v12.16b, v21.16b\n"
+                    "add %[weight_ptr_oc], %[weight_ptr_oc], "
+                    "%[weight_step]\n"
+                    "smlal2 v31.8h, v13.16b, v21.16b\n"
+                    "ldp q16, q17, [%[weight_ptr]]\n"
+                    "smull v24.8h, v8.8b, v18.8b\n"
+                    "sadalp %[c10].4s, v28.8h\n"
+                    "smull v25.8h, v9.8b, v18.8b\n"
+                    "sadalp %[c11].4s, v29.8h\n"
+                    "smull v26.8h, v14.8b, v18.8b\n"
+                    "sadalp %[c12].4s, v30.8h\n"
+                    "smull v27.8h, v15.8b, v18.8b\n"
+                    "sadalp %[c13].4s, v31.8h\n"
+                    "smull v28.8h, v8.8b, v19.8b\n"
+                    "sadalp %[c00].4s, v24.8h\n"
+                    "smull v29.8h, v9.8b, v19.8b\n"
+                    "ldp q8, q9, [%[nchw_src_ptr]]\n"
+                    "sadalp %[c01].4s, v25.8h\n"
+                    "smull v30.8h, v14.8b, v19.8b\n"
+                    "sadalp %[c02].4s, v26.8h\n"
+                    "smull v31.8h, v15.8b, v19.8b\n"
+                    "ldp q19, q20, [%[weight_ptr_oc]]\n"
+                    "sadalp %[c03].4s, v27.8h\n"
+                    "smull v24.8h, v8.8b, v16.8b\n"
+                    "sadalp %[c10].4s, v28.8h\n"
+                    "smull v25.8h, v9.8b, v16.8b\n"
+                    "ldp q12, q13, [%[nchw_src_ptr], #64]\n"
+                    "sadalp %[c11].4s, v29.8h\n"
+                    "smull v26.8h, v10.8b, v16.8b\n"
+                    "ldr q18, [%[weight_ptr],#32]\n"
+                    "sadalp %[c12].4s, v30.8h\n"
+                    "smull v27.8h, v11.8b, v16.8b\n"
+                    "ldr q21, [%[weight_ptr_oc],#32]\n"
+                    "sadalp %[c13].4s, v31.8h\n"
+                    //! fh = 2
+                    "smull v28.8h, v8.8b, v19.8b\n"
+                    "smlal2 v24.8h, v8.16b, v16.16b\n"
+                    "smlal2 v25.8h, v9.16b, v16.16b\n"
+                    "smlal2 v26.8h, v10.16b, v16.16b\n"
+                    "smlal2 v27.8h, v11.16b, v16.16b\n"
+                    "sadalp %[c00].4s, v24.8h\n"
+                    "smull v29.8h, v9.8b, v19.8b\n"
+                    "sadalp %[c01].4s, v25.8h\n"
+                    "smull v30.8h, v10.8b, v19.8b\n"
+                    "sadalp %[c02].4s, v26.8h\n"
+                    "smull v31.8h, v11.8b, v19.8b\n"
+                    "sadalp %[c03].4s, v27.8h\n"
+                    "smlal2 v28.8h, v8.16b, v19.16b\n"
+                    "ldr d8, [%[nchw_src_ptr],#48]\n"
+                    "smlal2 v29.8h, v9.16b, v19.16b\n"
+                    "smlal2 v30.8h, v10.16b, v19.16b\n"
+                    "smlal2 v31.8h, v11.16b, v19.16b\n"
+                    "smull v24.8h, v9.8b, v17.8b\n"
+                    "sadalp %[c10].4s, v28.8h\n"
+                    "smull v25.8h, v10.8b, v17.8b\n"
+                    "sadalp %[c11].4s, v29.8h\n"
+                    "smull v26.8h, v11.8b, v17.8b\n"
+                    "sadalp %[c12].4s, v30.8h\n"
+                    "smull v27.8h, v12.8b, v17.8b\n"
+                    "sadalp %[c13].4s, v31.8h\n"
+                    "smlal2 v24.8h, v9.16b, v17.16b\n"
+                    "smlal2 v25.8h, v10.16b, v17.16b\n"
+                    "smlal2 v26.8h, v11.16b, v17.16b\n"
+                    "smlal2 v27.8h, v12.16b, v17.16b\n"
+                    "smull v28.8h, v9.8b, v20.8b\n"
+                    "sadalp %[c00].4s, v24.8h\n"
+                    "smull v29.8h, v10.8b, v20.8b\n"
+                    "sadalp %[c01].4s, v25.8h\n"
+                    "smull v30.8h, v11.8b, v20.8b\n"
+                    "sadalp %[c02].4s, v26.8h\n"
+                    "smull v31.8h, v12.8b, v20.8b\n"
+                    "sadalp %[c03].4s, v27.8h\n"
+                    "smlal2 v28.8h, v9.16b, v20.16b\n"
+                    "ldr d9, [%[nchw_src_ptr],#64]\n"
+                    "smlal2 v29.8h, v10.16b, v20.16b\n"
+                    "ldr d14, [%[nchw_src_ptr],#80]\n"
+                    "smlal2 v30.8h, v11.16b, v20.16b\n"
+                    "smlal2 v31.8h, v12.16b, v20.16b\n"
+                    "smull v24.8h, v10.8b, v18.8b\n"
+                    "sadalp %[c10].4s, v28.8h\n"
+                    "smull v25.8h, v11.8b, v18.8b\n"
+                    "sadalp %[c11].4s, v29.8h\n"
+                    "smull v26.8h, v12.8b, v18.8b\n"
+                    "sadalp %[c12].4s, v30.8h\n"
+                    "smull v27.8h, v13.8b, v18.8b\n"
+                    "sadalp %[c13].4s, v31.8h\n"
+                    "smlal2 v24.8h, v10.16b, v18.16b\n"
+                    "ldr d19, [%[weight_ptr_oc],#48]\n"
+                    "smlal2 v25.8h, v11.16b, v18.16b\n"
+                    "ldr d15, [%[nchw_src_ptr],#96]\n"
+                    "smlal2 v26.8h, v12.16b, v18.16b\n"
+                    "smlal2 v27.8h, v13.16b, v18.16b\n"
+                    "ldr d18, [%[weight_ptr],#48]\n"
+                    "smull v28.8h, v10.8b, v21.8b\n"
+                    "sadalp %[c00].4s, v24.8h\n"
+                    "smull v29.8h, v11.8b, v21.8b\n"
+                    "sadalp %[c01].4s, v25.8h\n"
+                    "smull v30.8h, v12.8b, v21.8b\n"
+                    "sadalp %[c02].4s, v26.8h\n"
+                    "smull v31.8h, v13.8b, v21.8b\n"
+                    "sadalp %[c03].4s, v27.8h\n"
+                    "smlal2 v28.8h, v10.16b, v21.16b\n"
+                    "add %[nchw_src_ptr], %[nchw_src_ptr], %[src_step]\n"
+                    "smlal2 v29.8h, v11.16b, v21.16b\n"
+                    "add %[weight_ptr], %[weight_ptr], %[weight_step]\n"
+                    "smlal2 v30.8h, v12.16b, v21.16b\n"
+                    "add %[weight_ptr_oc], %[weight_ptr_oc], "
+                    "%[weight_step]\n"
+                    "smlal2 v31.8h, v13.16b, v21.16b\n"
+                    "ldp q16, q17, [%[weight_ptr]]\n"
+                    "smull v24.8h, v8.8b, v18.8b\n"
+                    "ldp q10, q11, [%[nchw_src_ptr], #32]\n"
+                    "sadalp %[c10].4s, v28.8h\n"
+                    "smull v25.8h, v9.8b, v18.8b\n"
+                    "sadalp %[c11].4s, v29.8h\n"
+                    "smull v26.8h, v14.8b, v18.8b\n"
+                    "sadalp %[c12].4s, v30.8h\n"
+                    "smull v27.8h, v15.8b, v18.8b\n"
+                    "sadalp %[c13].4s, v31.8h\n"
+                    "smull v28.8h, v8.8b, v19.8b\n"
+                    "sadalp %[c00].4s, v24.8h\n"
+                    "smull v29.8h, v9.8b, v19.8b\n"
+                    "ldp q8, q9, [%[nchw_src_ptr]]\n"
+                    "sadalp %[c01].4s, v25.8h\n"
+                    "smull v30.8h, v14.8b, v19.8b\n"
+                    "sadalp %[c02].4s, v26.8h\n"
+                    "smull v31.8h, v15.8b, v19.8b\n"
+                    "ldp q19, q20, [%[weight_ptr_oc]]\n"
+                    "sadalp %[c03].4s, v27.8h\n"
+                    "smull v24.8h, v8.8b, v16.8b\n"
+                    "sadalp %[c10].4s, v28.8h\n"
+                    "smull v25.8h, v9.8b, v16.8b\n"
+                    "ldp q12, q13, [%[nchw_src_ptr], #64]\n"
+                    "sadalp %[c11].4s, v29.8h\n"
+                    "smull v26.8h, v10.8b, v16.8b\n"
+                    "ldr q18, [%[weight_ptr],#32]\n"
"sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v11.8b, v16.8b\n" + "ldr q21, [%[weight_ptr_oc],#32]\n" + "sadalp %[c13].4s, v31.8h\n" + //! fh = 4 + "smull v28.8h, v8.8b, v19.8b\n" + "smlal2 v24.8h, v8.16b, v16.16b\n" + "smlal2 v25.8h, v9.16b, v16.16b\n" + "smlal2 v26.8h, v10.16b, v16.16b\n" + "smlal2 v27.8h, v11.16b, v16.16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v10.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v11.8b, v19.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v8.16b, v19.16b\n" + "ldr d8, [%[nchw_src_ptr],#48]\n" + "smlal2 v29.8h, v9.16b, v19.16b\n" + "smlal2 v30.8h, v10.16b, v19.16b\n" + "smlal2 v31.8h, v11.16b, v19.16b\n" + "smull v24.8h, v9.8b, v17.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v10.8b, v17.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v11.8b, v17.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v12.8b, v17.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal2 v24.8h, v9.16b, v17.16b\n" + "smlal2 v25.8h, v10.16b, v17.16b\n" + "smlal2 v26.8h, v11.16b, v17.16b\n" + "smlal2 v27.8h, v12.16b, v17.16b\n" + "smull v28.8h, v9.8b, v20.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v10.8b, v20.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v11.8b, v20.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v12.8b, v20.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v9.16b, v20.16b\n" + "ldr d9, [%[nchw_src_ptr],#64]\n" + "smlal2 v29.8h, v10.16b, v20.16b\n" + "ldr d14, [%[nchw_src_ptr],#80]\n" + "smlal2 v30.8h, v11.16b, v20.16b\n" + "smlal2 v31.8h, v12.16b, v20.16b\n" + "smull v24.8h, v10.8b, v18.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v11.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v12.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v13.8b, v18.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal2 v24.8h, v10.16b, v18.16b\n" + "ldr d19, [%[weight_ptr_oc],#48]\n" + "smlal2 v25.8h, v11.16b, v18.16b\n" + "ldr d15, [%[nchw_src_ptr],#96]\n" + "smlal2 v26.8h, v12.16b, v18.16b\n" + "smlal2 v27.8h, v13.16b, v18.16b\n" + "ldr d18, [%[weight_ptr],#48]\n" + "smull v28.8h, v10.8b, v21.8b\n" + "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v11.8b, v21.8b\n" + "add %[weight_ptr_oc], %[weight_ptr_oc], %[weight_step]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v12.8b, v21.8b\n" + "ldr q16, [%[weight_ptr]]\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v13.8b, v21.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal2 v28.8h, v10.16b, v21.16b\n" + "smlal2 v29.8h, v11.16b, v21.16b\n" + "ldp q10, q11, [%[nchw_src_ptr_last_line], #32]\n" + "smlal2 v30.8h, v12.16b, v21.16b\n" + "smlal2 v31.8h, v13.16b, v21.16b\n" + "ldp q12, q13, [%[nchw_src_ptr_last_line], #64]\n" + "smull v24.8h, v8.8b, v18.8b\n" + "ldr d21, [%[weight_ptr_oc],#16]\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v25.8h, v9.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v26.8h, v14.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v27.8h, v15.8b, v18.8b\n" + "ldr d18, [%[weight_ptr],#16]\n" + "sadalp %[c13].4s, v31.8h\n" + "smull v28.8h, v8.8b, v19.8b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "ldp q8, q9, [%[nchw_src_ptr_last_line]]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v14.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v15.8b, v19.8b\n" + "ldr q19, [%[weight_ptr_oc]]\n" + "tbl v8.16b, {v8.16b}, %[vtbl].16b\n" + "tbl v9.16b, {v9.16b}, 
%[vtbl].16b\n" + "sadalp %[c03].4s, v27.8h\n" + "tbl v10.16b, {v10.16b}, %[vtbl].16b\n" + "tbl v11.16b, {v11.16b}, %[vtbl].16b\n" + "sadalp %[c10].4s, v28.8h\n" + "tbl v12.16b, {v12.16b}, %[vtbl].16b\n" + "tbl v13.16b, {v13.16b}, %[vtbl].16b\n" + "sadalp %[c11].4s, v29.8h\n" + /// last line//// + "smull v24.8h, v8.8b, v16.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v25.8h, v9.8b, v16.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smull v26.8h, v10.8b, v16.8b\n" + "smull v27.8h, v11.8b, v16.8b\n" + "smlal2 v24.8h, v9.16b, v16.16b\n" + "smlal2 v25.8h, v10.16b, v16.16b\n" + "smlal2 v26.8h, v11.16b, v16.16b\n" + "smlal2 v27.8h, v12.16b, v16.16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v28.8h, v8.8b, v19.8b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v29.8h, v9.8b, v19.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v30.8h, v10.8b, v19.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v31.8h, v11.8b, v19.8b\n" + "smlal2 v28.8h, v9.16b, v19.16b\n" + "dup v9.8b, v11.b[0]\n" + "smlal2 v29.8h, v10.16b, v19.16b\n" + "smlal2 v30.8h, v11.16b, v19.16b\n" + "smlal2 v31.8h, v12.16b, v19.16b\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v24.8h, v10.8b, v18.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v25.8h, v11.8b, v18.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v26.8h, v12.8b, v18.8b\n" + "sadalp %[c13].4s, v31.8h\n" + "smull v27.8h, v13.8b, v18.8b\n" + "add x10, %[nchw_src_ptr_last_line], #96\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v28.8h, v10.8b, v21.8b\n" + + "sadalp %[c01].4s, v25.8h\n" + "add x5, %[weight_ptr], #24\n" + "smull v29.8h, v11.8b, v21.8b\n" + "add x6, %[weight_ptr_oc], #24\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v30.8h, v12.8b, v21.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v31.8h, v13.8b, v21.8b\n" + "dup v10.8b, v12.b[0]\n" + "sadalp %[c10].4s, v28.8h\n" + "ld1r {v12.8b}, [x10]\n" + "sadalp %[c11].4s, v29.8h\n" + "dup v11.8b, v13.b[0]\n" + "sadalp %[c12].4s, v30.8h\n" + "ld1r {v16.2s}, [x5]\n" + "sadalp %[c13].4s, v31.8h\n" + "sxtl v16.8h, v16.8b\n" + ///////////////last element///////// + "add %[weight_ptr], %[weight_ptr], %[weight_step_small]\n" + "sxtl v9.8h, v9.8b\n" + "ld1r {v19.2s}, [x6]\n" + "sxtl v10.8h, v10.8b\n" + "sxtl v11.8h, v11.8b\n" + "smlal %[c00].4s, v9.4h, v16.4h\n" + "sxtl v12.8h, v12.8b\n" + "smlal %[c01].4s, v10.4h, v16.4h\n" + "sxtl v19.8h, v19.8b\n" + "smlal %[c02].4s, v11.4h, v16.4h\n" + "smlal %[c03].4s, v12.4h, v16.4h\n" + "smlal %[c10].4s, v9.4h, v19.4h\n" + "smlal %[c11].4s, v10.4h, v19.4h\n" + "smlal %[c12].4s, v11.4h, v19.4h\n" + "smlal %[c13].4s, v12.4h, v19.4h\n" + : + + [c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]), + [c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]), + [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), + [c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]), + [nchw_src_ptr] "+r"(nchw_src_ptr), + [weight_ptr] "+r"(weight_ptr), + [weight_ptr_oc] "+r"(weight_ptr_oc) + + : [vtbl] "w"(vtbl), + [nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line), + [src_step] "r"(src_step), [weight_step] "r"(weight_step), + [weight_step_small] "r"(weight_step_small) + : "x5", "x6", "x7", "x8", "x9", "x10", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +#endif template struct KerNeonXXs2NchwNchw44 { @@ -467,6 +931,166 @@ struct KerNeonXXs2NchwNchw44 { store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); } }; +#if MEGDNN_AARCH64 +template +struct KerNeonXXs2NchwNchw44 { + static void impl(const 
int8_t* src_ptr, const int8_t* weight_ptr, + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op& op) { + constexpr int filter_size = 3; + static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + constexpr int oc_block = 8; + constexpr int remain_w = 0; + + constexpr int oc_step = 4; + constexpr int ic_step = 1; + constexpr int loop_ic_step = 1; + constexpr int pack_iw_len = 4; + + const int ic_stride = ih * iw * pack_iw_len; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const size_t weight_step = filter_size * filter_size * pack_iw_len; + constexpr int c_dim = OCHelper::val; + + int32x4_t c[c_dim][4]; + init_ocx_ow4(c, bias_ptr, oc_step); + uint8x16_t vtbl = vld1q_u8(src_idx_buffer); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; + const int8_t* nchw_src_ptr_last_line = + src_ptr + ic_idx * ic_stride + + 2 * iw * ic_step * pack_iw_len; + const int8_t* weight_ptr_oc = weight_ptr + ld_weight_oc; + /** + * r0-r7 c + * r24-r31 temp + * r8-r15 src + * r16-r19 weight + * r20-vtbl + */ + asm volatile( + //! load src 0,1 + "ldp q8,q9, [%[nchw_src_ptr]]\n" + "ldr q16, [%[weight_ptr]]\n" + "ldp q10,q11, [%[nchw_src_ptr], #32]\n" + "add x5, %[weight_ptr], #32\n" + "smull v24.8h, v8.8b, v16.8b\n" + "ldr q17, [%[weight_ptr_oc]]\n" + "smull v25.8h, v9.8b, v16.8b\n" + "add x6, %[weight_ptr_oc], #32\n" + "smull v26.8h, v10.8b, v16.8b\n" + "smull v27.8h, v11.8b, v16.8b\n" + "smlal2 v24.8h, v8.16b, v16.16b\n" + "add x7, %[nchw_src_ptr_last_line], #64\n" + "smlal2 v25.8h, v9.16b, v16.16b\n" + "smlal2 v26.8h, v10.16b, v16.16b\n" + "smlal2 v27.8h, v11.16b, v16.16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v28.8h, v8.8b, v17.8b\n" + "ldr d12, [%[nchw_src_ptr],#16]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v29.8h, v9.8b, v17.8b\n" + "ldr d13, [%[nchw_src_ptr],#32]\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v30.8h, v10.8b, v17.8b\n" + "ldr d14, [%[nchw_src_ptr],#48]\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v31.8h, v11.8b, v17.8b\n" + "ldr d18, [%[weight_ptr],#16]\n" + "smlal2 v28.8h, v8.16b, v17.16b\n" + "ldr d19, [%[weight_ptr_oc],#16]\n" + "smlal2 v29.8h, v9.16b, v17.16b\n" + "ldr d15, [%[nchw_src_ptr],#64]\n" + "smlal2 v30.8h, v10.16b, v17.16b\n" + "ldp q8,q9, [%[nchw_src_ptr_last_line]]\n" + "smull v24.8h, v12.8b, v18.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smlal2 v31.8h, v11.16b, v17.16b\n" + "ldp q10,q11, [%[nchw_src_ptr_last_line], #32]\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v25.8h, v13.8b, v18.8b\n" + "tbl v8.16b, {v8.16b}, %[vtbl].16b\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v26.8h, v14.8b, v18.8b\n" + "ldr d16, [%[weight_ptr],#24]\n" + "sadalp %[c13].4s, v31.8h\n" + "ldr d17, [%[weight_ptr_oc],#24]\n" + "smull v27.8h, v15.8b, v18.8b\n" + "tbl v9.16b, {v9.16b}, %[vtbl].16b\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v28.8h, v12.8b, v19.8b\n" + "tbl v10.16b, {v10.16b}, %[vtbl].16b\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v29.8h, v13.8b, v19.8b\n" + "tbl v11.16b, {v11.16b}, %[vtbl].16b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v30.8h, v14.8b, v19.8b\n" + "ld1r {v18.2s}, [x5]\n" + "sadalp %[c03].4s, v27.8h\n" + "smull v31.8h, v15.8b, v19.8b\n" + "ld1r {v19.2s}, [x6]\n" + "sadalp %[c10].4s, v28.8h\n" + "smull v24.8h, v8.8b, v16.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smull v25.8h, v9.8b, v16.8b\n" + "dup v12.8b, v9.b[0]\n" + "sadalp %[c12].4s, v30.8h\n" + "smull v26.8h, v10.8b, v16.8b\n" + "dup v12.8b, v9.b[0]\n" + 
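+                    //! v12-v15 collect the broadcast source bytes for the
+                    //! remaining last-line tap; they are widened with sxtl
+                    //! and accumulated with smlal below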
"sadalp %[c13].4s, v31.8h\n" + "smull v27.8h, v11.8b, v16.8b\n" + "dup v13.8b, v10.b[0]\n" + "smull v28.8h, v8.8b, v17.8b\n" + "dup v14.8b, v11.b[0]\n" + "sadalp %[c00].4s, v24.8h\n" + "smull v29.8h, v9.8b, v17.8b\n" + "ld1r {v15.8b}, [x7]\n" + "sadalp %[c01].4s, v25.8h\n" + "smull v30.8h, v10.8b, v17.8b\n" + "sxtl v12.8h, v12.8b\n" + "sxtl v18.8h, v18.8b\n" + "sadalp %[c02].4s, v26.8h\n" + "smull v31.8h, v11.8b, v17.8b\n" + "sxtl v13.8h, v13.8b\n" + "sadalp %[c03].4s, v27.8h\n" + "smlal %[c00].4s, v12.4h, v18.4h\n" + "sxtl v14.8h, v14.8b\n" + "sadalp %[c10].4s, v28.8h\n" + "smlal %[c01].4s, v13.4h, v18.4h\n" + "sxtl v15.8h, v15.8b\n" + "sadalp %[c11].4s, v29.8h\n" + "smlal %[c02].4s, v14.4h, v18.4h\n" + "sxtl v19.8h, v19.8b\n" + "sadalp %[c12].4s, v30.8h\n" + "add %[weight_ptr], %[weight_ptr], %[weight_step]\n" + "smlal %[c03].4s, v15.4h, v18.4h\n" + "sadalp %[c13].4s, v31.8h\n" + "smlal %[c10].4s, v12.4h, v19.4h\n" + "smlal %[c11].4s, v13.4h, v19.4h\n" + "smlal %[c12].4s, v14.4h, v19.4h\n" + "smlal %[c13].4s, v15.4h, v19.4h\n" + : + + [c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]), + [c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]), + [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), + [c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]), + + [weight_ptr] "+r"(weight_ptr), + [weight_ptr_oc] "+r"(weight_ptr_oc) + : [vtbl] "w"(vtbl), [nchw_src_ptr] "r"(nchw_src_ptr), + [nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line), + [weight_step] "r"(weight_step) + : "x5", "x6", "x7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); + } + store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); + } +}; +#endif + template struct KerNeonXXs2NchwNchw44 {