Unverified commit 3b1206f4, authored by Ray Liu, committed via GitHub

Merge branch 'develop' into develop_python_develop

@@ -27,9 +27,6 @@ void align_element(float **data_in, int num_per_div_before_alignment, int num) {
(num + num_per_div_before_alignment - 1) / num_per_div_before_alignment;
int num_per_div_after_alignment =
align_to_x(num_per_div_before_alignment, BS_NUM_ALIGNMENT);
- if (num_per_div_before_alignment == num_per_div_after_alignment) {
-   return;
- }
int num_element =
2 * div_num * num_per_div_after_alignment;  // including bias & scale
float *ptr_aligned =
......
@@ -3379,7 +3379,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
// 对 B 分块
NC = L1 / (KC * sizeof(float));
if (NC == 0) {
-   NC == NR;
  NC = NR;
} else {
  int nblock_num = (n + NC - 1) / NC;
  NC = (n + nblock_num - 1) / nblock_num;
......
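A side note on the blocking arithmetic in the hunk above: NC is derived from an L1 cache budget and then evened out across blocks, and the corrected line only takes effect when that derivation yields zero. A small worked sketch, using hypothetical values for L1, KC, NR and n (none of them taken from this patch):

#include <cstdio>

int main() {
  const int L1 = 32 * 1024;            // assumed L1 cache budget in bytes
  const int KC = 256;                  // assumed packed-panel depth
  const int NR = 8;                    // assumed register-tile width
  int n = 100;                         // assumed column count of B
  int NC = L1 / (KC * sizeof(float));  // 32768 / 1024 = 32
  if (NC == 0) {
    NC = NR;  // the corrected assignment; "NC == NR" was a no-op comparison
  } else {
    int nblock_num = (n + NC - 1) / NC;      // ceil(100 / 32) = 4 blocks
    NC = (n + nblock_num - 1) / nblock_num;  // ceil(100 / 4) = 25, evenly sized
  }
  std::printf("NC = %d\n", NC);  // prints 25
  return 0;
}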
@@ -22,9 +22,11 @@ limitations under the License. */
#define C(i, j) C[(i)*ldc + (j)]
#if __aarch64__
#define MR_INT8 4
#define MR 6
#define NR 16
#else
#define MR_INT8 4
#define MR 6
#define NR 8
#endif
@@ -189,6 +191,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits function cluster begins
// 8 bits int small block inner product
void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
               int32_t ldc);
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
               int32_t ldc);
@@ -199,6 +203,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
               int8_t *bias);
// 8 bits int pack function
void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                    int32_t lda, int8_t *buffer);
void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                    int32_t lda, int8_t *buffer);
void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
......
@@ -26,11 +26,228 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO
#else
const int8_t *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int32_t kc1 = k >> 3;
int32_t kc2 = k & 7;
int32_t kc3 = kc2 >> 2;
int32_t kc4 = kc2 & 3;
int32_t kc5 = kc4 >> 1;
int32_t kc6 = kc4 & 1;
int32_t step = sizeof(int32_t) * ldc;
asm volatile(
// q8-q15: save 32 results
"pld [%[a_ptr]] \n\t"
"pld [%[b_ptr]] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vmov.s32 q8, #0 \n\t"
"vmov.s32 q9, q8 \n\t"
"vmov.s32 q10, q8 \n\t"
"vmov.s32 q11, q8 \n\t"
"vmov.s32 q12, q8 \n\t"
"vmov.s32 q13, q8 \n\t"
"vmov.s32 q14, q8 \n\t"
"vmov.s32 q15, q8 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt 1f \n\t"
"0: \n\t"
"pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr], #128] \n\t"
"vld1.s8 {d0-d3}, [%[a_ptr]]! \n\t" // load A 8 cols
"vld1.s8 {d8-d11}, [%[b_ptr]]! \n\t" // load B first 4 rows
"vmovl.s8 q2, d0 \n\t" // process B first 4
// rows
"vmovl.s8 q3, d8 \n\t"
"vmlal.s16 q8, d6, d4[0]\n\t"
"vmlal.s16 q9, d7, d4[0]\n\t"
"vmlal.s16 q10, d6, d4[1]\n\t"
"vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s16 q12, d6, d4[2]\n\t"
"vmlal.s16 q13, d7, d4[2]\n\t"
"vmlal.s16 q14, d6, d4[3]\n\t"
"vmlal.s16 q15, d7, d4[3]\n\t"
"vmovl.s8 q3, d9 \n\t"
"vmlal.s16 q8, d6, d5[0]\n\t"
"vmlal.s16 q9, d7, d5[0]\n\t"
"vmlal.s16 q10, d6, d5[1]\n\t"
"vmlal.s16 q11, d7, d5[1]\n\t"
"vmlal.s16 q12, d6, d5[2]\n\t"
"vmlal.s16 q13, d7, d5[2]\n\t"
"vmlal.s16 q14, d6, d5[3]\n\t"
"vmlal.s16 q15, d7, d5[3]\n\t"
"vld1.s8 {d12-d15}, [%[b_ptr]]! \n\t" // load B second 4
// rows
"vmovl.s8 q2, d1 \n\t"
"vmovl.s8 q3, d10 \n\t"
"vmlal.s16 q8, d6, d4[0]\n\t"
"vmlal.s16 q9, d7, d4[0]\n\t"
"vmlal.s16 q10, d6, d4[1]\n\t"
"vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s16 q12, d6, d4[2]\n\t"
"vmlal.s16 q13, d7, d4[2]\n\t"
"vmlal.s16 q14, d6, d4[3]\n\t"
"vmlal.s16 q15, d7, d4[3]\n\t"
"vmovl.s8 q3, d11 \n\t"
"vmlal.s16 q8, d6, d5[0]\n\t"
"vmlal.s16 q9, d7, d5[0]\n\t"
"vmlal.s16 q10, d6, d5[1]\n\t"
"vmlal.s16 q11, d7, d5[1]\n\t"
"vmlal.s16 q12, d6, d5[2]\n\t"
"vmlal.s16 q13, d7, d5[2]\n\t"
"vmlal.s16 q14, d6, d5[3]\n\t"
"vmlal.s16 q15, d7, d5[3]\n\t"
"vmovl.s8 q2, d2 \n\t" // process B second 4
// rows
"vmovl.s8 q3, d12 \n\t"
"vmlal.s16 q8, d6, d4[0]\n\t"
"vmlal.s16 q9, d7, d4[0]\n\t"
"vmlal.s16 q10, d6, d4[1]\n\t"
"vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s16 q12, d6, d4[2]\n\t"
"vmlal.s16 q13, d7, d4[2]\n\t"
"vmlal.s16 q14, d6, d4[3]\n\t"
"vmlal.s16 q15, d7, d4[3]\n\t"
"vmovl.s8 q3, d13 \n\t"
"vmlal.s16 q8, d6, d5[0]\n\t"
"vmlal.s16 q9, d7, d5[0]\n\t"
"vmlal.s16 q10, d6, d5[1]\n\t"
"vmlal.s16 q11, d7, d5[1]\n\t"
"vmlal.s16 q12, d6, d5[2]\n\t"
"vmlal.s16 q13, d7, d5[2]\n\t"
"vmlal.s16 q14, d6, d5[3]\n\t"
"vmlal.s16 q15, d7, d5[3]\n\t"
"vmovl.s8 q2, d3 \n\t"
"vmovl.s8 q3, d14 \n\t"
"vmlal.s16 q8, d6, d4[0]\n\t"
"vmlal.s16 q9, d7, d4[0]\n\t"
"vmlal.s16 q10, d6, d4[1]\n\t"
"vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s16 q12, d6, d4[2]\n\t"
"vmlal.s16 q13, d7, d4[2]\n\t"
"vmlal.s16 q14, d6, d4[3]\n\t"
"vmlal.s16 q15, d7, d4[3]\n\t"
"vmovl.s8 q3, d15 \n\t"
"vmlal.s16 q8, d6, d5[0]\n\t"
"vmlal.s16 q9, d7, d5[0]\n\t"
"vmlal.s16 q10, d6, d5[1]\n\t"
"vmlal.s16 q11, d7, d5[1]\n\t"
"vmlal.s16 q12, d6, d5[2]\n\t"
"vmlal.s16 q13, d7, d5[2]\n\t"
"vmlal.s16 q14, d6, d5[3]\n\t"
"vmlal.s16 q15, d7, d5[3]\n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge 0b \n\t"
"1: \n\t" // last 4 rows
"subs %[kc3], %[kc3], #1 \n\t"
"blt 2f \n\t"
"vld1.s8 {d0-d1}, [%[a_ptr]]! \n\t" // load A 4 cols
"vld1.s8 {d8-d11}, [%[b_ptr]]! \n\t" // load B 4 rows
"vmovl.s8 q2, d0 \n\t"
"vmovl.s8 q3, d8 \n\t"
"vmlal.s16 q8, d6, d4[0]\n\t"
"vmlal.s16 q9, d7, d4[0]\n\t"
"vmlal.s16 q10, d6, d4[1]\n\t"
"vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s16 q12, d6, d4[2]\n\t"
"vmlal.s16 q13, d7, d4[2]\n\t"
"vmlal.s16 q14, d6, d4[3]\n\t"
"vmlal.s16 q15, d7, d4[3]\n\t"
"vmovl.s8 q3, d9 \n\t"
"vmlal.s16 q8, d6, d5[0]\n\t"
"vmlal.s16 q9, d7, d5[0]\n\t"
"vmlal.s16 q10, d6, d5[1]\n\t"
"vmlal.s16 q11, d7, d5[1]\n\t"
"vmlal.s16 q12, d6, d5[2]\n\t"
"vmlal.s16 q13, d7, d5[2]\n\t"
"vmlal.s16 q14, d6, d5[3]\n\t"
"vmlal.s16 q15, d7, d5[3]\n\t"
"vmovl.s8 q2, d1 \n\t"
"vmovl.s8 q3, d10 \n\t"
"vmlal.s16 q8, d6, d4[0]\n\t"
"vmlal.s16 q9, d7, d4[0]\n\t"
"vmlal.s16 q10, d6, d4[1]\n\t"
"vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s16 q12, d6, d4[2]\n\t"
"vmlal.s16 q13, d7, d4[2]\n\t"
"vmlal.s16 q14, d6, d4[3]\n\t"
"vmlal.s16 q15, d7, d4[3]\n\t"
"vmovl.s8 q3, d11 \n\t"
"vmlal.s16 q8, d6, d5[0]\n\t"
"vmlal.s16 q9, d7, d5[0]\n\t"
"vmlal.s16 q10, d6, d5[1]\n\t"
"vmlal.s16 q11, d7, d5[1]\n\t"
"vmlal.s16 q12, d6, d5[2]\n\t"
"vmlal.s16 q13, d7, d5[2]\n\t"
"vmlal.s16 q14, d6, d5[3]\n\t"
"vmlal.s16 q15, d7, d5[3]\n\t"
"2: \n\t" // last 2 rows
"subs %[kc5], %[kc5], #1 \n\t"
"blt 3f \n\t"
"vld1.s8 {d0}, [%[a_ptr]]! \n\t" // load A 2 cols
"vld1.s8 {d8-d9}, [%[b_ptr]]! \n\t" // load B 2 rows
"vmovl.s8 q2, d0 \n\t"
"vmovl.s8 q3, d8 \n\t"
"vmlal.s16 q8, d6, d4[0]\n\t"
"vmlal.s16 q9, d7, d4[0]\n\t"
"vmlal.s16 q10, d6, d4[1]\n\t"
"vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s16 q12, d6, d4[2]\n\t"
"vmlal.s16 q13, d7, d4[2]\n\t"
"vmlal.s16 q14, d6, d4[3]\n\t"
"vmlal.s16 q15, d7, d4[3]\n\t"
"vmovl.s8 q3, d9 \n\t"
"vmlal.s16 q8, d6, d5[0]\n\t"
"vmlal.s16 q9, d7, d5[0]\n\t"
"vmlal.s16 q10, d6, d5[1]\n\t"
"vmlal.s16 q11, d7, d5[1]\n\t"
"vmlal.s16 q12, d6, d5[2]\n\t"
"vmlal.s16 q13, d7, d5[2]\n\t"
"vmlal.s16 q14, d6, d5[3]\n\t"
"vmlal.s16 q15, d7, d5[3]\n\t"
"3: \n\t" // last 1 row
"subs %[kc6], %[kc6], #1 \n\t"
"blt 4f \n\t"
"vld1.s8 {d0}, [%[a_ptr]] \n\t" // load A 1 col
"vld1.s8 {d8}, [%[b_ptr]] \n\t" // load B 1 row
"vmovl.s8 q2, d0 \n\t"
"vmovl.s8 q3, d8 \n\t"
"vmlal.s16 q8, d6, d4[0]\n\t"
"vmlal.s16 q9, d7, d4[0]\n\t"
"vmlal.s16 q10, d6, d4[1]\n\t"
"vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s16 q12, d6, d4[2]\n\t"
"vmlal.s16 q13, d7, d4[2]\n\t"
"vmlal.s16 q14, d6, d4[3]\n\t"
"vmlal.s16 q15, d7, d4[3]\n\t"
"4: \n\t"
"vst1.32 {q8, q9}, [%[c]], %[step] \n\t"
"vst1.32 {q10, q11}, [%[c]], %[step] \n\t"
"vst1.32 {q12, q13}, [%[c]], %[step] \n\t"
"vst1.32 {q14, q15}, [%[c]] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc3] "r"(kc3), [kc5] "r"(kc5), [kc6] "r"(kc6), [step] "r"(step)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif // __aarch64__
#endif // __ARM_NEON
}
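For readers who don't want to trace the NEON above: a minimal scalar sketch of what this 4x8 micro-kernel computes, assuming the packed layouts produced by PackMatrixA_4r (4 int8 values of A per k step) and PackMatrixB_8c (8 int8 values of B per k step). The name AddDot4x8_reference and the function itself are illustrative only, not part of the patch.

#include <cstdint>

// Scalar equivalent of the kernel above: c is a 4 x 8 int32 block with leading
// dimension ldc, overwritten (not accumulated) with the panel product.
void AddDot4x8_reference(int32_t k, const int8_t *a, const int8_t *b,
                         int32_t *c, int32_t ldc) {
  for (int32_t i = 0; i < 4; ++i) {
    for (int32_t j = 0; j < 8; ++j) {
      int32_t acc = 0;
      for (int32_t p = 0; p < k; ++p) {
        acc += static_cast<int32_t>(a[p * 4 + i]) *
               static_cast<int32_t>(b[p * 8 + j]);
      }
      c[i * ldc + j] = acc;
    }
  }
}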
// 8 bits int small block inner product
void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
                     int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO
#else
const int8_t *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
@@ -46,383 +263,265 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
"pld [%[a_ptr]] \n\t"
"pld [%[b_ptr]] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vmov.s8 q4, #0 \n\t" "vmov.s32 q4, #0 \n\t"
"vmov.s8 q5, #0 \n\t" "vmov.s32 q5, q4 \n\t"
"vmov.s8 q6, #0 \n\t" "vmov.s32 q6, q4 \n\t"
"vmov.s8 q7, #0 \n\t" "vmov.s32 q7, q4 \n\t"
"vmov.s8 q8, #0 \n\t" "vmov.s32 q8, q4 \n\t"
"vmov.s8 q9, #0 \n\t" "vmov.s32 q9, q4 \n\t"
"vmov.s8 q10, #0 \n\t" "vmov.s32 q10, q4 \n\t"
"vmov.s8 q11, #0 \n\t" "vmov.s32 q11, q4 \n\t"
"vmov.s8 q12, #0 \n\t" "vmov.s32 q12, q4 \n\t"
"vmov.s8 q13, #0 \n\t" "vmov.s32 q13, q4 \n\t"
"vmov.s8 q14, #0 \n\t" "vmov.s32 q14, q4 \n\t"
"vmov.s8 q15, #0 \n\t" "vmov.s32 q15, q4 \n\t"
"mov r0, #12 \n\t" "mov r0, #12 \n\t"
"subs %[kc1], %[kc1], #1 \n\t" "subs %[kc1], %[kc1], #1 \n\t"
"blt 1f \n\t" "blt 1f \n\t"
"0: \n\t" "0: \n\t"
"pld [%[a_ptr], #64] \n\t" "pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr], #128] \n\t" "pld [%[b_ptr], #128] \n\t"
"vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols, q0 used, "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols
// 1/2 q3 used "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row
"vmov.s8 q2, #0 \n\t" // q2 used "vmovl.s8 q2, d0 \n\t"
"vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, "vmovl.s8 q3, d3 \n\t"
// q1 "vmlal.s16 q4, d6, d4[0]\n\t"
"vdup.s8 d3, d0[0] \n\t" // q3 used // used "vmlal.s16 q5, d7, d4[0]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 "vmlal.s16 q6, d6, d4[1]\n\t"
"vdup.s8 d3, d0[6] \n\t" // q3 used "vmlal.s16 q7, d7, d4[1]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, "vmlal.s16 q8, d6, d4[2]\n\t"
// q3 free "vmlal.s16 q9, d7, d4[2]\n\t"
"vaddw.s16 q4, q4, d4 \n\t" "vmlal.s16 q10, d6, d4[3]\n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0 "vmlal.s16 q11, d7, d4[3]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q12, d6, d5[0]\n\t"
"vdup.s8 d3, d0[1] \n\t" "vmlal.s16 q13, d7, d5[0]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q14, d6, d5[1]\n\t"
"vdup.s8 d3, d0[7] \n\t" "vmlal.s16 q15, d7, d5[1]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row
"vaddw.s16 q6, q6, d4 \n\t" "vmovl.s8 q3, d3 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1 "vmlal.s16 q4, d6, d5[2]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q5, d7, d5[2]\n\t"
"vdup.s8 d3, d0[2] \n\t" "vmlal.s16 q6, d6, d5[3]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q7, d7, d5[3]\n\t"
"vdup.s8 d3, d1[0] \n\t" "vmovl.s8 q2, d1 \n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q8, d6, d4[0]\n\t"
"vaddw.s16 q8, q8, d4 \n\t" "vmlal.s16 q9, d7, d4[0]\n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2 "vmlal.s16 q10, d6, d4[1]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q11, d7, d4[1]\n\t"
"vdup.s8 d3, d0[3] \n\t" "vmlal.s16 q12, d6, d4[2]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q13, d7, d4[2]\n\t"
"vdup.s8 d3, d1[1] \n\t" "vmlal.s16 q14, d6, d4[3]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q15, d7, d4[3]\n\t"
"vaddw.s16 q10, q10, d4 \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 3th row
"vaddw.s16 q11, q11, d5 \n\t" // res row 3 "vmovl.s8 q3, d3 \n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q4, d6, d5[0]\n\t"
"vdup.s8 d3, d0[4] \n\t" "vmlal.s16 q5, d7, d5[0]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q6, d6, d5[1]\n\t"
"vdup.s8 d3, d1[2] \n\t" "vmlal.s16 q7, d7, d5[1]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q8, d6, d5[2]\n\t"
"vaddw.s16 q12, q12, d4 \n\t" "vmlal.s16 q9, d7, d5[2]\n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4 "vmlal.s16 q10, d6, d5[3]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q11, d7, d5[3]\n\t"
"vdup.s8 d3, d0[5] \n\t" "vmovl.s8 q2, d2 \n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q12, d6, d4[0]\n\t"
"vdup.s8 d3, d1[3] \n\t" "vmlal.s16 q13, d7, d4[0]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q14, d6, d4[1]\n\t"
"vaddw.s16 q14, q14, d4 \n\t" "vmlal.s16 q15, d7, d4[1]\n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5 "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 4th row
"vmovl.s8 q3, d3 \n\t"
"vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, "vmlal.s16 q4, d6, d4[2]\n\t"
// q1 "vmlal.s16 q5, d7, d4[2]\n\t"
"vmov.s8 q2, #0 \n\t" // q2 used "vmlal.s16 q6, d6, d4[3]\n\t"
"vdup.s8 d3, d1[4] \n\t" // q3 used // used "vmlal.s16 q7, d7, d4[3]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 "vmlal.s16 q8, d6, d5[0]\n\t"
"vdup.s8 d3, d2[2] \n\t" // q3 used "vmlal.s16 q9, d7, d5[0]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, "vmlal.s16 q10, d6, d5[1]\n\t"
// q3 free "vmlal.s16 q11, d7, d5[1]\n\t"
"vaddw.s16 q4, q4, d4 \n\t" "vmlal.s16 q12, d6, d5[2]\n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0 "vmlal.s16 q13, d7, d5[2]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q14, d6, d5[3]\n\t"
"vdup.s8 d3, d1[5] \n\t" "vmlal.s16 q15, d7, d5[3]\n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[3] \n\t" "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols
"vmlal.s8 q2, d7, d3 \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row
"vaddw.s16 q6, q6, d4 \n\t" "vmovl.s8 q2, d0 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1 "vmovl.s8 q3, d3 \n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q4, d6, d4[0]\n\t"
"vdup.s8 d3, d1[6] \n\t" "vmlal.s16 q5, d7, d4[0]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q6, d6, d4[1]\n\t"
"vdup.s8 d3, d2[4] \n\t" "vmlal.s16 q7, d7, d4[1]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q8, d6, d4[2]\n\t"
"vaddw.s16 q8, q8, d4 \n\t" "vmlal.s16 q9, d7, d4[2]\n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2 "vmlal.s16 q10, d6, d4[3]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q11, d7, d4[3]\n\t"
"vdup.s8 d3, d1[7] \n\t" "vmlal.s16 q12, d6, d5[0]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q13, d7, d5[0]\n\t"
"vdup.s8 d3, d2[5] \n\t" "vmlal.s16 q14, d6, d5[1]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q15, d7, d5[1]\n\t"
"vaddw.s16 q10, q10, d4 \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row
"vaddw.s16 q11, q11, d5 \n\t" // res row 3 "vmovl.s8 q3, d3 \n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q4, d6, d5[2]\n\t"
"vdup.s8 d3, d2[0] \n\t" "vmlal.s16 q5, d7, d5[2]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q6, d6, d5[3]\n\t"
"vdup.s8 d3, d2[6] \n\t" "vmlal.s16 q7, d7, d5[3]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmovl.s8 q2, d1 \n\t"
"vaddw.s16 q12, q12, d4 \n\t" "vmlal.s16 q8, d6, d4[0]\n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4 "vmlal.s16 q9, d7, d4[0]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q10, d6, d4[1]\n\t"
"vdup.s8 d3, d2[1] \n\t" "vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q12, d6, d4[2]\n\t"
"vdup.s8 d3, d2[7] \n\t" "vmlal.s16 q13, d7, d4[2]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q14, d6, d4[3]\n\t"
"vaddw.s16 q14, q14, d4 \n\t" "vmlal.s16 q15, d7, d4[3]\n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5 "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 3th row
"vmovl.s8 q3, d3 \n\t"
"vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols, q0 used, "vmlal.s16 q4, d6, d5[0]\n\t"
// 1/2 q3 used "vmlal.s16 q5, d7, d5[0]\n\t"
"vmov.s8 q2, #0 \n\t" // q2 used "vmlal.s16 q6, d6, d5[1]\n\t"
"vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1, "vmlal.s16 q7, d7, d5[1]\n\t"
// q1 "vmlal.s16 q8, d6, d5[2]\n\t"
"vdup.s8 d3, d0[0] \n\t" // q3 used // used "vmlal.s16 q9, d7, d5[2]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0 "vmlal.s16 q10, d6, d5[3]\n\t"
"vdup.s8 d3, d0[6] \n\t" // q3 used "vmlal.s16 q11, d7, d5[3]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1, "vmovl.s8 q2, d2 \n\t"
// q3 free "vmlal.s16 q12, d6, d4[0]\n\t"
"vaddw.s16 q4, q4, d4 \n\t" "vmlal.s16 q13, d7, d4[0]\n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0 "vmlal.s16 q14, d6, d4[1]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q15, d7, d4[1]\n\t"
"vdup.s8 d3, d0[1] \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 4th row
"vmlal.s8 q2, d6, d3 \n\t" "vmovl.s8 q3, d3 \n\t"
"vdup.s8 d3, d0[7] \n\t" "vmlal.s16 q4, d6, d4[2]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q5, d7, d4[2]\n\t"
"vaddw.s16 q6, q6, d4 \n\t" "vmlal.s16 q6, d6, d4[3]\n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1 "vmlal.s16 q7, d7, d4[3]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q8, d6, d5[0]\n\t"
"vdup.s8 d3, d0[2] \n\t" "vmlal.s16 q9, d7, d5[0]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q10, d6, d5[1]\n\t"
"vdup.s8 d3, d1[0] \n\t" "vmlal.s16 q11, d7, d5[1]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q12, d6, d5[2]\n\t"
"vaddw.s16 q8, q8, d4 \n\t" "vmlal.s16 q13, d7, d5[2]\n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2 "vmlal.s16 q14, d6, d5[3]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q15, d7, d5[3]\n\t"
"vdup.s8 d3, d0[3] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d1[1] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d0[4] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d1[2] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d0[5] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d1[3] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" // B 2 rows, B row1,
// q1
"vmov.s8 q2, #0 \n\t" // q2 used
"vdup.s8 d3, d1[4] \n\t" // q3 used // used
"vmlal.s8 q2, d6, d3 \n\t" // A col00 * B row0
"vdup.s8 d3, d2[2] \n\t" // q3 used
"vmlal.s8 q2, d7, d3 \n\t" // A col10 * B row1,
// q3 free
"vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d1[5] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[3] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q6, q6, d4 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d1[6] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[4] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d1[7] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[5] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d2[0] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[6] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d2[1] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[7] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"subs %[kc1], %[kc1], #1 \n\t" "subs %[kc1], %[kc1], #1 \n\t"
"bge 0b \n\t" "bge 0b \n\t"
"1: \n\t" // last <8 rows "1: \n\t" // last <8 rows
"subs %[kc3], %[kc3], #1 \n\t" "subs %[kc3], %[kc3], #1 \n\t"
"blt 2f \n\t" "blt 2f \n\t"
"vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" "vld1.s8 {d0-d2}, [%[a_ptr]]! \n\t" // A 4 cols
"vmov.s8 q2, #0 \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row
"vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" "vmovl.s8 q2, d0 \n\t"
"vdup.s8 d3, d0[0] \n\t" "vmovl.s8 q3, d3 \n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q4, d6, d4[0]\n\t"
"vdup.s8 d3, d0[6] \n\t" "vmlal.s16 q5, d7, d4[0]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q6, d6, d4[1]\n\t"
"vaddw.s16 q4, q4, d4 \n\t" "vmlal.s16 q7, d7, d4[1]\n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0 "vmlal.s16 q8, d6, d4[2]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q9, d7, d4[2]\n\t"
"vdup.s8 d3, d0[1] \n\t" "vmlal.s16 q10, d6, d4[3]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q11, d7, d4[3]\n\t"
"vdup.s8 d3, d0[7] \n\t" "vmlal.s16 q12, d6, d5[0]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q13, d7, d5[0]\n\t"
"vaddw.s16 q6, q6, d4 \n\t" "vmlal.s16 q14, d6, d5[1]\n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1 "vmlal.s16 q15, d7, d5[1]\n\t"
"vmov.s8 q2, #0 \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row
"vdup.s8 d3, d0[2] \n\t" "vmovl.s8 q3, d3 \n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q4, d6, d5[2]\n\t"
"vdup.s8 d3, d1[0] \n\t" "vmlal.s16 q5, d7, d5[2]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q6, d6, d5[3]\n\t"
"vaddw.s16 q8, q8, d4 \n\t" "vmlal.s16 q7, d7, d5[3]\n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2 "vmovl.s8 q2, d1 \n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q8, d6, d4[0]\n\t"
"vdup.s8 d3, d0[3] \n\t" "vmlal.s16 q9, d7, d4[0]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q10, d6, d4[1]\n\t"
"vdup.s8 d3, d1[1] \n\t" "vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q12, d6, d4[2]\n\t"
"vaddw.s16 q10, q10, d4 \n\t" "vmlal.s16 q13, d7, d4[2]\n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3 "vmlal.s16 q14, d6, d4[3]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q15, d7, d4[3]\n\t"
"vdup.s8 d3, d0[4] \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 3th row
"vmlal.s8 q2, d6, d3 \n\t" "vmovl.s8 q3, d3 \n\t"
"vdup.s8 d3, d1[2] \n\t" "vmlal.s16 q4, d6, d5[0]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q5, d7, d5[0]\n\t"
"vaddw.s16 q12, q12, d4 \n\t" "vmlal.s16 q6, d6, d5[1]\n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4 "vmlal.s16 q7, d7, d5[1]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q8, d6, d5[2]\n\t"
"vdup.s8 d3, d0[5] \n\t" "vmlal.s16 q9, d7, d5[2]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q10, d6, d5[3]\n\t"
"vdup.s8 d3, d1[3] \n\t" "vmlal.s16 q11, d7, d5[3]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmovl.s8 q2, d2 \n\t"
"vaddw.s16 q14, q14, d4 \n\t" "vmlal.s16 q12, d6, d4[0]\n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5 "vmlal.s16 q13, d7, d4[0]\n\t"
"vmlal.s16 q14, d6, d4[1]\n\t"
"vld1.s8 {d6-d7}, [%[b_ptr]]! \n\t" "vmlal.s16 q15, d7, d4[1]\n\t"
"vmov.s8 q2, #0 \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 4th row
"vdup.s8 d3, d1[4] \n\t" "vmovl.s8 q3, d3 \n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q4, d6, d4[2]\n\t"
"vdup.s8 d3, d2[2] \n\t" "vmlal.s16 q5, d7, d4[2]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q6, d6, d4[3]\n\t"
"vaddw.s16 q4, q4, d4 \n\t" "vmlal.s16 q7, d7, d4[3]\n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0 "vmlal.s16 q8, d6, d5[0]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q9, d7, d5[0]\n\t"
"vdup.s8 d3, d1[5] \n\t" "vmlal.s16 q10, d6, d5[1]\n\t"
"vmlal.s8 q2, d6, d3 \n\t" "vmlal.s16 q11, d7, d5[1]\n\t"
"vdup.s8 d3, d2[3] \n\t" "vmlal.s16 q12, d6, d5[2]\n\t"
"vmlal.s8 q2, d7, d3 \n\t" "vmlal.s16 q13, d7, d5[2]\n\t"
"vaddw.s16 q6, q6, d4 \n\t" "vmlal.s16 q14, d6, d5[3]\n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1 "vmlal.s16 q15, d7, d5[3]\n\t"
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d1[6] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[4] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d1[7] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[5] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d2[0] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[6] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d3, d2[1] \n\t"
"vmlal.s8 q2, d6, d3 \n\t"
"vdup.s8 d3, d2[7] \n\t"
"vmlal.s8 q2, d7, d3 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"2: \n\t" // last <4 rows "2: \n\t" // last <4 rows
"subs %[kc5], %[kc5], #1 \n\t" "subs %[kc5], %[kc5], #1 \n\t"
"blt 3f \n\t" "blt 3f \n\t"
"vld1.s8 {d0, d1}, [%[a_ptr]], r0 \n\t" "vld1.s8 {d0, d1}, [%[a_ptr]], r0 \n\t"
"vmov.s8 q2, #0 \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 1st row
"vdup.s8 d6, d0[0] \n\t" "vmovl.s8 q2, d0 \n\t"
"vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" "vmovl.s8 q3, d3 \n\t"
"vdup.s8 d7, d0[6] \n\t" "vmlal.s16 q4, d6, d4[0]\n\t"
"vmlal.s8 q2, d2, d6 \n\t" "vmlal.s16 q5, d7, d4[0]\n\t"
"vmlal.s8 q2, d3, d7 \n\t" "vmlal.s16 q6, d6, d4[1]\n\t"
"vaddw.s16 q4, q4, d4 \n\t" "vmlal.s16 q7, d7, d4[1]\n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0 "vmlal.s16 q8, d6, d4[2]\n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q9, d7, d4[2]\n\t"
"vdup.s8 d6, d0[1] \n\t" "vmlal.s16 q10, d6, d4[3]\n\t"
"vdup.s8 d7, d0[7] \n\t" "vmlal.s16 q11, d7, d4[3]\n\t"
"vmlal.s8 q2, d2, d6 \n\t" "vmlal.s16 q12, d6, d5[0]\n\t"
"vmlal.s8 q2, d3, d7 \n\t" "vmlal.s16 q13, d7, d5[0]\n\t"
"vaddw.s16 q6, q6, d4 \n\t" "vmlal.s16 q14, d6, d5[1]\n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1 "vmlal.s16 q15, d7, d5[1]\n\t"
"vmov.s8 q2, #0 \n\t" "vld1.s8 {d3}, [%[b_ptr]]! \n\t" // B 2nd row
"vdup.s8 d6, d0[2] \n\t" "vmovl.s8 q3, d3 \n\t"
"vdup.s8 d7, d1[0] \n\t" "vmlal.s16 q4, d6, d5[2]\n\t"
"vmlal.s8 q2, d2, d6 \n\t" "vmlal.s16 q5, d7, d5[2]\n\t"
"vmlal.s8 q2, d3, d7 \n\t" "vmlal.s16 q6, d6, d5[3]\n\t"
"vaddw.s16 q8, q8, d4 \n\t" "vmlal.s16 q7, d7, d5[3]\n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2 "vmovl.s8 q2, d1 \n\t"
"vmov.s8 q2, #0 \n\t" "vmlal.s16 q8, d6, d4[0]\n\t"
"vdup.s8 d6, d0[3] \n\t" "vmlal.s16 q9, d7, d4[0]\n\t"
"vdup.s8 d7, d1[1] \n\t" "vmlal.s16 q10, d6, d4[1]\n\t"
"vmlal.s8 q2, d2, d6 \n\t" "vmlal.s16 q11, d7, d4[1]\n\t"
"vmlal.s8 q2, d3, d7 \n\t" "vmlal.s16 q12, d6, d4[2]\n\t"
"vaddw.s16 q10, q10, d4 \n\t" "vmlal.s16 q13, d7, d4[2]\n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3 "vmlal.s16 q14, d6, d4[3]\n\t"
"vmov.s8 q2, #0. \n\t" "vmlal.s16 q15, d7, d4[3]\n\t"
"vdup.s8 d6, d0[4] \n\t"
"vdup.s8 d7, d1[2] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[5] \n\t"
"vdup.s8 d7, d1[3] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"3: \n\t" // last <2 rows "3: \n\t" // last <2 rows
"subs %[kc6], %[kc6], #1 \n\t" "subs %[kc6], %[kc6], #1 \n\t"
"blt 4f \n\t" "blt 4f \n\t"
"vld1.s8 {d0}, [%[a_ptr]] \n\t" "vld1.s8 {d0}, [%[a_ptr]] \n\t"
"vld1.s8 {d1}, [%[b_ptr]] \n\t" "vld1.s8 {d3}, [%[b_ptr]] \n\t"
"vdup.s8 d2, d0[0] \n\t" "vmovl.s8 q2, d0 \n\t"
"vmull.s8 q2, d1, d2 \n\t" "vmovl.s8 q3, d3 \n\t"
"vaddw.s16 q4, q4, d4 \n\t" "vmlal.s16 q4, d6, d4[0]\n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0 "vmlal.s16 q5, d7, d4[0]\n\t"
"vdup.s8 d2, d0[1] \n\t" "vmlal.s16 q6, d6, d4[1]\n\t"
"vmull.s8 q2, d1, d2 \n\t" "vmlal.s16 q7, d7, d4[1]\n\t"
"vaddw.s16 q6, q6, d4 \n\t" "vmlal.s16 q8, d6, d4[2]\n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1 "vmlal.s16 q9, d7, d4[2]\n\t"
"vdup.s8 d2, d0[2] \n\t" "vmlal.s16 q10, d6, d4[3]\n\t"
"vmull.s8 q2, d1, d2 \n\t" "vmlal.s16 q11, d7, d4[3]\n\t"
"vaddw.s16 q8, q8, d4 \n\t" "vmlal.s16 q12, d6, d5[0]\n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2 "vmlal.s16 q13, d7, d5[0]\n\t"
"vdup.s8 d2, d0[3] \n\t" "vmlal.s16 q14, d6, d5[1]\n\t"
"vmull.s8 q2, d1, d2 \n\t" "vmlal.s16 q15, d7, d5[1]\n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vdup.s8 d2, d0[4] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vdup.s8 d2, d0[5] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 4
"4: \n\t" "4: \n\t"
"vst1.32 {q4, q5}, [%[c]], %[step] \n\t" "vst1.32 {q4, q5}, [%[c]], %[step] \n\t"
"vst1.32 {q6, q7}, [%[c]], %[step] \n\t" "vst1.32 {q6, q7}, [%[c]], %[step] \n\t"
@@ -435,7 +534,8 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
[kc3] "r"(kc3), [kc5] "r"(kc5), [kc6] "r"(kc6), [step] "r"(step)
: "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
  "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
- #endif
#endif  // __aarch64__
#endif  // __ARM_NEON
}
// 8 bits int inner product
@@ -445,8 +545,9 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
                               int8_t *bias) {
#pragma omp parallel for
for (int32_t j = 0; j < nc; j += NR) {
-   for (int32_t i = 0; i < mc; i += MR) {
  for (int32_t i = 0; i < mc; i += MR_INT8) {
-     AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
    // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
    AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
  }
}
if (alpha != 1) {
@@ -474,12 +575,53 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
  return;
}
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer) {
const int8_t *a0, *a1, *a2, *a3;
for (int32_t i = 0; i < m - m_tail; i += MR_INT8) {
a0 = A + i * lda;
a1 = A + (i + 1) * lda;
a2 = A + (i + 2) * lda;
a3 = A + (i + 3) * lda;
for (int32_t j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
}
}
if (m_tail != 0) {
a0 = &A(m - m_tail, 0);
a1 = a0 + lda;
a2 = a0 + 2 * lda;
a3 = a0 + 3 * lda;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*buffer++ = *a0++;
*buffer++ = *a1++;
*buffer++ = *a2++;
*buffer++ = *a3++;
}
}
}
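A small self-contained check of the layout this packing produces; the example matrix and the main() harness are illustrative, not part of the patch. Each column of a 4-row panel is written as 4 consecutive int8 values, which is exactly the order AddDot4x8 walks through the packed A buffer.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int32_t m = 4, k = 3, lda = k;
  std::vector<int8_t> A = {0, 1, 2,     // row 0
                           3, 4, 5,     // row 1
                           6, 7, 8,     // row 2
                           9, 10, 11};  // row 3
  std::vector<int8_t> packed(m * k);
  int8_t *buffer = packed.data();
  // Same loop structure as the m_tail == 0 path of PackMatrixA_4r above.
  for (int32_t i = 0; i < m; i += 4) {
    const int8_t *a0 = A.data() + i * lda;
    const int8_t *a1 = a0 + lda;
    const int8_t *a2 = a0 + 2 * lda;
    const int8_t *a3 = a0 + 3 * lda;
    for (int32_t j = 0; j < k; ++j) {
      *buffer++ = *a0++;
      *buffer++ = *a1++;
      *buffer++ = *a2++;
      *buffer++ = *a3++;
    }
  }
  // Column 0 of the panel comes out first: 0, 3, 6, 9.
  assert(packed[0] == 0 && packed[1] == 3 && packed[2] == 6 && packed[3] == 9);
  return 0;
}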
- // 8 bits int PackMatrixA
// 8 bits int PackMatrixA_6r
void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                          int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
- for (int32_t i = 0; i < i_length; i += MR) {
for (int32_t i = 0; i < i_length; i += MR_INT8) {
  const int8_t *a0 = A + i * lda;
  const int8_t *a1 = A + (i + 1) * lda;
  const int8_t *a2 = A + (i + 2) * lda;
@@ -539,6 +681,9 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
for (int32_t i = 0; i < k; ++i) {
  const int8_t *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
// TODO
#else
  asm volatile(
      //  "pld [%[b0]] \n\t"
      "vld1.s8 {d0}, [%[b0]] \n\t"
@@ -546,6 +691,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
      : [local_buffer] "+r"(local_buffer)
      : [b0] "r"(b0)
      : "memory", "q0");
#endif  // __aarch64__
#else
  *local_buffer++ = *b0++;
  *local_buffer++ = *b0++;
@@ -585,13 +731,13 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
MC = L1 / (KC * sizeof(int8_t));
NC = L2 / (KC * sizeof(int8_t));
- // make sure MC is multiple of MR, and NC is multiple of NR
// make sure MC is multiple of MR_INT8, and NC is multiple of NR
if (MC == 0) {
-   MC = MR;
  MC = MR_INT8;
} else {
  int32_t mblock_num = (m + MC - 1) / MC;
  MC = (m + mblock_num - 1) / mblock_num;
-   MC = (MC + MR - 1) / MR * MR;
  MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
@@ -618,7 +764,8 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8);
for (int32_t i = 0; i < m; i += MC) {
  mc = s_min(m - i, MC);
-   PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8);
  // PackMatrixA_6r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
  PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
  if (bias == nullptr) {
    InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
                        packedC_int8, &C(i, j), ldc, relu, nullptr);
@@ -642,6 +789,10 @@ void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
// C = A * B, 8位 int32_t
void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
                      int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO
#else
int32_t nc1 = nc >> 4;
int32_t _nc1 = nc & 15;
int32_t step = sizeof(int32_t) * ldc;
@@ -695,6 +846,8 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
    }
  }
}
#endif  // __aarch64__
#endif  // __ARM_NEON
}
// C = A * B + C
......
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
- #include <stdint-gcc.h>
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/mul_op.h"
......