diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 2990f7a0f8d4712a3dc3c429d9b57e5aa3809325..44621ba99a92a3ed456b8d7d0959e3580662d910 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -3379,7 +3379,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, // 对 B 分块 NC = L1 / (KC * sizeof(float)); if (NC == 0) { - NC == NR; + NC = NR; } else { int nblock_num = (n + NC - 1) / NC; NC = (n + nblock_num - 1) / nblock_num; diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index adc6924d8ad273012a9b44677f8ad1a29bc37787..ea023bc134033aee6577ebf06c95f2a762d08bca 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -22,9 +22,11 @@ limitations under the License. */ #define C(i, j) C[(i)*ldc + (j)] #if __aarch64__ +#define MR_INT8 4 #define MR 6 #define NR 16 #else +#define MR_INT8 4 #define MR 6 #define NR 8 #endif @@ -189,6 +191,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, // 8 bits function cluster begins // 8 bits int small block inner product + void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc); void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); @@ -199,6 +203,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, int8_t *bias); // 8 bits int pack function + void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp index bd5286dbcb5c871d5d327875b836ad9777c270bf..5dd8a7c3131543f426f32e258efb3181be9b2f61 100644 --- a/src/operators/math/gemm_int8.cpp +++ b/src/operators/math/gemm_int8.cpp @@ -26,11 +26,228 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { namespace math { +void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc) { +#if __ARM_NEON +#if __aarch64__ +// TODO +#else + const int8_t *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int32_t kc1 = k >> 3; + int32_t kc2 = k & 7; + int32_t kc3 = kc2 >> 2; + int32_t kc4 = kc2 & 3; + int32_t kc5 = kc4 >> 1; + int32_t kc6 = kc4 & 1; + int32_t step = sizeof(int32_t) * ldc; + asm volatile( + // q8-q15: save 32 results + "pld [%[a_ptr]] \n\t" + "pld [%[b_ptr]] \n\t" + "pld [%[b_ptr], #64] \n\t" + "vmov.s32 q8, #0 \n\t" + "vmov.s32 q9, q8 \n\t" + "vmov.s32 q10, q8 \n\t" + "vmov.s32 q11, q8 \n\t" + "vmov.s32 q12, q8 \n\t" + "vmov.s32 q13, q8 \n\t" + "vmov.s32 q14, q8 \n\t" + "vmov.s32 q15, q8 \n\t" + "subs %[kc1], %[kc1], #1 \n\t" + "blt 1f \n\t" + "0: \n\t" + "pld [%[a_ptr], #64] \n\t" + "pld [%[b_ptr], #128] \n\t" + "vld1.s8 {d0-d3}, [%[a_ptr]]! \n\t" // load A 8 cols + "vld1.s8 {d8-d11}, [%[b_ptr]]! \n\t" // load B first 4 rows + "vmovl.s8 q2, d0 \n\t" // process B first 4 + // rows + "vmovl.s8 q3, d8 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d9 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "vld1.s8 {d12-d15}, [%[b_ptr]]! \n\t" // load B second 4 + // rows + "vmovl.s8 q2, d1 \n\t" + "vmovl.s8 q3, d10 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d11 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "vmovl.s8 q2, d2 \n\t" // process B second 4 + // rows + "vmovl.s8 q3, d12 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d13 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "vmovl.s8 q2, d3 \n\t" + "vmovl.s8 q3, d14 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d15 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "bge 0b \n\t" + "1: \n\t" // last 4 rows + "subs %[kc3], %[kc3], #1 \n\t" + "blt 2f \n\t" + "vld1.s8 {d0-d1}, [%[a_ptr]]! \n\t" // load A 4 cols + "vld1.s8 {d8-d11}, [%[b_ptr]]! \n\t" // load B 4 rows + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d8 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d9 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmovl.s8 q3, d10 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d11 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "2: \n\t" // last 2 rows + "subs %[kc5], %[kc5], #1 \n\t" + "blt 3f \n\t" + "vld1.s8 {d0}, [%[a_ptr]]! \n\t" // load A 2 cols + "vld1.s8 {d8-d9}, [%[b_ptr]]! \n\t" // load B 2 rows + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d8 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vmovl.s8 q3, d9 \n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + "3: \n\t" // last 1 row + "subs %[kc6], %[kc6], #1 \n\t" + "blt 4f \n\t" + "vld1.s8 {d0}, [%[a_ptr]] \n\t" // load A 1 col + "vld1.s8 {d8}, [%[b_ptr]] \n\t" // load B 1 row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d8 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "4: \n\t" + "vst1.32 {q8, q9}, [%[c]], %[step] \n\t" + "vst1.32 {q10, q11}, [%[c]], %[step] \n\t" + "vst1.32 {q12, q13}, [%[c]], %[step] \n\t" + "vst1.32 {q14, q15}, [%[c]] \n\t" + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [kc3] "r"(kc3), [kc5] "r"(kc5), [kc6] "r"(kc6), [step] "r"(step) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#endif // __ARM_NEON +} // 8 bits int small block inner product void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON +#if __aarch64__ +// TODO +#else const int8_t *a_ptr, *b_ptr; a_ptr = a; b_ptr = b; @@ -46,383 +263,265 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, "pld [%[a_ptr]] \n\t" "pld [%[b_ptr]] \n\t" "pld [%[b_ptr], #64] \n\t" - "vmov.s8 q4, #0 \n\t" - "vmov.s8 q5, #0 \n\t" - "vmov.s8 q6, #0 \n\t" - "vmov.s8 q7, #0 \n\t" - "vmov.s8 q8, #0 \n\t" - "vmov.s8 q9, #0 \n\t" - "vmov.s8 q10, #0 \n\t" - "vmov.s8 q11, #0 \n\t" - "vmov.s8 q12, #0 \n\t" - "vmov.s8 q13, #0 \n\t" - "vmov.s8 q14, #0 \n\t" - "vmov.s8 q15, #0 \n\t" + "vmov.s32 q4, #0 \n\t" + "vmov.s32 q5, q4 \n\t" + "vmov.s32 q6, q4 \n\t" + "vmov.s32 q7, q4 \n\t" + "vmov.s32 q8, q4 \n\t" + "vmov.s32 q9, q4 \n\t" + "vmov.s32 q10, q4 \n\t" + "vmov.s32 q11, q4 \n\t" + "vmov.s32 q12, q4 \n\t" + "vmov.s32 q13, q4 \n\t" + "vmov.s32 q14, q4 \n\t" + "vmov.s32 q15, q4 \n\t" "mov r0, #12 \n\t" - "subs %[kc1], %[kc1], #1 \n\t" + "subs %[kc1], %[kc1], #1 \n\t" "blt 1f \n\t" "0: \n\t" "pld [%[a_ptr], #64] \n\t" "pld [%[b_ptr], #128] \n\t" - "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols, q0 used, - // 1/2 q3 used - "vmov.s8 q2, #0 \n\t" // q2 used - "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, - // q1 - "vdup.s8 d3, d0[0] \n\t" // q3 used // used - "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 - "vdup.s8 d3, d0[6] \n\t" // q3 used - "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, - // q3 free - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d0[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[2] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[0] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[3] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[1] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[4] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[2] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - - "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, - // q1 - "vmov.s8 q2, #0 \n\t" // q2 used - "vdup.s8 d3, d1[4] \n\t" // q3 used // used - "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 - "vdup.s8 d3, d2[2] \n\t" // q3 used - "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, - // q3 free - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[6] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[4] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[7] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[5] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[0] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[6] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - - "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols, q0 used, - // 1/2 q3 used - "vmov.s8 q2, #0 \n\t" // q2 used - "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, - // q1 - "vdup.s8 d3, d0[0] \n\t" // q3 used // used - "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 - "vdup.s8 d3, d0[6] \n\t" // q3 used - "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, - // q3 free - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d0[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[2] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[0] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[3] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[1] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[4] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[2] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - - "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, - // q1 - "vmov.s8 q2, #0 \n\t" // q2 used - "vdup.s8 d3, d1[4] \n\t" // q3 used // used - "vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 - "vdup.s8 d3, d2[2] \n\t" // q3 used - "vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, - // q3 free - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[6] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[4] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[7] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[5] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[0] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[6] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[2]\n\t" + "vmlal.s16 q5, d7, d5[2]\n\t" + "vmlal.s16 q6, d6, d5[3]\n\t" + "vmlal.s16 q7, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 3th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[0]\n\t" + "vmlal.s16 q5, d7, d5[0]\n\t" + "vmlal.s16 q6, d6, d5[1]\n\t" + "vmlal.s16 q7, d7, d5[1]\n\t" + "vmlal.s16 q8, d6, d5[2]\n\t" + "vmlal.s16 q9, d7, d5[2]\n\t" + "vmlal.s16 q10, d6, d5[3]\n\t" + "vmlal.s16 q11, d7, d5[3]\n\t" + "vmovl.s8 q2, d2 \n\t" + "vmlal.s16 q12, d6, d4[0]\n\t" + "vmlal.s16 q13, d7, d4[0]\n\t" + "vmlal.s16 q14, d6, d4[1]\n\t" + "vmlal.s16 q15, d7, d4[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 4th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[2]\n\t" + "vmlal.s16 q5, d7, d4[2]\n\t" + "vmlal.s16 q6, d6, d4[3]\n\t" + "vmlal.s16 q7, d7, d4[3]\n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" + + "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[2]\n\t" + "vmlal.s16 q5, d7, d5[2]\n\t" + "vmlal.s16 q6, d6, d5[3]\n\t" + "vmlal.s16 q7, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 3th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[0]\n\t" + "vmlal.s16 q5, d7, d5[0]\n\t" + "vmlal.s16 q6, d6, d5[1]\n\t" + "vmlal.s16 q7, d7, d5[1]\n\t" + "vmlal.s16 q8, d6, d5[2]\n\t" + "vmlal.s16 q9, d7, d5[2]\n\t" + "vmlal.s16 q10, d6, d5[3]\n\t" + "vmlal.s16 q11, d7, d5[3]\n\t" + "vmovl.s8 q2, d2 \n\t" + "vmlal.s16 q12, d6, d4[0]\n\t" + "vmlal.s16 q13, d7, d4[0]\n\t" + "vmlal.s16 q14, d6, d4[1]\n\t" + "vmlal.s16 q15, d7, d4[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 4th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[2]\n\t" + "vmlal.s16 q5, d7, d4[2]\n\t" + "vmlal.s16 q6, d6, d4[3]\n\t" + "vmlal.s16 q7, d7, d4[3]\n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" "subs %[kc1], %[kc1], #1 \n\t" "bge 0b \n\t" "1: \n\t" // last <8 rows "subs %[kc3], %[kc3], #1 \n\t" "blt 2f \n\t" - "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" - "vmov.s8 q2, #0 \n\t" - "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" - "vdup.s8 d3, d0[0] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d0[6] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d0[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[2] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[0] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[3] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[1] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[4] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[2] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d0[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d1[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - - "vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[4] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[2] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[5] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[3] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[6] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[4] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d1[7] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[5] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[0] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[6] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d3, d2[1] \n\t" - "vmlal.s8 q2, d6, d3 \n\t" - "vdup.s8 d3, d2[7] \n\t" - "vmlal.s8 q2, d7, d3 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[2]\n\t" + "vmlal.s16 q5, d7, d5[2]\n\t" + "vmlal.s16 q6, d6, d5[3]\n\t" + "vmlal.s16 q7, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 3th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[0]\n\t" + "vmlal.s16 q5, d7, d5[0]\n\t" + "vmlal.s16 q6, d6, d5[1]\n\t" + "vmlal.s16 q7, d7, d5[1]\n\t" + "vmlal.s16 q8, d6, d5[2]\n\t" + "vmlal.s16 q9, d7, d5[2]\n\t" + "vmlal.s16 q10, d6, d5[3]\n\t" + "vmlal.s16 q11, d7, d5[3]\n\t" + "vmovl.s8 q2, d2 \n\t" + "vmlal.s16 q12, d6, d4[0]\n\t" + "vmlal.s16 q13, d7, d4[0]\n\t" + "vmlal.s16 q14, d6, d4[1]\n\t" + "vmlal.s16 q15, d7, d4[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 4th row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[2]\n\t" + "vmlal.s16 q5, d7, d4[2]\n\t" + "vmlal.s16 q6, d6, d4[3]\n\t" + "vmlal.s16 q7, d7, d4[3]\n\t" + "vmlal.s16 q8, d6, d5[0]\n\t" + "vmlal.s16 q9, d7, d5[0]\n\t" + "vmlal.s16 q10, d6, d5[1]\n\t" + "vmlal.s16 q11, d7, d5[1]\n\t" + "vmlal.s16 q12, d6, d5[2]\n\t" + "vmlal.s16 q13, d7, d5[2]\n\t" + "vmlal.s16 q14, d6, d5[3]\n\t" + "vmlal.s16 q15, d7, d5[3]\n\t" "2: \n\t" // last <4 rows "subs %[kc5], %[kc5], #1 \n\t" "blt 3f \n\t" "vld1.s8 {d0, d1}, [%[a_ptr]], r0 \n\t" - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[0] \n\t" - "vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" - "vdup.s8 d7, d0[6] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[1] \n\t" - "vdup.s8 d7, d0[7] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[2] \n\t" - "vdup.s8 d7, d1[0] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[3] \n\t" - "vdup.s8 d7, d1[1] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0. \n\t" - "vdup.s8 d6, d0[4] \n\t" - "vdup.s8 d7, d1[2] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[5] \n\t" - "vdup.s8 d7, d1[3] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vmlal.s8 q2, d3, d7 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" + "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d5[2]\n\t" + "vmlal.s16 q5, d7, d5[2]\n\t" + "vmlal.s16 q6, d6, d5[3]\n\t" + "vmlal.s16 q7, d7, d5[3]\n\t" + "vmovl.s8 q2, d1 \n\t" + "vmlal.s16 q8, d6, d4[0]\n\t" + "vmlal.s16 q9, d7, d4[0]\n\t" + "vmlal.s16 q10, d6, d4[1]\n\t" + "vmlal.s16 q11, d7, d4[1]\n\t" + "vmlal.s16 q12, d6, d4[2]\n\t" + "vmlal.s16 q13, d7, d4[2]\n\t" + "vmlal.s16 q14, d6, d4[3]\n\t" + "vmlal.s16 q15, d7, d4[3]\n\t" "3: \n\t" // last <2 rows "subs %[kc6], %[kc6], #1 \n\t" "blt 4f \n\t" "vld1.s8 {d0}, [%[a_ptr]] \n\t" - "vld1.s8 {d1}, [%[b_ptr]] \n\t" - "vdup.s8 d2, d0[0] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vdup.s8 d2, d0[1] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vdup.s8 d2, d0[2] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vdup.s8 d2, d0[3] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vdup.s8 d2, d0[4] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vdup.s8 d2, d0[5] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 4 + "vld1.s8 {d3}, [%[b_ptr]] \n\t" + "vmovl.s8 q2, d0 \n\t" + "vmovl.s8 q3, d3 \n\t" + "vmlal.s16 q4, d6, d4[0]\n\t" + "vmlal.s16 q5, d7, d4[0]\n\t" + "vmlal.s16 q6, d6, d4[1]\n\t" + "vmlal.s16 q7, d7, d4[1]\n\t" + "vmlal.s16 q8, d6, d4[2]\n\t" + "vmlal.s16 q9, d7, d4[2]\n\t" + "vmlal.s16 q10, d6, d4[3]\n\t" + "vmlal.s16 q11, d7, d4[3]\n\t" + "vmlal.s16 q12, d6, d5[0]\n\t" + "vmlal.s16 q13, d7, d5[0]\n\t" + "vmlal.s16 q14, d6, d5[1]\n\t" + "vmlal.s16 q15, d7, d5[1]\n\t" "4: \n\t" "vst1.32 {q4, q5}, [%[c]], %[step] \n\t" "vst1.32 {q6, q7}, [%[c]], %[step] \n\t" @@ -435,7 +534,8 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, [kc3] "r"(kc3), [kc5] "r"(kc5), [kc6] "r"(kc6), [step] "r"(step) : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); -#endif +#endif // __aarch64__ +#endif // __ARM_NEON } // 8 bits int inner product @@ -445,8 +545,9 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, int8_t *bias) { #pragma omp parallel for for (int32_t j = 0; j < nc; j += NR) { - for (int32_t i = 0; i < mc; i += MR) { - AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + for (int32_t i = 0; i < mc; i += MR_INT8) { + // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); } } if (alpha != 1) { @@ -474,12 +575,53 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, return; } } +// 8 bits int PackMatrixA_4r +void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer) { + const int8_t *a0, *a1, *a2, *a3; + for (int32_t i = 0; i < m - m_tail; i += MR_INT8) { + a0 = A + i * lda; + a1 = A + (i + 1) * lda; + a2 = A + (i + 2) * lda; + a3 = A + (i + 3) * lda; + for (int32_t j = 0; j < k; ++j) { + *buffer++ = *a0++; + *buffer++ = *a1++; + *buffer++ = *a2++; + *buffer++ = *a3++; + } + } + + if (m_tail != 0) { + a0 = &A(m - m_tail, 0); + a1 = a0 + lda; + a2 = a0 + 2 * lda; + a3 = a0 + 3 * lda; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int j = 0; j < k; ++j) { + *buffer++ = *a0++; + *buffer++ = *a1++; + *buffer++ = *a2++; + *buffer++ = *a3++; + } + } +} -// 8 bits int PackMatrixA +// 8 bits int PackMatrixA_6r void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { const int32_t i_length = m - m_tail; - for (int32_t i = 0; i < i_length; i += MR) { + for (int32_t i = 0; i < i_length; i += MR_INT8) { const int8_t *a0 = A + i * lda; const int8_t *a1 = A + (i + 1) * lda; const int8_t *a2 = A + (i + 2) * lda; @@ -539,6 +681,9 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, for (int32_t i = 0; i < k; ++i) { const int8_t *b0 = &B(i, j); #if __ARM_NEON +#if __aarch64__ + // TODO +#else asm volatile( // "pld [%[b0]] \n\t" "vld1.s8 {d0}, [%[b0]] \n\t" @@ -546,6 +691,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "q0"); +#endif // __aarch64__ #else *local_buffer++ = *b0++; *local_buffer++ = *b0++; @@ -585,13 +731,13 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, MC = L1 / (KC * sizeof(int8_t)); NC = L2 / (KC * sizeof(int8_t)); - // make sure MC is multiple of MR, and NC is multiple of NR + // make sure MC is multiple of MR_INT8, and NC is multiple of NR if (MC == 0) { - MC = MR; + MC = MR_INT8; } else { int32_t mblock_num = (m + MC - 1) / MC; MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR - 1) / MR * MR; + MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; } // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; if (NC == 0) { @@ -618,7 +764,8 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8); for (int32_t i = 0; i < m; i += MC) { mc = s_min(m - i, MC); - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8); + // PackMatrixA_6r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); + PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); if (bias == nullptr) { InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, packedC_int8, &C(i, j), ldc, relu, nullptr); @@ -642,6 +789,10 @@ void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, // C = A * B, 8位 int32_t void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc) { +#if __ARM_NEON +#if __aarch64__ +// TODO +#else int32_t nc1 = nc >> 4; int32_t _nc1 = nc & 15; int32_t step = sizeof(int32_t) * ldc; @@ -695,6 +846,8 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, } } } +#endif // __aarch64__ +#endif // __ARM_NEON } // C = A * B + C diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp index 678add6dcedd22e788e0bd2df64a8eba59ad8514..10dab2cda1b3c692f42cf8760eb2b48ae6451f39 100644 --- a/test/operators/test_mul_op.cpp +++ b/test/operators/test_mul_op.cpp @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include "../test_helper.h" #include "../test_include.h" #include "operators/mul_op.h"