Commit 4b9df8fb authored by Y yiicy, committed by Xiaoyang LI

improve dw conv performance

* improve prepack_input func speed in int8 3x3s1 dw conv

* fix code style

* fix code style

* improve 3x3s1 dw fp32 conv speed a little

* arm add 5x5s1 int8 dw conv, test=develop
Parent 33c335a5
......@@ -145,14 +145,17 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
outc21 = ptr_write;
outc31 = ptr_write;
}
auto c00 = outc00;
auto c01 = outc01;
auto c10 = outc10;
auto c11 = outc11;
auto c20 = outc20;
auto c21 = outc21;
auto c30 = outc30;
auto c31 = outc31;
float* outl[] = {outc00,
outc10,
outc20,
outc30,
outc01,
outc11,
outc21,
outc31,
reinterpret_cast<float*>(bias_local),
reinterpret_cast<float*>(flag_relu)};
void* outl_ptr = reinterpret_cast<void*>(outl);
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
......@@ -210,6 +213,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
"fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/
"fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/
"fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/
"ldp x0, x1, [%[outl]] \n"
"fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/
"fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/
"fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/
......@@ -230,6 +234,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
"fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/
"fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/
"fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/
"ldp x2, x3, [%[outl], #16] \n"
"fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/
"fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/
"fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/
......@@ -239,6 +244,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
"fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/
"fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/
"fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/
"ldp x4, x5, [%[outl], #32] \n"
"fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/
"fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/
"fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/
......@@ -248,25 +254,83 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
"fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/
"fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/
"fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/
"ldp x6, x7, [%[outl], #48] \n"
"fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/
"fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/
"fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/
"fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/
/* save result */
"stp q15, q16, [%[out]], #32\n"
"stp q17, q18, [%[out]], #32\n"
"stp q19, q20, [%[out]], #32\n"
"stp q21, q22, [%[out]]\n"
"fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */
"fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */
"fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */
"fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */
"fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */
"fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */
"fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */
"fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */
/* transpose */
"trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/
"trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/
"trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/
"trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/
"trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/
"trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/
"trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/
"trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/
"trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/
"trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/
"trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/
"trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/
"trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/
"trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/
"trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/
"trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/
"cbz %w[flag_relu], 0f\n" /* skip relu*/
"movi v0.4s, #0\n" /* for relu */
"fmax v15.4s, v15.4s, v0.4s\n"
"fmax v16.4s, v16.4s, v0.4s\n"
"fmax v17.4s, v17.4s, v0.4s\n"
"fmax v18.4s, v18.4s, v0.4s\n"
"fmax v19.4s, v19.4s, v0.4s\n"
"fmax v20.4s, v20.4s, v0.4s\n"
"fmax v21.4s, v21.4s, v0.4s\n"
"fmax v22.4s, v22.4s, v0.4s\n"
"0:\n"
"cbnz %w[flag_mask], 1f\n"
"str q15, [x0]\n" /* save outc00 */
"str q16, [x4]\n" /* save outc01 */
"str q17, [x1]\n" /* save outc10 */
"str q18, [x5]\n" /* save outc11 */
"str q19, [x2]\n" /* save outc20 */
"str q20, [x6]\n" /* save outc21 */
"str q21, [x3]\n" /* save outc30 */
"str q22, [x7]\n" /* save outc31 */
"b 2f\n"
"1:\n"
"str q15, [%[out]], #16 \n" /* save remain to pre_out */
"str q17, [%[out]], #16 \n" /* save remain to pre_out */
"str q19, [%[out]], #16 \n" /* save remain to pre_out */
"str q21, [%[out]], #16 \n" /* save remain to pre_out */
"str q16, [%[out]], #16 \n" /* save remain to pre_out */
"str q18, [%[out]], #16 \n" /* save remain to pre_out */
"str q20, [%[out]], #16 \n" /* save remain to pre_out */
"str q22, [%[out]], #16 \n" /* save remain to pre_out */
"2:\n"
:[inr0] "+r"(inr0), [inr1] "+r"(inr1),
[inr2] "+r"(inr2), [inr3] "+r"(inr3),
[out]"+r"(out0)
:[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
[w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
[w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8)
[w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8),
[vbias]"w" (vbias), [outl] "r" (outl_ptr),
[flag_mask] "r" (flag_mask), [flag_relu] "r" (flag_relu)
: "cc", "memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8", "v9", "v10", "v11", "v15",
"v16","v17","v18","v19","v20","v21","v22"
"v16","v17","v18","v19","v20","v21","v22",
"x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7"
);
#else
asm volatile(
......@@ -355,6 +419,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
"vmla.f32 q8, q5, q2 @ w8 * inr22\n"
"vmla.f32 q9, q5, q3 @ w8 * inr23\n"
"vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n"
"ldr r4, [%[outl], #32] @ load bias addr to r4\n"
"vmla.f32 q14, q6, q0 @ w5 * inr24\n"
"vmla.f32 q15, q6, q1 @ w5 * inr25\n"
"vmla.f32 q10, q5, q0 @ w8 * inr24\n"
......@@ -364,174 +429,103 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
/* mul r3 with w6, w7, w8, get out r1 */
"vmla.f32 q12, q7, q2 @ w6 * inr30\n"
"vmla.f32 q13, q7, q3 @ w6 * inr31\n"
"vst1.32 {d16-d19}, [%[out0]]! @ save r00, r01, c0~c3\n"
"vmla.f32 q14, q7, q0 @ w6 * inr32\n"
"vmla.f32 q15, q7, q1 @ w6 * inr33\n"
"vst1.32 {d20-d23}, [%[out0]]! @ save r02, r03, c0~c3\n"
"vmla.f32 q12, q4, q3 @ w7 * inr31\n"
"vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n"
"vld1.32 {d12-d13}, [r4] @ load bias\n"
"vmla.f32 q13, q4, q0 @ w7 * inr32\n"
"vmla.f32 q14, q4, q1 @ w7 * inr33\n"
"vmla.f32 q15, q4, q2 @ w7 * inr34\n"
"ldr r0, [%[outl]] @ load outc00 to r0\n"
"vmla.f32 q12, q5, q0 @ w8 * inr32\n"
"vmla.f32 q13, q5, q1 @ w8 * inr33\n"
"ldr r5, [%[outl], #36] @ load flag_relu to r5\n"
"vmla.f32 q14, q5, q2 @ w8 * inr34\n"
"vmla.f32 q15, q5, q3 @ w8 * inr35\n"
"vst1.32 {d24-d27}, [%[out0]]! @ save r10, r11, c0~c3\n"
"vst1.32 {d28-d31}, [%[out0]]! @ save r12, r13, c0~c3\n"
"ldr r1, [%[outl], #4] @ load outc10 to r1\n"
"vadd.f32 q8, q8, q6 @ r00 add bias\n"
"vadd.f32 q9, q9, q6 @ r01 add bias\n"
"vadd.f32 q10, q10, q6 @ r02 add bias\n"
"vadd.f32 q11, q11, q6 @ r03 add bias\n"
"ldr r2, [%[outl], #8] @ load outc20 to r2\n"
"vadd.f32 q12, q12, q6 @ r10 add bias\n"
"vadd.f32 q13, q13, q6 @ r11 add bias\n"
"vadd.f32 q14, q14, q6 @ r12 add bias\n"
"vadd.f32 q15, q15, q6 @ r13 add bias\n"
"ldr r3, [%[outl], #12] @ load outc30 to r3\n"
"vmov.u32 q7, #0 @ mov zero to q7\n"
"cmp r5, #0 @ cmp flag relu\n"
"beq 1f @ skip relu\n"
"vmax.f32 q8, q8, q7 @ r00 relu\n"
"vmax.f32 q9, q9, q7 @ r01 relu\n"
"vmax.f32 q10, q10, q7 @ r02 relu\n"
"vmax.f32 q11, q11, q7 @ r03 relu\n"
"vmax.f32 q12, q12, q7 @ r10 relu\n"
"vmax.f32 q13, q13, q7 @ r11 relu\n"
"vmax.f32 q14, q14, q7 @ r12 relu\n"
"vmax.f32 q15, q15, q7 @ r13 relu\n"
"1:\n"
"ldr r4, [%[outl], #16] @ load outc01 to r4\n"
"vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n"
"vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n"
"vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n"
"vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n"
"ldr r5, [%[outl], #20] @ load outc11 to r5\n"
"vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n"
"vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n"
"vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n"
"vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n"
"cmp %[flag_mask], #0 @ cmp flag mask\n"
"bne 2f\n"
"vst1.32 {d16-d17}, [r0] @ save outc00\n"
"vst1.32 {d18-d19}, [r1] @ save outc10\n"
"vst1.32 {d20-d21}, [r2] @ save outc20\n"
"vst1.32 {d22-d23}, [r3] @ save outc30\n"
"vst1.32 {d24-d25}, [r4] @ save outc01\n"
"vst1.32 {d26-d27}, [r5] @ save outc11\n"
"ldr r0, [%[outl], #24] @ load outc21 to r0\n"
"ldr r1, [%[outl], #28] @ load outc31 to r1\n"
"vst1.32 {d28-d29}, [r0] @ save outc21\n"
"vst1.32 {d30-d31}, [r1] @ save outc31\n"
"b 3f @ branch end\n"
"2: \n"
"vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n"
"3: \n"
: [r0] "+r"(inr0), [r1] "+r"(inr1),
[r2] "+r"(inr2), [r3] "+r"(inr3),
[out0] "+r"(out0), [wc0] "+r"(weight_c)
:
: [flag_mask] "r" (flag_mask), [outl] "r" (outl_ptr)
: "cc", "memory",
"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13","q14", "q15"
"q10", "q11", "q12", "q13","q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5"
);
#endif  // __aarch64__
float* out1 = pre_out;
if (flag_mask) {
c00 = outc00;
c01 = outc01;
c10 = outc10;
c11 = outc11;
c20 = outc20;
c21 = outc21;
c30 = outc30;
c31 = outc31;
outc00 = pre_out;
outc01 = pre_out + 4;
outc10 = pre_out + 8;
outc11 = pre_out + 12;
outc20 = pre_out + 16;
outc21 = pre_out + 20;
outc30 = pre_out + 24;
outc31 = pre_out + 28;
}
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[din]], #32\n" /* load input*/
"ldp q2, q3, [%[din]], #32\n" /* load input*/
"fadd v15.4s, v0.4s, %[vbias].4s\n" /* add bias */
"fadd v16.4s, v1.4s, %[vbias].4s\n" /* add bias */
"ldp q4, q5, [%[din]], #32\n" /* load input*/
"fadd v17.4s, v2.4s, %[vbias].4s\n" /* add bias */
"fadd v18.4s, v3.4s, %[vbias].4s\n" /* add bias */
"ldp q6, q7, [%[din]]\n" /* load input*/
"fadd v19.4s, v4.4s, %[vbias].4s\n" /* add bias */
"fadd v20.4s, v5.4s, %[vbias].4s\n" /* add bias */
"fadd v21.4s, v6.4s, %[vbias].4s\n" /* add bias */
"fadd v22.4s, v7.4s, %[vbias].4s\n" /* add bias */
/* transpose */
"trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/
"trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/
"trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/
"trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/
"trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/
"trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/
"trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/
"trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/
"trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/
"trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/
"trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/
"trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/
"trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/
"trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/
"trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/
"trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/
"cbz %w[flag_relu], 0f\n" /* skip relu*/
"movi v0.4s, #0\n" /* for relu */
"fmax v15.4s, v15.4s, v0.4s\n"
"fmax v16.4s, v16.4s, v0.4s\n"
"fmax v17.4s, v17.4s, v0.4s\n"
"fmax v18.4s, v18.4s, v0.4s\n"
"fmax v19.4s, v19.4s, v0.4s\n"
"fmax v20.4s, v20.4s, v0.4s\n"
"fmax v21.4s, v21.4s, v0.4s\n"
"fmax v22.4s, v22.4s, v0.4s\n"
"0:\n"
"str q15, [%[outc00]], #16\n" /* save outc00*/
"str q16, [%[outc01]], #16\n" /* save outc01*/
"str q17, [%[outc10]], #16\n" /* save outc10*/
"str q18, [%[outc11]], #16\n" /* save outc11*/
"str q19, [%[outc20]], #16\n" /* save outc20*/
"str q20, [%[outc21]], #16\n" /* save outc21*/
"str q21, [%[outc30]], #16\n" /* save outc30*/
"str q22, [%[outc31]], #16\n" /* save outc31*/
:[outc00] "+r"(outc00), [outc01] "+r"(outc01),
[outc10] "+r"(outc10), [outc11] "+r"(outc11),
[outc20] "+r"(outc20), [outc21] "+r"(outc21),
[outc30] "+r"(outc30), [outc31] "+r"(outc31),
[din] "+r"(out1)
:[vbias]"w" (vbias), [flag_relu] "r"(flag_relu)
: "cc", "memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v15", "v16","v17","v18","v19","v20","v21","v22"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[din]]!\n" /* load input*/
"vld1.32 {d4-d7}, [%[din]]!\n" /* load input*/
"vadd.f32 q0, q0, %q[vbias]\n" /* add bias */
"vadd.f32 q1, q1, %q[vbias]\n" /* add bias */
"vld1.32 {d8-d11}, [%[din]]!\n" /* load input*/
"vadd.f32 q2, q2, %q[vbias]\n" /* add bias */
"vadd.f32 q3, q3, %q[vbias]\n" /* add bias */
"vld1.32 {d12-d15}, [%[din]]!\n" /* load input*/
"vadd.f32 q4, q4, %q[vbias]\n" /* add bias */
"vadd.f32 q5, q5, %q[vbias]\n" /* add bias */
"vadd.f32 q6, q6, %q[vbias]\n" /* add bias */
"vadd.f32 q7, q7, %q[vbias]\n" /* add bias */
/* transpose */
"vtrn.32 q0, q1\n" /* r0: q0: a0a1c0c1, q1: b0b1d0d1*/
"vtrn.32 q2, q3\n" /* r0: q2: a2a3c2c3, q3: b2b3d2d3*/
"vtrn.32 q4, q5\n" /* r1: q4: a0a1c0c1, q5: b0b1d0d1*/
"vtrn.32 q6, q7\n" /* r1: q6: a2a3c2c3, q7: b2b3d2d3*/
"vswp d1, d4\n" /* r0: q0: a0a1a2a3, q2: c0c1c2c3*/
"vswp d3, d6\n" /* r0: q1: b0b1b2b3, q3: d0d1d2d3*/
"vswp d9, d12\n" /* r1: q4: a0a1a2a3, q6: c0c1c2c3*/
"vswp d11, d14\n" /* r1: q5: b0b1b2b3, q7: d0d1d2d3*/
"cmp %[flag_relu], #0\n"
"beq 0f\n" /* skip relu*/
"vmov.u32 q15, #0\n"
"vmax.f32 q0, q0, q15\n"
"vmax.f32 q1, q1, q15\n"
"vmax.f32 q2, q2, q15\n"
"vmax.f32 q3, q3, q15\n"
"vmax.f32 q4, q4, q15\n"
"vmax.f32 q5, q5, q15\n"
"vmax.f32 q6, q6, q15\n"
"vmax.f32 q7, q7, q15\n"
"0:\n"
"vst1.32 {d0-d1}, [%[outc00]]!\n" /* save outc00*/
"vst1.32 {d2-d3}, [%[outc10]]!\n" /* save outc10*/
"vst1.32 {d4-d5}, [%[outc20]]!\n" /* save outc20*/
"vst1.32 {d6-d7}, [%[outc30]]!\n" /* save outc30*/
"vst1.32 {d8-d9}, [%[outc01]]!\n" /* save outc01*/
"vst1.32 {d10-d11}, [%[outc11]]!\n" /* save outc11*/
"vst1.32 {d12-d13}, [%[outc21]]!\n" /* save outc21*/
"vst1.32 {d14-d15}, [%[outc31]]!\n" /* save outc31*/
:[outc00] "+r"(outc00), [outc01] "+r"(outc01),
[outc10] "+r"(outc10), [outc11] "+r"(outc11),
[outc20] "+r"(outc20), [outc21] "+r"(outc21),
[outc30] "+r"(outc30), [outc31] "+r"(outc31),
[din] "+r"(out1)
:[vbias]"w" (vbias), [flag_relu] "r"(flag_relu)
: "cc", "memory",
"q0","q1","q2","q3","q4","q5","q6","q7", "q15"
);
#endif // __aarch64__
// clang-format on
outl[0] += 4;
outl[1] += 4;
outl[2] += 4;
outl[3] += 4;
outl[4] += 4;
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
if (flag_mask) {
for (int i = 0; i < remain; ++i) {
c00[i] = pre_out[i];
c01[i] = pre_out[i + 4];
c10[i] = pre_out[i + 8];
c11[i] = pre_out[i + 12];
c20[i] = pre_out[i + 16];
c21[i] = pre_out[i + 20];
c30[i] = pre_out[i + 24];
c31[i] = pre_out[i + 28];
}
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
}
}
}
......
......@@ -56,13 +56,14 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
const int win_round = wout_round + 2;
//! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! llc_size = threads * win_round * hout_c_block * hin_r_block *
//! sizeof(int8_t)
//! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2
int hout_r_block =
(llc_size - 2 * win_round * threads) /
(win_round * threads + hout_c_block * wout_round * threads * 4);
int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
(win_round * threads * hout_c_block +
hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
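For reference, a minimal standalone sketch of the h-block sizing above, assuming illustrative values for llc_size, threads, wout_round, hout and hout_r_kernel (the real values come from the ARMContext and the layer shape); the s2 and 5x5 variants below only change the stride/window factors in the same formula:

    // Hypothetical illustration of the c8 int8 h-block sizing (3x3 s1 case).
    // Only the formula mirrors the diff; all concrete values are made up.
    #include <algorithm>
    #include <cstdio>

    int main() {
      const int llc_size = 512 * 1024 / 4;   // assumed cache budget, bytes / 4
      const int threads = 4;                 // assumed thread count
      const int hout_c_block = 8;            // 8 channels packed per block
      const int hout_r_kernel = 2;           // assumed output rows per pass
      const int wout_round = 64;             // assumed wout rounded up to 4
      const int win_round = wout_round + 2;  // 3x3 s1 needs 2 extra columns
      const int hout = 56;                   // assumed output height
      // budget: int8 input rows + int32 output rows per h-block,
      // with hin_r_block = hout_r_block + 2:
      int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
                         (win_round * threads * hout_c_block +
                          hout_c_block * wout_round * threads * 4);
      hout_r_block = std::min(hout_r_block, hout);
      hout_r_block =
          ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
      printf("hout_r_block = %d\n", hout_r_block);
      return 0;
    }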
......@@ -115,17 +116,9 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din;
#endif
prepack_input_nxw_c8_int8(din_batch,
pre_din,
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
prepack_input_nxwc8_int8_dw(
din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
......
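The prepack_input_nxwc8_int8_dw call above replaces the old per-channel prepack_input_nxw_c8_int8. Judging from in_len = win_round * hout_c_block and the block_inr row strides, it appears to write the 8 channels of a block interleaved per pixel, zero-padded over [hs, he) x [ws, we). A rough sketch of that layout under this assumption (the helper below is hypothetical, not the library routine):

    // Hypothetical sketch of an NxWC8 int8 packing for one 8-channel block.
    // Assumed layout: pre_din[(row * win_round + col) * 8 + ci], zeros outside
    // the input; the real prepack_input_nxwc8_int8_dw may differ in detail.
    #include <cstdint>

    void pack_c8_block(const int8_t* din, int8_t* pre_din, int c0, int hs,
                       int he, int ws, int we, int chin, int win, int hin) {
      const int win_round = we - ws;
      for (int h = hs; h < he; ++h) {
        for (int w = ws; w < we; ++w) {
          int8_t* out = pre_din + ((h - hs) * win_round + (w - ws)) * 8;
          for (int ci = 0; ci < 8; ++ci) {
            int c = c0 + ci;
            bool inside = (h >= 0 && h < hin && w >= 0 && w < win && c < chin);
            out[ci] = inside ? din[(c * hin + h) * win + w] : 0;  // NCHW source
          }
        }
      }
    }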
......@@ -56,13 +56,14 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
const int win_round = wout_round * 2 /*stride*/ + 1;
//! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! llc_size = threads * win_round * hin_r_block * hout_c_block *
//! sizeof(int8_t)
//! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2
int hout_r_block =
(llc_size - 2 * win_round * threads) /
(2 * win_round * threads + hout_c_block * wout_round * threads * 4);
int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
(2 * win_round * threads * hout_c_block +
hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
......@@ -115,17 +116,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din;
#endif
prepack_input_nxw_c8_int8(din_batch,
pre_din,
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
prepack_input_nxwc8_int8_dw(
din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
......
......@@ -14,6 +14,7 @@
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_depthwise.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/operators/op_params.h"
......@@ -26,592 +27,749 @@ namespace lite {
namespace arm {
namespace math {
void conv_depthwise_5x5s1_int8(int32_t* dout,
#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
template <typename Dtype>
void conv_depthwise_5x5s1_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const int* bias,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int chin,
const int hin,
const int win,
const int hout,
const int wout,
ARMContext* ctx,
PrecisionType out_type,
const float* scale);
void conv_depthwise_5x5_int8(const int8_t* din,
int32_t* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int8_t* weights,
const int32_t* bias,
const operators::ConvParam& param,
ARMContext* ctx,
PrecisionType out_type,
const float* scale) {
int stride_h = param.strides[0];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
// if (param.activation_param.has_active){
// if (param.activation_param.active == Active_relu ||
// fabs(param.activation_param.negative_slope) > 1e-6f){
// flag_relu = true;
// }
// }
if (stride_h == 1) {
#ifdef __aarch64__
conv_depthwise_5x5s1_int8(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
chin,
hin,
win,
hout,
wout,
ctx,
out_type,
scale);
#else
LOG(FATAL) << "5x5 dw conv armv7 has not impl";
#endif
}
}
/**
* \brief depthwise convolution, kernel size 5x5, stride 1, pad 1, with bias,
* width > 4
*/
// 2 line
#ifdef __aarch64__
template <typename Dtype>
inline void prefetch(const Dtype* din) {
#ifdef __aarch64__
asm volatile("PRFM PLDL1KEEP, [%[din]] \n" : : [din] "r"(din) : "memory");
#else
asm volatile("pld [%[din]] \n" : : [din] "r"(din) : "memory");
#endif
}
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx) {
const int threads = ctx->threads();
int llc_size = ctx->llc_size() / 4;
const int hout_c_block = 8;
const int hout_r_kernel = 1;
const int wout_block = 4;
const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block;
const int win_round = wout_round + 4;
//! get h block
//! llc_size = threads * win_round * hout_c_block * hin_r_block *
//! sizeof(int8_t)
//! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 4
//! hin_r_block = hout_r_block + 4
int hout_r_block = (llc_size - 4 * win_round * hout_c_block * threads) /
(win_round * hout_c_block * threads +
hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
const int hin_r_block = hout_r_block + 4;
auto tmp_work_space = ctx->workspace_data<int8_t>();
int8_t ptr_zero[win_round]; // NOLINT
memset(ptr_zero, 0, sizeof(int8_t) * win_round);
Dtype ptr_write[wout_round]; // NOLINT
int in_len = win_round * hout_c_block;
int pre_in_size = hin_r_block * in_len;
pre_in_size = ROUNDUP(pre_in_size, 4);
int pre_out_size = hout_c_block * hout_r_block * wout_round;
int8_t* tmp_din = tmp_work_space;
void conv_depthwise_5x5s1_int8(
int32_t* dout,
const int8_t* din,
const int8_t* weights,
const int32_t* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int chin,
const int hin,
const int win,
const int hout,
const int wout,
ARMContext* ctx,
PrecisionType od_type,
float const* scales) { /// scale_size = channel-out
// printf("5*5 multiply\n");
int size_in_channel = win * hin;
int size_out_channel = wout * hout;
int w_stride = 5 * 5;
static int const stride_w = 1;
int const stride_h = stride_w;
int const chout = chin;
int const pad_w = 2;
int const pad_h = pad_w;
int const wout_round = ((wout + 7) / 8) * 8;
int const win_round = wout_round * stride_w + 5 - 1;
int const hout_round = ((hout + 2) / 3) * 3;
int const hin_round = hout_round * stride_h + 5 - 1;
int const tile_h = hout_round / 3;
int const tile_w = wout_round / 8;
int const pre_in_size = hin_round * win_round;
int const pre_out_size = hout_round * wout_round;
int const pre_io_size = pre_in_size + pre_out_size * sizeof(int);
int const hs = -pad_h;
int const he = hs + hin_round;
int const ws = -pad_w;
int const we = ws + win_round;
// signed char* tmp_work_space = new signed char [1024*5];
signed char* tmp_work_space = ctx->workspace_data<signed char>();
signed char* ptr_zero = tmp_work_space;
int* ptr_write = reinterpret_cast<int*>(ptr_zero + win_round);
signed char* pre_data =
reinterpret_cast<signed char*>(ptr_write + wout_round);
memset(ptr_zero, 0, win_round * sizeof(signed char));
int w_stride = 25; // kernel_w * kernel_h;
int ws = -padw;
int we = ws + win_round;
int w_loop = wout_round / 4;
int chout = chin;
int out_row_stride = hout_c_block * wout_round;
for (int n = 0; n < num; ++n) {
signed char const* din_batch = din + n * chin * size_in_channel;
int* dout_batch = dout + n * chout * size_out_channel;
const int8_t* din_batch = din + n * chin * size_in_channel;
int8_t* dout_batch = reinterpret_cast<int8_t*>(dout) +
n * chout * size_out_channel * sizeof(Dtype);
for (int h = 0; h < hout; h += hout_r_block) {
int h_kernel = hout_r_block;
if (h + hout_r_block > hout) {
h_kernel = hout - h;
}
int hs = h - padh;
int he = hs + h_kernel + 4;
// #pragma omp parallel for
for (int c = 0; c < chout; c++) {
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < chout; c += hout_c_block) {
#ifdef ARM_WITH_OMP
int const thno = omp_get_thread_num();
int8_t* pre_din =
tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size * 4);
int32_t* pre_out = reinterpret_cast<int*>(pre_din + pre_in_size);
#else
int const thno = 0;
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din;
#endif
signed char const* din_channel = din_batch + c * size_in_channel;
signed char* pre_din = pre_data + thno * pre_io_size;
int* pre_out = reinterpret_cast<int*>(pre_din + pre_in_size);
int* dout_ptr = pre_out;
prepack_input_nxw(din_channel,
pre_din,
c,
c + 1,
hs,
he,
ws,
we,
1,
win,
hin,
ptr_zero);
signed char const* wei_ptr = weights + c * w_stride;
int bias_val = flag_bias ? bias[c] : 0.f;
int8x8_t wr00 = vdup_n_s8(wei_ptr[0 * 5 + 0]);
int8x8_t wr01 = vdup_n_s8(wei_ptr[0 * 5 + 1]);
int8x8_t wr02 = vdup_n_s8(wei_ptr[0 * 5 + 2]);
int8x8_t wr03 = vdup_n_s8(wei_ptr[0 * 5 + 3]);
int8x8_t wr04 = vdup_n_s8(wei_ptr[0 * 5 + 4]);
int8x8_t wr10 = vdup_n_s8(wei_ptr[1 * 5 + 0]);
int8x8_t wr11 = vdup_n_s8(wei_ptr[1 * 5 + 1]);
int8x8_t wr12 = vdup_n_s8(wei_ptr[1 * 5 + 2]);
int8x8_t wr13 = vdup_n_s8(wei_ptr[1 * 5 + 3]);
int8x8_t wr14 = vdup_n_s8(wei_ptr[1 * 5 + 4]);
int8x8_t wr20 = vdup_n_s8(wei_ptr[2 * 5 + 0]);
int8x8_t wr21 = vdup_n_s8(wei_ptr[2 * 5 + 1]);
int8x8_t wr22 = vdup_n_s8(wei_ptr[2 * 5 + 2]);
int8x8_t wr23 = vdup_n_s8(wei_ptr[2 * 5 + 3]);
int8x8_t wr24 = vdup_n_s8(wei_ptr[2 * 5 + 4]);
int8x8_t wr30 = vdup_n_s8(wei_ptr[3 * 5 + 0]);
int8x8_t wr31 = vdup_n_s8(wei_ptr[3 * 5 + 1]);
int8x8_t wr32 = vdup_n_s8(wei_ptr[3 * 5 + 2]);
int8x8_t wr33 = vdup_n_s8(wei_ptr[3 * 5 + 3]);
int8x8_t wr34 = vdup_n_s8(wei_ptr[3 * 5 + 4]);
int8x8_t wr40 = vdup_n_s8(wei_ptr[4 * 5 + 0]);
int8x8_t wr41 = vdup_n_s8(wei_ptr[4 * 5 + 1]);
int8x8_t wr42 = vdup_n_s8(wei_ptr[4 * 5 + 2]);
int8x8_t wr43 = vdup_n_s8(wei_ptr[4 * 5 + 3]);
int8x8_t wr44 = vdup_n_s8(wei_ptr[4 * 5 + 4]);
int* doutr0 = nullptr;
int* doutr1 = nullptr;
int* doutr2 = nullptr;
signed char const* dr0 = pre_din;
signed char const* dr1 = dr0 + win_round;
signed char const* dr2 = dr1 + win_round;
signed char const* dr3 = dr2 + win_round;
signed char const* dr4 = dr3 + win_round;
signed char const* dr5 = dr4 + win_round;
signed char const* dr6 = dr5 + win_round;
signed char const* din_ptr0 = nullptr;
signed char const* din_ptr1 = nullptr;
signed char const* din_ptr2 = nullptr;
signed char const* din_ptr3 = nullptr;
signed char const* din_ptr4 = nullptr;
signed char const* din_ptr5 = nullptr;
signed char const* din_ptr6 = nullptr;
for (int h = 0; h < tile_h; h++) {
// printf("c:%d h:%d\n", c, h);
doutr0 = dout_ptr;
doutr1 = doutr0 + wout_round;
doutr2 = doutr1 + wout_round;
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
din_ptr6 = dr6;
prefetch(doutr0);
prefetch(doutr1);
prefetch(doutr2);
prefetch(din_ptr0);
prefetch(din_ptr1);
prefetch(din_ptr2);
prefetch(din_ptr3);
prefetch(din_ptr4);
prefetch(din_ptr5);
prefetch(din_ptr6);
for (int j = 0; j < tile_w; ++j) {
// printf("j:%d\n", j);
int32x4_t voutr00 = vdupq_n_s32(bias_val);
int32x4_t voutr01 = vdupq_n_s32(bias_val);
int32x4_t voutr10 = vdupq_n_s32(bias_val);
int32x4_t voutr11 = vdupq_n_s32(bias_val);
int32x4_t voutr20 = vdupq_n_s32(bias_val);
int32x4_t voutr21 = vdupq_n_s32(bias_val);
// din data
int8x8_t vinr00 = vld1_s8(din_ptr0 + 0);
int8x8_t vinr01 = vld1_s8(din_ptr0 + 8);
int8x8_t vinr10 = vld1_s8(din_ptr1 + 0);
int8x8_t vinr11 = vld1_s8(din_ptr1 + 8);
int8x8_t vinr20 = vld1_s8(din_ptr2 + 0);
int8x8_t vinr21 = vld1_s8(din_ptr2 + 8);
int8x8_t vinr30 = vld1_s8(din_ptr3 + 0);
int8x8_t vinr31 = vld1_s8(din_ptr3 + 8);
int8x8_t vinr40 = vld1_s8(din_ptr4 + 0);
int8x8_t vinr41 = vld1_s8(din_ptr4 + 8);
int8x8_t vinr50 = vld1_s8(din_ptr5 + 0);
int8x8_t vinr51 = vld1_s8(din_ptr5 + 8);
int8x8_t vinr60 = vld1_s8(din_ptr6 + 0);
int8x8_t vinr61 = vld1_s8(din_ptr6 + 8);
/// the first row
// r0
int8x8_t vtmp1 = vext_s8(vinr00, vinr01, 1); // 12345678
int8x8_t vtmp2 = vext_s8(vinr00, vinr01, 2); // 2345678
int8x8_t vtmp3 = vext_s8(vinr00, vinr01, 3); // 345678
int8x8_t vtmp4 = vext_s8(vinr00, vinr01, 4); // 45678
int16x8_t tvoutr0 = vmull_s8(vinr00, wr00);
tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr01);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp2, wr02);
tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr03);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp4, wr04);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
// r1
vtmp1 = vext_s8(vinr10, vinr11, 1); // 12345678
vtmp2 = vext_s8(vinr10, vinr11, 2); // 2345678
vtmp3 = vext_s8(vinr10, vinr11, 3); // 345678
vtmp4 = vext_s8(vinr10, vinr11, 4); // 45678
tvoutr0 = vmull_s8(vinr10, wr10);
tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr11);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp2, wr12);
tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr13);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp4, wr14);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
int16x8_t tvoutr1 = vmull_s8(vinr10, wr00);
tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr01);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp2, wr02);
tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr03);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp4, wr04);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
// r2
vtmp1 = vext_s8(vinr20, vinr21, 1); // 12345678
vtmp2 = vext_s8(vinr20, vinr21, 2); // 2345678
vtmp3 = vext_s8(vinr20, vinr21, 3); // 345678
vtmp4 = vext_s8(vinr20, vinr21, 4); // 45678
tvoutr0 = vmull_s8(vinr20, wr20);
tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr21);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp2, wr22);
tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr23);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp4, wr24);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr1 = vmull_s8(vinr20, wr10);
tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr11);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp2, wr12);
tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr13);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp4, wr14);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
int16x8_t tvoutr2 = vmull_s8(vinr20, wr00);
tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr01);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp2, wr02);
tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr03);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp4, wr04);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
// r3
vtmp1 = vext_s8(vinr30, vinr31, 1); // 12345678
vtmp2 = vext_s8(vinr30, vinr31, 2); // 2345678
vtmp3 = vext_s8(vinr30, vinr31, 3); // 345678
vtmp4 = vext_s8(vinr30, vinr31, 4); // 45678
tvoutr0 = vmull_s8(vinr30, wr30);
tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr31);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp2, wr32);
tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr33);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp4, wr34);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr1 = vmull_s8(vinr30, wr20);
tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr21);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp2, wr22);
tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr23);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp4, wr24);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr2 = vmull_s8(vinr30, wr10);
tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr11);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp2, wr12);
tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr13);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp4, wr14);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
// r4
vtmp1 = vext_s8(vinr40, vinr41, 1); // 12345678
vtmp2 = vext_s8(vinr40, vinr41, 2); // 2345678
vtmp3 = vext_s8(vinr40, vinr41, 3); // 345678
vtmp4 = vext_s8(vinr40, vinr41, 4); // 45678
tvoutr0 = vmull_s8(vinr40, wr40);
tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr41);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp2, wr42);
tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr43);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr0 = vmull_s8(vtmp4, wr44);
voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
tvoutr1 = vmull_s8(vinr40, wr30);
tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr31);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp2, wr32);
tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr33);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp4, wr34);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr2 = vmull_s8(vinr40, wr20);
tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr21);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp2, wr22);
tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr23);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp4, wr24);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
// r5
vtmp1 = vext_s8(vinr50, vinr51, 1); // 12345678
vtmp2 = vext_s8(vinr50, vinr51, 2); // 2345678
vtmp3 = vext_s8(vinr50, vinr51, 3); // 345678
vtmp4 = vext_s8(vinr50, vinr51, 4); // 45678
tvoutr1 = vmull_s8(vinr50, wr40);
tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr41);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp2, wr42);
tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr43);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr1 = vmull_s8(vtmp4, wr44);
voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
tvoutr2 = vmull_s8(vinr50, wr30);
tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr31);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp2, wr32);
tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr33);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp4, wr34);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
// r6
vtmp1 = vext_s8(vinr60, vinr61, 1); // 12345678
vtmp2 = vext_s8(vinr60, vinr61, 2); // 2345678
vtmp3 = vext_s8(vinr60, vinr61, 3); // 345678
vtmp4 = vext_s8(vinr60, vinr61, 4); // 45678
tvoutr2 = vmull_s8(vinr60, wr40);
tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr41);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp2, wr42);
tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr43);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
tvoutr2 = vmull_s8(vtmp4, wr44);
voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
/// data shift 8 bytes
din_ptr0 += 8;
din_ptr1 += 8;
din_ptr2 += 8;
din_ptr3 += 8;
din_ptr4 += 8;
din_ptr5 += 8;
din_ptr6 += 8;
/// store
vst1q_s32(doutr0, voutr00);
vst1q_s32(doutr1, voutr10);
vst1q_s32(doutr2, voutr20);
doutr0 += 4;
doutr1 += 4;
doutr2 += 4;
vst1q_s32(doutr0, voutr01);
vst1q_s32(doutr1, voutr11);
vst1q_s32(doutr2, voutr21);
doutr0 += 4;
doutr1 += 4;
doutr2 += 4;
} /// end of tile_w
dr0 = dr3;
dr1 = dr4;
dr2 = dr5;
dr3 = dr6;
dr4 = dr3 + win_round;
dr5 = dr4 + win_round;
dr6 = dr5 + win_round;
dout_ptr = dout_ptr + 3 * wout_round;
} /// end of tile_h
if (scales == 0) {
write_to_output_numc(pre_out,
dout_batch,
1,
hout_round,
c,
c + 1,
0,
hout,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
ptr_write);
} else if (od_type == PRECISION(kFloat)) {
write2_to_output_numc(pre_out,
reinterpret_cast<float*>(dout_batch),
1,
hout_round,
c,
c + 1,
0,
hout,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
reinterpret_cast<float*>(ptr_write),
scales);
} else if (od_type == PRECISION(kInt8)) {
write2_to_output_numc(pre_out,
reinterpret_cast<signed char*>(dout_batch),
1,
hout_round,
prepack_input_nxwc8_int8_dw(
din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
const int8_t* block_inr3 = block_inr2 + in_len;
const int8_t* block_inr4 = block_inr3 + in_len;
const int8_t* weight_c = weights + c * w_stride;
float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
if (flag_bias) {
bias_local[0] = bias[c];
bias_local[1] = bias[c + 1];
bias_local[2] = bias[c + 2];
bias_local[3] = bias[c + 3];
bias_local[4] = bias[c + 4];
bias_local[5] = bias[c + 5];
bias_local[6] = bias[c + 6];
bias_local[7] = bias[c + 7];
}
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
int cnt = w_loop;
const int8_t* inr0 = block_inr0;
const int8_t* inr1 = block_inr1;
const int8_t* inr2 = block_inr2;
const int8_t* inr3 = block_inr3;
const int8_t* inr4 = block_inr4;
int32_t* ptr_out0 = pre_out + hk * out_row_stride;
// clang-format off
#ifdef __aarch64__
auto wptr = weight_c;
asm volatile(
"ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */
"ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 0-3 */
"1:\n"
/* in r0 */
"smull v20.8h, v0.8b, v8.8b\n" /* w0, int16, out0 */
"smull v21.8h, v1.8b, v8.8b\n" /* w0, int16, out1 */
"smull v22.8h, v2.8b, v8.8b\n" /* w0, int16, out2 */
"smull v23.8h, v3.8b, v8.8b\n" /* w0, int16, out3 */
"ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r0]]\n" /* load r0 4-7 */
"smlal v20.8h, v1.8b, v9.8b\n" /* w1, int16, out0 */
"smlal v21.8h, v2.8b, v9.8b\n" /* w1, int16, out1 */
"smlal v22.8h, v3.8b, v9.8b\n" /* w1, int16, out2 */
"smlal v23.8h, v4.8b, v9.8b\n" /* w1, int16, out3 */
"sxtl v24.4s, v20.4h\n" /* mov to out0 low */
"sxtl2 v25.4s, v20.8h\n" /* mov to out0 hig */
"sxtl v26.4s, v21.4h\n" /* mov to out1 low */
"sxtl2 v27.4s, v21.8h\n" /* mov to out1 hig */
"sxtl v28.4s, v22.4h\n" /* mov to out2 low */
"sxtl2 v29.4s, v22.8h\n" /* mov to out2 hig */
"sxtl v30.4s, v23.4h\n" /* mov to out3 low */
"sxtl2 v31.4s, v23.8h\n" /* mov to out3 hig */
"ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 4-7 */
"smull v20.8h, v2.8b, v10.8b\n" /* w2, int16, out0 */
"smull v21.8h, v3.8b, v10.8b\n" /* w2, int16, out1 */
"smull v22.8h, v4.8b, v10.8b\n" /* w2, int16, out2 */
"smull v23.8h, v5.8b, v10.8b\n" /* w2, int16, out3 */
"smlal v20.8h, v3.8b, v11.8b\n" /* w3, int16, out0 */
"smlal v21.8h, v4.8b, v11.8b\n" /* w3, int16, out1 */
"smlal v22.8h, v5.8b, v11.8b\n" /* w3, int16, out2 */
"smlal v23.8h, v6.8b, v11.8b\n" /* w3, int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r1]], #32\n" /* load r1 0-3 */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"smull v20.8h, v4.8b, v12.8b\n" /* w4, int16, out0 */
"smull v21.8h, v5.8b, v12.8b\n" /* w4, int16, out1 */
"smull v22.8h, v6.8b, v12.8b\n" /* w4, int16, out2 */
"smull v23.8h, v7.8b, v12.8b\n" /* w4, int16, out3 */
"ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r1]]\n" /* load r1 4-7 */
/* in r1 */
"smlal v20.8h, v0.8b, v13.8b\n" /* w5, int16, out0 */
"smlal v21.8h, v1.8b, v13.8b\n" /* w5, int16, out1 */
"smlal v22.8h, v2.8b, v13.8b\n" /* w5, int16, out2 */
"smlal v23.8h, v3.8b, v13.8b\n" /* w5, int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 8-11 */
"smull v20.8h, v1.8b, v14.8b\n" /* w6, int16, out0 */
"smull v21.8h, v2.8b, v14.8b\n" /* w6, int16, out1 */
"smull v22.8h, v3.8b, v14.8b\n" /* w6, int16, out2 */
"smull v23.8h, v4.8b, v14.8b\n" /* w6, int16, out3 */
"smlal v20.8h, v2.8b, v15.8b\n" /* w7, int16, out0 */
"smlal v21.8h, v3.8b, v15.8b\n" /* w7, int16, out1 */
"smlal v22.8h, v4.8b, v15.8b\n" /* w7, int16, out2 */
"smlal v23.8h, v5.8b, v15.8b\n" /* w7, int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"smull v20.8h, v3.8b, v16.8b\n" /* w8, int16, out0 */
"smull v21.8h, v4.8b, v16.8b\n" /* w8, int16, out1 */
"smull v22.8h, v5.8b, v16.8b\n" /* w8, int16, out2 */
"smull v23.8h, v6.8b, v16.8b\n" /* w8, int16, out3 */
"ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r2]], #32\n" /* load r2 0-3 */
"smlal v20.8h, v4.8b, v17.8b\n" /* w9, int16, out0 */
"smlal v21.8h, v5.8b, v17.8b\n" /* w9, int16, out1 */
"smlal v22.8h, v6.8b, v17.8b\n" /* w9, int16, out2 */
"smlal v23.8h, v7.8b, v17.8b\n" /* w9, int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r2]]\n" /* load r2 4-7 */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
/* in r2 */
"smull v20.8h, v0.8b, v18.8b\n" /* w10, int16, out0 */
"smull v21.8h, v1.8b, v18.8b\n" /* w10, int16, out1 */
"smull v22.8h, v2.8b, v18.8b\n" /* w10, int16, out2 */
"smull v23.8h, v3.8b, v18.8b\n" /* w10, int16, out3 */
"smlal v20.8h, v1.8b, v19.8b\n" /* w11, int16, out0 */
"smlal v21.8h, v2.8b, v19.8b\n" /* w11, int16, out1 */
"smlal v22.8h, v3.8b, v19.8b\n" /* w11, int16, out2 */
"smlal v23.8h, v4.8b, v19.8b\n" /* w11, int16, out3 */
"ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 12-15 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"smull v20.8h, v2.8b, v8.8b\n" /* w12, int16, out0 */
"smull v21.8h, v3.8b, v8.8b\n" /* w12, int16, out1 */
"smull v22.8h, v4.8b, v8.8b\n" /* w12, int16, out2 */
"smull v23.8h, v5.8b, v8.8b\n" /* w12, int16, out3 */
"smlal v20.8h, v3.8b, v9.8b\n" /* w13, int16, out0 */
"smlal v21.8h, v4.8b, v9.8b\n" /* w13, int16, out1 */
"smlal v22.8h, v5.8b, v9.8b\n" /* w13, int16, out2 */
"smlal v23.8h, v6.8b, v9.8b\n" /* w13, int16, out3 */
"ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r3]], #32\n" /* load r3 0-3 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"smull v20.8h, v4.8b, v10.8b\n" /* w14, int16, out0 */
"smull v21.8h, v5.8b, v10.8b\n" /* w14, int16, out1 */
"smull v22.8h, v6.8b, v10.8b\n" /* w14, int16, out2 */
"smull v23.8h, v7.8b, v10.8b\n" /* w14, int16, out3 */
"ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r3]]\n" /* load r3 4-7 */
/* in r3 */
"smlal v20.8h, v0.8b, v11.8b\n" /* w15, int16, out0 */
"smlal v21.8h, v1.8b, v11.8b\n" /* w15, int16, out1 */
"smlal v22.8h, v2.8b, v11.8b\n" /* w15, int16, out2 */
"smlal v23.8h, v3.8b, v11.8b\n" /* w15, int16, out3 */
"ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 16-19 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"smull v20.8h, v1.8b, v12.8b\n" /* w16, int16, out0 */
"smull v21.8h, v2.8b, v12.8b\n" /* w16, int16, out1 */
"smull v22.8h, v3.8b, v12.8b\n" /* w16, int16, out2 */
"smull v23.8h, v4.8b, v12.8b\n" /* w16, int16, out3 */
"ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 20-23 */
"smlal v20.8h, v2.8b, v13.8b\n" /* w17, int16, out0 */
"smlal v21.8h, v3.8b, v13.8b\n" /* w17, int16, out1 */
"smlal v22.8h, v4.8b, v13.8b\n" /* w17, int16, out2 */
"smlal v23.8h, v5.8b, v13.8b\n" /* w17, int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"smull v20.8h, v3.8b, v14.8b\n" /* w18, int16, out0 */
"smull v21.8h, v4.8b, v14.8b\n" /* w18, int16, out1 */
"smull v22.8h, v5.8b, v14.8b\n" /* w18, int16, out2 */
"smull v23.8h, v6.8b, v14.8b\n" /* w18, int16, out3 */
"ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r4]], #32\n" /* load r4 0-3 */
"smlal v20.8h, v4.8b, v15.8b\n" /* w19, int16, out0 */
"smlal v21.8h, v5.8b, v15.8b\n" /* w19, int16, out1 */
"smlal v22.8h, v6.8b, v15.8b\n" /* w19, int16, out2 */
"smlal v23.8h, v7.8b, v15.8b\n" /* w19, int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r4]]\n" /* load r4 4-7 */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
/* in r4 */
"smull v20.8h, v0.8b, v16.8b\n" /* w20, int16, out0 */
"smull v21.8h, v1.8b, v16.8b\n" /* w20, int16, out1 */
"smull v22.8h, v2.8b, v16.8b\n" /* w20, int16, out2 */
"smull v23.8h, v3.8b, v16.8b\n" /* w20, int16, out3 */
"smlal v20.8h, v1.8b, v17.8b\n" /* w21, int16, out0 */
"smlal v21.8h, v2.8b, v17.8b\n" /* w21, int16, out1 */
"smlal v22.8h, v3.8b, v17.8b\n" /* w21, int16, out2 */
"smlal v23.8h, v4.8b, v17.8b\n" /* w21, int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"ld1 {v12.8b}, [%[wc]], #8\n" /* load wc 24 */
"smull v20.8h, v2.8b, v18.8b\n" /* w22, int16, out0 */
"smull v21.8h, v3.8b, v18.8b\n" /* w22, int16, out1 */
"smull v22.8h, v4.8b, v18.8b\n" /* w22, int16, out2 */
"smull v23.8h, v5.8b, v18.8b\n" /* w22, int16, out3 */
"smlal v20.8h, v3.8b, v19.8b\n" /* w23, int16, out0 */
"smlal v21.8h, v4.8b, v19.8b\n" /* w23, int16, out1 */
"smlal v22.8h, v5.8b, v19.8b\n" /* w23, int16, out2 */
"smlal v23.8h, v6.8b, v19.8b\n" /* w23, int16, out3 */
"ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */
"sub %[wc], %[wc], #200 \n"
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 0-3 */
"smull v20.8h, v4.8b, v12.8b\n" /* w24, int16, out0 */
"smull v21.8h, v5.8b, v12.8b\n" /* w24, int16, out1 */
"smull v22.8h, v6.8b, v12.8b\n" /* w24, int16, out2 */
"smull v23.8h, v7.8b, v12.8b\n" /* w24, int16, out3 */
"saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
"saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
"saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
"saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
"stp q24, q25, [%[ptr_out0]], #32\n"
"saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
"saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
"stp q26, q27, [%[ptr_out0]], #32\n"
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
"saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
"subs %w[cnt], %w[cnt], #1\n"
"stp q28, q29, [%[ptr_out0]], #32\n"
"stp q30, q31, [%[ptr_out0]], #32\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[wc] "+r"(wptr),
[ptr_out0] "+r"(ptr_out0)
:
: "cc","memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8","v9","v10","v11","v12","v13",
"v14","v15","v16","v17","v18","v19",
"v20","v21","v22","v23","v24","v25",
"v26","v27","v28","v29","v30","v31"
);
#else
auto wptr = weight_c;
asm volatile(
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */
"vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
"1:\n"
/* inr0 */
"vmull.s8 q4, d0, d6\n" /* int16, out0 */
"vmull.s8 q5, d1, d6\n" /* int16, out1 */
"vmull.s8 q6, d2, d6\n" /* int16, out2 */
"vmull.s8 q7, d3, d6\n" /* int16, out3 */
"vmlal.s8 q4, d1, d7\n" /* int16, out0 */
"vmlal.s8 q5, d2, d7\n" /* int16, out1 */
"vmlal.s8 q6, d3, d7\n" /* int16, out2 */
"vmlal.s8 q7, d4, d7\n" /* int16, out3 */
"vmovl.s16 q8, d8\n" /* mov to out0 low */
"vmovl.s16 q9, d9\n" /* mov to out0 hig */
"vmovl.s16 q10, d10\n" /* mov to out1 low */
"vmovl.s16 q11, d11\n" /* mov to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w2-w3 */
"vmovl.s16 q12, d12\n" /* mov to out2 low */
"vmovl.s16 q13, d13\n" /* mov to out2 hig */
"vmovl.s16 q14, d14\n" /* mov to out3 low */
"vmovl.s16 q15, d15\n" /* mov to out3 hig */
"vld1.32 {d0-d1}, [%[r0]]\n" /* load r0, 6-7 */
"vmull.s8 q4, d2, d6\n" /* w2, int16, out0 */
"vmull.s8 q5, d3, d6\n" /* w2, int16, out1 */
"vmull.s8 q6, d4, d6\n" /* w2, int16, out2 */
"vmull.s8 q7, d5, d6\n" /* w2, int16, out3 */
"vmlal.s8 q4, d3, d7\n" /* w3, int16, out0 */
"vmlal.s8 q5, d4, d7\n" /* w3, int16, out1 */
"vmlal.s8 q6, d5, d7\n" /* w3, int16, out2 */
"vmlal.s8 q7, d0, d7\n" /* w3, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w4-w5 */
"sub %[r0], %[r0], #16\n" /* r0 = r0 - 16 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vmull.s8 q4, d4, d6\n" /* w4, int16, out0 */
"vmull.s8 q5, d5, d6\n" /* w4, int16, out1 */
"vmull.s8 q6, d0, d6\n" /* w4, int16, out2 */
"vmull.s8 q7, d1, d6\n" /* w4, int16, out3 */
"vld1.32 {d0-d3}, [%[r1]]!\n" /* load r1, 0-3 */
/* inr1 */
"vmlal.s8 q4, d0, d7\n" /* w5, int16, out0 */
"vmlal.s8 q5, d1, d7\n" /* w5, int16, out1 */
"vmlal.s8 q6, d2, d7\n" /* w5, int16, out2 */
"vmlal.s8 q7, d3, d7\n" /* w5, int16, out3 */
"vld1.32 {d4-d5}, [%[r1]]!\n" /* load r1, 4-5 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w6-w7 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vmull.s8 q4, d1, d6\n" /* w6, int16, out0 */
"vmull.s8 q5, d2, d6\n" /* w6, int16, out1 */
"vmull.s8 q6, d3, d6\n" /* w6, int16, out2 */
"vmull.s8 q7, d4, d6\n" /* w6, int16, out3 */
"vld1.32 {d0-d1}, [%[r1]]\n" /* load r1, 6-7 */
"vmlal.s8 q4, d2, d7\n" /* w7, int16, out0 */
"vmlal.s8 q5, d3, d7\n" /* w7, int16, out1 */
"vmlal.s8 q6, d4, d7\n" /* w7, int16, out2 */
"vmlal.s8 q7, d5, d7\n" /* w7, int16, out3 */
"sub %[r1], %[r1], #16\n" /* r0 = r0 - 16 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w8-w9 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vmull.s8 q4, d3, d6\n" /* w8, int16, out0 */
"vmull.s8 q5, d4, d6\n" /* w8, int16, out1 */
"vmull.s8 q6, d5, d6\n" /* w8, int16, out2 */
"vmull.s8 q7, d0, d6\n" /* w8, int16, out3 */
"vmlal.s8 q4, d4, d7\n" /* w9, int16, out0 */
"vmlal.s8 q5, d5, d7\n" /* w9, int16, out1 */
"vmlal.s8 q6, d0, d7\n" /* w9, int16, out2 */
"vmlal.s8 q7, d1, d7\n" /* w9, int16, out3 */
"vld1.32 {d0-d3}, [%[r2]]!\n" /* load r2, 0-3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w10-w11 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vld1.32 {d4-d5}, [%[r2]]!\n" /* load r2, 4-5 */
/* inr2 */
"vmull.s8 q4, d0, d6\n" /* w10, int16, out0 */
"vmull.s8 q5, d1, d6\n" /* w10, int16, out1 */
"vmull.s8 q6, d2, d6\n" /* w10, int16, out2 */
"vmull.s8 q7, d3, d6\n" /* w10, int16, out3 */
"vmlal.s8 q4, d1, d7\n" /* w11, int16, out0 */
"vmlal.s8 q5, d2, d7\n" /* w11, int16, out1 */
"vmlal.s8 q6, d3, d7\n" /* w11, int16, out2 */
"vmlal.s8 q7, d4, d7\n" /* w11, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w12-w13 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vld1.32 {d0-d1}, [%[r2]]\n" /* load r2, 6-7 */
"vmull.s8 q4, d2, d6\n" /* w12, int16, out0 */
"vmull.s8 q5, d3, d6\n" /* w12, int16, out1 */
"vmull.s8 q6, d4, d6\n" /* w12, int16, out2 */
"vmull.s8 q7, d5, d6\n" /* w12, int16, out3 */
"vmlal.s8 q4, d3, d7\n" /* w13, int16, out0 */
"vmlal.s8 q5, d4, d7\n" /* w13, int16, out1 */
"vmlal.s8 q6, d5, d7\n" /* w13, int16, out2 */
"vmlal.s8 q7, d0, d7\n" /* w13, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w14-w15 */
"sub %[r2], %[r2], #16\n" /* r2 = r2 - 16 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vmull.s8 q4, d4, d6\n" /* w14, int16, out0 */
"vmull.s8 q5, d5, d6\n" /* w14, int16, out1 */
"vmull.s8 q6, d0, d6\n" /* w14, int16, out2 */
"vmull.s8 q7, d1, d6\n" /* w14, int16, out3 */
"vld1.32 {d0-d3}, [%[r3]]!\n" /* load r3, 0-3 */
/* inr3 */
"vmlal.s8 q4, d0, d7\n" /* w15, int16, out0 */
"vmlal.s8 q5, d1, d7\n" /* w15, int16, out1 */
"vmlal.s8 q6, d2, d7\n" /* w15, int16, out2 */
"vmlal.s8 q7, d3, d7\n" /* w15, int16, out3 */
"vld1.32 {d4-d5}, [%[r3]]!\n" /* load r3, 4-5 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w16-w17 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vmull.s8 q4, d1, d6\n" /* w16, int16, out0 */
"vmull.s8 q5, d2, d6\n" /* w16, int16, out1 */
"vmull.s8 q6, d3, d6\n" /* w16, int16, out2 */
"vmull.s8 q7, d4, d6\n" /* w16, int16, out3 */
"vld1.32 {d0-d1}, [%[r3]]\n" /* load r3, 6-7 */
"vmlal.s8 q4, d2, d7\n" /* w17, int16, out0 */
"vmlal.s8 q5, d3, d7\n" /* w17, int16, out1 */
"vmlal.s8 q6, d4, d7\n" /* w17, int16, out2 */
"vmlal.s8 q7, d5, d7\n" /* w17, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w18-w19 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"sub %[r3], %[r3], #16\n" /* r3 = r3 - 16 */
"vmull.s8 q4, d3, d6\n" /* w18, int16, out0 */
"vmull.s8 q5, d4, d6\n" /* w18, int16, out1 */
"vmull.s8 q6, d5, d6\n" /* w18, int16, out2 */
"vmull.s8 q7, d0, d6\n" /* w18, int16, out3 */
"vmlal.s8 q4, d4, d7\n" /* w19, int16, out0 */
"vmlal.s8 q5, d5, d7\n" /* w19, int16, out1 */
"vmlal.s8 q6, d0, d7\n" /* w19, int16, out2 */
"vmlal.s8 q7, d1, d7\n" /* w19, int16, out3 */
"vld1.32 {d0-d3}, [%[r4]]!\n" /* load r4, 0-3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w20-w21 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vld1.32 {d4-d5}, [%[r4]]!\n" /* load r4, 4-5 */
/* inr4 */
"vmull.s8 q4, d0, d6\n" /* w20, int16, out0 */
"vmull.s8 q5, d1, d6\n" /* w20, int16, out1 */
"vmull.s8 q6, d2, d6\n" /* w20, int16, out2 */
"vmull.s8 q7, d3, d6\n" /* w20, int16, out3 */
"vmlal.s8 q4, d1, d7\n" /* w21, int16, out0 */
"vmlal.s8 q5, d2, d7\n" /* w21, int16, out1 */
"vmlal.s8 q6, d3, d7\n" /* w21, int16, out2 */
"vmlal.s8 q7, d4, d7\n" /* w21, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w22-w23 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vld1.32 {d0-d1}, [%[r4]]\n" /* load r4, 5-6 */
"vmull.s8 q4, d2, d6\n" /* w22, int16, out0 */
"vmull.s8 q5, d3, d6\n" /* w22, int16, out1 */
"vmull.s8 q6, d4, d6\n" /* w22, int16, out2 */
"vmull.s8 q7, d5, d6\n" /* w22, int16, out3 */
"vmlal.s8 q4, d3, d7\n" /* w23, int16, out0 */
"vmlal.s8 q5, d4, d7\n" /* w23, int16, out1 */
"vmlal.s8 q6, d5, d7\n" /* w23, int16, out2 */
"vmlal.s8 q7, d0, d7\n" /* w23, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6}, [%[wptr]]!\n" /* load w24 */
"sub %[r4], %[r4], #16\n" /* r4 = r4 - 16 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"sub %[wptr], %[wptr], #200 \n" /* wptr = wptr - 200 */
"vmull.s8 q4, d4, d6\n" /* w22, int16, out0 */
"vmull.s8 q5, d5, d6\n" /* w22, int16, out1 */
"vmull.s8 q6, d0, d6\n" /* w22, int16, out2 */
"vmull.s8 q7, d1, d6\n" /* w22, int16, out3 */
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 0-3 */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vst1.32 {d16-d19}, [%[ptr_out0]]!\n"/* store out0 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vst1.32 {d20-d23}, [%[ptr_out0]]!\n"/*store out1 */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"subs %[cnt], #1\n" /* cnt = cnt - 1 */
"vst1.32 {d24-d27}, [%[ptr_out0]]!\n"/* store out2 */
"vst1.32 {d28-d31}, [%[ptr_out0]]!\n"/* store out3 */
"bne 1b\n" /* branch main loop */
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[ptr_out0] "+r"(ptr_out0),
[wptr] "+r"(wptr)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
// clang-format on
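// Note on the pattern above (a reading of the NEON sequence, per lane):
// vmull.s8 starts a fresh int16 product from two int8 operands, vmlal.s8 adds
// a second int8 x int8 product into the same int16 lane, and vaddw.s16 widens
// the int16 partial sum into the int32 accumulators q8-q15. In scalar form:
//   int16_t p = (int16_t)(w[k] * x[j]) + (int16_t)(w[k + 1] * x[j + 1]);  // vmull.s8 + vmlal.s8
//   acc += (int32_t)p;                                                    // vaddw.s16
// so at most two int8 products sit in an int16 lane before being widened.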
int32_t* ptr_tmp = ptr_out0 - w_loop * 32;
block_inr0 = block_inr1;
block_inr1 = block_inr2;
block_inr2 = block_inr3;
block_inr3 = block_inr4;
block_inr4 = block_inr3 + in_len;
}
write_int32_nchwc8_to_nchw<Dtype>(pre_out,
reinterpret_cast<Dtype*>(dout_batch),
c,
c + 1,
0,
hout,
c + hout_c_block,
h,
h + h_kernel,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
reinterpret_cast<signed char*>(ptr_write),
scales);
bias_local,
flag_bias,
ptr_write,
scale + c);
}
}
}
// else if (od_type == AK_INT32) {
// write2_to_output_numc(pre_out, (int*)dout_batch, 1, hout_round, c,
// c+1,
// 0, hout, 0, wout_round, chout, hout, wout, flag_relu,
// (int*)ptr_write, scales);
// }
} /// end of chout
} /// end of batch num
}
#endif // __aarch64__
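// Illustrative reference (a minimal sketch, not taken from this file; the
// function name and the plain NCHW layout are assumptions). It computes the
// same 5x5, stride-1 int8 depthwise sum that the packed NCHWc8 kernels above
// accumulate in int32, with padding handling omitted:
static inline int32_t dw5x5s1_ref_point(const int8_t* in,   // one input channel plane, hin x win
                                        const int8_t* wgt,  // 25 weights of that channel
                                        int win,
                                        int oh,
                                        int ow) {
  int32_t acc = 0;
  for (int kh = 0; kh < 5; ++kh) {
    for (int kw = 0; kw < 5; ++kw) {
      acc += static_cast<int32_t>(in[(oh + kh) * win + (ow + kw)]) *
             static_cast<int32_t>(wgt[kh * 5 + kw]);
    }
  }
  return acc;  // later scaled by scale[c], biased, and optionally ReLU'd
}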
template void conv_depthwise_5x5s1_int8<int8_t>(int8_t* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
template void conv_depthwise_5x5s1_int8<float>(float* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -389,10 +389,9 @@ inline void prepack_input_nxwc4_dw(const float* din,
}
}
inline void prepack_input_nxw_c8_int8(const int8_t* din,
inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
int8_t* dout,
int cs,
int ce,
int hs,
int he,
int ws,
......@@ -402,95 +401,61 @@ inline void prepack_input_nxw_c8_int8(const int8_t* din,
int height) {
int n = he - hs;
if (n <= 0) {
LOG(FATAL) << "prepack_input_nxw_c8 input height must > 0";
return;
LOG(FATAL) << "prepack_dw_input_int8, valid height must > zero";
}
int size_w = we - ws;
int w0 = ws < 0 ? 0 : ws;
int w1 = we > width ? width : we;
int size_w = we - ws;
int size_channel_in = width * height;
int size_out_row = size_w * 8;
int valid_w = w1 - w0;
size_t valid_w_byte = valid_w * sizeof(int8_t);
auto ptr_c = static_cast<int8_t*>(TargetMalloc(TARGET(kARM), 8 * size_w));
int8_t* ptr_r[8];
int8_t* ptr_c_ori[8] = {ptr_c,
ptr_c + size_w,
ptr_c + 2 * size_w,
ptr_c + 3 * size_w,
ptr_c + 4 * size_w,
ptr_c + 5 * size_w,
ptr_c + 6 * size_w,
ptr_c + 7 * size_w};
int pad_l = ws < 0 ? -ws : 0;
int pad_r = we > width ? we - width : 0;
int size_c = width * height;
int valid_cnt = valid_w >> 3;
int remain = valid_w & 7;
int8_t zero_ptr[size_w * 2]; // NOLINT
memset(zero_ptr, 0, size_w * 2);
int loop = size_w / 8;
int remain = size_w - loop * 8;
for (int c = cs; c < ce; c += 8) {
auto din_c = din + c * size_channel_in;
for (int j = 0; j < 8; ++j) {
ptr_r[j] = ptr_c_ori[j];
}
//! valid channel
if (c + 8 > channel) {
switch (c + 8 - channel) {
for (int h = hs; h < he; ++h) {
const int8_t* ptr_c0 = din + h * width + cs * size_c;
const int8_t* ptr_c1 = ptr_c0 + size_c;
const int8_t* ptr_c2 = ptr_c1 + size_c;
const int8_t* ptr_c3 = ptr_c2 + size_c;
const int8_t* ptr_c4 = ptr_c3 + size_c;
const int8_t* ptr_c5 = ptr_c4 + size_c;
const int8_t* ptr_c6 = ptr_c5 + size_c;
const int8_t* ptr_c7 = ptr_c6 + size_c;
if (h < 0 || h >= height) {
memset(dout, 0, 8 * size_w * sizeof(int8_t));
dout += size_w * 8;
continue;
} else if (cs + 8 > channel) {
switch (cs + 8 - channel) {
case 7:
ptr_r[1] = zero_ptr;
ptr_c1 = zero_ptr;
case 6:
ptr_r[2] = zero_ptr;
ptr_c2 = zero_ptr;
case 5:
ptr_r[3] = zero_ptr;
ptr_c3 = zero_ptr;
case 4:
ptr_r[4] = zero_ptr;
ptr_c4 = zero_ptr;
case 3:
ptr_r[5] = zero_ptr;
ptr_c5 = zero_ptr;
case 2:
ptr_r[6] = zero_ptr;
ptr_c6 = zero_ptr;
case 1:
ptr_r[7] = zero_ptr;
ptr_c7 = zero_ptr;
default:
break;
}
}
//! valid height
int j = 0;
for (int i = hs; i < he; i++) {
auto din_r = din_c + i * width;
for (int k = 0; k < 8; ++k) {
if (ptr_r[k] != zero_ptr) {
if (i < 0 || i >= height) {
ptr_r[k] = zero_ptr + size_w;
} else {
ptr_r[k] = ptr_c_ori[k];
auto ptr = ptr_r[k];
for (int w = ws; w < w0; ++w) {
*(ptr++) = 0;
if (pad_l) {
memset(dout, 0, pad_l * 8 * sizeof(int8_t));
dout += pad_l * 8;
}
memcpy(ptr, din_r + k * size_channel_in, valid_w_byte);
ptr += valid_w;
for (int w = w1; w < we; ++w) {
*(ptr++) = 0;
}
}
}
}
int cnt = loop;
int8_t* inr0 = ptr_r[0];
int8_t* inr1 = ptr_r[1];
int8_t* inr2 = ptr_r[2];
int8_t* inr3 = ptr_r[3];
int8_t* inr4 = ptr_r[4];
int8_t* inr5 = ptr_r[5];
int8_t* inr6 = ptr_r[6];
int8_t* inr7 = ptr_r[7];
auto ptr_out = dout + j * size_out_row;
if (cnt > 0) {
if (valid_cnt) {
int cnt = valid_cnt;
#ifdef __aarch64__
asm volatile(
/* main loop */
......@@ -534,15 +499,15 @@ inline void prepack_input_nxw_c8_int8(const int8_t* din,
"stp d14, d15, [%[ptr_out]], #16\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[r5] "+r"(inr5),
[r6] "+r"(inr6),
[r7] "+r"(inr7),
[ptr_out] "+r"(ptr_out)
[r0] "+r"(ptr_c0),
[r1] "+r"(ptr_c1),
[r2] "+r"(ptr_c2),
[r3] "+r"(ptr_c3),
[r4] "+r"(ptr_c4),
[r5] "+r"(ptr_c5),
[r6] "+r"(ptr_c6),
[r7] "+r"(ptr_c7),
[ptr_out] "+r"(dout)
:
: "cc",
"memory",
......@@ -591,35 +556,35 @@ inline void prepack_input_nxw_c8_int8(const int8_t* din,
"vst1.32 {d4-d7}, [%[ptr_out]]!\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[r5] "+r"(inr5),
[r6] "+r"(inr6),
[r7] "+r"(inr7),
[ptr_out] "+r"(ptr_out)
[r0] "+r"(ptr_c0),
[r1] "+r"(ptr_c1),
[r2] "+r"(ptr_c2),
[r3] "+r"(ptr_c3),
[r4] "+r"(ptr_c4),
[r5] "+r"(ptr_c5),
[r6] "+r"(ptr_c6),
[r7] "+r"(ptr_c7),
[ptr_out] "+r"(dout)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
#endif // aarch64
#endif // __aarch64__
}
for (int k = 0; k < remain; ++k) {
ptr_out[0] = *(inr0++);
ptr_out[1] = *(inr1++);
ptr_out[2] = *(inr2++);
ptr_out[3] = *(inr3++);
ptr_out[4] = *(inr4++);
ptr_out[5] = *(inr5++);
ptr_out[6] = *(inr6++);
ptr_out[7] = *(inr7++);
ptr_out += 8;
for (int i = 0; i < remain; ++i) {
dout[0] = *(ptr_c0++);
dout[1] = *(ptr_c1++);
dout[2] = *(ptr_c2++);
dout[3] = *(ptr_c3++);
dout[4] = *(ptr_c4++);
dout[5] = *(ptr_c5++);
dout[6] = *(ptr_c6++);
dout[7] = *(ptr_c7++);
dout += 8;
}
j++;
if (pad_r) {
memset(dout, 0, pad_r * 8 * sizeof(int8_t));
dout += pad_r * 8;
}
}
TargetFree(TARGET(kARM), ptr_c);
}
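// Illustrative reference (a minimal sketch of what the packing above produces;
// the function name and the plain scalar form are assumptions). For one input
// row h and one 8-channel block starting at cs, every output column w in
// [ws, we) stores its 8 channel lanes contiguously, with out-of-range rows,
// columns, and channels zero-filled:
inline void prepack_nxwc8_int8_ref(const int8_t* din,
                                   int8_t* dout,
                                   int cs,
                                   int h,
                                   int ws,
                                   int we,
                                   int channel,
                                   int height,
                                   int width) {
  for (int w = ws; w < we; ++w) {
    for (int k = 0; k < 8; ++k) {
      int c = cs + k;
      bool valid =
          c < channel && h >= 0 && h < height && w >= 0 && w < width;
      *dout++ = valid ? din[(c * height + h) * width + w] : 0;
    }
  }
}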
/* write result in outputs
......
......@@ -153,6 +153,24 @@ void conv_depthwise_5x5s2_fp32(const float* din,
bool flag_relu,
ARMContext* ctx);
template <typename Dtype>
void conv_depthwise_5x5s1_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -805,6 +805,88 @@ void conv_depthwise_3x3_int8_int8(const void* din,
}
}
void conv_depthwise_5x5_int8_fp32(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
void conv_depthwise_5x5_int8_int8(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -97,7 +97,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1);
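// The earlier `ph == 2` restriction is dropped here: with the new 5x5s1 int8
// kernel, padding is handled in the input pre-pack, so (presumably) pad 0/1/2
// are all accepted; the 5x5 dw int8 unit test below exercises pad in {0, 1, 2}.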
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
......@@ -136,7 +136,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
......
......@@ -31,7 +31,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
// select dw conv kernel
if (kw == 3) {
VLOG(5) << "invoke 3x3 dw conv fp32";
/// trans weights
// trans weights
constexpr int cblock = 4;
auto oc = w_dims[0];
auto kh = w_dims[2];
......@@ -75,6 +75,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
}
/// select dw conv kernel
if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
......@@ -83,6 +84,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
flag_trans_weights_ = true;
} else {
LOG(FATAL) << "this type dw conv not impl";
}
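// A hedged reading of the repack above: conv_trans_weights_numc(wptr, wptr_new,
// oc, 1, 8, 25) is taken to rearrange the [oc, 1, 5, 5] filters into 8-channel
// blocks matching the {cround / 8, 1, kh * kw, 8} shape, roughly
//   wptr_new[(cb * 25 + k) * 8 + l] = wptr[(cb * 8 + l) * 25 + k]
// for cb in [0, cround / 8), k in [0, 25), l in [0, 8), with tail channels
// zero-padded, so the asm can load one d-register holding the same kernel tap
// for 8 consecutive channels. The indexing is an inferred sketch, not copied
// from the source.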
......@@ -123,6 +134,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
}
/// select dw conv kernel
if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
......@@ -131,6 +143,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
flag_trans_weights_ = true;
} else {
LOG(FATAL) << "this type dw conv not impl";
}
......
......@@ -481,10 +481,10 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
}
#endif /// 3x3dw
#if 0 /// 5x5dw
#if 1 /// 5x5dw
TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) {
for (auto& stride : {1}) {
for (auto& pad : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
......@@ -492,7 +492,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) {
for (auto &h : {1, 3, 15, 19, 28, 32, 75}) {
for (auto& h : {1, 3, 15, 19, 28, 32, 75}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
......