Commit 498a30cf authored by yiicy, committed by Xiaoyang LI

improve dw conv performance

* improve prepack_input func speed in int8 3x3s1 dw conv

* fix code style

* fix code style

* improve 3x3s1 dw fp32 conv speed a little

* arm add 5x5s1 int8 dw conv, test=develop

Parent 499fa1b8
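The main 3x3s1 fp32 change below removes the second per-tile pass (bias add, ReLU, transpose, and scatter through eight per-channel row pointers) and fuses it into the compute asm, which now receives one packed pointer array (outl) whose trailing slots carry the bias address and the relu flag. A minimal C++ sketch of the resulting store policy; the names store_tile and scratch are illustrative, not from the patch:

#include <cstring>

// outl packs the eight output-row pointers the way the patch's "outl" does
// (bias pointer and relu flag ride in two extra slots, not shown here).
void store_tile(float* outl[8], const float tile[8][4],
                bool tail, int remain, float* scratch) {
  if (!tail) {
    // full tile: each 4-wide result goes straight to its row pointer
    for (int i = 0; i < 8; ++i) {
      std::memcpy(outl[i], tile[i], 4 * sizeof(float));
    }
  } else {
    // partial tile at the row end: stage full vectors to scratch,
    // then copy only the `remain` valid values per row
    for (int i = 0; i < 8; ++i) {
      std::memcpy(scratch + 4 * i, tile[i], 4 * sizeof(float));
      std::memcpy(outl[i], scratch + 4 * i, remain * sizeof(float));
    }
  }
}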
@@ -145,14 +145,17 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
         outc21 = ptr_write;
         outc31 = ptr_write;
       }
-      auto c00 = outc00;
-      auto c01 = outc01;
-      auto c10 = outc10;
-      auto c11 = outc11;
-      auto c20 = outc20;
-      auto c21 = outc21;
-      auto c30 = outc30;
-      auto c31 = outc31;
+      float* outl[] = {outc00,
+                       outc10,
+                       outc20,
+                       outc30,
+                       outc01,
+                       outc11,
+                       outc21,
+                       outc31,
+                       reinterpret_cast<float*>(bias_local),
+                       reinterpret_cast<float*>(flag_relu)};
+      void* outl_ptr = reinterpret_cast<void*>(outl);
       for (int w = 0; w < w_loop; ++w) {
         bool flag_mask = (w == w_loop - 1) && flag_remain;
         float* out0 = pre_out;
@@ -210,6 +213,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
           "fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/
           "fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/
           "fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/
+          "ldp x0, x1, [%[outl]] \n"
           "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/
           "fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/
           "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/
@@ -230,6 +234,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
           "fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/
           "fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/
           "fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/
+          "ldp x2, x3, [%[outl], #16] \n"
           "fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/
           "fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/
           "fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/
@@ -239,6 +244,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
           "fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/
           "fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/
           "fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/
+          "ldp x4, x5, [%[outl], #32] \n"
           "fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/
           "fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/
           "fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/
@@ -248,25 +254,83 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
           "fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/
           "fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/
           "fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/
+          "ldp x6, x7, [%[outl], #48] \n"
           "fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/
           "fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/
           "fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/
           "fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/
-          /* save result */
-          "stp q15, q16, [%[out]], #32\n"
-          "stp q17, q18, [%[out]], #32\n"
-          "stp q19, q20, [%[out]], #32\n"
-          "stp q21, q22, [%[out]]\n"
+          "fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */
+          "fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */
+          "fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */
+          "fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */
+          "fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */
+          "fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */
+          "fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */
+          "fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */
+          /* transpose */
+          "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/
+          "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/
+          "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/
+          "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/
+          "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/
+          "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/
+          "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/
+          "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/
+          "trn1 v15.2d, v0.2d, v2.2d\n"  /* r0: a0a1a2a3*/
+          "trn2 v19.2d, v0.2d, v2.2d\n"  /* r0: c0c1c2c3*/
+          "trn1 v17.2d, v1.2d, v3.2d\n"  /* r0: b0b1b2b3*/
+          "trn2 v21.2d, v1.2d, v3.2d\n"  /* r0: d0d1d2d3*/
+          "trn1 v16.2d, v4.2d, v6.2d\n"  /* r1: a0a1a2a3*/
+          "trn2 v20.2d, v4.2d, v6.2d\n"  /* r1: c0c1c2c3*/
+          "trn1 v18.2d, v5.2d, v7.2d\n"  /* r1: b0b1b2b3*/
+          "trn2 v22.2d, v5.2d, v7.2d\n"  /* r1: d0d1d2d3*/
+          "cbz %w[flag_relu], 0f\n"      /* skip relu*/
+          "movi v0.4s, #0\n"             /* for relu */
+          "fmax v15.4s, v15.4s, v0.4s\n"
+          "fmax v16.4s, v16.4s, v0.4s\n"
+          "fmax v17.4s, v17.4s, v0.4s\n"
+          "fmax v18.4s, v18.4s, v0.4s\n"
+          "fmax v19.4s, v19.4s, v0.4s\n"
+          "fmax v20.4s, v20.4s, v0.4s\n"
+          "fmax v21.4s, v21.4s, v0.4s\n"
+          "fmax v22.4s, v22.4s, v0.4s\n"
+          "0:\n"
+          "cbnz %w[flag_mask], 1f\n"
+          "str q15, [x0]\n" /* save outc00 */
+          "str q16, [x4]\n" /* save outc01 */
+          "str q17, [x1]\n" /* save outc10 */
+          "str q18, [x5]\n" /* save outc11 */
+          "str q19, [x2]\n" /* save outc20 */
+          "str q20, [x6]\n" /* save outc21 */
+          "str q21, [x3]\n" /* save outc30 */
+          "str q22, [x7]\n" /* save outc31 */
+          "b 2f\n"
+          "1:\n"
+          "str q15, [%[out]], #16 \n" /* save remain to pre_out */
+          "str q17, [%[out]], #16 \n" /* save remain to pre_out */
+          "str q19, [%[out]], #16 \n" /* save remain to pre_out */
+          "str q21, [%[out]], #16 \n" /* save remain to pre_out */
+          "str q16, [%[out]], #16 \n" /* save remain to pre_out */
+          "str q18, [%[out]], #16 \n" /* save remain to pre_out */
+          "str q20, [%[out]], #16 \n" /* save remain to pre_out */
+          "str q22, [%[out]], #16 \n" /* save remain to pre_out */
+          "2:\n"
          :[inr0] "+r"(inr0), [inr1] "+r"(inr1),
           [inr2] "+r"(inr2), [inr3] "+r"(inr3),
           [out]"+r"(out0)
          :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
           [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
-          [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8)
+          [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8),
+          [vbias]"w" (vbias), [outl] "r" (outl_ptr),
+          [flag_mask] "r" (flag_mask), [flag_relu] "r" (flag_relu)
          : "cc", "memory",
            "v0","v1","v2","v3","v4","v5","v6","v7",
            "v8", "v9", "v10", "v11", "v15",
-           "v16","v17","v18","v19","v20","v21","v22"
+           "v16","v17","v18","v19","v20","v21","v22",
+           "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7"
          );
 #else
       asm volatile(
@@ -355,183 +419,113 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
           "vmla.f32 q8, q5, q2 @ w8 * inr22\n"
           "vmla.f32 q9, q5, q3 @ w8 * inr23\n"
           "vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n"
+          "ldr r4, [%[outl], #32] @ load bias addr to r4\n"
           "vmla.f32 q14, q6, q0 @ w5 * inr24\n"
           "vmla.f32 q15, q6, q1 @ w5 * inr25\n"
           "vmla.f32 q10, q5, q0 @ w8 * inr24\n"
           "vmla.f32 q11, q5, q1 @ w8 * inr25\n"
           "vld1.32 {d0-d3}, [%[r3]]! @ load r3, q0, q1\n"
           "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n"
           /* mul r3 with w6, w7, w8, get out r1 */
           "vmla.f32 q12, q7, q2 @ w6 * inr30\n"
           "vmla.f32 q13, q7, q3 @ w6 * inr31\n"
-          "vst1.32 {d16-d19}, [%[out0]]! @ save r00, r01, c0~c3\n"
           "vmla.f32 q14, q7, q0 @ w6 * inr32\n"
           "vmla.f32 q15, q7, q1 @ w6 * inr33\n"
-          "vst1.32 {d20-d23}, [%[out0]]! @ save r02, r03, c0~c3\n"
           "vmla.f32 q12, q4, q3 @ w7 * inr31\n"
           "vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n"
+          "vld1.32 {d12-d13}, [r4] @ load bias\n"
           "vmla.f32 q13, q4, q0 @ w7 * inr32\n"
           "vmla.f32 q14, q4, q1 @ w7 * inr33\n"
           "vmla.f32 q15, q4, q2 @ w7 * inr34\n"
+          "ldr r0, [%[outl]] @ load outc00 to r0\n"
           "vmla.f32 q12, q5, q0 @ w8 * inr32\n"
           "vmla.f32 q13, q5, q1 @ w8 * inr33\n"
+          "ldr r5, [%[outl], #36] @ load flag_relu to r5\n"
           "vmla.f32 q14, q5, q2 @ w8 * inr34\n"
           "vmla.f32 q15, q5, q3 @ w8 * inr35\n"
-          "vst1.32 {d24-d27}, [%[out0]]! @ save r10, r11, c0~c3\n"
-          "vst1.32 {d28-d31}, [%[out0]]! @ save r12, r13, c0~c3\n"
+          "ldr r1, [%[outl], #4] @ load outc10 to r1\n"
+          "vadd.f32 q8, q8, q6 @ r00 add bias\n"
+          "vadd.f32 q9, q9, q6 @ r01 add bias\n"
+          "vadd.f32 q10, q10, q6 @ r02 add bias\n"
+          "vadd.f32 q11, q11, q6 @ r03 add bias\n"
+          "ldr r2, [%[outl], #8] @ load outc20 to r2\n"
+          "vadd.f32 q12, q12, q6 @ r10 add bias\n"
+          "vadd.f32 q13, q13, q6 @ r11 add bias\n"
+          "vadd.f32 q14, q14, q6 @ r12 add bias\n"
+          "vadd.f32 q15, q15, q6 @ r13 add bias\n"
+          "ldr r3, [%[outl], #12] @ load outc30 to r3\n"
+          "vmov.u32 q7, #0 @ mov zero to q7\n"
+          "cmp r5, #0 @ cmp flag relu\n"
+          "beq 1f @ skip relu\n"
+          "vmax.f32 q8, q8, q7 @ r00 relu\n"
+          "vmax.f32 q9, q9, q7 @ r01 relu\n"
+          "vmax.f32 q10, q10, q7 @ r02 relu\n"
+          "vmax.f32 q11, q11, q7 @ r03 relu\n"
+          "vmax.f32 q12, q12, q7 @ r10 relu\n"
+          "vmax.f32 q13, q13, q7 @ r11 relu\n"
+          "vmax.f32 q14, q14, q7 @ r12 relu\n"
+          "vmax.f32 q15, q15, q7 @ r13 relu\n"
+          "1:\n"
+          "ldr r4, [%[outl], #16] @ load outc01 to r4\n"
+          "vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n"
+          "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n"
+          "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n"
+          "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n"
+          "ldr r5, [%[outl], #20] @ load outc11 to r5\n"
+          "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n"
+          "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n"
+          "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n"
+          "vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n"
+          "cmp %[flag_mask], #0 @ cmp flag mask\n"
+          "bne 2f\n"
+          "vst1.32 {d16-d17}, [r0] @ save outc00\n"
+          "vst1.32 {d18-d19}, [r1] @ save outc10\n"
+          "vst1.32 {d20-d21}, [r2] @ save outc20\n"
+          "vst1.32 {d22-d23}, [r3] @ save outc30\n"
+          "vst1.32 {d24-d25}, [r4] @ save outc01\n"
+          "vst1.32 {d26-d27}, [r5] @ save outc11\n"
+          "ldr r0, [%[outl], #24] @ load outc21 to r0\n"
+          "ldr r1, [%[outl], #28] @ load outc31 to r1\n"
+          "vst1.32 {d28-d29}, [r0] @ save outc21\n"
+          "vst1.32 {d30-d31}, [r1] @ save outc31\n"
+          "b 3f @ branch end\n"
+          "2: \n"
+          "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n"
+          "vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n"
+          "vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n"
+          "vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n"
+          "vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n"
+          "vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n"
+          "vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n"
+          "vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n"
+          "3: \n"
          : [r0] "+r"(inr0), [r1] "+r"(inr1),
            [r2] "+r"(inr2), [r3] "+r"(inr3),
            [out0] "+r"(out0), [wc0] "+r"(weight_c)
-         :
+         : [flag_mask] "r" (flag_mask), [outl] "r" (outl_ptr)
          : "cc", "memory",
            "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
-           "q10", "q11", "q12", "q13","q14", "q15"
+           "q10", "q11", "q12", "q13","q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5"
          );
 #endif  // __arch64__
-      float* out1 = pre_out;
-      if (flag_mask) {
-        c00 = outc00;
-        c01 = outc01;
-        c10 = outc10;
-        c11 = outc11;
-        c20 = outc20;
-        c21 = outc21;
-        c30 = outc30;
-        c31 = outc31;
-        outc00 = pre_out;
-        outc01 = pre_out + 4;
-        outc10 = pre_out + 8;
-        outc11 = pre_out + 12;
-        outc20 = pre_out + 16;
-        outc21 = pre_out + 20;
-        outc30 = pre_out + 24;
-        outc31 = pre_out + 28;
-      }
-#ifdef __aarch64__
-      asm volatile(
-          "ldp q0, q1, [%[din]], #32\n" /* load input*/
-          "ldp q2, q3, [%[din]], #32\n" /* load input*/
-          "fadd v15.4s, v0.4s, %[vbias].4s\n" /* add bias */
-          "fadd v16.4s, v1.4s, %[vbias].4s\n" /* add bias */
-          "ldp q4, q5, [%[din]], #32\n" /* load input*/
-          "fadd v17.4s, v2.4s, %[vbias].4s\n" /* add bias */
-          "fadd v18.4s, v3.4s, %[vbias].4s\n" /* add bias */
-          "ldp q6, q7, [%[din]]\n" /* load input*/
-          "fadd v19.4s, v4.4s, %[vbias].4s\n" /* add bias */
-          "fadd v20.4s, v5.4s, %[vbias].4s\n" /* add bias */
-          "fadd v21.4s, v6.4s, %[vbias].4s\n" /* add bias */
-          "fadd v22.4s, v7.4s, %[vbias].4s\n" /* add bias */
-          /* transpose */
-          "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/
-          "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/
-          "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/
-          "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/
-          "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/
-          "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/
-          "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/
-          "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/
-          "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/
-          "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/
-          "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/
-          "trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/
-          "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/
-          "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/
-          "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/
-          "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/
-          "cbz %w[flag_relu], 0f\n" /* skip relu*/
-          "movi v0.4s, #0\n" /* for relu */
-          "fmax v15.4s, v15.4s, v0.4s\n"
-          "fmax v16.4s, v16.4s, v0.4s\n"
-          "fmax v17.4s, v17.4s, v0.4s\n"
-          "fmax v18.4s, v18.4s, v0.4s\n"
-          "fmax v19.4s, v19.4s, v0.4s\n"
-          "fmax v20.4s, v20.4s, v0.4s\n"
-          "fmax v21.4s, v21.4s, v0.4s\n"
-          "fmax v22.4s, v22.4s, v0.4s\n"
-          "0:\n"
-          "str q15, [%[outc00]], #16\n" /* save outc00*/
-          "str q16, [%[outc01]], #16\n" /* save outc01*/
-          "str q17, [%[outc10]], #16\n" /* save outc10*/
-          "str q18, [%[outc11]], #16\n" /* save outc11*/
-          "str q19, [%[outc20]], #16\n" /* save outc20*/
-          "str q20, [%[outc21]], #16\n" /* save outc21*/
-          "str q21, [%[outc30]], #16\n" /* save outc30*/
-          "str q22, [%[outc31]], #16\n" /* save outc31*/
-          :[outc00] "+r"(outc00), [outc01] "+r"(outc01),
-           [outc10] "+r"(outc10), [outc11] "+r"(outc11),
-           [outc20] "+r"(outc20), [outc21] "+r"(outc21),
-           [outc30] "+r"(outc30), [outc31] "+r"(outc31),
-           [din] "+r"(out1)
-          :[vbias]"w" (vbias), [flag_relu] "r"(flag_relu)
-          : "cc", "memory",
-            "v0","v1","v2","v3","v4","v5","v6","v7",
-            "v15", "v16","v17","v18","v19","v20","v21","v22"
-          );
-#else
-      asm volatile(
-          "vld1.32 {d0-d3}, [%[din]]!\n" /* load input*/
-          "vld1.32 {d4-d7}, [%[din]]!\n" /* load input*/
-          "vadd.f32 q0, q0, %q[vbias]\n" /* add bias */
-          "vadd.f32 q1, q1, %q[vbias]\n" /* add bias */
-          "vld1.32 {d8-d11}, [%[din]]!\n" /* load input*/
-          "vadd.f32 q2, q2, %q[vbias]\n" /* add bias */
-          "vadd.f32 q3, q3, %q[vbias]\n" /* add bias */
-          "vld1.32 {d12-d15}, [%[din]]!\n" /* load input*/
-          "vadd.f32 q4, q4, %q[vbias]\n" /* add bias */
-          "vadd.f32 q5, q5, %q[vbias]\n" /* add bias */
-          "vadd.f32 q6, q6, %q[vbias]\n" /* add bias */
-          "vadd.f32 q7, q7, %q[vbias]\n" /* add bias */
-          /* transpose */
-          "vtrn.32 q0, q1\n" /* r0: q0: a0a1c0c1, q1: b0b1d0d1*/
-          "vtrn.32 q2, q3\n" /* r0: q2: a2a3c2c3, q3: b2b3d2d3*/
-          "vtrn.32 q4, q5\n" /* r1: q4: a0a1c0c1, q5: b0b1d0d1*/
-          "vtrn.32 q6, q7\n" /* r1: q6: a2a3c2c3, q7: b2b3d2d3*/
-          "vswp d1, d4\n" /* r0: q0: a0a1a2a3, q2: c0c1c2c3*/
-          "vswp d3, d6\n" /* r0: q1: b0b1b2b3, q3: d0d1d2d3*/
-          "vswp d9, d12\n" /* r1: q4: a0a1a2a3, q6: c0c1c2c3*/
-          "vswp d11, d14\n" /* r1: q5: b0b1b2b3, q7: d0d1d2d3*/
-          "cmp %[flag_relu], #0\n"
-          "beq 0f\n" /* skip relu*/
-          "vmov.u32 q15, #0\n"
-          "vmax.f32 q0, q0, q15\n"
-          "vmax.f32 q1, q1, q15\n"
-          "vmax.f32 q2, q2, q15\n"
-          "vmax.f32 q3, q3, q15\n"
-          "vmax.f32 q4, q4, q15\n"
-          "vmax.f32 q5, q5, q15\n"
-          "vmax.f32 q6, q6, q15\n"
-          "vmax.f32 q7, q7, q15\n"
-          "0:\n"
-          "vst1.32 {d0-d1}, [%[outc00]]!\n" /* save outc00*/
-          "vst1.32 {d2-d3}, [%[outc10]]!\n" /* save outc10*/
-          "vst1.32 {d4-d5}, [%[outc20]]!\n" /* save outc20*/
-          "vst1.32 {d6-d7}, [%[outc30]]!\n" /* save outc30*/
-          "vst1.32 {d8-d9}, [%[outc01]]!\n" /* save outc01*/
-          "vst1.32 {d10-d11}, [%[outc11]]!\n" /* save outc11*/
-          "vst1.32 {d12-d13}, [%[outc21]]!\n" /* save outc21*/
-          "vst1.32 {d14-d15}, [%[outc31]]!\n" /* save outc31*/
-          :[outc00] "+r"(outc00), [outc01] "+r"(outc01),
-           [outc10] "+r"(outc10), [outc11] "+r"(outc11),
-           [outc20] "+r"(outc20), [outc21] "+r"(outc21),
-           [outc30] "+r"(outc30), [outc31] "+r"(outc31),
-           [din] "+r"(out1)
-          :[vbias]"w" (vbias), [flag_relu] "r"(flag_relu)
-          : "cc", "memory",
-            "q0","q1","q2","q3","q4","q5","q6","q7", "q15"
-          );
-#endif // __aarch64__
       // clang-format on
+      outl[0] += 4;
+      outl[1] += 4;
+      outl[2] += 4;
+      outl[3] += 4;
+      outl[4] += 4;
+      outl[5] += 4;
+      outl[6] += 4;
+      outl[7] += 4;
       if (flag_mask) {
-        for (int i = 0; i < remain; ++i) {
-          c00[i] = pre_out[i];
-          c01[i] = pre_out[i + 4];
-          c10[i] = pre_out[i + 8];
-          c11[i] = pre_out[i + 12];
-          c20[i] = pre_out[i + 16];
-          c21[i] = pre_out[i + 20];
-          c30[i] = pre_out[i + 24];
-          c31[i] = pre_out[i + 28];
-        }
+        memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
+        memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
+        memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
+        memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
+        memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
+        memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
+        memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
+        memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
       }
     }
   }
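The trn1/trn2 sequence (.4s lanes, then .2d lanes) added to the aarch64 asm above is the standard NEON 4x4 register transpose; a sketch in aarch64 intrinsics (assuming an A64 target, since vtrn1q_f32/vtrn1q_f64 are A64-only):

#include <arm_neon.h>

// Transpose four 4-float registers in place: r[i] row i -> column i.
static inline void transpose4x4(float32x4_t r[4]) {
  // step 1: 32-bit interleave (matches the .4s trn1/trn2 pairs)
  float32x4_t t0 = vtrn1q_f32(r[0], r[1]);  // a0 b0 a2 b2
  float32x4_t t1 = vtrn2q_f32(r[0], r[1]);  // a1 b1 a3 b3
  float32x4_t t2 = vtrn1q_f32(r[2], r[3]);  // c0 d0 c2 d2
  float32x4_t t3 = vtrn2q_f32(r[2], r[3]);  // c1 d1 c3 d3
  // step 2: 64-bit interleave (matches the .2d trn1/trn2 pairs)
  r[0] = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(t0),
                                          vreinterpretq_f64_f32(t2)));  // a0 b0 c0 d0
  r[1] = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(t1),
                                          vreinterpretq_f64_f32(t3)));  // a1 b1 c1 d1
  r[2] = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(t0),
                                          vreinterpretq_f64_f32(t2)));  // a2 b2 c2 d2
  r[3] = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(t1),
                                          vreinterpretq_f64_f32(t3)));  // a3 b3 c3 d3
}

After the transpose, each register lands contiguously in its destination row, so the kernel can write it with a single str to the matching outl pointer.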
......
@@ -56,13 +56,14 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
   const int win_round = wout_round + 2;
   //! get h block
-  //! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
-  //! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
+  //! llc_size = threads * win_round * hout_c_block * hin_r_block *
+  //! sizeof(int8_t)
+  //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
   //! win_round = wout_round + 2
   //! hin_r_block = hout_r_block + 2
-  int hout_r_block =
-      (llc_size - 2 * win_round * threads) /
-      (win_round * threads + hout_c_block * wout_round * threads * 4);
+  int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
+                     (win_round * threads * hout_c_block +
+                      hout_c_block * wout_round * threads * 4);
   hout_r_block = hout_r_block > hout ? hout : hout_r_block;
   hout_r_block =
       ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
@@ -115,17 +116,9 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
       int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
       auto pre_din = tmp_din;
 #endif
-      prepack_input_nxw_c8_int8(din_batch,
-                                pre_din,
-                                c,
-                                c + hout_c_block,
-                                hs,
-                                he,
-                                ws,
-                                we,
-                                chin,
-                                win,
-                                hin);
+      prepack_input_nxwc8_int8_dw(
+          din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
       const int8_t* block_inr0 = pre_din;
       const int8_t* block_inr1 = block_inr0 + in_len;
       const int8_t* block_inr2 = block_inr1 + in_len;
......
@@ -56,13 +56,14 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
   const int win_round = wout_round * 2 /*stride*/ + 1;
   //! get h block
-  //! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
-  //! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
+  //! llc_size = threads * win_round * hin_r_block * hout_c_block *
+  //! sizeof(int8_t)
+  //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
   //! win_round = wout_round + 2
   //! hin_r_block = hout_r_block + 2
-  int hout_r_block =
-      (llc_size - 2 * win_round * threads) /
-      (2 * win_round * threads + hout_c_block * wout_round * threads * 4);
+  int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
+                     (2 * win_round * threads * hout_c_block +
+                      hout_c_block * wout_round * threads * 4);
   hout_r_block = hout_r_block > hout ? hout : hout_r_block;
   hout_r_block =
       ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
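For reference, the rewritten hout_r_block expressions in both hunks follow from solving the cache-budget inequality stated in the comment (a sketch, with T = threads, cb = hout_c_block, hr = hout_r_block; stride-1 case):

    T * win_round * cb * (hr + 2)        // packed int8 input rows
  + 4 * T * wout_round * cb * hr         // int32 output rows
  <= llc_size

  => hr <= (llc_size - 2 * T * win_round * cb)
           / (T * win_round * cb + 4 * T * wout_round * cb)

which is exactly the stride-1 expression above; the stride-2 variant doubles the per-row input term in the denominator, since each output row there consumes two packed input rows.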
@@ -115,17 +116,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
       int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
       auto pre_din = tmp_din;
 #endif
-      prepack_input_nxw_c8_int8(din_batch,
-                                pre_din,
-                                c,
-                                c + hout_c_block,
-                                hs,
-                                he,
-                                ws,
-                                we,
-                                chin,
-                                win,
-                                hin);
+      prepack_input_nxwc8_int8_dw(
+          din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
       const int8_t* block_inr0 = pre_din;
       const int8_t* block_inr1 = block_inr0 + in_len;
       const int8_t* block_inr2 = block_inr1 + in_len;
......
@@ -14,6 +14,7 @@
 #include <arm_neon.h>
 #include "lite/backends/arm/math/conv_block_utils.h"
+#include "lite/backends/arm/math/conv_depthwise.h"
 #include "lite/backends/arm/math/conv_impl.h"
 #include "lite/core/context.h"
 #include "lite/operators/op_params.h"
@@ -26,592 +27,749 @@ namespace lite {
 namespace arm {
 namespace math {
-void conv_depthwise_5x5s1_int8(int32_t* dout,
-                               const int8_t* din,
-                               const int8_t* weights,
-                               const int* bias,
-                               bool flag_bias,
-                               bool flag_relu,
-                               const int num,
-                               const int chin,
-                               const int hin,
-                               const int win,
-                               const int hout,
-                               const int wout,
-                               ARMContext* ctx,
-                               PrecisionType out_type,
-                               const float* scale);
-void conv_depthwise_5x5_int8(const int8_t* din,
-                             int32_t* dout,
-                             int num,
-                             int chout,
-                             int hout,
-                             int wout,
-                             int chin,
-                             int hin,
-                             int win,
-                             const int8_t* weights,
-                             const int32_t* bias,
-                             const operators::ConvParam& param,
-                             ARMContext* ctx,
-                             PrecisionType out_type,
-                             const float* scale) {
-  int stride_h = param.strides[0];
-  bool flag_relu = param.fuse_relu;
-  bool flag_bias = param.bias != nullptr;
-  // if (param.activation_param.has_active){
-  //   if (param.activation_param.active == Active_relu ||
-  //   fabs(param.activation_param.negative_slope) > 1e-6f){
-  //     flag_relu = true;
-  //   }
-  // }
-  if (stride_h == 1) {
-#ifdef __aarch64__
-    conv_depthwise_5x5s1_int8(dout,
-                              din,
-                              weights,
-                              bias,
-                              flag_bias,
-                              flag_relu,
-                              num,
-                              chin,
-                              hin,
-                              win,
-                              hout,
-                              wout,
-                              ctx,
-                              out_type,
-                              scale);
-#else
-    LOG(FATAL) << "5x5 dw conv armv7 has not impl";
-#endif
-  }
-}
-/**
- * \brief depthwise convolution, kernel size 5x5, stride 1, pad 1, with bias,
- * width > 4
- */
-// 2 line
-#ifdef __aarch64__
-template <typename Dtype>
-inline void prefetch(const Dtype* din) {
-#ifdef __aarch64__
-  asm volatile("PRFM PLDL1KEEP, [%[din]] \n" : : [din] "r"(din) : "memory");
-#else
-  asm volatile("pld [%[din]] \n" : : [din] "r"(din) : "memory");
-#endif
-}
-void conv_depthwise_5x5s1_int8(
-    int32_t* dout,
-    const int8_t* din,
-    const int8_t* weights,
-    const int32_t* bias,
-    bool flag_bias,
-    bool flag_relu,
-    const int num,
-    const int chin,
-    const int hin,
-    const int win,
-    const int hout,
-    const int wout,
-    ARMContext* ctx,
-    PrecisionType od_type,
-    float const* scales) {  /// scale_size = channel-out
-  // printf("5*5 multiply\n");
-  int size_in_channel = win * hin;
-  int size_out_channel = wout * hout;
-  int w_stride = 5 * 5;
-  static int const stride_w = 1;
-  int const stride_h = stride_w;
-  int const chout = chin;
-  int const pad_w = 2;
-  int const pad_h = pad_w;
-  int const wout_round = ((wout + 7) / 8) * 8;
-  int const win_round = wout_round * stride_w + 5 - 1;
-  int const hout_round = ((hout + 2) / 3) * 3;
-  int const hin_round = hout_round * stride_h + 5 - 1;
-  int const tile_h = hout_round / 3;
-  int const tile_w = wout_round / 8;
-  int const pre_in_size = hin_round * win_round;
-  int const pre_out_size = hout_round * wout_round;
-  int const pre_io_size = pre_in_size + pre_out_size * sizeof(int);
-  int const hs = -pad_h;
-  int const he = hs + hin_round;
-  int const ws = -pad_w;
-  int const we = ws + win_round;
-  // signed char* tmp_work_space = new signed char [1024*5];
-  signed char* tmp_work_space = ctx->workspace_data<signed char>();
-  signed char* ptr_zero = tmp_work_space;
-  int* ptr_write = reinterpret_cast<int*>(ptr_zero + win_round);
-  signed char* pre_data =
-      reinterpret_cast<signed char*>(ptr_write + wout_round);
-  memset(ptr_zero, 0, win_round * sizeof(signed char));
-  for (int n = 0; n < num; ++n) {
-    signed char const* din_batch = din + n * chin * size_in_channel;
-    int* dout_batch = dout + n * chout * size_out_channel;
-    // #pragma omp parallel for
-    for (int c = 0; c < chout; c++) {
-#ifdef ARM_WITH_OMP
-      int const thno = omp_get_thread_num();
-#else
-      int const thno = 0;
-#endif
-      signed char const* din_channel = din_batch + c * size_in_channel;
-      signed char* pre_din = pre_data + thno * pre_io_size;
-      int* pre_out = reinterpret_cast<int*>(pre_din + pre_in_size);
-      int* dout_ptr = pre_out;
-      prepack_input_nxw(din_channel,
-                        pre_din,
-                        c,
-                        c + 1,
-                        hs,
-                        he,
-                        ws,
-                        we,
-                        1,
-                        win,
-                        hin,
-                        ptr_zero);
-      signed char const* wei_ptr = weights + c * w_stride;
-      int bias_val = flag_bias ? bias[c] : 0.f;
-      int8x8_t wr00 = vdup_n_s8(wei_ptr[0 * 5 + 0]);
-      int8x8_t wr01 = vdup_n_s8(wei_ptr[0 * 5 + 1]);
-      int8x8_t wr02 = vdup_n_s8(wei_ptr[0 * 5 + 2]);
-      int8x8_t wr03 = vdup_n_s8(wei_ptr[0 * 5 + 3]);
-      int8x8_t wr04 = vdup_n_s8(wei_ptr[0 * 5 + 4]);
-      int8x8_t wr10 = vdup_n_s8(wei_ptr[1 * 5 + 0]);
-      int8x8_t wr11 = vdup_n_s8(wei_ptr[1 * 5 + 1]);
-      int8x8_t wr12 = vdup_n_s8(wei_ptr[1 * 5 + 2]);
-      int8x8_t wr13 = vdup_n_s8(wei_ptr[1 * 5 + 3]);
-      int8x8_t wr14 = vdup_n_s8(wei_ptr[1 * 5 + 4]);
-      int8x8_t wr20 = vdup_n_s8(wei_ptr[2 * 5 + 0]);
-      int8x8_t wr21 = vdup_n_s8(wei_ptr[2 * 5 + 1]);
-      int8x8_t wr22 = vdup_n_s8(wei_ptr[2 * 5 + 2]);
-      int8x8_t wr23 = vdup_n_s8(wei_ptr[2 * 5 + 3]);
-      int8x8_t wr24 = vdup_n_s8(wei_ptr[2 * 5 + 4]);
-      int8x8_t wr30 = vdup_n_s8(wei_ptr[3 * 5 + 0]);
-      int8x8_t wr31 = vdup_n_s8(wei_ptr[3 * 5 + 1]);
-      int8x8_t wr32 = vdup_n_s8(wei_ptr[3 * 5 + 2]);
-      int8x8_t wr33 = vdup_n_s8(wei_ptr[3 * 5 + 3]);
-      int8x8_t wr34 = vdup_n_s8(wei_ptr[3 * 5 + 4]);
-      int8x8_t wr40 = vdup_n_s8(wei_ptr[4 * 5 + 0]);
-      int8x8_t wr41 = vdup_n_s8(wei_ptr[4 * 5 + 1]);
-      int8x8_t wr42 = vdup_n_s8(wei_ptr[4 * 5 + 2]);
-      int8x8_t wr43 = vdup_n_s8(wei_ptr[4 * 5 + 3]);
-      int8x8_t wr44 = vdup_n_s8(wei_ptr[4 * 5 + 4]);
-      int* doutr0 = nullptr;
-      int* doutr1 = nullptr;
-      int* doutr2 = nullptr;
-      signed char const* dr0 = pre_din;
-      signed char const* dr1 = dr0 + win_round;
-      signed char const* dr2 = dr1 + win_round;
-      signed char const* dr3 = dr2 + win_round;
-      signed char const* dr4 = dr3 + win_round;
-      signed char const* dr5 = dr4 + win_round;
-      signed char const* dr6 = dr5 + win_round;
-      signed char const* din_ptr0 = nullptr;
-      signed char const* din_ptr1 = nullptr;
-      signed char const* din_ptr2 = nullptr;
-      signed char const* din_ptr3 = nullptr;
-      signed char const* din_ptr4 = nullptr;
-      signed char const* din_ptr5 = nullptr;
-      signed char const* din_ptr6 = nullptr;
-      for (int h = 0; h < tile_h; h++) {
-        // printf("c:%d h:%d\n", c, h);
-        doutr0 = dout_ptr;
-        doutr1 = doutr0 + wout_round;
-        doutr2 = doutr1 + wout_round;
-        din_ptr0 = dr0;
-        din_ptr1 = dr1;
-        din_ptr2 = dr2;
-        din_ptr3 = dr3;
-        din_ptr4 = dr4;
-        din_ptr5 = dr5;
-        din_ptr6 = dr6;
-        prefetch(doutr0);
-        prefetch(doutr1);
-        prefetch(doutr2);
-        prefetch(din_ptr0);
-        prefetch(din_ptr1);
-        prefetch(din_ptr2);
-        prefetch(din_ptr3);
-        prefetch(din_ptr4);
-        prefetch(din_ptr5);
-        prefetch(din_ptr6);
-        for (int j = 0; j < tile_w; ++j) {
-          // printf("j:%d\n", j);
-          int32x4_t voutr00 = vdupq_n_s32(bias_val);
-          int32x4_t voutr01 = vdupq_n_s32(bias_val);
-          int32x4_t voutr10 = vdupq_n_s32(bias_val);
-          int32x4_t voutr11 = vdupq_n_s32(bias_val);
-          int32x4_t voutr20 = vdupq_n_s32(bias_val);
-          int32x4_t voutr21 = vdupq_n_s32(bias_val);
-          // din data
-          int8x8_t vinr00 = vld1_s8(din_ptr0 + 0);
-          int8x8_t vinr01 = vld1_s8(din_ptr0 + 8);
-          int8x8_t vinr10 = vld1_s8(din_ptr1 + 0);
-          int8x8_t vinr11 = vld1_s8(din_ptr1 + 8);
-          int8x8_t vinr20 = vld1_s8(din_ptr2 + 0);
-          int8x8_t vinr21 = vld1_s8(din_ptr2 + 8);
-          int8x8_t vinr30 = vld1_s8(din_ptr3 + 0);
-          int8x8_t vinr31 = vld1_s8(din_ptr3 + 8);
-          int8x8_t vinr40 = vld1_s8(din_ptr4 + 0);
-          int8x8_t vinr41 = vld1_s8(din_ptr4 + 8);
-          int8x8_t vinr50 = vld1_s8(din_ptr5 + 0);
-          int8x8_t vinr51 = vld1_s8(din_ptr5 + 8);
-          int8x8_t vinr60 = vld1_s8(din_ptr6 + 0);
-          int8x8_t vinr61 = vld1_s8(din_ptr6 + 8);
-          /// the first row
-          // r0
-          int8x8_t vtmp1 = vext_s8(vinr00, vinr01, 1);  // 12345678
-          int8x8_t vtmp2 = vext_s8(vinr00, vinr01, 2);  // 2345678
-          int8x8_t vtmp3 = vext_s8(vinr00, vinr01, 3);  // 345678
-          int8x8_t vtmp4 = vext_s8(vinr00, vinr01, 4);  // 45678
-          int16x8_t tvoutr0 = vmull_s8(vinr00, wr00);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr01);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp2, wr02);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr03);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp4, wr04);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          // r1
-          vtmp1 = vext_s8(vinr10, vinr11, 1);  // 12345678
-          vtmp2 = vext_s8(vinr10, vinr11, 2);  // 2345678
-          vtmp3 = vext_s8(vinr10, vinr11, 3);  // 345678
-          vtmp4 = vext_s8(vinr10, vinr11, 4);  // 45678
-          tvoutr0 = vmull_s8(vinr10, wr10);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr11);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp2, wr12);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr13);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp4, wr14);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          int16x8_t tvoutr1 = vmull_s8(vinr10, wr00);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr01);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp2, wr02);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr03);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp4, wr04);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          // r2
-          vtmp1 = vext_s8(vinr20, vinr21, 1);  // 12345678
-          vtmp2 = vext_s8(vinr20, vinr21, 2);  // 2345678
-          vtmp3 = vext_s8(vinr20, vinr21, 3);  // 345678
-          vtmp4 = vext_s8(vinr20, vinr21, 4);  // 45678
-          tvoutr0 = vmull_s8(vinr20, wr20);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr21);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp2, wr22);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr23);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp4, wr24);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr1 = vmull_s8(vinr20, wr10);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr11);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp2, wr12);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr13);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp4, wr14);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          int16x8_t tvoutr2 = vmull_s8(vinr20, wr00);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr01);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp2, wr02);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr03);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp4, wr04);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          // r3
-          vtmp1 = vext_s8(vinr30, vinr31, 1);  // 12345678
-          vtmp2 = vext_s8(vinr30, vinr31, 2);  // 2345678
-          vtmp3 = vext_s8(vinr30, vinr31, 3);  // 345678
-          vtmp4 = vext_s8(vinr30, vinr31, 4);  // 45678
-          tvoutr0 = vmull_s8(vinr30, wr30);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr31);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp2, wr32);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr33);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp4, wr34);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr1 = vmull_s8(vinr30, wr20);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr21);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp2, wr22);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr23);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp4, wr24);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr2 = vmull_s8(vinr30, wr10);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr11);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp2, wr12);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr13);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp4, wr14);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          // r4
-          vtmp1 = vext_s8(vinr40, vinr41, 1);  // 12345678
-          vtmp2 = vext_s8(vinr40, vinr41, 2);  // 2345678
-          vtmp3 = vext_s8(vinr40, vinr41, 3);  // 345678
-          vtmp4 = vext_s8(vinr40, vinr41, 4);  // 45678
-          tvoutr0 = vmull_s8(vinr40, wr40);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr41);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp2, wr42);
-          tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr43);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr0 = vmull_s8(vtmp4, wr44);
-          voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0));
-          voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0));
-          tvoutr1 = vmull_s8(vinr40, wr30);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr31);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp2, wr32);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr33);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp4, wr34);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr2 = vmull_s8(vinr40, wr20);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr21);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp2, wr22);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr23);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp4, wr24);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          // r5
-          vtmp1 = vext_s8(vinr50, vinr51, 1);  // 12345678
-          vtmp2 = vext_s8(vinr50, vinr51, 2);  // 2345678
-          vtmp3 = vext_s8(vinr50, vinr51, 3);  // 345678
-          vtmp4 = vext_s8(vinr50, vinr51, 4);  // 45678
-          tvoutr1 = vmull_s8(vinr50, wr40);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr41);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp2, wr42);
-          tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr43);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr1 = vmull_s8(vtmp4, wr44);
-          voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1));
-          voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1));
-          tvoutr2 = vmull_s8(vinr50, wr30);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr31);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp2, wr32);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr33);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp4, wr34);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          // r6
-          vtmp1 = vext_s8(vinr60, vinr61, 1);  // 12345678
-          vtmp2 = vext_s8(vinr60, vinr61, 2);  // 2345678
-          vtmp3 = vext_s8(vinr60, vinr61, 3);  // 345678
-          vtmp4 = vext_s8(vinr60, vinr61, 4);  // 45678
-          tvoutr2 = vmull_s8(vinr60, wr40);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr41);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp2, wr42);
-          tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr43);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          tvoutr2 = vmull_s8(vtmp4, wr44);
-          voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2));
-          voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2));
-          /// data shift 8 bytes
-          din_ptr0 += 8;
-          din_ptr1 += 8;
-          din_ptr2 += 8;
-          din_ptr3 += 8;
-          din_ptr4 += 8;
-          din_ptr5 += 8;
-          din_ptr6 += 8;
-          /// store
-          vst1q_s32(doutr0, voutr00);
-          vst1q_s32(doutr1, voutr10);
-          vst1q_s32(doutr2, voutr20);
-          doutr0 += 4;
-          doutr1 += 4;
-          doutr2 += 4;
-          vst1q_s32(doutr0, voutr01);
-          vst1q_s32(doutr1, voutr11);
-          vst1q_s32(doutr2, voutr21);
-          doutr0 += 4;
-          doutr1 += 4;
-          doutr2 += 4;
-        }  /// end of tile_w
-        dr0 = dr3;
-        dr1 = dr4;
-        dr2 = dr5;
-        dr3 = dr6;
-        dr4 = dr3 + win_round;
-        dr5 = dr4 + win_round;
-        dr6 = dr5 + win_round;
-        dout_ptr = dout_ptr + 3 * wout_round;
-      }  /// end of tile_h
+#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
+template <typename Dtype>
+void conv_depthwise_5x5s1_int8(Dtype* dout,
+                               const int8_t* din,
+                               const int8_t* weights,
+                               const float* scale,
+                               const float* bias,
+                               bool flag_bias,
+                               bool flag_relu,
+                               int num,
+                               int chin,
+                               int hin,
+                               int win,
+                               int hout,
+                               int wout,
+                               int padw,
+                               int padh,
+                               ARMContext* ctx) {
+  const int threads = ctx->threads();
+  int llc_size = ctx->llc_size() / 4;
+  const int hout_c_block = 8;
+  const int hout_r_kernel = 1;
+  const int wout_block = 4;
+  const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block;
+  const int win_round = wout_round + 4;
+  //! get h block
+  //! llc_size = threads * win_round * hout_c_block * hin_r_block *
+  //! sizeof(int8_t)
+  //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
+  //! win_round = wout_round + 4
+  //! hin_r_block = hout_r_block + 4
+  int hout_r_block = (llc_size - 4 * win_round * hout_c_block * threads) /
+                     (win_round * hout_c_block * threads +
+                      hout_c_block * wout_round * threads * 4);
+  hout_r_block = hout_r_block > hout ? hout : hout_r_block;
+  hout_r_block =
+      ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
+  hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
+  const int hin_r_block = hout_r_block + 4;
+  auto tmp_work_space = ctx->workspace_data<int8_t>();
+  int8_t ptr_zero[win_round];  // NOLINT
+  memset(ptr_zero, 0, sizeof(int8_t) * win_round);
+  Dtype ptr_write[wout_round];  // NOLINT
+  int in_len = win_round * hout_c_block;
+  int pre_in_size = hin_r_block * in_len;
+  pre_in_size = ROUNDUP(pre_in_size, 4);
+  int pre_out_size = hout_c_block * hout_r_block * wout_round;
+  int8_t* tmp_din = tmp_work_space;
+  int size_in_channel = win * hin;
+  int size_out_channel = wout * hout;
+  int w_stride = 25;  // kernel_w * kernel_h;
+  int ws = -padw;
+  int we = ws + win_round;
+  int w_loop = wout_round / 4;
+  int chout = chin;
+  int out_row_stride = hout_c_block * wout_round;
+  for (int n = 0; n < num; ++n) {
+    const int8_t* din_batch = din + n * chin * size_in_channel;
+    int8_t* dout_batch = reinterpret_cast<int8_t*>(dout) +
+                         n * chout * size_out_channel * sizeof(Dtype);
+    for (int h = 0; h < hout; h += hout_r_block) {
+      int h_kernel = hout_r_block;
+      if (h + hout_r_block > hout) {
+        h_kernel = hout - h;
+      }
+      int hs = h - padh;
+      int he = hs + h_kernel + 4;
+#pragma omp parallel for num_threads(threads)
+      for (int c = 0; c < chout; c += hout_c_block) {
+#ifdef ARM_WITH_OMP
+        int8_t* pre_din =
+            tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size * 4);
+        int32_t* pre_out = reinterpret_cast<int*>(pre_din + pre_in_size);
+#else
+        int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
+        auto pre_din = tmp_din;
+#endif
+        prepack_input_nxwc8_int8_dw(
+            din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
+        const int8_t* block_inr0 = pre_din;
+        const int8_t* block_inr1 = block_inr0 + in_len;
+        const int8_t* block_inr2 = block_inr1 + in_len;
+        const int8_t* block_inr3 = block_inr2 + in_len;
+        const int8_t* block_inr4 = block_inr3 + in_len;
+        const int8_t* weight_c = weights + c * w_stride;
+        float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+        if (flag_bias) {
+          bias_local[0] = bias[c];
+          bias_local[1] = bias[c + 1];
+          bias_local[2] = bias[c + 2];
+          bias_local[3] = bias[c + 3];
+          bias_local[4] = bias[c + 4];
+          bias_local[5] = bias[c + 5];
+          bias_local[6] = bias[c + 6];
+          bias_local[7] = bias[c + 7];
+        }
+        for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
+          int cnt = w_loop;
+          const int8_t* inr0 = block_inr0;
+          const int8_t* inr1 = block_inr1;
+          const int8_t* inr2 = block_inr2;
+          const int8_t* inr3 = block_inr3;
+          const int8_t* inr4 = block_inr4;
+          int32_t* ptr_out0 = pre_out + hk * out_row_stride;
+          // clang-format off
+#ifdef __aarch64__
+          auto wptr = weight_c;
+          asm volatile(
+              "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */
+              "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 0-3 */
+              "1:\n"
+              /* in r0 */
+              "smull v20.8h, v0.8b, v8.8b\n" /* w0, int16, out0 */
+              "smull v21.8h, v1.8b, v8.8b\n" /* w0, int16, out1 */
+              "smull v22.8h, v2.8b, v8.8b\n" /* w0, int16, out2 */
+              "smull v23.8h, v3.8b, v8.8b\n" /* w0, int16, out3 */
+              "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r0]]\n" /* load r0 4-7 */
+              "smlal v20.8h, v1.8b, v9.8b\n" /* w1, int16, out0 */
+              "smlal v21.8h, v2.8b, v9.8b\n" /* w1, int16, out1 */
+              "smlal v22.8h, v3.8b, v9.8b\n" /* w1, int16, out2 */
+              "smlal v23.8h, v4.8b, v9.8b\n" /* w1, int16, out3 */
+              "sxtl v24.4s, v20.4h\n" /* mov to out0 low */
+              "sxtl2 v25.4s, v20.8h\n" /* mov to out0 hig */
+              "sxtl v26.4s, v21.4h\n" /* mov to out1 low */
+              "sxtl2 v27.4s, v21.8h\n" /* mov to out1 hig */
+              "sxtl v28.4s, v22.4h\n" /* mov to out2 low */
+              "sxtl2 v29.4s, v22.8h\n" /* mov to out2 hig */
+              "sxtl v30.4s, v23.4h\n" /* mov to out3 low */
+              "sxtl2 v31.4s, v23.8h\n" /* mov to out3 hig */
+              "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 4-7 */
+              "smull v20.8h, v2.8b, v10.8b\n" /* w2, int16, out0 */
+              "smull v21.8h, v3.8b, v10.8b\n" /* w2, int16, out1 */
+              "smull v22.8h, v4.8b, v10.8b\n" /* w2, int16, out2 */
+              "smull v23.8h, v5.8b, v10.8b\n" /* w2, int16, out3 */
+              "smlal v20.8h, v3.8b, v11.8b\n" /* w3, int16, out0 */
+              "smlal v21.8h, v4.8b, v11.8b\n" /* w3, int16, out1 */
+              "smlal v22.8h, v5.8b, v11.8b\n" /* w3, int16, out2 */
+              "smlal v23.8h, v6.8b, v11.8b\n" /* w3, int16, out3 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r1]], #32\n" /* load r1 0-3 */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "smull v20.8h, v4.8b, v12.8b\n" /* w4, int16, out0 */
+              "smull v21.8h, v5.8b, v12.8b\n" /* w4, int16, out1 */
+              "smull v22.8h, v6.8b, v12.8b\n" /* w4, int16, out2 */
+              "smull v23.8h, v7.8b, v12.8b\n" /* w4, int16, out3 */
+              "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r1]]\n" /* load r1 4-7 */
+              /* in r1 */
+              "smlal v20.8h, v0.8b, v13.8b\n" /* w5, int16, out0 */
+              "smlal v21.8h, v1.8b, v13.8b\n" /* w5, int16, out1 */
+              "smlal v22.8h, v2.8b, v13.8b\n" /* w5, int16, out2 */
+              "smlal v23.8h, v3.8b, v13.8b\n" /* w5, int16, out3 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 8-11 */
+              "smull v20.8h, v1.8b, v14.8b\n" /* w6, int16, out0 */
+              "smull v21.8h, v2.8b, v14.8b\n" /* w6, int16, out1 */
+              "smull v22.8h, v3.8b, v14.8b\n" /* w6, int16, out2 */
+              "smull v23.8h, v4.8b, v14.8b\n" /* w6, int16, out3 */
+              "smlal v20.8h, v2.8b, v15.8b\n" /* w7, int16, out0 */
+              "smlal v21.8h, v3.8b, v15.8b\n" /* w7, int16, out1 */
+              "smlal v22.8h, v4.8b, v15.8b\n" /* w7, int16, out2 */
+              "smlal v23.8h, v5.8b, v15.8b\n" /* w7, int16, out3 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "smull v20.8h, v3.8b, v16.8b\n" /* w8, int16, out0 */
+              "smull v21.8h, v4.8b, v16.8b\n" /* w8, int16, out1 */
+              "smull v22.8h, v5.8b, v16.8b\n" /* w8, int16, out2 */
+              "smull v23.8h, v6.8b, v16.8b\n" /* w8, int16, out3 */
+              "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r2]], #32\n" /* load r2 0-3 */
+              "smlal v20.8h, v4.8b, v17.8b\n" /* w9, int16, out0 */
+              "smlal v21.8h, v5.8b, v17.8b\n" /* w9, int16, out1 */
+              "smlal v22.8h, v6.8b, v17.8b\n" /* w9, int16, out2 */
+              "smlal v23.8h, v7.8b, v17.8b\n" /* w9, int16, out3 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r2]]\n" /* load r2 4-7 */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              /* in r2 */
+              "smull v20.8h, v0.8b, v18.8b\n" /* w10, int16, out0 */
+              "smull v21.8h, v1.8b, v18.8b\n" /* w10, int16, out1 */
+              "smull v22.8h, v2.8b, v18.8b\n" /* w10, int16, out2 */
+              "smull v23.8h, v3.8b, v18.8b\n" /* w10, int16, out3 */
+              "smlal v20.8h, v1.8b, v19.8b\n" /* w11, int16, out0 */
+              "smlal v21.8h, v2.8b, v19.8b\n" /* w11, int16, out1 */
+              "smlal v22.8h, v3.8b, v19.8b\n" /* w11, int16, out2 */
+              "smlal v23.8h, v4.8b, v19.8b\n" /* w11, int16, out3 */
+              "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 12-15 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "smull v20.8h, v2.8b, v8.8b\n" /* w12, int16, out0 */
+              "smull v21.8h, v3.8b, v8.8b\n" /* w12, int16, out1 */
+              "smull v22.8h, v4.8b, v8.8b\n" /* w12, int16, out2 */
+              "smull v23.8h, v5.8b, v8.8b\n" /* w12, int16, out3 */
+              "smlal v20.8h, v3.8b, v9.8b\n" /* w13, int16, out0 */
+              "smlal v21.8h, v4.8b, v9.8b\n" /* w13, int16, out1 */
+              "smlal v22.8h, v5.8b, v9.8b\n" /* w13, int16, out2 */
+              "smlal v23.8h, v6.8b, v9.8b\n" /* w13, int16, out3 */
+              "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r3]], #32\n" /* load r3 0-3 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "smull v20.8h, v4.8b, v10.8b\n" /* w14, int16, out0 */
+              "smull v21.8h, v5.8b, v10.8b\n" /* w14, int16, out1 */
+              "smull v22.8h, v6.8b, v10.8b\n" /* w14, int16, out2 */
+              "smull v23.8h, v7.8b, v10.8b\n" /* w14, int16, out3 */
+              "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r3]]\n" /* load r3 4-7 */
+              /* in r3 */
+              "smlal v20.8h, v0.8b, v11.8b\n" /* w15, int16, out0 */
+              "smlal v21.8h, v1.8b, v11.8b\n" /* w15, int16, out1 */
+              "smlal v22.8h, v2.8b, v11.8b\n" /* w15, int16, out2 */
+              "smlal v23.8h, v3.8b, v11.8b\n" /* w15, int16, out3 */
+              "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 16-19 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "smull v20.8h, v1.8b, v12.8b\n" /* w16, int16, out0 */
+              "smull v21.8h, v2.8b, v12.8b\n" /* w16, int16, out1 */
+              "smull v22.8h, v3.8b, v12.8b\n" /* w16, int16, out2 */
+              "smull v23.8h, v4.8b, v12.8b\n" /* w16, int16, out3 */
+              "ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 20-23 */
+              "smlal v20.8h, v2.8b, v13.8b\n" /* w17, int16, out0 */
+              "smlal v21.8h, v3.8b, v13.8b\n" /* w17, int16, out1 */
+              "smlal v22.8h, v4.8b, v13.8b\n" /* w17, int16, out2 */
+              "smlal v23.8h, v5.8b, v13.8b\n" /* w17, int16, out3 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "smull v20.8h, v3.8b, v14.8b\n" /* w18, int16, out0 */
+              "smull v21.8h, v4.8b, v14.8b\n" /* w18, int16, out1 */
+              "smull v22.8h, v5.8b, v14.8b\n" /* w18, int16, out2 */
+              "smull v23.8h, v6.8b, v14.8b\n" /* w18, int16, out3 */
+              "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r4]], #32\n" /* load r4 0-3 */
+              "smlal v20.8h, v4.8b, v15.8b\n" /* w19, int16, out0 */
+              "smlal v21.8h, v5.8b, v15.8b\n" /* w19, int16, out1 */
+              "smlal v22.8h, v6.8b, v15.8b\n" /* w19, int16, out2 */
+              "smlal v23.8h, v7.8b, v15.8b\n" /* w19, int16, out3 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r4]]\n" /* load r4 4-7 */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              /* in r4 */
+              "smull v20.8h, v0.8b, v16.8b\n" /* w20, int16, out0 */
+              "smull v21.8h, v1.8b, v16.8b\n" /* w20, int16, out1 */
+              "smull v22.8h, v2.8b, v16.8b\n" /* w20, int16, out2 */
+              "smull v23.8h, v3.8b, v16.8b\n" /* w20, int16, out3 */
+              "smlal v20.8h, v1.8b, v17.8b\n" /* w21, int16, out0 */
+              "smlal v21.8h, v2.8b, v17.8b\n" /* w21, int16, out1 */
+              "smlal v22.8h, v3.8b, v17.8b\n" /* w21, int16, out2 */
+              "smlal v23.8h, v4.8b, v17.8b\n" /* w21, int16, out3 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "ld1 {v12.8b}, [%[wc]], #8\n" /* load wc 24 */
+              "smull v20.8h, v2.8b, v18.8b\n" /* w22, int16, out0 */
+              "smull v21.8h, v3.8b, v18.8b\n" /* w22, int16, out1 */
+              "smull v22.8h, v4.8b, v18.8b\n" /* w22, int16, out2 */
+              "smull v23.8h, v5.8b, v18.8b\n" /* w22, int16, out3 */
+              "smlal v20.8h, v3.8b, v19.8b\n" /* w23, int16, out0 */
+              "smlal v21.8h, v4.8b, v19.8b\n" /* w23, int16, out1 */
+              "smlal v22.8h, v5.8b, v19.8b\n" /* w23, int16, out2 */
+              "smlal v23.8h, v6.8b, v19.8b\n" /* w23, int16, out3 */
+              "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */
+              "sub %[wc], %[wc], #200 \n"
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 0-3 */
+              "smull v20.8h, v4.8b, v12.8b\n" /* w24, int16, out0 */
+              "smull v21.8h, v5.8b, v12.8b\n" /* w24, int16, out1 */
+              "smull v22.8h, v6.8b, v12.8b\n" /* w24, int16, out2 */
+              "smull v23.8h, v7.8b, v12.8b\n" /* w24, int16, out3 */
+              "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */
+              "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */
+              "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */
+              "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */
+              "stp q24, q25, [%[ptr_out0]], #32\n"
+              "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */
+              "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */
+              "stp q26, q27, [%[ptr_out0]], #32\n"
+              "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */
+              "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */
+              "subs %w[cnt], %w[cnt], #1\n"
+              "stp q28, q29, [%[ptr_out0]], #32\n"
+              "stp q30, q31, [%[ptr_out0]], #32\n"
+              "bne 1b\n"
+              : [cnt] "+r"(cnt),
+                [r0] "+r"(inr0),
+                [r1] "+r"(inr1),
+                [r2] "+r"(inr2),
+                [r3] "+r"(inr3),
+                [r4] "+r"(inr4),
+                [wc] "+r"(wptr),
+                [ptr_out0] "+r"(ptr_out0)
+              :
+              : "cc","memory",
+                "v0","v1","v2","v3","v4","v5","v6","v7",
+                "v8","v9","v10","v11","v12","v13",
+                "v14","v15","v16","v17","v18","v19",
+                "v20","v21","v22","v23","v24","v25",
+                "v26","v27","v28","v29","v30","v31"
+              );
+#else
+          auto wptr = weight_c;
+          asm volatile(
+              "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */
+              "vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */
+              "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
+              "1:\n"
+              /* inr0 */
+              "vmull.s8 q4, d0, d6\n" /* int16, out0 */
+              "vmull.s8 q5, d1, d6\n" /* int16, out1 */
+              "vmull.s8 q6, d2, d6\n" /* int16, out2 */
+              "vmull.s8 q7, d3, d6\n" /* int16, out3 */
+              "vmlal.s8 q4, d1, d7\n" /* int16, out0 */
+              "vmlal.s8 q5, d2, d7\n" /* int16, out1 */
+              "vmlal.s8 q6, d3, d7\n" /* int16, out2 */
+              "vmlal.s8 q7, d4, d7\n" /* int16, out3 */
+              "vmovl.s16 q8, d8\n" /* mov to out0 low */
+              "vmovl.s16 q9, d9\n" /* mov to out0 hig */
+              "vmovl.s16 q10, d10\n" /* mov to out1 low */
+              "vmovl.s16 q11, d11\n" /* mov to out1 hig */
+              "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w2-w3 */
+              "vmovl.s16 q12, d12\n" /* mov to out2 low */
+              "vmovl.s16 q13, d13\n" /* mov to out2 hig */
+              "vmovl.s16 q14, d14\n" /* mov to out3 low */
+              "vmovl.s16 q15, d15\n" /* mov to out3 hig */
+              "vld1.32 {d0-d1}, [%[r0]]\n" /* load r0, 6-7 */
+              "vmull.s8 q4, d2, d6\n" /* w2, int16, out0 */
+              "vmull.s8 q5, d3, d6\n" /* w2, int16, out1 */
+              "vmull.s8 q6, d4, d6\n" /* w2, int16, out2 */
+              "vmull.s8 q7, d5, d6\n" /* w2, int16, out3 */
+              "vmlal.s8 q4, d3, d7\n" /* w3, int16, out0 */
+              "vmlal.s8 q5, d4, d7\n" /* w3, int16, out1 */
+              "vmlal.s8 q6, d5, d7\n" /* w3, int16, out2 */
+              "vmlal.s8 q7, d0, d7\n" /* w3, int16, out3 */
+              "vaddw.s16 q8, q8, d8\n" /* add to out0 low */
+              "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
+              "vaddw.s16 q10, q10, d10\n" /* add to out1 low */
+              "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
+              "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w4-w5 */
+              "sub %[r0], %[r0], #16\n" /* r0 = r0 - 16 */
+              "vaddw.s16 q12, q12, d12\n" /* add to out2 low */
+              "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
+              "vaddw.s16 q14, q14, d14\n" /* add to out3 low */
+              "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
+              "vmull.s8 q4, d4, d6\n" /* w4, int16, out0 */
+              "vmull.s8 q5, d5, d6\n" /* w4, int16, out1 */
+              "vmull.s8 q6, d0, d6\n" /* w4, int16, out2 */
+              "vmull.s8 q7, d1, d6\n" /* w4, int16, out3 */
+              "vld1.32 {d0-d3}, [%[r1]]!\n" /* load r1, 0-3 */
+              /* inr1 */
+              "vmlal.s8 q4, d0, d7\n" /* w5, int16, out0 */
+              "vmlal.s8 q5, d1, d7\n" /* w5, int16, out1 */
+              "vmlal.s8 q6, d2, d7\n" /* w5, int16, out2 */
+              "vmlal.s8 q7, d3, d7\n" /* w5, int16, out3 */
+              "vld1.32 {d4-d5}, [%[r1]]!\n" /* load r1, 4-5 */
+              "vaddw.s16 q8, q8, d8\n" /* add to out0 low */
+              "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
+              "vaddw.s16 q10, q10, d10\n" /* add to out1 low */
+              "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
+              "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w6-w7 */
+              "vaddw.s16 q12, q12, d12\n" /* add to out2 low */
+              "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
+              "vaddw.s16 q14, q14, d14\n" /* add to out3 low */
+              "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
+              "vmull.s8 q4, d1, d6\n" /* w6, int16, out0 */
+              "vmull.s8 q5, d2, d6\n" /* w6, int16, out1 */
+              "vmull.s8 q6, d3, d6\n" /* w6, int16, out2 */
+              "vmull.s8 q7, d4, d6\n" /* w6, int16, out3 */
+              "vld1.32 {d0-d1}, [%[r1]]\n" /* load r1, 6-7 */
+              "vmlal.s8 q4, d2, d7\n" /* w7, int16, out0 */
+              "vmlal.s8 q5, d3, d7\n" /* w7, int16, out1 */
+              "vmlal.s8 q6, d4, d7\n" /* w7, int16, out2 */
+              "vmlal.s8 q7, d5, d7\n" /* w7, int16, out3 */
+              "sub %[r1], %[r1], #16\n" /* r0 = r0 - 16 */
+              "vaddw.s16 q8, q8, d8\n" /* add to out0 low */
+              "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
+              "vaddw.s16 q10, q10, d10\n" /* add to out1 low */
+              "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w8-w9 */
if (scales == 0) { "vaddw.s16 q12, q12, d12\n" /* add to out2 low */
write_to_output_numc(pre_out, "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
dout_batch, "vaddw.s16 q14, q14, d14\n" /* add to out3 low */
1, "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
hout_round,
c, "vmull.s8 q4, d3, d6\n" /* w8, int16, out0 */
c + 1, "vmull.s8 q5, d4, d6\n" /* w8, int16, out1 */
0, "vmull.s8 q6, d5, d6\n" /* w8, int16, out2 */
hout, "vmull.s8 q7, d0, d6\n" /* w8, int16, out3 */
0, "vmlal.s8 q4, d4, d7\n" /* w9, int16, out0 */
wout_round, "vmlal.s8 q5, d5, d7\n" /* w9, int16, out1 */
chout, "vmlal.s8 q6, d0, d7\n" /* w9, int16, out2 */
hout, "vmlal.s8 q7, d1, d7\n" /* w9, int16, out3 */
wout, "vld1.32 {d0-d3}, [%[r2]]!\n" /* load r2, 0-3 */
flag_relu, "vaddw.s16 q8, q8, d8\n" /* add to out0 low */
ptr_write); "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
} else if (od_type == PRECISION(kFloat)) { "vaddw.s16 q10, q10, d10\n" /* add to out1 low */
write2_to_output_numc(pre_out, "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
reinterpret_cast<float*>(dout_batch), "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w10-w11 */
1, "vaddw.s16 q12, q12, d12\n" /* add to out2 low */
hout_round, "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
c, "vaddw.s16 q14, q14, d14\n" /* add to out3 low */
c + 1, "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
0, "vld1.32 {d4-d5}, [%[r2]]!\n" /* load r2, 4-5 */
hout,
0, /* inr2 */
wout_round, "vmull.s8 q4, d0, d6\n" /* w10, int16, out0 */
chout, "vmull.s8 q5, d1, d6\n" /* w10, int16, out1 */
hout, "vmull.s8 q6, d2, d6\n" /* w10, int16, out2 */
wout, "vmull.s8 q7, d3, d6\n" /* w10, int16, out3 */
flag_relu, "vmlal.s8 q4, d1, d7\n" /* w11, int16, out0 */
reinterpret_cast<float*>(ptr_write), "vmlal.s8 q5, d2, d7\n" /* w11, int16, out1 */
scales); "vmlal.s8 q6, d3, d7\n" /* w11, int16, out2 */
} else if (od_type == PRECISION(kInt8)) { "vmlal.s8 q7, d4, d7\n" /* w11, int16, out3 */
write2_to_output_numc(pre_out, "vaddw.s16 q8, q8, d8\n" /* add to out0 low */
reinterpret_cast<signed char*>(dout_batch), "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
1, "vaddw.s16 q10, q10, d10\n" /* add to out1 low */
hout_round, "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
c, "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w12-w13 */
c + 1, "vaddw.s16 q12, q12, d12\n" /* add to out2 low */
0, "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
hout, "vaddw.s16 q14, q14, d14\n" /* add to out3 low */
0, "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
wout_round, "vld1.32 {d0-d1}, [%[r2]]\n" /* load r2, 6-7 */
chout,
hout, "vmull.s8 q4, d2, d6\n" /* w12, int16, out0 */
wout, "vmull.s8 q5, d3, d6\n" /* w12, int16, out1 */
flag_relu, "vmull.s8 q6, d4, d6\n" /* w12, int16, out2 */
reinterpret_cast<signed char*>(ptr_write), "vmull.s8 q7, d5, d6\n" /* w12, int16, out3 */
scales); "vmlal.s8 q4, d3, d7\n" /* w13, int16, out0 */
"vmlal.s8 q5, d4, d7\n" /* w13, int16, out1 */
"vmlal.s8 q6, d5, d7\n" /* w13, int16, out2 */
"vmlal.s8 q7, d0, d7\n" /* w13, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w14-w15 */
"sub %[r2], %[r2], #16\n" /* r2 = r2 - 16 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vmull.s8 q4, d4, d6\n" /* w14, int16, out0 */
"vmull.s8 q5, d5, d6\n" /* w14, int16, out1 */
"vmull.s8 q6, d0, d6\n" /* w14, int16, out2 */
"vmull.s8 q7, d1, d6\n" /* w14, int16, out3 */
"vld1.32 {d0-d3}, [%[r3]]!\n" /* load r3, 0-3 */
/* inr3 */
"vmlal.s8 q4, d0, d7\n" /* w15, int16, out0 */
"vmlal.s8 q5, d1, d7\n" /* w15, int16, out1 */
"vmlal.s8 q6, d2, d7\n" /* w15, int16, out2 */
"vmlal.s8 q7, d3, d7\n" /* w15, int16, out3 */
"vld1.32 {d4-d5}, [%[r3]]!\n" /* load r3, 4-5 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w16-w17 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vmull.s8 q4, d1, d6\n" /* w16, int16, out0 */
"vmull.s8 q5, d2, d6\n" /* w16, int16, out1 */
"vmull.s8 q6, d3, d6\n" /* w16, int16, out2 */
"vmull.s8 q7, d4, d6\n" /* w16, int16, out3 */
"vld1.32 {d0-d1}, [%[r3]]\n" /* load r3, 6-7 */
"vmlal.s8 q4, d2, d7\n" /* w17, int16, out0 */
"vmlal.s8 q5, d3, d7\n" /* w17, int16, out1 */
"vmlal.s8 q6, d4, d7\n" /* w17, int16, out2 */
"vmlal.s8 q7, d5, d7\n" /* w17, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w18-w19 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"sub %[r3], %[r3], #16\n" /* r3 = r3 - 16 */
"vmull.s8 q4, d3, d6\n" /* w18, int16, out0 */
"vmull.s8 q5, d4, d6\n" /* w18, int16, out1 */
"vmull.s8 q6, d5, d6\n" /* w18, int16, out2 */
"vmull.s8 q7, d0, d6\n" /* w18, int16, out3 */
"vmlal.s8 q4, d4, d7\n" /* w19, int16, out0 */
"vmlal.s8 q5, d5, d7\n" /* w19, int16, out1 */
"vmlal.s8 q6, d0, d7\n" /* w19, int16, out2 */
"vmlal.s8 q7, d1, d7\n" /* w19, int16, out3 */
"vld1.32 {d0-d3}, [%[r4]]!\n" /* load r4, 0-3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w20-w21 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vld1.32 {d4-d5}, [%[r4]]!\n" /* load r4, 4-5 */
/* inr4 */
"vmull.s8 q4, d0, d6\n" /* w20, int16, out0 */
"vmull.s8 q5, d1, d6\n" /* w20, int16, out1 */
"vmull.s8 q6, d2, d6\n" /* w20, int16, out2 */
"vmull.s8 q7, d3, d6\n" /* w20, int16, out3 */
"vmlal.s8 q4, d1, d7\n" /* w21, int16, out0 */
"vmlal.s8 q5, d2, d7\n" /* w21, int16, out1 */
"vmlal.s8 q6, d3, d7\n" /* w21, int16, out2 */
"vmlal.s8 q7, d4, d7\n" /* w21, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w22-w23 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"vld1.32 {d0-d1}, [%[r4]]\n" /* load r4, 5-6 */
"vmull.s8 q4, d2, d6\n" /* w22, int16, out0 */
"vmull.s8 q5, d3, d6\n" /* w22, int16, out1 */
"vmull.s8 q6, d4, d6\n" /* w22, int16, out2 */
"vmull.s8 q7, d5, d6\n" /* w22, int16, out3 */
"vmlal.s8 q4, d3, d7\n" /* w23, int16, out0 */
"vmlal.s8 q5, d4, d7\n" /* w23, int16, out1 */
"vmlal.s8 q6, d5, d7\n" /* w23, int16, out2 */
"vmlal.s8 q7, d0, d7\n" /* w23, int16, out3 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vld1.32 {d6}, [%[wptr]]!\n" /* load w24 */
"sub %[r4], %[r4], #16\n" /* r4 = r4 - 16 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"sub %[wptr], %[wptr], #200 \n" /* wptr = wptr - 200 */
"vmull.s8 q4, d4, d6\n" /* w22, int16, out0 */
"vmull.s8 q5, d5, d6\n" /* w22, int16, out1 */
"vmull.s8 q6, d0, d6\n" /* w22, int16, out2 */
"vmull.s8 q7, d1, d6\n" /* w22, int16, out3 */
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */
"vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
"vaddw.s16 q8, q8, d8\n" /* add to out0 low */
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
"vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 0-3 */
"vaddw.s16 q10, q10, d10\n" /* add to out1 low */
"vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
"vst1.32 {d16-d19}, [%[ptr_out0]]!\n"/* store out0 */
"vaddw.s16 q12, q12, d12\n" /* add to out2 low */
"vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
"vst1.32 {d20-d23}, [%[ptr_out0]]!\n"/*store out1 */
"vaddw.s16 q14, q14, d14\n" /* add to out3 low */
"vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
"subs %[cnt], #1\n" /* cnt = cnt - 1 */
"vst1.32 {d24-d27}, [%[ptr_out0]]!\n"/* store out2 */
"vst1.32 {d28-d31}, [%[ptr_out0]]!\n"/* store out3 */
"bne 1b\n" /* branch main loop */
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[ptr_out0] "+r"(ptr_out0),
[wptr] "+r"(wptr)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
// clang-format on
int32_t* ptr_tmp = ptr_out0 - w_loop * 32;
block_inr0 = block_inr1;
block_inr1 = block_inr2;
block_inr2 = block_inr3;
block_inr3 = block_inr4;
block_inr4 = block_inr3 + in_len;
}
write_int32_nchwc8_to_nchw<Dtype>(pre_out,
reinterpret_cast<Dtype*>(dout_batch),
c,
c + hout_c_block,
h,
h + h_kernel,
0,
wout_round,
chout,
hout,
wout,
flag_relu,
bias_local,
flag_bias,
ptr_write,
scale + c);
      }
    }
  }
}

/* ---- removed: tail of the old implementation ---- */
      }
      // else if (od_type == AK_INT32) {
      //   write2_to_output_numc(pre_out, (int*)dout_batch, 1, hout_round, c,
      //   c+1,
      //   0, hout, 0, wout_round, chout, hout, wout, flag_relu,
      //   (int*)ptr_write, scales);
      // }
    }  /// end of chout
  }  /// end of batch num
}
#endif  // __aarch64__

template void conv_depthwise_5x5s1_int8<int8_t>(int8_t* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
template void conv_depthwise_5x5s1_int8<float>(float* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
}  // namespace math
}  // namespace arm
}  // namespace lite
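The new kernel's arithmetic, on both the aarch64 and armv7 paths above, is one repeated pattern: an int8 multiply widened to int16 (vmull.s8), a second tap fused while still at int16 (vmlal.s8), then a widening add into int32 accumulators (vaddw.s16). Two taps are safe at int16 because |2 * 127 * (-128)| = 32512 < 32767. A hedged intrinsics sketch of that two-tap step; the helper name and signature are illustrative, not part of the patch:

#include <arm_neon.h>

// Hypothetical helper mirroring the vmull/vmlal/vaddw chain in the asm above.
static inline void dw_mac_two_taps_s8(int8x8_t in0, int8x8_t in1,
                                      int8x8_t w0, int8x8_t w1,
                                      int32x4_t* acc_lo, int32x4_t* acc_hi) {
  // int8 x int8 -> int16 for the first tap; fuse the second tap at int16
  // (the two-product sum cannot overflow int16).
  int16x8_t p = vmull_s8(in0, w0);
  p = vmlal_s8(p, in1, w1);
  // Widen the eight int16 lanes into two int32 accumulators.
  *acc_lo = vaddw_s16(*acc_lo, vget_low_s16(p));
  *acc_hi = vaddw_s16(*acc_hi, vget_high_s16(p));
}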
...
@@ -389,237 +389,202 @@ inline void prepack_input_nxwc4_dw(const float* din,
  }
}
/* ---- removed: old implementation ---- */
inline void prepack_input_nxw_c8_int8(const int8_t* din,
                                      int8_t* dout,
                                      int cs,
                                      int ce,
                                      int hs,
                                      int he,
                                      int ws,
                                      int we,
                                      int channel,
                                      int width,
                                      int height) {
  int n = he - hs;
  if (n <= 0) {
    LOG(FATAL) << "prepack_input_nxw_c8 input height must > 0";
    return;
  }
  int w0 = ws < 0 ? 0 : ws;
  int w1 = we > width ? width : we;
  int size_w = we - ws;
  int size_channel_in = width * height;
  int size_out_row = size_w * 8;
  int valid_w = w1 - w0;
  size_t valid_w_byte = valid_w * sizeof(int8_t);
  auto ptr_c = static_cast<int8_t*>(TargetMalloc(TARGET(kARM), 8 * size_w));
  int8_t* ptr_r[8];
  int8_t* ptr_c_ori[8] = {ptr_c,
                          ptr_c + size_w,
                          ptr_c + 2 * size_w,
                          ptr_c + 3 * size_w,
                          ptr_c + 4 * size_w,
                          ptr_c + 5 * size_w,
                          ptr_c + 6 * size_w,
                          ptr_c + 7 * size_w};
  int8_t zero_ptr[size_w * 2];  // NOLINT
  memset(zero_ptr, 0, size_w * 2);
  int loop = size_w / 8;
  int remain = size_w - loop * 8;
  for (int c = cs; c < ce; c += 8) {
    auto din_c = din + c * size_channel_in;
    for (int j = 0; j < 8; ++j) {
      ptr_r[j] = ptr_c_ori[j];
    }
    //! valid channel
    if (c + 8 > channel) {
      switch (c + 8 - channel) {
        case 7:
          ptr_r[1] = zero_ptr;
        case 6:
          ptr_r[2] = zero_ptr;
        case 5:
          ptr_r[3] = zero_ptr;
        case 4:
          ptr_r[4] = zero_ptr;
        case 3:
          ptr_r[5] = zero_ptr;
        case 2:
          ptr_r[6] = zero_ptr;
        case 1:
          ptr_r[7] = zero_ptr;
        default:
          break;
      }
    }
    //! valid height
    int j = 0;
    for (int i = hs; i < he; i++) {
      auto din_r = din_c + i * width;
      for (int k = 0; k < 8; ++k) {
        if (ptr_r[k] != zero_ptr) {
          if (i < 0 || i >= height) {
            ptr_r[k] = zero_ptr + size_w;
          } else {
            ptr_r[k] = ptr_c_ori[k];
            auto ptr = ptr_r[k];
            for (int w = ws; w < w0; ++w) {
              *(ptr++) = 0;
            }
            memcpy(ptr, din_r + k * size_channel_in, valid_w_byte);
            ptr += valid_w;
            for (int w = w1; w < we; ++w) {
              *(ptr++) = 0;
            }
          }
        }
      }
      int cnt = loop;
      int8_t* inr0 = ptr_r[0];
      int8_t* inr1 = ptr_r[1];
      int8_t* inr2 = ptr_r[2];
      int8_t* inr3 = ptr_r[3];
      int8_t* inr4 = ptr_r[4];
      int8_t* inr5 = ptr_r[5];
      int8_t* inr6 = ptr_r[6];
      int8_t* inr7 = ptr_r[7];
      auto ptr_out = dout + j * size_out_row;
      if (cnt > 0) {
#ifdef __aarch64__
        /* (same NEON 8x8 byte-transpose asm as in the new version below,
           with [r0]..[r7] bound to inr0..inr7 and [ptr_out] to ptr_out) */
#else
        /* (same armv7 vtrn-based transpose asm as below) */
#endif  // aarch64
      }
      for (int k = 0; k < remain; ++k) {
        ptr_out[0] = *(inr0++);
        ptr_out[1] = *(inr1++);
        ptr_out[2] = *(inr2++);
        ptr_out[3] = *(inr3++);
        ptr_out[4] = *(inr4++);
        ptr_out[5] = *(inr5++);
        ptr_out[6] = *(inr6++);
        ptr_out[7] = *(inr7++);
        ptr_out += 8;
      }
      j++;
    }
  }
  TargetFree(TARGET(kARM), ptr_c);
}

/* ---- added: new implementation ---- */
inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
                                        int8_t* dout,
                                        int cs,
                                        int hs,
                                        int he,
                                        int ws,
                                        int we,
                                        int channel,
                                        int width,
                                        int height) {
  int n = he - hs;
  if (n <= 0) {
    LOG(FATAL) << "prepack_dw_input_int8, valid height must > zero";
  }
  int size_w = we - ws;
  int w0 = ws < 0 ? 0 : ws;
  int w1 = we > width ? width : we;
  int valid_w = w1 - w0;
  int pad_l = ws < 0 ? -ws : 0;
  int pad_r = we > width ? we - width : 0;
  int size_c = width * height;
  int valid_cnt = valid_w >> 3;
  int remain = valid_w & 7;
  int8_t zero_ptr[size_w * 2];  // NOLINT
  memset(zero_ptr, 0, size_w * 2);
  for (int h = hs; h < he; ++h) {
    const int8_t* ptr_c0 = din + h * width + cs * size_c;
    const int8_t* ptr_c1 = ptr_c0 + size_c;
    const int8_t* ptr_c2 = ptr_c1 + size_c;
    const int8_t* ptr_c3 = ptr_c2 + size_c;
    const int8_t* ptr_c4 = ptr_c3 + size_c;
    const int8_t* ptr_c5 = ptr_c4 + size_c;
    const int8_t* ptr_c6 = ptr_c5 + size_c;
    const int8_t* ptr_c7 = ptr_c6 + size_c;
    if (h < 0 || h >= height) {
      memset(dout, 0, 8 * size_w * sizeof(int8_t));
      dout += size_w * 8;
      continue;
    } else if (cs + 8 > channel) {
      switch (cs + 8 - channel) {
        case 7:
          ptr_c1 = zero_ptr;
        case 6:
          ptr_c2 = zero_ptr;
        case 5:
          ptr_c3 = zero_ptr;
        case 4:
          ptr_c4 = zero_ptr;
        case 3:
          ptr_c5 = zero_ptr;
        case 2:
          ptr_c6 = zero_ptr;
        case 1:
          ptr_c7 = zero_ptr;
        default:
          break;
      }
    }
    if (pad_l) {
      memset(dout, 0, pad_l * 8 * sizeof(int8_t));
      dout += pad_l * 8;
    }
    if (valid_cnt) {
      int cnt = valid_cnt;
#ifdef __aarch64__
      asm volatile(
          /* main loop */
          "1:\n"
          "ldr d0, [%[r0]], #8\n"
          "ldr d1, [%[r1]], #8\n"
          "ldr d2, [%[r2]], #8\n"
          "ldr d3, [%[r3]], #8\n"
          "ldr d4, [%[r4]], #8\n"
          "ldr d5, [%[r5]], #8\n"
          "ldr d6, [%[r6]], #8\n"
          "ldr d7, [%[r7]], #8\n"
          "trn1 v8.8b, v0.8b, v1.8b\n"
          "trn2 v9.8b, v0.8b, v1.8b\n"
          "trn1 v10.8b, v2.8b, v3.8b\n"
          "trn2 v11.8b, v2.8b, v3.8b\n"
          "trn1 v12.8b, v4.8b, v5.8b\n"
          "trn2 v13.8b, v4.8b, v5.8b\n"
          "trn1 v14.8b, v6.8b, v7.8b\n"
          "trn2 v15.8b, v6.8b, v7.8b\n"
          "trn1 v0.4h, v8.4h, v10.4h\n"
          "trn2 v1.4h, v8.4h, v10.4h\n"
          "trn1 v2.4h, v9.4h, v11.4h\n"
          "trn2 v3.4h, v9.4h, v11.4h\n"
          "trn1 v4.4h, v12.4h, v14.4h\n"
          "trn2 v5.4h, v12.4h, v14.4h\n"
          "trn1 v6.4h, v13.4h, v15.4h\n"
          "trn2 v7.4h, v13.4h, v15.4h\n"
          "trn1 v8.2s, v0.2s, v4.2s\n"
          "trn1 v9.2s, v2.2s, v6.2s\n"
          "trn1 v10.2s, v1.2s, v5.2s\n"
          "trn1 v11.2s, v3.2s, v7.2s\n"
          "stp d8, d9, [%[ptr_out]], #16\n"
          "trn2 v12.2s, v0.2s, v4.2s\n"
          "trn2 v13.2s, v2.2s, v6.2s\n"
          "stp d10, d11, [%[ptr_out]], #16\n"
          "trn2 v14.2s, v1.2s, v5.2s\n"
          "trn2 v15.2s, v3.2s, v7.2s\n"
          "subs %w[cnt], %w[cnt], #1\n"
          "stp d12, d13, [%[ptr_out]], #16\n"
          "stp d14, d15, [%[ptr_out]], #16\n"
          "bne 1b\n"
          : [cnt] "+r"(cnt),
            [r0] "+r"(ptr_c0),
            [r1] "+r"(ptr_c1),
            [r2] "+r"(ptr_c2),
            [r3] "+r"(ptr_c3),
            [r4] "+r"(ptr_c4),
            [r5] "+r"(ptr_c5),
            [r6] "+r"(ptr_c6),
            [r7] "+r"(ptr_c7),
            [ptr_out] "+r"(dout)
          :
          : "cc", "memory",
            "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
            "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15");
#else
      asm volatile(
          /* main loop */
          "1:\n"
          "vld1.32 {d0}, [%[r0]]!\n"
          "vld1.32 {d1}, [%[r1]]!\n"
          "vld1.32 {d2}, [%[r2]]!\n"
          "vld1.32 {d3}, [%[r3]]!\n"
          "vld1.32 {d4}, [%[r4]]!\n"
          "vld1.32 {d5}, [%[r5]]!\n"
          "vld1.32 {d6}, [%[r6]]!\n"
          "vld1.32 {d7}, [%[r7]]!\n"
          "vtrn.8 d0, d1\n"
          "vtrn.8 d2, d3\n"
          "vtrn.8 d4, d5\n"
          "vtrn.8 d6, d7\n"
          "vtrn.16 d0, d2\n"
          "vtrn.16 d1, d3\n"
          "vtrn.16 d4, d6\n"
          "vtrn.16 d5, d7\n"
          "vtrn.32 d0, d4\n"
          "vtrn.32 d2, d6\n"
          "vtrn.32 d1, d5\n"
          "vtrn.32 d3, d7\n"
          "subs %[cnt], #1\n"
          "vst1.32 {d0-d3}, [%[ptr_out]]!\n"
          "vst1.32 {d4-d7}, [%[ptr_out]]!\n"
          "bne 1b\n"
          : [cnt] "+r"(cnt),
            [r0] "+r"(ptr_c0),
            [r1] "+r"(ptr_c1),
            [r2] "+r"(ptr_c2),
            [r3] "+r"(ptr_c3),
            [r4] "+r"(ptr_c4),
            [r5] "+r"(ptr_c5),
            [r6] "+r"(ptr_c6),
            [r7] "+r"(ptr_c7),
            [ptr_out] "+r"(dout)
          :
          : "cc", "memory", "q0", "q1", "q2", "q3");
#endif  // __aarch64__
    }
    for (int i = 0; i < remain; ++i) {
      dout[0] = *(ptr_c0++);
      dout[1] = *(ptr_c1++);
      dout[2] = *(ptr_c2++);
      dout[3] = *(ptr_c3++);
      dout[4] = *(ptr_c4++);
      dout[5] = *(ptr_c5++);
      dout[6] = *(ptr_c6++);
      dout[7] = *(ptr_c7++);
      dout += 8;
    }
    if (pad_r) {
      memset(dout, 0, pad_r * 8 * sizeof(int8_t));
      dout += pad_r * 8;
    }
  }
}
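What the repack produces is easier to see in scalar form. A hedged reference (hypothetical helper name, not part of the patch): for each padded row, eight consecutive channels starting at cs are interleaved so every output position carries 8 channel bytes, with zeros for out-of-image rows/columns and for channels past channel:

#include <cstdint>

// Hypothetical scalar equivalent of one row of prepack_input_nxwc8_int8_dw.
static void prepack_c8_row_ref(const int8_t* din, int8_t* dout,
                               int cs, int h, int ws, int we,
                               int channel, int width, int height) {
  for (int w = ws; w < we; ++w) {
    for (int lane = 0; lane < 8; ++lane) {
      const int c = cs + lane;
      const bool in_range =
          (h >= 0 && h < height) && (w >= 0 && w < width) && (c < channel);
      // NCHW source index; zero-fill padding and the channel tail.
      *dout++ = in_range ? din[(c * height + h) * width + w] : int8_t(0);
    }
  }
}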
/* write result in outputs
...
@@ -153,6 +153,24 @@ void conv_depthwise_5x5s2_fp32(const float* din,
                               bool flag_relu,
                               ARMContext* ctx);
template <typename Dtype>
void conv_depthwise_5x5s1_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
}  // namespace math
}  // namespace arm
}  // namespace lite
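The declaration above pairs with the explicit instantiations for int8_t and float in the .cc file shown earlier, so the kernel is compiled exactly once per output type. A minimal stand-alone illustration of that declare-then-explicitly-instantiate pattern (hypothetical kernel_stub, for illustration only):

#include <cstdint>

template <typename T>
void kernel_stub(T* out, int n);  // header: declaration only

template <typename T>
void kernel_stub(T* out, int n) {  // .cc file: single definition
  for (int i = 0; i < n; ++i) out[i] = T(0);
}

// Explicit instantiation pins down the only two types callers may use,
// mirroring the <int8_t> and <float> instantiations above.
template void kernel_stub<int8_t>(int8_t*, int);
template void kernel_stub<float>(float*, int);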
...
@@ -805,6 +805,88 @@ void conv_depthwise_3x3_int8_int8(const void* din,
  }
}
void conv_depthwise_5x5_int8_fp32(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
void conv_depthwise_5x5_int8_int8(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
}  // namespace math
}  // namespace arm
}  // namespace lite
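The two wrappers above differ only in the cast applied to dout. A hedged sketch of a shared forwarder that would express this (hypothetical name; the patch keeps the two explicit wrappers):

#include <cstdint>

class ARMContext;  // from the surrounding file

template <typename Dtype>
void conv_depthwise_5x5s1_int8(Dtype* dout, const int8_t* din,
                               const int8_t* weights, const float* scale,
                               const float* bias, bool flag_bias,
                               bool flag_relu, int num, int chin, int hin,
                               int win, int hout, int wout, int padw,
                               int padh, ARMContext* ctx);

// Hypothetical shared forwarder; not part of the patch.
template <typename OutT>
static void dispatch_dw5x5s1_int8(void* dout, const void* din,
                                  const void* weights, const float* scale,
                                  const float* bias, bool flag_bias,
                                  bool flag_relu, int num, int ch_in,
                                  int h_in, int w_in, int h_out, int w_out,
                                  int pad_w, int pad_h, ARMContext* ctx) {
  conv_depthwise_5x5s1_int8(reinterpret_cast<OutT*>(dout),
                            reinterpret_cast<const int8_t*>(din),
                            reinterpret_cast<const int8_t*>(weights),
                            scale, bias, flag_bias, flag_relu, num, ch_in,
                            h_in, w_in, h_out, w_out, pad_w, pad_h, ctx);
}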
...
@@ -97,7 +97,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
  bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
  bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
  bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
  bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);  // removed
  bool flag_dw_5x5 = (kw == 5 && sw == 1);             // added
  bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
  if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
@@ -136,7 +136,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
  bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
  bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
  bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
  bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);  // removed
  bool flag_dw_5x5 = (kw == 5 && sw == 1);             // added
  bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
  if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
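With the ph == 2 requirement dropped, the int8 fast path now accepts any symmetric padding for 5x5 stride-1 depthwise. Condensed into one hedged predicate (illustrative function, mirroring the checks above):

// Hypothetical condensation of the depthwise fast-path test; not library code.
bool is_dw_fast_path(int groups, int ic, int oc, int kw, int kh,
                     int sw, int sh, int pw, int ph,
                     int dila_w, int dila_h) {
  bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
  bool no_dilation = (dila_w == 1) && (dila_h == 1);
  bool flag_dw_3x3 = (kw == 3) && (sw == 1 || sw == 2);
  bool flag_dw_5x5 = (kw == 5 && sw == 1);  // pad==2 requirement dropped here
  return groups == ic && ic == oc && kps_equal && no_dilation &&
         (flag_dw_3x3 || flag_dw_5x5);
}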
...
@@ -31,7 +31,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
  // select dw conv kernel
  if (kw == 3) {
    VLOG(5) << "invoke 3x3 dw conv fp32";
    // trans weights
    constexpr int cblock = 4;
    auto oc = w_dims[0];
    auto kh = w_dims[2];
@@ -75,6 +75,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
  }
  /// select dw conv kernel
  if (kw == 3) {
    // trans weights
    VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
    impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
    int cround = ROUNDUP(w_dims[0], 8);
@@ -83,6 +84,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
    auto wptr_new = weights_.mutable_data<int8_t>();
    lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
    flag_trans_weights_ = true;
  } else if (kw == 5) {
    // trans weights
    VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
    impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
    int cround = ROUNDUP(w_dims[0], 8);
    weights_.Resize({cround / 8, 1, kh * kw, 8});
    auto wptr = param.filter->data<int8_t>();
    auto wptr_new = weights_.mutable_data<int8_t>();
    lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
    flag_trans_weights_ = true;
  } else {
    LOG(FATAL) << "this type dw conv not impl";
  }
@@ -123,6 +134,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
  }
  /// select dw conv kernel
  if (kw == 3) {
    // trans weights
    VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
    impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
    int cround = ROUNDUP(w_dims[0], 8);
@@ -131,6 +143,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
    auto wptr_new = weights_.mutable_data<int8_t>();
    lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
    flag_trans_weights_ = true;
  } else if (kw == 5) {
    // trans weights
    VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
    impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
    int cround = ROUNDUP(w_dims[0], 8);
    weights_.Resize({cround / 8, 1, kh * kw, 8});
    auto wptr = param.filter->data<int8_t>();
    auto wptr_new = weights_.mutable_data<int8_t>();
    lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
    flag_trans_weights_ = true;
  } else {
    LOG(FATAL) << "this type dw conv not impl";
  }
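Both 5x5 branches reshape the filter to {cround / 8, 1, kh * kw, 8} before calling conv_trans_weights_numc(..., 8, 25). A hedged scalar sketch of the layout this implies; the exact semantics of the library routine are not shown in this diff, so treat this as an inferred reference:

#include <cstdint>

// Hypothetical reference: regroup per-channel [oc][ksize] weights into
// channel-blocked [oc/8][ksize][8] so the kernel can fetch 8 channels'
// taps with a single 8-byte vector load.
static void trans_weights_c8_ref(const int8_t* src, int8_t* dst,
                                 int oc, int ksize /* 25 for 5x5 */) {
  const int oc_pad = (oc + 7) / 8 * 8;
  for (int c = 0; c < oc_pad; ++c) {
    for (int k = 0; k < ksize; ++k) {
      const int8_t v = (c < oc) ? src[c * ksize + k] : int8_t(0);  // pad tail
      dst[(c / 8) * ksize * 8 + k * 8 + (c % 8)] = v;
    }
  }
}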
...
@@ -481,10 +481,10 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
}
#endif  /// 3x3dw
#if 0  /// 5x5dw   // removed
#if 1  /// 5x5dw   // added
TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
  if (FLAGS_basic_test) {
    for (auto& stride : {1, 2}) {  // removed
    for (auto& stride : {1}) {     // added
      for (auto& pad : {0, 1, 2}) {
        for (auto& flag_bias : {false, true}) {
          for (auto& flag_relu : {false, true}) {
@@ -492,7 +492,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
            std::vector<DDim> dims;
            DDim weights_dim({c, 1, 5, 5});
            for (auto& batch : {1, 2}) {
              for (auto& h : {1, 3, 15, 19, 28, 32, 75}) {
                dims.push_back(DDim({batch, c, h, h}));
              }
            }
...