From 4b9df8fbc82634da6b99fe7de528f16f4369789a Mon Sep 17 00:00:00 2001 From: yiicy Date: Wed, 9 Oct 2019 20:37:57 +0800 Subject: [PATCH] improve dw conv performance * imporve prepack_input func speed in int8 3x3s1 dw conv * fix code style * fix code style * improve 3x3s1 dw fp32 conv speed a little * arm add 5x5s1 int8 dw conv, test=develop --- .../arm/math/conv3x3s1_depthwise_fp32.cc | 328 ++--- .../arm/math/conv3x3s1_depthwise_int8.cc | 25 +- .../arm/math/conv3x3s2_depthwise_int8.cc | 24 +- .../arm/math/conv5x5s1_depthwise_int8.cc | 1298 +++++++++-------- lite/backends/arm/math/conv_block_utils.h | 377 +++-- lite/backends/arm/math/conv_depthwise.h | 18 + lite/backends/arm/math/conv_impl.cc | 82 ++ lite/kernels/arm/conv_compute.cc | 4 +- lite/kernels/arm/conv_depthwise.cc | 24 +- lite/tests/math/conv_int8_compute_test.cc | 6 +- 10 files changed, 1205 insertions(+), 981 deletions(-) diff --git a/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc index a8b243df76..848f4f4210 100644 --- a/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc @@ -145,14 +145,17 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, outc21 = ptr_write; outc31 = ptr_write; } - auto c00 = outc00; - auto c01 = outc01; - auto c10 = outc10; - auto c11 = outc11; - auto c20 = outc20; - auto c21 = outc21; - auto c30 = outc30; - auto c31 = outc31; + float* outl[] = {outc00, + outc10, + outc20, + outc30, + outc01, + outc11, + outc21, + outc31, + reinterpret_cast(bias_local), + reinterpret_cast(flag_relu)}; + void* outl_ptr = reinterpret_cast(outl); for (int w = 0; w < w_loop; ++w) { bool flag_mask = (w == w_loop - 1) && flag_remain; float* out0 = pre_out; @@ -210,6 +213,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, "fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/ "fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/ "fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/ + "ldp x0, x1, [%[outl]] \n" "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/ "fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/ "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/ @@ -230,6 +234,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, "fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/ "fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/ "fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/ + "ldp x2, x3, [%[outl], #16] \n" "fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/ "fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/ "fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/ @@ -239,6 +244,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, "fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/ "fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/ "fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/ + "ldp x4, x5, [%[outl], #32] \n" "fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/ "fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/ "fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/ @@ -248,25 +254,83 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, "fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/ "fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/ "fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/ + "ldp x6, x7, [%[outl], #48] \n" "fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/ "fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/ "fmla v21.4s , 
%[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/ "fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/ - /* save result */ - "stp q15, q16, [%[out]], #32\n" - "stp q17, q18, [%[out]], #32\n" - "stp q19, q20, [%[out]], #32\n" - "stp q21, q22, [%[out]]\n" + + "fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */ + "fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */ + "fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */ + "fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */ + "fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */ + "fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */ + "fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */ + "fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */ + + /* transpose */ + "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ + "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ + "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ + "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ + "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ + "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ + "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ + "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ + "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ + "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ + "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ + "trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ + "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ + "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ + "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ + "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/ + + "cbz %w[flag_relu], 0f\n" /* skip relu*/ + "movi v0.4s, #0\n" /* for relu */ + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "0:\n" + "cbnz %w[flag_mask], 1f\n" + "str q15, [x0]\n" /* save outc00 */ + "str q16, [x4]\n" /* save outc01 */ + "str q17, [x1]\n" /* save outc10 */ + "str q18, [x5]\n" /* save outc11 */ + "str q19, [x2]\n" /* save outc20 */ + "str q20, [x6]\n" /* save outc21 */ + "str q21, [x3]\n" /* save outc30 */ + "str q22, [x7]\n" /* save outc31 */ + "b 2f\n" + "1:\n" + "str q15, [%[out]], #16 \n" /* save remain to pre_out */ + "str q17, [%[out]], #16 \n" /* save remain to pre_out */ + "str q19, [%[out]], #16 \n" /* save remain to pre_out */ + "str q21, [%[out]], #16 \n" /* save remain to pre_out */ + "str q16, [%[out]], #16 \n" /* save remain to pre_out */ + "str q18, [%[out]], #16 \n" /* save remain to pre_out */ + "str q20, [%[out]], #16 \n" /* save remain to pre_out */ + "str q22, [%[out]], #16 \n" /* save remain to pre_out */ + "2:\n" :[inr0] "+r"(inr0), [inr1] "+r"(inr1), [inr2] "+r"(inr2), [inr3] "+r"(inr3), [out]"+r"(out0) :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5), - [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8) + [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8), + [vbias]"w" (vbias), [outl] "r" (outl_ptr), + [flag_mask] "r" (flag_mask), [flag_relu] "r" (flag_relu) : "cc", "memory", "v0","v1","v2","v3","v4","v5","v6","v7", "v8", "v9", "v10", "v11", "v15", - "v16","v17","v18","v19","v20","v21","v22" + "v16","v17","v18","v19","v20","v21","v22", + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7" ); #else asm volatile( @@ -355,183 +419,113 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, "vmla.f32 q8, q5, q2 @ w8 * inr22\n" "vmla.f32 q9, q5, q3 @ w8 * inr23\n" "vld1.32 {d4-d7}, [%[r3]]! 
@ load r3, q2, q3\n" + "ldr r4, [%[outl], #32] @ load bias addr to r4\n" "vmla.f32 q14, q6, q0 @ w5 * inr24\n" "vmla.f32 q15, q6, q1 @ w5 * inr25\n" "vmla.f32 q10, q5, q0 @ w8 * inr24\n" "vmla.f32 q11, q5, q1 @ w8 * inr25\n" "vld1.32 {d0-d3}, [%[r3]]! @ load r3, q0, q1\n" - "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" + "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" /* mul r3 with w6, w7, w8, get out r1 */ "vmla.f32 q12, q7, q2 @ w6 * inr30\n" "vmla.f32 q13, q7, q3 @ w6 * inr31\n" - "vst1.32 {d16-d19}, [%[out0]]! @ save r00, r01, c0~c3\n" "vmla.f32 q14, q7, q0 @ w6 * inr32\n" "vmla.f32 q15, q7, q1 @ w6 * inr33\n" - "vst1.32 {d20-d23}, [%[out0]]! @ save r02, r03, c0~c3\n" "vmla.f32 q12, q4, q3 @ w7 * inr31\n" "vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" + "vld1.32 {d12-d13}, [r4] @ load bias\n" "vmla.f32 q13, q4, q0 @ w7 * inr32\n" "vmla.f32 q14, q4, q1 @ w7 * inr33\n" "vmla.f32 q15, q4, q2 @ w7 * inr34\n" + "ldr r0, [%[outl]] @ load outc00 to r0\n" "vmla.f32 q12, q5, q0 @ w8 * inr32\n" "vmla.f32 q13, q5, q1 @ w8 * inr33\n" + "ldr r5, [%[outl], #36] @ load flag_relu to r5\n" "vmla.f32 q14, q5, q2 @ w8 * inr34\n" "vmla.f32 q15, q5, q3 @ w8 * inr35\n" - "vst1.32 {d24-d27}, [%[out0]]! @ save r10, r11, c0~c3\n" - "vst1.32 {d28-d31}, [%[out0]]! @ save r12, r13, c0~c3\n" + "ldr r1, [%[outl], #4] @ load outc10 to r1\n" + "vadd.f32 q8, q8, q6 @ r00 add bias\n" + "vadd.f32 q9, q9, q6 @ r01 add bias\n" + "vadd.f32 q10, q10, q6 @ r02 add bias\n" + "vadd.f32 q11, q11, q6 @ r03 add bias\n" + "ldr r2, [%[outl], #8] @ load outc20 to r2\n" + "vadd.f32 q12, q12, q6 @ r10 add bias\n" + "vadd.f32 q13, q13, q6 @ r11 add bias\n" + "vadd.f32 q14, q14, q6 @ r12 add bias\n" + "vadd.f32 q15, q15, q6 @ r13 add bias\n" + "ldr r3, [%[outl], #12] @ load outc30 to r3\n" + "vmov.u32 q7, #0 @ mov zero to q7\n" + "cmp r5, #0 @ cmp flag relu\n" + "beq 1f @ skip relu\n" + "vmax.f32 q8, q8, q7 @ r00 relu\n" + "vmax.f32 q9, q9, q7 @ r01 relu\n" + "vmax.f32 q10, q10, q7 @ r02 relu\n" + "vmax.f32 q11, q11, q7 @ r03 relu\n" + "vmax.f32 q12, q12, q7 @ r10 relu\n" + "vmax.f32 q13, q13, q7 @ r11 relu\n" + "vmax.f32 q14, q14, q7 @ r12 relu\n" + "vmax.f32 q15, q15, q7 @ r13 relu\n" + "1:\n" + "ldr r4, [%[outl], #16] @ load outc01 to r4\n" + "vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" + "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" + "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" + "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" + "ldr r5, [%[outl], #20] @ load outc11 to r5\n" + "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" + "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" + "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" + "vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" + "cmp %[flag_mask], #0 @ cmp flag mask\n" + "bne 2f\n" + "vst1.32 {d16-d17}, [r0] @ save outc00\n" + "vst1.32 {d18-d19}, [r1] @ save outc10\n" + "vst1.32 {d20-d21}, [r2] @ save outc20\n" + "vst1.32 {d22-d23}, [r3] @ save outc30\n" + "vst1.32 {d24-d25}, [r4] @ save outc01\n" + "vst1.32 {d26-d27}, [r5] @ save outc11\n" + "ldr r0, [%[outl], #24] @ load outc21 to r0\n" + "ldr r1, [%[outl], #28] @ load outc31 to r1\n" + "vst1.32 {d28-d29}, [r0] @ save outc21\n" + "vst1.32 {d30-d31}, [r1] @ save outc31\n" + "b 3f @ branch end\n" + "2: \n" + "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d22-d23}, [%[out0]]! 
@ save remain to pre_out\n" + "vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" + "vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n" + "3: \n" : [r0] "+r"(inr0), [r1] "+r"(inr1), - [r2] "+r"(inr2), [r3] "+r"(inr3), - [out0] "+r"(out0), [wc0] "+r"(weight_c) - : + [r2] "+r"(inr2), [r3] "+r"(inr3), + [out0] "+r"(out0), [wc0] "+r"(weight_c) + : [flag_mask] "r" (flag_mask), [outl] "r" (outl_ptr) : "cc", "memory", - "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "q10", "q11", "q12", "q13","q14", "q15" + "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13","q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5" ); #endif // __arch64__ - float* out1 = pre_out; - if (flag_mask) { - c00 = outc00; - c01 = outc01; - c10 = outc10; - c11 = outc11; - c20 = outc20; - c21 = outc21; - c30 = outc30; - c31 = outc31; - outc00 = pre_out; - outc01 = pre_out + 4; - outc10 = pre_out + 8; - outc11 = pre_out + 12; - outc20 = pre_out + 16; - outc21 = pre_out + 20; - outc30 = pre_out + 24; - outc31 = pre_out + 28; - } -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[din]], #32\n" /* load input*/ - "ldp q2, q3, [%[din]], #32\n" /* load input*/ - "fadd v15.4s, v0.4s, %[vbias].4s\n" /* add bias */ - "fadd v16.4s, v1.4s, %[vbias].4s\n" /* add bias */ - "ldp q4, q5, [%[din]], #32\n" /* load input*/ - "fadd v17.4s, v2.4s, %[vbias].4s\n" /* add bias */ - "fadd v18.4s, v3.4s, %[vbias].4s\n" /* add bias */ - "ldp q6, q7, [%[din]]\n" /* load input*/ - "fadd v19.4s, v4.4s, %[vbias].4s\n" /* add bias */ - "fadd v20.4s, v5.4s, %[vbias].4s\n" /* add bias */ - "fadd v21.4s, v6.4s, %[vbias].4s\n" /* add bias */ - "fadd v22.4s, v7.4s, %[vbias].4s\n" /* add bias */ - /* transpose */ - "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ - "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ - "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ - "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ - "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ - "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ - "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ - "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ - "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ - "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ - "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ - "trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ - "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ - "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ - "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ - "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/ - "cbz %w[flag_relu], 0f\n" /* skip relu*/ - "movi v0.4s, #0\n" /* for relu */ - "fmax v15.4s, v15.4s, v0.4s\n" - "fmax v16.4s, v16.4s, v0.4s\n" - "fmax v17.4s, v17.4s, v0.4s\n" - "fmax v18.4s, v18.4s, v0.4s\n" - "fmax v19.4s, v19.4s, v0.4s\n" - "fmax v20.4s, v20.4s, v0.4s\n" - "fmax v21.4s, v21.4s, v0.4s\n" - "fmax v22.4s, v22.4s, v0.4s\n" - "0:\n" - "str q15, [%[outc00]], #16\n" /* save outc00*/ - "str q16, [%[outc01]], #16\n" /* save outc01*/ - "str q17, [%[outc10]], #16\n" /* save outc10*/ - "str q18, [%[outc11]], #16\n" /* save outc11*/ - "str q19, [%[outc20]], #16\n" /* save outc20*/ - "str q20, [%[outc21]], #16\n" /* save outc21*/ - "str q21, [%[outc30]], #16\n" /* save outc30*/ - "str q22, [%[outc31]], #16\n" /* save outc31*/ - :[outc00] "+r"(outc00), [outc01] "+r"(outc01), - [outc10] "+r"(outc10), [outc11] "+r"(outc11), - [outc20] "+r"(outc20), [outc21] "+r"(outc21), - 
[outc30] "+r"(outc30), [outc31] "+r"(outc31), - [din] "+r"(out1) - :[vbias]"w" (vbias), [flag_relu] "r"(flag_relu) - : "cc", "memory", - "v0","v1","v2","v3","v4","v5","v6","v7", - "v15", "v16","v17","v18","v19","v20","v21","v22" - ); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[din]]!\n" /* load input*/ - "vld1.32 {d4-d7}, [%[din]]!\n" /* load input*/ - "vadd.f32 q0, q0, %q[vbias]\n" /* add bias */ - "vadd.f32 q1, q1, %q[vbias]\n" /* add bias */ - "vld1.32 {d8-d11}, [%[din]]!\n" /* load input*/ - "vadd.f32 q2, q2, %q[vbias]\n" /* add bias */ - "vadd.f32 q3, q3, %q[vbias]\n" /* add bias */ - "vld1.32 {d12-d15}, [%[din]]!\n" /* load input*/ - "vadd.f32 q4, q4, %q[vbias]\n" /* add bias */ - "vadd.f32 q5, q5, %q[vbias]\n" /* add bias */ - "vadd.f32 q6, q6, %q[vbias]\n" /* add bias */ - "vadd.f32 q7, q7, %q[vbias]\n" /* add bias */ - /* transpose */ - "vtrn.32 q0, q1\n" /* r0: q0: a0a1c0c1, q1: b0b1d0d1*/ - "vtrn.32 q2, q3\n" /* r0: q2: a2a3c2c3, q3: b2b3d2d3*/ - "vtrn.32 q4, q5\n" /* r1: q4: a0a1c0c1, q5: b0b1d0d1*/ - "vtrn.32 q6, q7\n" /* r1: q6: a2a3c2c3, q7: b2b3d2d3*/ - "vswp d1, d4\n" /* r0: q0: a0a1a2a3, q2: c0c1c2c3*/ - "vswp d3, d6\n" /* r0: q1: b0b1b2b3, q3: d0d1d2d3*/ - "vswp d9, d12\n" /* r1: q4: a0a1a2a3, q6: c0c1c2c3*/ - "vswp d11, d14\n" /* r1: q5: b0b1b2b3, q7: d0d1d2d3*/ - "cmp %[flag_relu], #0\n" - "beq 0f\n" /* skip relu*/ - "vmov.u32 q15, #0\n" - "vmax.f32 q0, q0, q15\n" - "vmax.f32 q1, q1, q15\n" - "vmax.f32 q2, q2, q15\n" - "vmax.f32 q3, q3, q15\n" - "vmax.f32 q4, q4, q15\n" - "vmax.f32 q5, q5, q15\n" - "vmax.f32 q6, q6, q15\n" - "vmax.f32 q7, q7, q15\n" - "0:\n" - "vst1.32 {d0-d1}, [%[outc00]]!\n" /* save outc00*/ - "vst1.32 {d2-d3}, [%[outc10]]!\n" /* save outc10*/ - "vst1.32 {d4-d5}, [%[outc20]]!\n" /* save outc20*/ - "vst1.32 {d6-d7}, [%[outc30]]!\n" /* save outc30*/ - "vst1.32 {d8-d9}, [%[outc01]]!\n" /* save outc01*/ - "vst1.32 {d10-d11}, [%[outc11]]!\n" /* save outc11*/ - "vst1.32 {d12-d13}, [%[outc21]]!\n" /* save outc21*/ - "vst1.32 {d14-d15}, [%[outc31]]!\n" /* save outc31*/ - :[outc00] "+r"(outc00), [outc01] "+r"(outc01), - [outc10] "+r"(outc10), [outc11] "+r"(outc11), - [outc20] "+r"(outc20), [outc21] "+r"(outc21), - [outc30] "+r"(outc30), [outc31] "+r"(outc31), - [din] "+r"(out1) - :[vbias]"w" (vbias), [flag_relu] "r"(flag_relu) - : "cc", "memory", - "q0","q1","q2","q3","q4","q5","q6","q7", "q15" - ); -#endif // __aarch64__ // clang-format on + outl[0] += 4; + outl[1] += 4; + outl[2] += 4; + outl[3] += 4; + outl[4] += 4; + outl[5] += 4; + outl[6] += 4; + outl[7] += 4; if (flag_mask) { - for (int i = 0; i < remain; ++i) { - c00[i] = pre_out[i]; - c01[i] = pre_out[i + 4]; - c10[i] = pre_out[i + 8]; - c11[i] = pre_out[i + 12]; - c20[i] = pre_out[i + 16]; - c21[i] = pre_out[i + 20]; - c30[i] = pre_out[i + 24]; - c31[i] = pre_out[i + 28]; - } + memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); + memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); + memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float)); + memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float)); + memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float)); + memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float)); + memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float)); + memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float)); } } } diff --git a/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc b/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc index 7e77bf9d08..bc2097b928 100644 --- a/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc +++ 
b/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc @@ -56,13 +56,14 @@ void conv_depthwise_3x3s1_int8(Dtype* dout, const int win_round = wout_round + 2; //! get h block - //! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round - //! * hout_c_block * hout_r_block * threads * sizeof(int32_t) + //! llc_size = threads * win_round * hout_c_block * hin_r_block * + //! sizeof(int8_t) + //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t) //! win_round = wout_round + 2 //! hin_r_block = hout_r_block + 2 - int hout_r_block = - (llc_size - 2 * win_round * threads) / - (win_round * threads + hout_c_block * wout_round * threads * 4); + int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) / + (win_round * threads * hout_c_block + + hout_c_block * wout_round * threads * 4); hout_r_block = hout_r_block > hout ? hout : hout_r_block; hout_r_block = ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; @@ -115,17 +116,9 @@ void conv_depthwise_3x3s1_int8(Dtype* dout, int32_t* pre_out = reinterpret_cast(tmp_din + pre_in_size); auto pre_din = tmp_din; #endif - prepack_input_nxw_c8_int8(din_batch, - pre_din, - c, - c + hout_c_block, - hs, - he, - ws, - we, - chin, - win, - hin); + prepack_input_nxwc8_int8_dw( + din_batch, pre_din, c, hs, he, ws, we, chin, win, hin); + const int8_t* block_inr0 = pre_din; const int8_t* block_inr1 = block_inr0 + in_len; const int8_t* block_inr2 = block_inr1 + in_len; diff --git a/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc b/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc index d272b62508..2e475fc606 100644 --- a/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc +++ b/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc @@ -56,13 +56,14 @@ void conv_depthwise_3x3s2_int8(Dtype* dout, const int win_round = wout_round * 2 /*stride*/ + 1; //! get h block - //! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round - //! * hout_c_block * hout_r_block * threads * sizeof(int32_t) + //! llc_size = threads * win_round * hin_r_block * hout_c_block * + //! sizeof(int8_t) + //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t) //! win_round = wout_round + 2 //! hin_r_block = hout_r_block + 2 - int hout_r_block = - (llc_size - 2 * win_round * threads) / - (2 * win_round * threads + hout_c_block * wout_round * threads * 4); + int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) / + (2 * win_round * threads * hout_c_block + + hout_c_block * wout_round * threads * 4); hout_r_block = hout_r_block > hout ? 
hout : hout_r_block; hout_r_block = ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; @@ -115,17 +116,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout, int32_t* pre_out = reinterpret_cast(tmp_din + pre_in_size); auto pre_din = tmp_din; #endif - prepack_input_nxw_c8_int8(din_batch, - pre_din, - c, - c + hout_c_block, - hs, - he, - ws, - we, - chin, - win, - hin); + prepack_input_nxwc8_int8_dw( + din_batch, pre_din, c, hs, he, ws, we, chin, win, hin); const int8_t* block_inr0 = pre_din; const int8_t* block_inr1 = block_inr0 + in_len; const int8_t* block_inr2 = block_inr1 + in_len; diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc index 0d0034dd85..802082048c 100644 --- a/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc +++ b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc @@ -14,6 +14,7 @@ #include #include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_depthwise.h" #include "lite/backends/arm/math/conv_impl.h" #include "lite/core/context.h" #include "lite/operators/op_params.h" @@ -26,592 +27,749 @@ namespace lite { namespace arm { namespace math { -void conv_depthwise_5x5s1_int8(int32_t* dout, +#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) + +template +void conv_depthwise_5x5s1_int8(Dtype* dout, const int8_t* din, const int8_t* weights, - const int* bias, + const float* scale, + const float* bias, bool flag_bias, bool flag_relu, - const int num, - const int chin, - const int hin, - const int win, - const int hout, - const int wout, - ARMContext* ctx, - PrecisionType out_type, - const float* scale); - -void conv_depthwise_5x5_int8(const int8_t* din, - int32_t* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const int8_t* weights, - const int32_t* bias, - const operators::ConvParam& param, - ARMContext* ctx, - PrecisionType out_type, - const float* scale) { - int stride_h = param.strides[0]; - bool flag_relu = param.fuse_relu; - bool flag_bias = param.bias != nullptr; - // if (param.activation_param.has_active){ - // if (param.activation_param.active == Active_relu || - // fabs(param.activation_param.negative_slope) > 1e-6f){ - // flag_relu = true; - // } - // } - if (stride_h == 1) { -#ifdef __aarch64__ - conv_depthwise_5x5s1_int8(dout, - din, - weights, - bias, - flag_bias, - flag_relu, - num, - chin, - hin, - win, - hout, - wout, - ctx, - out_type, - scale); -#else - - LOG(FATAL) << "5x5 dw conv armv7 has not impl"; -#endif - } -} - -/** - * \brief depthwise convolution, kernel size 5x5, stride 1, pad 1, with bias, - * width > 4 - */ -// 2 line -#ifdef __aarch64__ + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx) { + const int threads = ctx->threads(); + int llc_size = ctx->llc_size() / 4; + + const int hout_c_block = 8; + const int hout_r_kernel = 1; + const int wout_block = 4; + const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block; + const int win_round = wout_round + 4; + + //! get h block + //! llc_size = threads * win_round * hout_c_block * hin_r_block * + //! sizeof(int8_t) + //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t) + //! win_round = wout_round + 4 + //! hin_r_block = hout_r_block + 4 + int hout_r_block = (llc_size - 4 * win_round * hout_c_block * threads) / + (win_round * hout_c_block * threads + + hout_c_block * wout_round * threads * 4); + hout_r_block = hout_r_block > hout ? 
hout : hout_r_block; + hout_r_block = + ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block + 4; + + auto tmp_work_space = ctx->workspace_data(); + int8_t ptr_zero[win_round]; // NOLINT + memset(ptr_zero, 0, sizeof(int8_t) * win_round); + Dtype ptr_write[wout_round]; // NOLINT + + int in_len = win_round * hout_c_block; + int pre_in_size = hin_r_block * in_len; + pre_in_size = ROUNDUP(pre_in_size, 4); + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + int8_t* tmp_din = tmp_work_space; -template -inline void prefetch(const Dtype* din) { -#ifdef __aarch64__ - asm volatile("PRFM PLDL1KEEP, [%[din]] \n" : : [din] "r"(din) : "memory"); -#else - asm volatile("pld [%[din]] \n" : : [din] "r"(din) : "memory"); -#endif -} - -void conv_depthwise_5x5s1_int8( - int32_t* dout, - const int8_t* din, - const int8_t* weights, - const int32_t* bias, - bool flag_bias, - bool flag_relu, - const int num, - const int chin, - const int hin, - const int win, - const int hout, - const int wout, - ARMContext* ctx, - PrecisionType od_type, - float const* scales) { /// scale_size = channel-out - - // printf("5*5 multiply\n"); int size_in_channel = win * hin; int size_out_channel = wout * hout; - int w_stride = 5 * 5; - - static int const stride_w = 1; - int const stride_h = stride_w; - int const chout = chin; - int const pad_w = 2; - int const pad_h = pad_w; - - int const wout_round = ((wout + 7) / 8) * 8; - int const win_round = wout_round * stride_w + 5 - 1; - int const hout_round = ((hout + 2) / 3) * 3; - int const hin_round = hout_round * stride_h + 5 - 1; - int const tile_h = hout_round / 3; - int const tile_w = wout_round / 8; - - int const pre_in_size = hin_round * win_round; - int const pre_out_size = hout_round * wout_round; - int const pre_io_size = pre_in_size + pre_out_size * sizeof(int); - - int const hs = -pad_h; - int const he = hs + hin_round; - int const ws = -pad_w; - int const we = ws + win_round; - - // signed char* tmp_work_space = new signed char [1024*5]; - signed char* tmp_work_space = ctx->workspace_data(); - signed char* ptr_zero = tmp_work_space; - int* ptr_write = reinterpret_cast(ptr_zero + win_round); - signed char* pre_data = - reinterpret_cast(ptr_write + wout_round); - - memset(ptr_zero, 0, win_round * sizeof(signed char)); + int w_stride = 25; // kernel_w * kernel_h; + + int ws = -padw; + int we = ws + win_round; + int w_loop = wout_round / 4; + int chout = chin; + int out_row_stride = hout_c_block * wout_round; for (int n = 0; n < num; ++n) { - signed char const* din_batch = din + n * chin * size_in_channel; - int* dout_batch = dout + n * chout * size_out_channel; + const int8_t* din_batch = din + n * chin * size_in_channel; + int8_t* dout_batch = reinterpret_cast(dout) + + n * chout * size_out_channel * sizeof(Dtype); + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h - padh; + int he = hs + h_kernel + 4; - // #pragma omp parallel for - for (int c = 0; c < chout; c++) { +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { #ifdef ARM_WITH_OMP - int const thno = omp_get_thread_num(); + int8_t* pre_din = + tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size * 4); + int32_t* pre_out = reinterpret_cast(pre_din + pre_in_size); #else - int const thno = 0; + int32_t* pre_out = 
reinterpret_cast(tmp_din + pre_in_size); + auto pre_din = tmp_din; #endif - signed char const* din_channel = din_batch + c * size_in_channel; - signed char* pre_din = pre_data + thno * pre_io_size; - int* pre_out = reinterpret_cast(pre_din + pre_in_size); - int* dout_ptr = pre_out; - - prepack_input_nxw(din_channel, - pre_din, - c, - c + 1, - hs, - he, - ws, - we, - 1, - win, - hin, - ptr_zero); - - signed char const* wei_ptr = weights + c * w_stride; - int bias_val = flag_bias ? bias[c] : 0.f; - - int8x8_t wr00 = vdup_n_s8(wei_ptr[0 * 5 + 0]); - int8x8_t wr01 = vdup_n_s8(wei_ptr[0 * 5 + 1]); - int8x8_t wr02 = vdup_n_s8(wei_ptr[0 * 5 + 2]); - int8x8_t wr03 = vdup_n_s8(wei_ptr[0 * 5 + 3]); - int8x8_t wr04 = vdup_n_s8(wei_ptr[0 * 5 + 4]); - - int8x8_t wr10 = vdup_n_s8(wei_ptr[1 * 5 + 0]); - int8x8_t wr11 = vdup_n_s8(wei_ptr[1 * 5 + 1]); - int8x8_t wr12 = vdup_n_s8(wei_ptr[1 * 5 + 2]); - int8x8_t wr13 = vdup_n_s8(wei_ptr[1 * 5 + 3]); - int8x8_t wr14 = vdup_n_s8(wei_ptr[1 * 5 + 4]); - - int8x8_t wr20 = vdup_n_s8(wei_ptr[2 * 5 + 0]); - int8x8_t wr21 = vdup_n_s8(wei_ptr[2 * 5 + 1]); - int8x8_t wr22 = vdup_n_s8(wei_ptr[2 * 5 + 2]); - int8x8_t wr23 = vdup_n_s8(wei_ptr[2 * 5 + 3]); - int8x8_t wr24 = vdup_n_s8(wei_ptr[2 * 5 + 4]); - - int8x8_t wr30 = vdup_n_s8(wei_ptr[3 * 5 + 0]); - int8x8_t wr31 = vdup_n_s8(wei_ptr[3 * 5 + 1]); - int8x8_t wr32 = vdup_n_s8(wei_ptr[3 * 5 + 2]); - int8x8_t wr33 = vdup_n_s8(wei_ptr[3 * 5 + 3]); - int8x8_t wr34 = vdup_n_s8(wei_ptr[3 * 5 + 4]); - - int8x8_t wr40 = vdup_n_s8(wei_ptr[4 * 5 + 0]); - int8x8_t wr41 = vdup_n_s8(wei_ptr[4 * 5 + 1]); - int8x8_t wr42 = vdup_n_s8(wei_ptr[4 * 5 + 2]); - int8x8_t wr43 = vdup_n_s8(wei_ptr[4 * 5 + 3]); - int8x8_t wr44 = vdup_n_s8(wei_ptr[4 * 5 + 4]); - - int* doutr0 = nullptr; - int* doutr1 = nullptr; - int* doutr2 = nullptr; - - signed char const* dr0 = pre_din; - signed char const* dr1 = dr0 + win_round; - signed char const* dr2 = dr1 + win_round; - signed char const* dr3 = dr2 + win_round; - signed char const* dr4 = dr3 + win_round; - signed char const* dr5 = dr4 + win_round; - signed char const* dr6 = dr5 + win_round; - - signed char const* din_ptr0 = nullptr; - signed char const* din_ptr1 = nullptr; - signed char const* din_ptr2 = nullptr; - signed char const* din_ptr3 = nullptr; - signed char const* din_ptr4 = nullptr; - signed char const* din_ptr5 = nullptr; - signed char const* din_ptr6 = nullptr; - - for (int h = 0; h < tile_h; h++) { - // printf("c:%d h:%d\n", c, h); - doutr0 = dout_ptr; - doutr1 = doutr0 + wout_round; - doutr2 = doutr1 + wout_round; - - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - din_ptr6 = dr6; - - prefetch(doutr0); - prefetch(doutr1); - prefetch(doutr2); - prefetch(din_ptr0); - prefetch(din_ptr1); - prefetch(din_ptr2); - prefetch(din_ptr3); - prefetch(din_ptr4); - prefetch(din_ptr5); - prefetch(din_ptr6); - - for (int j = 0; j < tile_w; ++j) { - // printf("j:%d\n", j); - int32x4_t voutr00 = vdupq_n_s32(bias_val); - int32x4_t voutr01 = vdupq_n_s32(bias_val); - int32x4_t voutr10 = vdupq_n_s32(bias_val); - int32x4_t voutr11 = vdupq_n_s32(bias_val); - int32x4_t voutr20 = vdupq_n_s32(bias_val); - int32x4_t voutr21 = vdupq_n_s32(bias_val); - - // din data - int8x8_t vinr00 = vld1_s8(din_ptr0 + 0); - int8x8_t vinr01 = vld1_s8(din_ptr0 + 8); - int8x8_t vinr10 = vld1_s8(din_ptr1 + 0); - int8x8_t vinr11 = vld1_s8(din_ptr1 + 8); - int8x8_t vinr20 = vld1_s8(din_ptr2 + 0); - int8x8_t vinr21 = vld1_s8(din_ptr2 + 8); - int8x8_t vinr30 = vld1_s8(din_ptr3 + 
0); - int8x8_t vinr31 = vld1_s8(din_ptr3 + 8); - int8x8_t vinr40 = vld1_s8(din_ptr4 + 0); - int8x8_t vinr41 = vld1_s8(din_ptr4 + 8); - int8x8_t vinr50 = vld1_s8(din_ptr5 + 0); - int8x8_t vinr51 = vld1_s8(din_ptr5 + 8); - int8x8_t vinr60 = vld1_s8(din_ptr6 + 0); - int8x8_t vinr61 = vld1_s8(din_ptr6 + 8); - - /// the first row - // r0 - int8x8_t vtmp1 = vext_s8(vinr00, vinr01, 1); // 12345678 - int8x8_t vtmp2 = vext_s8(vinr00, vinr01, 2); // 2345678 - int8x8_t vtmp3 = vext_s8(vinr00, vinr01, 3); // 345678 - int8x8_t vtmp4 = vext_s8(vinr00, vinr01, 4); // 45678 - - int16x8_t tvoutr0 = vmull_s8(vinr00, wr00); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr01); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr02); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr03); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr04); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - // r1 - vtmp1 = vext_s8(vinr10, vinr11, 1); // 12345678 - vtmp2 = vext_s8(vinr10, vinr11, 2); // 2345678 - vtmp3 = vext_s8(vinr10, vinr11, 3); // 345678 - vtmp4 = vext_s8(vinr10, vinr11, 4); // 45678 - - tvoutr0 = vmull_s8(vinr10, wr10); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr11); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr12); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr13); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr14); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - int16x8_t tvoutr1 = vmull_s8(vinr10, wr00); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr01); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr02); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr03); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp4, wr04); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - // r2 - vtmp1 = vext_s8(vinr20, vinr21, 1); // 12345678 - vtmp2 = vext_s8(vinr20, vinr21, 2); // 2345678 - vtmp3 = vext_s8(vinr20, vinr21, 3); // 345678 - vtmp4 = vext_s8(vinr20, vinr21, 4); // 45678 - - tvoutr0 = vmull_s8(vinr20, wr20); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr21); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr22); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr23); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr24); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - tvoutr1 = vmull_s8(vinr20, wr10); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr11); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr12); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr13); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp4, wr14); - voutr10 = 
vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - int16x8_t tvoutr2 = vmull_s8(vinr20, wr00); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr01); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr02); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr03); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr04); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - // r3 - vtmp1 = vext_s8(vinr30, vinr31, 1); // 12345678 - vtmp2 = vext_s8(vinr30, vinr31, 2); // 2345678 - vtmp3 = vext_s8(vinr30, vinr31, 3); // 345678 - vtmp4 = vext_s8(vinr30, vinr31, 4); // 45678 - - tvoutr0 = vmull_s8(vinr30, wr30); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr31); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr32); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr33); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr34); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - tvoutr1 = vmull_s8(vinr30, wr20); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr21); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr22); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr23); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp4, wr24); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - tvoutr2 = vmull_s8(vinr30, wr10); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr11); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr12); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr13); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr14); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - // r4 - vtmp1 = vext_s8(vinr40, vinr41, 1); // 12345678 - vtmp2 = vext_s8(vinr40, vinr41, 2); // 2345678 - vtmp3 = vext_s8(vinr40, vinr41, 3); // 345678 - vtmp4 = vext_s8(vinr40, vinr41, 4); // 45678 - - tvoutr0 = vmull_s8(vinr40, wr40); - tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr41); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp2, wr42); - tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr43); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - tvoutr0 = vmull_s8(vtmp4, wr44); - voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); - voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); - - tvoutr1 = vmull_s8(vinr40, wr30); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr31); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr32); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr33); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 
= vmull_s8(vtmp4, wr34); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - tvoutr2 = vmull_s8(vinr40, wr20); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr21); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr22); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr23); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr24); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - // r5 - vtmp1 = vext_s8(vinr50, vinr51, 1); // 12345678 - vtmp2 = vext_s8(vinr50, vinr51, 2); // 2345678 - vtmp3 = vext_s8(vinr50, vinr51, 3); // 345678 - vtmp4 = vext_s8(vinr50, vinr51, 4); // 45678 - - tvoutr1 = vmull_s8(vinr50, wr40); - tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr41); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp2, wr42); - tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr43); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - tvoutr1 = vmull_s8(vtmp4, wr44); - voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); - voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); - - tvoutr2 = vmull_s8(vinr50, wr30); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr31); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr32); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr33); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr34); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - // r6 - vtmp1 = vext_s8(vinr60, vinr61, 1); // 12345678 - vtmp2 = vext_s8(vinr60, vinr61, 2); // 2345678 - vtmp3 = vext_s8(vinr60, vinr61, 3); // 345678 - vtmp4 = vext_s8(vinr60, vinr61, 4); // 45678 - - tvoutr2 = vmull_s8(vinr60, wr40); - tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr41); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp2, wr42); - tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr43); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - tvoutr2 = vmull_s8(vtmp4, wr44); - voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); - voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); - - /// data shift 8 bytes - din_ptr0 += 8; - din_ptr1 += 8; - din_ptr2 += 8; - din_ptr3 += 8; - din_ptr4 += 8; - din_ptr5 += 8; - din_ptr6 += 8; - - /// store - vst1q_s32(doutr0, voutr00); - vst1q_s32(doutr1, voutr10); - vst1q_s32(doutr2, voutr20); - doutr0 += 4; - doutr1 += 4; - doutr2 += 4; - vst1q_s32(doutr0, voutr01); - vst1q_s32(doutr1, voutr11); - vst1q_s32(doutr2, voutr21); - doutr0 += 4; - doutr1 += 4; - doutr2 += 4; - } /// end of tile_w - - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - dr3 = dr6; - dr4 = dr3 + win_round; - dr5 = dr4 + win_round; - dr6 = dr5 + win_round; - - dout_ptr = dout_ptr + 3 * wout_round; - } /// end of tile_h - - if (scales == 0) { - write_to_output_numc(pre_out, - dout_batch, - 1, - hout_round, - c, - c + 1, - 0, - hout, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - ptr_write); - } else if (od_type == PRECISION(kFloat)) { - 
write2_to_output_numc(pre_out, - reinterpret_cast(dout_batch), - 1, - hout_round, - c, - c + 1, - 0, - hout, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - scales); - } else if (od_type == PRECISION(kInt8)) { - write2_to_output_numc(pre_out, - reinterpret_cast(dout_batch), - 1, - hout_round, - c, - c + 1, - 0, - hout, - 0, - wout_round, - chout, - hout, - wout, - flag_relu, - reinterpret_cast(ptr_write), - scales); + prepack_input_nxwc8_int8_dw( + din_batch, pre_din, c, hs, he, ws, we, chin, win, hin); + + const int8_t* block_inr0 = pre_din; + const int8_t* block_inr1 = block_inr0 + in_len; + const int8_t* block_inr2 = block_inr1 + in_len; + const int8_t* block_inr3 = block_inr2 + in_len; + const int8_t* block_inr4 = block_inr3 + in_len; + + const int8_t* weight_c = weights + c * w_stride; + float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + bias_local[4] = bias[c + 4]; + bias_local[5] = bias[c + 5]; + bias_local[6] = bias[c + 6]; + bias_local[7] = bias[c + 7]; + } + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + int cnt = w_loop; + const int8_t* inr0 = block_inr0; + const int8_t* inr1 = block_inr1; + const int8_t* inr2 = block_inr2; + const int8_t* inr3 = block_inr3; + const int8_t* inr4 = block_inr4; + + int32_t* ptr_out0 = pre_out + hk * out_row_stride; +// clang-format off +#ifdef __aarch64__ + auto wptr = weight_c; + asm volatile( + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */ + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 0-3 */ + "1:\n" + /* in r0 */ + "smull v20.8h, v0.8b, v8.8b\n" /* w0, int16, out0 */ + "smull v21.8h, v1.8b, v8.8b\n" /* w0, int16, out1 */ + "smull v22.8h, v2.8b, v8.8b\n" /* w0, int16, out2 */ + "smull v23.8h, v3.8b, v8.8b\n" /* w0, int16, out3 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r0]]\n" /* load r0 4-7 */ + "smlal v20.8h, v1.8b, v9.8b\n" /* w1, int16, out0 */ + "smlal v21.8h, v2.8b, v9.8b\n" /* w1, int16, out1 */ + "smlal v22.8h, v3.8b, v9.8b\n" /* w1, int16, out2 */ + "smlal v23.8h, v4.8b, v9.8b\n" /* w1, int16, out3 */ + "sxtl v24.4s, v20.4h\n" /* mov to out0 low */ + "sxtl2 v25.4s, v20.8h\n" /* mov to out0 hig */ + "sxtl v26.4s, v21.4h\n" /* mov to out1 low */ + "sxtl2 v27.4s, v21.8h\n" /* mov to out1 hig */ + "sxtl v28.4s, v22.4h\n" /* mov to out2 low */ + "sxtl2 v29.4s, v22.8h\n" /* mov to out2 hig */ + "sxtl v30.4s, v23.4h\n" /* mov to out3 low */ + "sxtl2 v31.4s, v23.8h\n" /* mov to out3 hig */ + "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 4-7 */ + + "smull v20.8h, v2.8b, v10.8b\n" /* w2, int16, out0 */ + "smull v21.8h, v3.8b, v10.8b\n" /* w2, int16, out1 */ + "smull v22.8h, v4.8b, v10.8b\n" /* w2, int16, out2 */ + "smull v23.8h, v5.8b, v10.8b\n" /* w2, int16, out3 */ + "smlal v20.8h, v3.8b, v11.8b\n" /* w3, int16, out0 */ + "smlal v21.8h, v4.8b, v11.8b\n" /* w3, int16, out1 */ + "smlal v22.8h, v5.8b, v11.8b\n" /* w3, int16, out2 */ + "smlal v23.8h, v6.8b, v11.8b\n" /* w3, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r1]], #32\n" /* load r1 0-3 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + 
"saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v4.8b, v12.8b\n" /* w4, int16, out0 */ + "smull v21.8h, v5.8b, v12.8b\n" /* w4, int16, out1 */ + "smull v22.8h, v6.8b, v12.8b\n" /* w4, int16, out2 */ + "smull v23.8h, v7.8b, v12.8b\n" /* w4, int16, out3 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r1]]\n" /* load r1 4-7 */ + /* in r1 */ + "smlal v20.8h, v0.8b, v13.8b\n" /* w5, int16, out0 */ + "smlal v21.8h, v1.8b, v13.8b\n" /* w5, int16, out1 */ + "smlal v22.8h, v2.8b, v13.8b\n" /* w5, int16, out2 */ + "smlal v23.8h, v3.8b, v13.8b\n" /* w5, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 8-11 */ + + "smull v20.8h, v1.8b, v14.8b\n" /* w6, int16, out0 */ + "smull v21.8h, v2.8b, v14.8b\n" /* w6, int16, out1 */ + "smull v22.8h, v3.8b, v14.8b\n" /* w6, int16, out2 */ + "smull v23.8h, v4.8b, v14.8b\n" /* w6, int16, out3 */ + "smlal v20.8h, v2.8b, v15.8b\n" /* w7, int16, out0 */ + "smlal v21.8h, v3.8b, v15.8b\n" /* w7, int16, out1 */ + "smlal v22.8h, v4.8b, v15.8b\n" /* w7, int16, out2 */ + "smlal v23.8h, v5.8b, v15.8b\n" /* w7, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v3.8b, v16.8b\n" /* w8, int16, out0 */ + "smull v21.8h, v4.8b, v16.8b\n" /* w8, int16, out1 */ + "smull v22.8h, v5.8b, v16.8b\n" /* w8, int16, out2 */ + "smull v23.8h, v6.8b, v16.8b\n" /* w8, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r2]], #32\n" /* load r2 0-3 */ + "smlal v20.8h, v4.8b, v17.8b\n" /* w9, int16, out0 */ + "smlal v21.8h, v5.8b, v17.8b\n" /* w9, int16, out1 */ + "smlal v22.8h, v6.8b, v17.8b\n" /* w9, int16, out2 */ + "smlal v23.8h, v7.8b, v17.8b\n" /* w9, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r2]]\n" /* load r2 4-7 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + /* in r2 */ + "smull v20.8h, v0.8b, v18.8b\n" /* w10, int16, out0 */ + "smull v21.8h, v1.8b, v18.8b\n" /* w10, int16, out1 */ + "smull v22.8h, v2.8b, v18.8b\n" /* w10, int16, out2 */ + "smull v23.8h, v3.8b, v18.8b\n" /* w10, int16, out3 */ + "smlal v20.8h, v1.8b, v19.8b\n" /* w11, int16, out0 */ + "smlal v21.8h, v2.8b, v19.8b\n" /* w11, int16, out1 */ + "smlal v22.8h, v3.8b, v19.8b\n" /* w11, int16, out2 */ + 
"smlal v23.8h, v4.8b, v19.8b\n" /* w11, int16, out3 */ + + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 12-15 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v2.8b, v8.8b\n" /* w12, int16, out0 */ + "smull v21.8h, v3.8b, v8.8b\n" /* w12, int16, out1 */ + "smull v22.8h, v4.8b, v8.8b\n" /* w12, int16, out2 */ + "smull v23.8h, v5.8b, v8.8b\n" /* w12, int16, out3 */ + "smlal v20.8h, v3.8b, v9.8b\n" /* w13, int16, out0 */ + "smlal v21.8h, v4.8b, v9.8b\n" /* w13, int16, out1 */ + "smlal v22.8h, v5.8b, v9.8b\n" /* w13, int16, out2 */ + "smlal v23.8h, v6.8b, v9.8b\n" /* w13, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r3]], #32\n" /* load r3 0-3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "smull v20.8h, v4.8b, v10.8b\n" /* w14, int16, out0 */ + "smull v21.8h, v5.8b, v10.8b\n" /* w14, int16, out1 */ + "smull v22.8h, v6.8b, v10.8b\n" /* w14, int16, out2 */ + "smull v23.8h, v7.8b, v10.8b\n" /* w14, int16, out3 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r3]]\n" /* load r3 4-7 */ + /* in r3 */ + "smlal v20.8h, v0.8b, v11.8b\n" /* w15, int16, out0 */ + "smlal v21.8h, v1.8b, v11.8b\n" /* w15, int16, out1 */ + "smlal v22.8h, v2.8b, v11.8b\n" /* w15, int16, out2 */ + "smlal v23.8h, v3.8b, v11.8b\n" /* w15, int16, out3 */ + "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 16-19 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v1.8b, v12.8b\n" /* w16, int16, out0 */ + "smull v21.8h, v2.8b, v12.8b\n" /* w16, int16, out1 */ + "smull v22.8h, v3.8b, v12.8b\n" /* w16, int16, out2 */ + "smull v23.8h, v4.8b, v12.8b\n" /* w16, int16, out3 */ + "ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 20-23 */ + "smlal v20.8h, v2.8b, v13.8b\n" /* w17, int16, out0 */ + "smlal v21.8h, v3.8b, v13.8b\n" /* w17, int16, out1 */ + "smlal v22.8h, v4.8b, v13.8b\n" /* w17, int16, out2 */ + "smlal v23.8h, v5.8b, v13.8b\n" /* w17, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, 
v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v3.8b, v14.8b\n" /* w18, int16, out0 */ + "smull v21.8h, v4.8b, v14.8b\n" /* w18, int16, out1 */ + "smull v22.8h, v5.8b, v14.8b\n" /* w18, int16, out2 */ + "smull v23.8h, v6.8b, v14.8b\n" /* w18, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r4]], #32\n" /* load r4 0-3 */ + "smlal v20.8h, v4.8b, v15.8b\n" /* w19, int16, out0 */ + "smlal v21.8h, v5.8b, v15.8b\n" /* w19, int16, out1 */ + "smlal v22.8h, v6.8b, v15.8b\n" /* w19, int16, out2 */ + "smlal v23.8h, v7.8b, v15.8b\n" /* w19, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r4]]\n" /* load r4 4-7 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + /* in r4 */ + "smull v20.8h, v0.8b, v16.8b\n" /* w20, int16, out0 */ + "smull v21.8h, v1.8b, v16.8b\n" /* w20, int16, out1 */ + "smull v22.8h, v2.8b, v16.8b\n" /* w20, int16, out2 */ + "smull v23.8h, v3.8b, v16.8b\n" /* w20, int16, out3 */ + "smlal v20.8h, v1.8b, v17.8b\n" /* w21, int16, out0 */ + "smlal v21.8h, v2.8b, v17.8b\n" /* w21, int16, out1 */ + "smlal v22.8h, v3.8b, v17.8b\n" /* w21, int16, out2 */ + "smlal v23.8h, v4.8b, v17.8b\n" /* w21, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "ld1 {v12.8b}, [%[wc]], #8\n" /* load wc 24 */ + "smull v20.8h, v2.8b, v18.8b\n" /* w22, int16, out0 */ + "smull v21.8h, v3.8b, v18.8b\n" /* w22, int16, out1 */ + "smull v22.8h, v4.8b, v18.8b\n" /* w22, int16, out2 */ + "smull v23.8h, v5.8b, v18.8b\n" /* w22, int16, out3 */ + "smlal v20.8h, v3.8b, v19.8b\n" /* w23, int16, out0 */ + "smlal v21.8h, v4.8b, v19.8b\n" /* w23, int16, out1 */ + "smlal v22.8h, v5.8b, v19.8b\n" /* w23, int16, out2 */ + "smlal v23.8h, v6.8b, v19.8b\n" /* w23, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */ + "sub %[wc], %[wc], #200 \n" + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[wc]], #32\n" /* load wc 0-3 */ + "smull v20.8h, v4.8b, v12.8b\n" /* w24, int16, out0 */ + "smull v21.8h, v5.8b, v12.8b\n" /* w24, int16, out1 */ + "smull v22.8h, v6.8b, v12.8b\n" /* w24, int16, out2 */ + "smull v23.8h, v7.8b, v12.8b\n" /* w24, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 
hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "stp q24, q25, [%[ptr_out0]], #32\n" + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "stp q26, q27, [%[ptr_out0]], #32\n" + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "subs %w[cnt], %w[cnt], #1\n" + "stp q28, q29, [%[ptr_out0]], #32\n" + "stp q30, q31, [%[ptr_out0]], #32\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc] "+r"(wptr), + [ptr_out0] "+r"(ptr_out0) + : + : "cc","memory", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13", + "v14","v15","v16","v17","v18","v19", + "v20","v21","v22","v23","v24","v25", + "v26","v27","v28","v29","v30","v31" + ); +#else + auto wptr = weight_c; + asm volatile( + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */ + "vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */ + "1:\n" + /* inr0 */ + "vmull.s8 q4, d0, d6\n" /* int16, out0 */ + "vmull.s8 q5, d1, d6\n" /* int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* int16, out3 */ + "vmlal.s8 q4, d1, d7\n" /* int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* int16, out1 */ + "vmlal.s8 q6, d3, d7\n" /* int16, out2 */ + "vmlal.s8 q7, d4, d7\n" /* int16, out3 */ + "vmovl.s16 q8, d8\n" /* mov to out0 low */ + "vmovl.s16 q9, d9\n" /* mov to out0 hig */ + "vmovl.s16 q10, d10\n" /* mov to out1 low */ + "vmovl.s16 q11, d11\n" /* mov to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w2-w3 */ + "vmovl.s16 q12, d12\n" /* mov to out2 low */ + "vmovl.s16 q13, d13\n" /* mov to out2 hig */ + "vmovl.s16 q14, d14\n" /* mov to out3 low */ + "vmovl.s16 q15, d15\n" /* mov to out3 hig */ + "vld1.32 {d0-d1}, [%[r0]]\n" /* load r0, 6-7 */ + + "vmull.s8 q4, d2, d6\n" /* w2, int16, out0 */ + "vmull.s8 q5, d3, d6\n" /* w2, int16, out1 */ + "vmull.s8 q6, d4, d6\n" /* w2, int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* w2, int16, out3 */ + "vmlal.s8 q4, d3, d7\n" /* w3, int16, out0 */ + "vmlal.s8 q5, d4, d7\n" /* w3, int16, out1 */ + "vmlal.s8 q6, d5, d7\n" /* w3, int16, out2 */ + "vmlal.s8 q7, d0, d7\n" /* w3, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w4-w5 */ + "sub %[r0], %[r0], #16\n" /* r0 = r0 - 16 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d4, d6\n" /* w4, int16, out0 */ + "vmull.s8 q5, d5, d6\n" /* w4, int16, out1 */ + "vmull.s8 q6, d0, d6\n" /* w4, int16, out2 */ + "vmull.s8 q7, d1, d6\n" /* w4, int16, out3 */ + "vld1.32 {d0-d3}, [%[r1]]!\n" /* load r1, 0-3 */ + /* inr1 */ + "vmlal.s8 q4, d0, d7\n" /* w5, int16, out0 */ + "vmlal.s8 q5, d1, d7\n" /* w5, int16, out1 */ + "vmlal.s8 q6, d2, d7\n" /* w5, int16, out2 */ + "vmlal.s8 q7, d3, d7\n" /* w5, int16, out3 */ + "vld1.32 {d4-d5}, [%[r1]]!\n" /* load r1, 4-5 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, 
d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w6-w7 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d1, d6\n" /* w6, int16, out0 */ + "vmull.s8 q5, d2, d6\n" /* w6, int16, out1 */ + "vmull.s8 q6, d3, d6\n" /* w6, int16, out2 */ + "vmull.s8 q7, d4, d6\n" /* w6, int16, out3 */ + "vld1.32 {d0-d1}, [%[r1]]\n" /* load r1, 6-7 */ + "vmlal.s8 q4, d2, d7\n" /* w7, int16, out0 */ + "vmlal.s8 q5, d3, d7\n" /* w7, int16, out1 */ + "vmlal.s8 q6, d4, d7\n" /* w7, int16, out2 */ + "vmlal.s8 q7, d5, d7\n" /* w7, int16, out3 */ + "sub %[r1], %[r1], #16\n" /* r0 = r0 - 16 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w8-w9 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d3, d6\n" /* w8, int16, out0 */ + "vmull.s8 q5, d4, d6\n" /* w8, int16, out1 */ + "vmull.s8 q6, d5, d6\n" /* w8, int16, out2 */ + "vmull.s8 q7, d0, d6\n" /* w8, int16, out3 */ + "vmlal.s8 q4, d4, d7\n" /* w9, int16, out0 */ + "vmlal.s8 q5, d5, d7\n" /* w9, int16, out1 */ + "vmlal.s8 q6, d0, d7\n" /* w9, int16, out2 */ + "vmlal.s8 q7, d1, d7\n" /* w9, int16, out3 */ + "vld1.32 {d0-d3}, [%[r2]]!\n" /* load r2, 0-3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w10-w11 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d4-d5}, [%[r2]]!\n" /* load r2, 4-5 */ + + /* inr2 */ + "vmull.s8 q4, d0, d6\n" /* w10, int16, out0 */ + "vmull.s8 q5, d1, d6\n" /* w10, int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* w10, int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* w10, int16, out3 */ + "vmlal.s8 q4, d1, d7\n" /* w11, int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* w11, int16, out1 */ + "vmlal.s8 q6, d3, d7\n" /* w11, int16, out2 */ + "vmlal.s8 q7, d4, d7\n" /* w11, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w12-w13 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d0-d1}, [%[r2]]\n" /* load r2, 6-7 */ + + "vmull.s8 q4, d2, d6\n" /* w12, int16, out0 */ + "vmull.s8 q5, d3, d6\n" /* w12, int16, out1 */ + "vmull.s8 q6, d4, d6\n" /* w12, int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* w12, int16, out3 */ + "vmlal.s8 q4, d3, d7\n" /* w13, int16, out0 */ + "vmlal.s8 q5, d4, d7\n" /* w13, int16, out1 */ + "vmlal.s8 q6, d5, d7\n" /* w13, int16, out2 */ + "vmlal.s8 q7, d0, d7\n" /* w13, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + 
"vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w14-w15 */ + "sub %[r2], %[r2], #16\n" /* r2 = r2 - 16 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d4, d6\n" /* w14, int16, out0 */ + "vmull.s8 q5, d5, d6\n" /* w14, int16, out1 */ + "vmull.s8 q6, d0, d6\n" /* w14, int16, out2 */ + "vmull.s8 q7, d1, d6\n" /* w14, int16, out3 */ + "vld1.32 {d0-d3}, [%[r3]]!\n" /* load r3, 0-3 */ + /* inr3 */ + "vmlal.s8 q4, d0, d7\n" /* w15, int16, out0 */ + "vmlal.s8 q5, d1, d7\n" /* w15, int16, out1 */ + "vmlal.s8 q6, d2, d7\n" /* w15, int16, out2 */ + "vmlal.s8 q7, d3, d7\n" /* w15, int16, out3 */ + "vld1.32 {d4-d5}, [%[r3]]!\n" /* load r3, 4-5 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w16-w17 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d1, d6\n" /* w16, int16, out0 */ + "vmull.s8 q5, d2, d6\n" /* w16, int16, out1 */ + "vmull.s8 q6, d3, d6\n" /* w16, int16, out2 */ + "vmull.s8 q7, d4, d6\n" /* w16, int16, out3 */ + "vld1.32 {d0-d1}, [%[r3]]\n" /* load r3, 6-7 */ + "vmlal.s8 q4, d2, d7\n" /* w17, int16, out0 */ + "vmlal.s8 q5, d3, d7\n" /* w17, int16, out1 */ + "vmlal.s8 q6, d4, d7\n" /* w17, int16, out2 */ + "vmlal.s8 q7, d5, d7\n" /* w17, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w18-w19 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "sub %[r3], %[r3], #16\n" /* r3 = r3 - 16 */ + + "vmull.s8 q4, d3, d6\n" /* w18, int16, out0 */ + "vmull.s8 q5, d4, d6\n" /* w18, int16, out1 */ + "vmull.s8 q6, d5, d6\n" /* w18, int16, out2 */ + "vmull.s8 q7, d0, d6\n" /* w18, int16, out3 */ + "vmlal.s8 q4, d4, d7\n" /* w19, int16, out0 */ + "vmlal.s8 q5, d5, d7\n" /* w19, int16, out1 */ + "vmlal.s8 q6, d0, d7\n" /* w19, int16, out2 */ + "vmlal.s8 q7, d1, d7\n" /* w19, int16, out3 */ + "vld1.32 {d0-d3}, [%[r4]]!\n" /* load r4, 0-3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w20-w21 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d4-d5}, [%[r4]]!\n" /* load r4, 4-5 */ + + /* inr4 */ + "vmull.s8 q4, d0, d6\n" /* w20, int16, out0 */ + "vmull.s8 q5, d1, d6\n" /* w20, int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* w20, int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* w20, int16, out3 */ 
+ "vmlal.s8 q4, d1, d7\n" /* w21, int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* w21, int16, out1 */ + "vmlal.s8 q6, d3, d7\n" /* w21, int16, out2 */ + "vmlal.s8 q7, d4, d7\n" /* w21, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w22-w23 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d0-d1}, [%[r4]]\n" /* load r4, 5-6 */ + + "vmull.s8 q4, d2, d6\n" /* w22, int16, out0 */ + "vmull.s8 q5, d3, d6\n" /* w22, int16, out1 */ + "vmull.s8 q6, d4, d6\n" /* w22, int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* w22, int16, out3 */ + "vmlal.s8 q4, d3, d7\n" /* w23, int16, out0 */ + "vmlal.s8 q5, d4, d7\n" /* w23, int16, out1 */ + "vmlal.s8 q6, d5, d7\n" /* w23, int16, out2 */ + "vmlal.s8 q7, d0, d7\n" /* w23, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vld1.32 {d6}, [%[wptr]]!\n" /* load w24 */ + "sub %[r4], %[r4], #16\n" /* r4 = r4 - 16 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "sub %[wptr], %[wptr], #200 \n" /* wptr = wptr - 200 */ + + "vmull.s8 q4, d4, d6\n" /* w22, int16, out0 */ + "vmull.s8 q5, d5, d6\n" /* w22, int16, out1 */ + "vmull.s8 q6, d0, d6\n" /* w22, int16, out2 */ + "vmull.s8 q7, d1, d6\n" /* w22, int16, out3 */ + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 0-3 */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vst1.32 {d16-d19}, [%[ptr_out0]]!\n"/* store out0 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vst1.32 {d20-d23}, [%[ptr_out0]]!\n"/*store out1 */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "subs %[cnt], #1\n" /* cnt = cnt - 1 */ + "vst1.32 {d24-d27}, [%[ptr_out0]]!\n"/* store out2 */ + "vst1.32 {d28-d31}, [%[ptr_out0]]!\n"/* store out3 */ + "bne 1b\n" /* branch main loop */ + : [cnt] "+r"(cnt), + [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [ptr_out0] "+r"(ptr_out0), + [wptr] "+r"(wptr) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + // clang-format on + int32_t* ptr_tmp = ptr_out0 - w_loop * 32; + block_inr0 = block_inr1; + block_inr1 = block_inr2; + block_inr2 = block_inr3; + block_inr3 = block_inr4; + block_inr4 = block_inr3 + in_len; + } + write_int32_nchwc8_to_nchw(pre_out, + reinterpret_cast(dout_batch), + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + bias_local, + flag_bias, + ptr_write, + scale + c); } - // else if (od_type == AK_INT32) { - // 
write2_to_output_numc(pre_out, (int*)dout_batch, 1, hout_round, c, - // c+1, - // 0, hout, 0, wout_round, chout, hout, wout, flag_relu, - // (int*)ptr_write, scales); - // } - } /// end of chout - } /// end of batch num + } + } } -#endif // __aarch64__ - +template void conv_depthwise_5x5s1_int8(int8_t* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx); + +template void conv_depthwise_5x5s1_int8(float* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx); } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index a9d6422942..b2d16d18d2 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -389,237 +389,202 @@ inline void prepack_input_nxwc4_dw(const float* din, } } -inline void prepack_input_nxw_c8_int8(const int8_t* din, - int8_t* dout, - int cs, - int ce, - int hs, - int he, - int ws, - int we, - int channel, - int width, - int height) { +inline void prepack_input_nxwc8_int8_dw(const int8_t* din, + int8_t* dout, + int cs, + int hs, + int he, + int ws, + int we, + int channel, + int width, + int height) { int n = he - hs; if (n <= 0) { - LOG(FATAL) << "prepack_input_nxw_c8 input height must > 0"; - return; + LOG(FATAL) << "prepack_dw_input_int8, valid height must > zero"; } + int size_w = we - ws; int w0 = ws < 0 ? 0 : ws; int w1 = we > width ? width : we; - - int size_w = we - ws; - int size_channel_in = width * height; - int size_out_row = size_w * 8; - int valid_w = w1 - w0; - size_t valid_w_byte = valid_w * sizeof(int8_t); - - auto ptr_c = static_cast(TargetMalloc(TARGET(kARM), 8 * size_w)); - int8_t* ptr_r[8]; - int8_t* ptr_c_ori[8] = {ptr_c, - ptr_c + size_w, - ptr_c + 2 * size_w, - ptr_c + 3 * size_w, - ptr_c + 4 * size_w, - ptr_c + 5 * size_w, - ptr_c + 6 * size_w, - ptr_c + 7 * size_w}; + int pad_l = ws < 0 ? -ws : 0; + int pad_r = we > width ? we - width : 0; + int size_c = width * height; + + int valid_cnt = valid_w >> 3; + int remain = valid_w & 7; int8_t zero_ptr[size_w * 2]; // NOLINT memset(zero_ptr, 0, size_w * 2); - int loop = size_w / 8; - int remain = size_w - loop * 8; - - for (int c = cs; c < ce; c += 8) { - auto din_c = din + c * size_channel_in; - for (int j = 0; j < 8; ++j) { - ptr_r[j] = ptr_c_ori[j]; - } - //! 
valid channel - if (c + 8 > channel) { - switch (c + 8 - channel) { + for (int h = hs; h < he; ++h) { + const int8_t* ptr_c0 = din + h * width + cs * size_c; + const int8_t* ptr_c1 = ptr_c0 + size_c; + const int8_t* ptr_c2 = ptr_c1 + size_c; + const int8_t* ptr_c3 = ptr_c2 + size_c; + const int8_t* ptr_c4 = ptr_c3 + size_c; + const int8_t* ptr_c5 = ptr_c4 + size_c; + const int8_t* ptr_c6 = ptr_c5 + size_c; + const int8_t* ptr_c7 = ptr_c6 + size_c; + if (h < 0 || h >= height) { + memset(dout, 0, 8 * size_w * sizeof(int8_t)); + dout += size_w * 8; + continue; + } else if (cs + 8 > channel) { + switch (cs + 8 - channel) { case 7: - ptr_r[1] = zero_ptr; + ptr_c1 = zero_ptr; case 6: - ptr_r[2] = zero_ptr; + ptr_c2 = zero_ptr; case 5: - ptr_r[3] = zero_ptr; + ptr_c3 = zero_ptr; case 4: - ptr_r[4] = zero_ptr; + ptr_c4 = zero_ptr; case 3: - ptr_r[5] = zero_ptr; + ptr_c5 = zero_ptr; case 2: - ptr_r[6] = zero_ptr; + ptr_c6 = zero_ptr; case 1: - ptr_r[7] = zero_ptr; + ptr_c7 = zero_ptr; default: break; } } - //! valid height - int j = 0; - for (int i = hs; i < he; i++) { - auto din_r = din_c + i * width; - for (int k = 0; k < 8; ++k) { - if (ptr_r[k] != zero_ptr) { - if (i < 0 || i >= height) { - ptr_r[k] = zero_ptr + size_w; - } else { - ptr_r[k] = ptr_c_ori[k]; - auto ptr = ptr_r[k]; - for (int w = ws; w < w0; ++w) { - *(ptr++) = 0; - } - memcpy(ptr, din_r + k * size_channel_in, valid_w_byte); - ptr += valid_w; - for (int w = w1; w < we; ++w) { - *(ptr++) = 0; - } - } - } - } - int cnt = loop; - int8_t* inr0 = ptr_r[0]; - int8_t* inr1 = ptr_r[1]; - int8_t* inr2 = ptr_r[2]; - int8_t* inr3 = ptr_r[3]; - int8_t* inr4 = ptr_r[4]; - int8_t* inr5 = ptr_r[5]; - int8_t* inr6 = ptr_r[6]; - int8_t* inr7 = ptr_r[7]; - auto ptr_out = dout + j * size_out_row; - if (cnt > 0) { + if (pad_l) { + memset(dout, 0, pad_l * 8 * sizeof(int8_t)); + dout += pad_l * 8; + } + if (valid_cnt) { + int cnt = valid_cnt; #ifdef __aarch64__ - asm volatile( - /* main loop */ - "1:\n" - "ldr d0, [%[r0]], #8\n" - "ldr d1, [%[r1]], #8\n" - "ldr d2, [%[r2]], #8\n" - "ldr d3, [%[r3]], #8\n" - "ldr d4, [%[r4]], #8\n" - "ldr d5, [%[r5]], #8\n" - "ldr d6, [%[r6]], #8\n" - "ldr d7, [%[r7]], #8\n" - "trn1 v8.8b, v0.8b, v1.8b\n" - "trn2 v9.8b, v0.8b, v1.8b\n" - "trn1 v10.8b, v2.8b, v3.8b\n" - "trn2 v11.8b, v2.8b, v3.8b\n" - "trn1 v12.8b, v4.8b, v5.8b\n" - "trn2 v13.8b, v4.8b, v5.8b\n" - "trn1 v14.8b, v6.8b, v7.8b\n" - "trn2 v15.8b, v6.8b, v7.8b\n" - "trn1 v0.4h, v8.4h, v10.4h\n" - "trn2 v1.4h, v8.4h, v10.4h\n" - "trn1 v2.4h, v9.4h, v11.4h\n" - "trn2 v3.4h, v9.4h, v11.4h\n" - "trn1 v4.4h, v12.4h, v14.4h\n" - "trn2 v5.4h, v12.4h, v14.4h\n" - "trn1 v6.4h, v13.4h, v15.4h\n" - "trn2 v7.4h, v13.4h, v15.4h\n" - "trn1 v8.2s, v0.2s, v4.2s\n" - "trn1 v9.2s, v2.2s, v6.2s\n" - "trn1 v10.2s, v1.2s, v5.2s\n" - "trn1 v11.2s, v3.2s, v7.2s\n" - "stp d8, d9, [%[ptr_out]], #16\n" - "trn2 v12.2s, v0.2s, v4.2s\n" - "trn2 v13.2s, v2.2s, v6.2s\n" - "stp d10, d11, [%[ptr_out]], #16\n" - "trn2 v14.2s, v1.2s, v5.2s\n" - "trn2 v15.2s, v3.2s, v7.2s\n" - "subs %w[cnt], %w[cnt], #1\n" - "stp d12, d13, [%[ptr_out]], #16\n" - "stp d14, d15, [%[ptr_out]], #16\n" - "bne 1b\n" - : [cnt] "+r"(cnt), - [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [r4] "+r"(inr4), - [r5] "+r"(inr5), - [r6] "+r"(inr6), - [r7] "+r"(inr7), - [ptr_out] "+r"(ptr_out) - : - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); + asm volatile( + /* main loop */ + "1:\n" + 
"ldr d0, [%[r0]], #8\n" + "ldr d1, [%[r1]], #8\n" + "ldr d2, [%[r2]], #8\n" + "ldr d3, [%[r3]], #8\n" + "ldr d4, [%[r4]], #8\n" + "ldr d5, [%[r5]], #8\n" + "ldr d6, [%[r6]], #8\n" + "ldr d7, [%[r7]], #8\n" + "trn1 v8.8b, v0.8b, v1.8b\n" + "trn2 v9.8b, v0.8b, v1.8b\n" + "trn1 v10.8b, v2.8b, v3.8b\n" + "trn2 v11.8b, v2.8b, v3.8b\n" + "trn1 v12.8b, v4.8b, v5.8b\n" + "trn2 v13.8b, v4.8b, v5.8b\n" + "trn1 v14.8b, v6.8b, v7.8b\n" + "trn2 v15.8b, v6.8b, v7.8b\n" + "trn1 v0.4h, v8.4h, v10.4h\n" + "trn2 v1.4h, v8.4h, v10.4h\n" + "trn1 v2.4h, v9.4h, v11.4h\n" + "trn2 v3.4h, v9.4h, v11.4h\n" + "trn1 v4.4h, v12.4h, v14.4h\n" + "trn2 v5.4h, v12.4h, v14.4h\n" + "trn1 v6.4h, v13.4h, v15.4h\n" + "trn2 v7.4h, v13.4h, v15.4h\n" + "trn1 v8.2s, v0.2s, v4.2s\n" + "trn1 v9.2s, v2.2s, v6.2s\n" + "trn1 v10.2s, v1.2s, v5.2s\n" + "trn1 v11.2s, v3.2s, v7.2s\n" + "stp d8, d9, [%[ptr_out]], #16\n" + "trn2 v12.2s, v0.2s, v4.2s\n" + "trn2 v13.2s, v2.2s, v6.2s\n" + "stp d10, d11, [%[ptr_out]], #16\n" + "trn2 v14.2s, v1.2s, v5.2s\n" + "trn2 v15.2s, v3.2s, v7.2s\n" + "subs %w[cnt], %w[cnt], #1\n" + "stp d12, d13, [%[ptr_out]], #16\n" + "stp d14, d15, [%[ptr_out]], #16\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(ptr_c0), + [r1] "+r"(ptr_c1), + [r2] "+r"(ptr_c2), + [r3] "+r"(ptr_c3), + [r4] "+r"(ptr_c4), + [r5] "+r"(ptr_c5), + [r6] "+r"(ptr_c6), + [r7] "+r"(ptr_c7), + [ptr_out] "+r"(dout) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); #else - asm volatile( - /* main loop */ - "1:\n" - "vld1.32 {d0}, [%[r0]]!\n" - "vld1.32 {d1}, [%[r1]]!\n" - "vld1.32 {d2}, [%[r2]]!\n" - "vld1.32 {d3}, [%[r3]]!\n" - "vld1.32 {d4}, [%[r4]]!\n" - "vld1.32 {d5}, [%[r5]]!\n" - "vld1.32 {d6}, [%[r6]]!\n" - "vld1.32 {d7}, [%[r7]]!\n" - "vtrn.8 d0, d1\n" - "vtrn.8 d2, d3\n" - "vtrn.8 d4, d5\n" - "vtrn.8 d6, d7\n" - "vtrn.16 d0, d2\n" - "vtrn.16 d1, d3\n" - "vtrn.16 d4, d6\n" - "vtrn.16 d5, d7\n" - "vtrn.32 d0, d4\n" - "vtrn.32 d2, d6\n" - "vtrn.32 d1, d5\n" - "vtrn.32 d3, d7\n" - "subs %[cnt], #1\n" - "vst1.32 {d0-d3}, [%[ptr_out]]!\n" - "vst1.32 {d4-d7}, [%[ptr_out]]!\n" - "bne 1b\n" - : [cnt] "+r"(cnt), - [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [r4] "+r"(inr4), - [r5] "+r"(inr5), - [r6] "+r"(inr6), - [r7] "+r"(inr7), - [ptr_out] "+r"(ptr_out) - : - : "cc", "memory", "q0", "q1", "q2", "q3"); - -#endif // aarch64 - } - for (int k = 0; k < remain; ++k) { - ptr_out[0] = *(inr0++); - ptr_out[1] = *(inr1++); - ptr_out[2] = *(inr2++); - ptr_out[3] = *(inr3++); - ptr_out[4] = *(inr4++); - ptr_out[5] = *(inr5++); - ptr_out[6] = *(inr6++); - ptr_out[7] = *(inr7++); - ptr_out += 8; - } - j++; + asm volatile( + /* main loop */ + "1:\n" + "vld1.32 {d0}, [%[r0]]!\n" + "vld1.32 {d1}, [%[r1]]!\n" + "vld1.32 {d2}, [%[r2]]!\n" + "vld1.32 {d3}, [%[r3]]!\n" + "vld1.32 {d4}, [%[r4]]!\n" + "vld1.32 {d5}, [%[r5]]!\n" + "vld1.32 {d6}, [%[r6]]!\n" + "vld1.32 {d7}, [%[r7]]!\n" + "vtrn.8 d0, d1\n" + "vtrn.8 d2, d3\n" + "vtrn.8 d4, d5\n" + "vtrn.8 d6, d7\n" + "vtrn.16 d0, d2\n" + "vtrn.16 d1, d3\n" + "vtrn.16 d4, d6\n" + "vtrn.16 d5, d7\n" + "vtrn.32 d0, d4\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d3, d7\n" + "subs %[cnt], #1\n" + "vst1.32 {d0-d3}, [%[ptr_out]]!\n" + "vst1.32 {d4-d7}, [%[ptr_out]]!\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(ptr_c0), + [r1] "+r"(ptr_c1), + [r2] "+r"(ptr_c2), + [r3] "+r"(ptr_c3), + [r4] "+r"(ptr_c4), + [r5] "+r"(ptr_c5), + [r6] "+r"(ptr_c6), + [r7] "+r"(ptr_c7), + 
[ptr_out] "+r"(dout) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ + } + for (int i = 0; i < remain; ++i) { + dout[0] = *(ptr_c0++); + dout[1] = *(ptr_c1++); + dout[2] = *(ptr_c2++); + dout[3] = *(ptr_c3++); + dout[4] = *(ptr_c4++); + dout[5] = *(ptr_c5++); + dout[6] = *(ptr_c6++); + dout[7] = *(ptr_c7++); + dout += 8; + } + if (pad_r) { + memset(dout, 0, pad_r * 8 * sizeof(int8_t)); + dout += pad_r * 8; } } - TargetFree(TARGET(kARM), ptr_c); } /*wirte result in outputs diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index 00925d1f3c..53acdb46c7 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -153,6 +153,24 @@ void conv_depthwise_5x5s2_fp32(const float* din, bool flag_relu, ARMContext* ctx); +template +void conv_depthwise_5x5s1_int8(Dtype* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx); + } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 3817262b7c..1b81d6d5e3 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -805,6 +805,88 @@ void conv_depthwise_3x3_int8_int8(const void* din, } } +void conv_depthwise_5x5_int8_fp32(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale) { + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + int stride = param.strides[1]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + if (stride == 1) { + conv_depthwise_5x5s1_int8(reinterpret_cast(dout), + reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); + } else { + LOG(FATAL) << "unsupport this type 5x5 dw conv int8"; + } +} + +void conv_depthwise_5x5_int8_int8(const void* din, + void* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const void* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const float* scale) { + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + int stride = param.strides[1]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + if (stride == 1) { + conv_depthwise_5x5s1_int8(reinterpret_cast(dout), + reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); + } else { + LOG(FATAL) << "unsupport this type 5x5 dw conv int8"; + } +} + } // namespace math } // namespace arm } // namespace lite diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index b80130c94a..98007db0d1 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -97,7 +97,7 @@ void ConvCompute::PrepareForRun() { bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2); - bool flag_dw_5x5 = (kw == 5 
&& sw == 1 && ph == 2); + bool flag_dw_5x5 = (kw == 5 && sw == 1); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { @@ -136,7 +136,7 @@ void ConvCompute::PrepareForRun() { bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2); - bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2); + bool flag_dw_5x5 = (kw == 5 && sw == 1); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc index 9be3283c58..57c366aee0 100644 --- a/lite/kernels/arm/conv_depthwise.cc +++ b/lite/kernels/arm/conv_depthwise.cc @@ -31,7 +31,7 @@ void DepthwiseConv::PrepareForRun() { // select dw conv kernel if (kw == 3) { VLOG(5) << "invoke 3x3 dw conv fp32"; - /// trans weights + // trans weights constexpr int cblock = 4; auto oc = w_dims[0]; auto kh = w_dims[2]; @@ -75,6 +75,7 @@ void DepthwiseConv::PrepareForRun() { } /// select dw conv kernel if (kw == 3) { + // trans weights VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out"; impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32; int cround = ROUNDUP(w_dims[0], 8); @@ -83,6 +84,16 @@ void DepthwiseConv::PrepareForRun() { auto wptr_new = weights_.mutable_data(); lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9); flag_trans_weights_ = true; + } else if (kw == 5) { + // trans weights + VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out"; + impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32; + int cround = ROUNDUP(w_dims[0], 8); + weights_.Resize({cround / 8, 1, kh * kw, 8}); + auto wptr = param.filter->data(); + auto wptr_new = weights_.mutable_data(); + lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25); + flag_trans_weights_ = true; } else { LOG(FATAL) << "this type dw conv not impl"; } @@ -123,6 +134,7 @@ void DepthwiseConv::PrepareForRun() { } /// select dw conv kernel if (kw == 3) { + // trans weights VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out"; impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8; int cround = ROUNDUP(w_dims[0], 8); @@ -131,6 +143,16 @@ void DepthwiseConv::PrepareForRun() { auto wptr_new = weights_.mutable_data(); lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9); flag_trans_weights_ = true; + } else if (kw == 5) { + // trans weights + VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out"; + impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8; + int cround = ROUNDUP(w_dims[0], 8); + weights_.Resize({cround / 8, 1, kh * kw, 8}); + auto wptr = param.filter->data(); + auto wptr_new = weights_.mutable_data(); + lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25); + flag_trans_weights_ = true; } else { LOG(FATAL) << "this type dw conv not impl"; } diff --git a/lite/tests/math/conv_int8_compute_test.cc b/lite/tests/math/conv_int8_compute_test.cc index b2ad011f8c..2a0971a298 100644 --- a/lite/tests/math/conv_int8_compute_test.cc +++ b/lite/tests/math/conv_int8_compute_test.cc @@ -481,10 +481,10 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { } #endif /// 3x3dw -#if 0 /// 5x5dw +#if 1 /// 5x5dw TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { if (FLAGS_basic_test) { - for (auto& stride : {1, 2}) { + for (auto& stride : {1}) { for (auto& pad : {0, 1, 2}) { for (auto& flag_bias : {false, true}) { for 
(auto& flag_relu : {false, true}) { @@ -492,7 +492,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { std::vector dims; DDim weights_dim({c, 1, 5, 5}); for (auto& batch : {1, 2}) { - for (auto &h : {1, 3, 15, 19, 28, 32, 75}) { + for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { dims.push_back(DDim({batch, c, h, h})); } } -- GitLab
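Note on the packed input layout: prepack_input_nxwc8_int8_dw emits the same c8-interleaved blocks as the previous prepack helper, but writes padding and channel tails directly instead of staging rows in a temporary buffer. The scalar sketch below illustrates the layout it produces; the helper name and the NCHW int8 source indexing are assumptions for illustration only, and the NEON trn-based paths in the patch are optimized equivalents of this loop.

#include <cstdint>

// Scalar reference (illustrative only): pack one 8-channel slice, rows [hs, he)
// and columns [ws, we), into channel-interleaved groups of 8 int8 values.
// Out-of-range rows/columns and channels beyond `channel` are written as zero.
void prepack_nxwc8_int8_ref(const int8_t* din, int8_t* dout,
                            int cs, int hs, int he, int ws, int we,
                            int channel, int width, int height) {
  for (int h = hs; h < he; ++h) {
    for (int w = ws; w < we; ++w) {
      for (int c = 0; c < 8; ++c) {
        int cc = cs + c;
        bool pad = (h < 0) || (h >= height) || (w < 0) || (w >= width) ||
                   (cc >= channel);
        // NCHW source index; padded positions produce zero.
        *dout++ = pad ? 0 : din[(cc * height + h) * width + w];
      }
    }
  }
}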
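Note on the accumulation scheme in the 5x5s1 int8 kernel: the multiplies run in three widths. int8 products are formed in int16 (vmull.s8 followed by vmlal.s8 on armv7; smull/smlal on armv8), and each int16 partial sum is then widened into the int32 accumulators (vaddw.s16 / saddw). At most two int8 products are paired in int16 before widening, which cannot overflow (2 * 127 * 127 = 32258 < 32767). A scalar sketch of the per-lane arithmetic, assuming nothing beyond what the assembly above shows:

#include <cstdint>

// One accumulation step as performed per SIMD lane by the kernel:
// pair two int8 products in int16, then widen-add into the int32 accumulator.
inline int32_t mac_pair_s8(int32_t acc, int8_t a0, int8_t w0,
                           int8_t a1, int8_t w1) {
  int16_t partial = static_cast<int16_t>(a0) * static_cast<int16_t>(w0);  // vmull.s8
  partial = static_cast<int16_t>(partial + a1 * w1);                      // vmlal.s8
  return acc + partial;                                                   // vaddw.s16
}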
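Note on the write-back step: after the NEON loop, write_int32_nchwc8_to_nchw converts the c8-blocked int32 accumulators back to NCHW and applies the per-output-channel scale, the optional bias, and the optional fused relu. The sketch below shows the assumed per-element math for the fp32-output path only (the int8-output path additionally rounds and saturates to the int8 range); it is not the library's implementation, just the intended semantics.

#include <algorithm>
#include <cstdint>

// Assumed per-element epilogue for the fp32 output path of the int8 dw conv:
// dequantize with the per-output-channel scale, add bias, apply optional relu.
inline float dw_int8_epilogue(int32_t acc, float scale_c, float bias_c, bool relu) {
  float v = static_cast<float>(acc) * scale_c + bias_c;
  return relu ? std::max(v, 0.f) : v;
}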