diff --git a/lite/backends/arm/math/packed_sgemm_c4.cc b/lite/backends/arm/math/packed_sgemm_c4.cc index 677490502e643fae3bc8149933e9936880711f0d..8087e0337bda0866f5d399a07ecb674f0fa55a3e 100644 --- a/lite/backends/arm/math/packed_sgemm_c4.cc +++ b/lite/backends/arm/math/packed_sgemm_c4.cc @@ -930,89 +930,105 @@ void sgemm_prepack_c4_small(int M, b += 4; } #else - for (; n > 5; n -= 6) { + for (; n > 7; n -= 8) { int cnt = kcnt; const float* a_ptr = A_packed; const float* b_ptr = b; + // clang-format off asm volatile( - "vld1.32 {d8-d9}, [%[bias]] \n" + "vld1.32 {d6-d7}, [%[bias]] \n" /* load a0, a1 */ - "vld1.32 {d12-d15}, [%[a]]! \n" + "vld1.32 {d8-d11}, [%[a]]! \n" /* mov bias to c0-c7*/ - "vmov.u32 q10, q4 \n" - "vmov.u32 q11, q4 \n" - "vmov.u32 q12, q4 \n" - /* load b0-b3 */ - "vld1.32 {d0-d3}, [%[b]]!\n" - "vld1.32 {d4-d7}, [%[b]]!\n" - "vmov.u32 q13, q4 \n" - "vmov.u32 q14, q4 \n" - "vmov.u32 q15, q4 \n" + "vmov.u32 q8, q3 \n" + "vmov.u32 q9, q3 \n" + "vmov.u32 q10, q3 \n" + "vmov.u32 q11, q3 \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmov.u32 q12, q3 \n" + "vmov.u32 q13, q3 \n" + "vmov.u32 q14, q3 \n" + "vmov.u32 q15, q3 \n" "1:\n" - /* load b4, b5 */ - "vld1.32 {d8-d11}, [%[b]]! \n" + /* load b2, b3 */ + "vld1.32 {d4-d7}, [%[b]]! \n" /* load a2, a3 */ - "vld1.32 {d16-d19}, [%[a]]!\n" - "vmla.f32 q10, q6, d0[0] \n" - "vmla.f32 q11, q6, d2[0] \n" - "vmla.f32 q12, q6, d4[0] \n" - "vmla.f32 q13, q6, d6[0] \n" - "vmla.f32 q14, q6, d8[0] \n" - "vmla.f32 q15, q6, d10[0] \n" - "sub %[b], %[b], #96 \n" - "vmla.f32 q10, q7, d0[1] \n" - "vmla.f32 q11, q7, d2[1] \n" - "vmla.f32 q12, q7, d4[1] \n" - "vmla.f32 q13, q7, d6[1] \n" - "vmla.f32 q14, q7, d8[1] \n" - "vmla.f32 q15, q7, d10[1] \n" - "add %[b], %[b], %[ldb] \n" - "pld [%[b]] \n" + "vld1.32 {d12-d15}, [%[a]]! \n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + "pld [%[b]] \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + /* load b4, b5 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + /* load b6, b7 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q12, q4, d0[0] \n" + "vmla.f32 q13, q4, d2[0] \n" + "vmla.f32 q14, q4, d4[0] \n" + "vmla.f32 q15, q4, d6[0] \n" + "sub %[b], %[b], #128 \n" + "vmla.f32 q12, q5, d0[1] \n" + "vmla.f32 q13, q5, d2[1] \n" + "vmla.f32 q14, q5, d4[1] \n" + "vmla.f32 q15, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q12, q6, d1[0] \n" + "vmla.f32 q13, q6, d3[0] \n" + "vmla.f32 q14, q6, d5[0] \n" + "vmla.f32 q15, q6, d7[0] \n" /* load a0, a1 */ - "vld1.32 {d12-d15}, [%[a]]!\n" - "vmla.f32 q10, q8, d1[0] \n" - "vmla.f32 q11, q8, d3[0] \n" - "vmla.f32 q12, q8, d5[0] \n" - "vmla.f32 q13, q8, d7[0] \n" - "pld [%[b], #64] \n" - "vmla.f32 q10, q9, d1[1] \n" - "vmla.f32 q11, q9, d3[1] \n" + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q12, q7, d1[1] \n" + "vmla.f32 q13, q7, d3[1] \n" /* load b0, b1 */ - "vld1.32 {d0-d3}, [%[b]]! \n" - "vmla.f32 q14, q8, d9[0] \n" - "vmla.f32 q15, q8, d11[0] \n" - "vmla.f32 q12, q9, d5[1] \n" - "vmla.f32 q13, q9, d7[1] \n" - "vmla.f32 q14, q9, d9[1] \n" - "vmla.f32 q15, q9, d11[1] \n" - /* load b2, b3 */ - "vld1.32 {d4-d7}, [%[b]]! \n" - "subs %[cnt], %[cnt], #1 \n" - "bne 1b \n" - "cmp %[relu], #0 \n" - "beq 2f \n" - "vmov.u32 q0, #0 \n" - "vmax.f32 q10, q10, q0 \n" - "vmax.f32 q11, q11, q0 \n" - "vmax.f32 q12, q12, q0 \n" - "vmax.f32 q13, q13, q0 \n" - "vmax.f32 q14, q14, q0 \n" - "vmax.f32 q15, q15, q0 \n" - "2: \n" - "vst1.32 {d20-d23}, [%[c]]! \n" - "vst1.32 {d24-d27}, [%[c]]! \n" - "vst1.32 {d28-d31}, [%[c]]! \n" + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q14, q7, d5[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + "bne 1b \n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmov.u32 q0, #0 \n" + "vmax.f32 q8, q8, q0 \n" + "vmax.f32 q9, q9, q0 \n" + "vmax.f32 q10, q10, q0 \n" + "vmax.f32 q11, q11, q0 \n" + "vmax.f32 q12, q12, q0 \n" + "vmax.f32 q13, q13, q0 \n" + "vmax.f32 q14, q14, q0 \n" + "vmax.f32 q15, q15, q0 \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]! \n" + "vst1.32 {d20-d23}, [%[c]]! \n" + "vst1.32 {d24-d27}, [%[c]]! \n" + "vst1.32 {d28-d31}, [%[c]]! \n" : [a] "+r" (a_ptr), [b] "+r" (b_ptr), [c] "+r" (C), [cnt] "+r" (cnt) : [relu] "r" (has_relu), [ldb] "r" (ldb_byte), - [bias] "r"(bias_ptr) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory" + [bias] "r" (bias_ptr) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "cc", "memory" ); - b += 4 * 6; + b += 4 * 8; } for (; n > 3; n -= 4) { int cnt = kcnt; @@ -1071,7 +1087,7 @@ void sgemm_prepack_c4_small(int M, [cnt] "+r" (cnt) : [relu] "r" (has_relu), [ldb] "r" (ldb_byte), - [bias] "r"(bias_ptr) + [bias] "r" (bias_ptr) : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "cc", "memory" @@ -1117,7 +1133,7 @@ void sgemm_prepack_c4_small(int M, [cnt] "+r" (cnt) : [relu] "r" (has_relu), [ldb] "r" (ldb_byte), - [bias] "r"(bias_ptr) + [bias] "r" (bias_ptr) : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "cc", "memory" );