From 90d08d2d42bf6370d01db245475f58e48e173dc7 Mon Sep 17 00:00:00 2001
From: zhaojiaying01
Date: Fri, 27 Jul 2018 10:31:57 +0800
Subject: [PATCH] optimize gemm

---
 src/operators/math/gemm.cpp | 318 +++++++++++++++++++++++++++++++++---
 src/operators/math/gemm.h   |  13 +-
 2 files changed, 302 insertions(+), 29 deletions(-)

diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp
index 4966ca1459..b9b61f4d1c 100644
--- a/src/operators/math/gemm.cpp
+++ b/src/operators/math/gemm.cpp
@@ -92,8 +92,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
  */

 // Pack a block of matrix A into contiguous memory (RowMajor)
-void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
-                  float *buffer) {
+void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
+                    float *buffer) {
   const float *a0, *a1, *a2, *a3;
   for (int i = 0; i < m - m_tail; i += MR) {
     a0 = A + i * lda;
     a1 = A + (i + 1) * lda;
     a2 = A + (i + 2) * lda;
     a3 = A + (i + 3) * lda;
@@ -131,9 +131,62 @@ void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
   }
 }

+void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
+                    float *buffer) {
+  const float *a0, *a1, *a2, *a3, *a4, *a5;
+  for (int i = 0; i < m - m_tail; i += MR) {
+    a0 = A + i * lda;
+    a1 = A + (i + 1) * lda;
+    a2 = A + (i + 2) * lda;
+    a3 = A + (i + 3) * lda;
+    a4 = A + (i + 4) * lda;
+    a5 = A + (i + 5) * lda;
+    for (int j = 0; j < k; ++j) {
+      *buffer++ = *a0++;
+      *buffer++ = *a1++;
+      *buffer++ = *a2++;
+      *buffer++ = *a3++;
+      *buffer++ = *a4++;
+      *buffer++ = *a5++;
+    }
+  }
+  int i = m - m_tail;
+  a0 = &A(i, 0);
+  a1 = a0 + lda;
+  a2 = a0 + 2 * lda;
+  a3 = a0 + 3 * lda;
+  a4 = a0 + 4 * lda;
+  a5 = a0 + 5 * lda;
+  if (m_tail != 0) {
+    if (m_tail <= 5) {
+      a5 = zero;
+    }
+    if (m_tail <= 4) {
+      a4 = zero;
+    }
+    if (m_tail <= 3) {
+      a3 = zero;
+    }
+    if (m_tail <= 2) {
+      a2 = zero;
+    }
+    if (m_tail <= 1) {
+      a1 = zero;
+    }
+    for (int j = 0; j < k; ++j) {
+      *buffer++ = *a0++;
+      *buffer++ = *a1++;
+      *buffer++ = *a2++;
+      *buffer++ = *a3++;
+      *buffer++ = *a4++;
+      *buffer++ = *a5++;
+    }
+  }
+}
+
 // Pack a block of matrix B into contiguous memory (RowMajor)
-void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
-                  float *buffer) {
+void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
+                    float *buffer) {
   const float *b0;
   for (int j = 0; j < n - n_tail; j += NR) {
     for (int i = 0; i < k; ++i) {
@@ -188,7 +241,8 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
   for (int j = 0; j < nc; j += NR) {
     for (int i = 0; i < mc; i += MR) {
       // AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
-      AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
     }
   }

@@ -218,7 +272,8 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
   for (int j = 0; j < nc; j += NR) {
     for (int i = 0; i < mc; i += MR) {
       // AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
-      AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
     }
   }

@@ -1868,22 +1923,22 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
           const float *B, int ldb, float beta, float *C, int ldc, bool relu) {
   // L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
   // L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
-  int L1 = 30 * 1024;
-  int L2 = 1 * 1024 * 1024;
+  int L1 = 32 * 1024;
+  int L2 = 0.5 * 1024 * 1024;

   KC = k;
-  MC = L2 / (2 * KC * sizeof(float));
-  NC = MC;
+  MC = L1 / (KC * sizeof(float));
+  NC = L2 / (KC * sizeof(float));

-  // make sure MC is multiple of 4, and NC is multiple of 8
+  // make sure MC is a multiple of MR, and NC is a multiple of NR
   int mblock_num = (m + MC - 1) / MC;
   MC = (m + mblock_num - 1) / mblock_num;
-  MC = (MC + 4 - 1) / 4 * 4;
+  MC = (MC + MR - 1) / MR * MR;
   // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";

   int nblock_num = (n + NC - 1) / NC;
   NC = (n + nblock_num - 1) / nblock_num;
-  NC = (NC + 8 - 1) / 8 * 8;
+  NC = (NC + NR - 1) / NR * NR;
   // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";

   packedA = static_cast<float *>(
@@ -1901,10 +1956,10 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
   int mc, nc;
   for (int j = 0; j < n; j += NC) {
     nc = s_min(n - j, NC);
-    PackMatrixB_(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
     for (int i = 0; i < m; i += MC) {
       mc = s_min(m - i, MC);
-      PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
       InnerKernel(mc, nc, alpha, packedA, packedB, beta, packedC, &C(i, j), ldc,
                   relu);
     }
@@ -1921,22 +1976,22 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
                  bool relu, float *new_scale, float *new_bias) {
   // L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
   // L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
-  int L1 = 30 * 1024;
-  int L2 = 1 * 1024 * 1024;
+  int L1 = 32 * 1024;
+  int L2 = 0.5 * 1024 * 1024;

   KC = k;
-  MC = L2 / (2 * KC * sizeof(float));
-  NC = MC;
+  MC = L1 / (KC * sizeof(float));
+  NC = L2 / (KC * sizeof(float));

-  // make sure MC is multiple of 4, and NC is multiple of 8
+  // make sure MC is a multiple of MR, and NC is a multiple of NR
   int mblock_num = (m + MC - 1) / MC;
   MC = (m + mblock_num - 1) / mblock_num;
-  MC = (MC + 4 - 1) / 4 * 4;
+  MC = (MC + MR - 1) / MR * MR;
   // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";

   int nblock_num = (n + NC - 1) / NC;
   NC = (n + nblock_num - 1) / nblock_num;
-  NC = (NC + 8 - 1) / 8 * 8;
+  NC = (NC + NR - 1) / NR * NR;
   // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";

   packedA = static_cast<float *>(
@@ -1954,10 +2009,10 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
   int mc, nc;
   for (int j = 0; j < n; j += NC) {
     nc = s_min(n - j, NC);
-    PackMatrixB_(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
     for (int i = 0; i < m; i += MC) {
       mc = s_min(m - i, MC);
-      PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
       InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
                         &C(i, j), ldc, relu, new_scale + i, new_bias + i);
     }
@@ -1969,6 +2024,221 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
   paddle_mobile::memory::Free(zero);
 }

+void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
+#if __ARM_NEON
+#if __aarch64__
+
+  // init C
+  float32x4_t cv0 = vdupq_n_f32(0.0);
+  float32x4_t cv1 = vdupq_n_f32(0.0);
+  float32x4_t cv2 = vdupq_n_f32(0.0);
+  float32x4_t cv3 = vdupq_n_f32(0.0);
+  float32x4_t cv4 = vdupq_n_f32(0.0);
+  float32x4_t cv5 = vdupq_n_f32(0.0);
+  float32x4_t cv6 = vdupq_n_f32(0.0);
+  float32x4_t cv7 = vdupq_n_f32(0.0);
+  float32x4_t cv8 = vdupq_n_f32(0.0);
+  float32x4_t cv9 = vdupq_n_f32(0.0);
+  float32x4_t cv10 = vdupq_n_f32(0.0);
+  float32x4_t cv11 = vdupq_n_f32(0.0);
+
+  float32x4_t av;
+  float32x4_t bv0;
+  
float32x4_t bv1; + + float32x2_t av01; + float32x2_t av23; + float32x2_t av45; + + for (int p = 0; p < k; p += 1) { + av = vld1q_f32(a); + av01 = vget_low_f32(av); + av23 = vget_high_f32(av); + av45 = vld1_f32(a + 4); + bv0 = vld1q_f32(b); + bv1 = vld1q_f32(b + 4); + + cv0 = vmlaq_lane_f32(cv0, bv0, av01, 0); + cv1 = vmlaq_lane_f32(cv1, bv1, av01, 0); + cv2 = vmlaq_lane_f32(cv2, bv0, av01, 1); + cv3 = vmlaq_lane_f32(cv3, bv1, av01, 1); + + cv4 = vmlaq_lane_f32(cv4, bv0, av23, 0); + cv5 = vmlaq_lane_f32(cv5, bv1, av23, 0); + cv6 = vmlaq_lane_f32(cv6, bv0, av23, 1); + cv7 = vmlaq_lane_f32(cv7, bv1, av23, 1); + + cv8 = vmlaq_lane_f32(cv8, bv0, av45, 0); + cv9 = vmlaq_lane_f32(cv9, bv1, av45, 0); + cv10 = vmlaq_lane_f32(cv10, bv0, av45, 1); + cv11 = vmlaq_lane_f32(cv11, bv1, av45, 1); + + a += MR; + b += NR; + } + + vst1q_f32(c, cv0); + vst1q_f32(c + 4, cv1); + vst1q_f32(c + ldc, cv2); + vst1q_f32(c + ldc + 4, cv3); + vst1q_f32(c + 2 * ldc, cv4); + vst1q_f32(c + 2 * ldc + 4, cv5); + vst1q_f32(c + 3 * ldc, cv6); + vst1q_f32(c + 3 * ldc + 4, cv7); + vst1q_f32(c + 4 * ldc, cv8); + vst1q_f32(c + 4 * ldc + 4, cv9); + vst1q_f32(c + 5 * ldc, cv10); + vst1q_f32(c + 5 * ldc + 4, cv11); + +#else + + const float *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int kc1 = k / 4; + int kc2 = k % 4; + int step = 4 * ldc; + asm volatile( + "pld [%[a_ptr]] \n\t" + "pld [%[b_ptr]] \n\t" + "pld [%[a_ptr], #64] \n\t" + "pld [%[b_ptr], #64] \n\t" + + "vmov.f32 q4, #0.0 \n\t" + "vmov.f32 q5, #0.0 \n\t" + "vmov.f32 q6, #0.0 \n\t" + "vmov.f32 q7, #0.0 \n\t" + "vmov.f32 q8, #0.0 \n\t" + "vmov.f32 q9, #0.0 \n\t" + "vmov.f32 q10, #0.0 \n\t" + "vmov.f32 q11, #0.0 \n\t" + "vmov.f32 q12, #0.0 \n\t" + "vmov.f32 q13, #0.0 \n\t" + "vmov.f32 q14, #0.0 \n\t" + "vmov.f32 q15, #0.0 \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "blt end_kc1_%= \n\t" + "loop_kc1_%=: \n\t" + + // "pld [%[a_ptr], #128] \n\t" + // "pld [%[b_ptr], #128] \n\t" + // "pld [%[a_ptr], #192] \n\t" + // "pld [%[b_ptr], #192] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! 
\n\t"
+
+      "vmla.f32 q4, q2, d0[0]          \n\t"
+      "vmla.f32 q5, q3, d0[0]          \n\t"
+      "vmla.f32 q6, q2, d0[1]          \n\t"
+      "vmla.f32 q7, q3, d0[1]          \n\t"
+      "vmla.f32 q8, q2, d1[0]          \n\t"
+      "vmla.f32 q9, q3, d1[0]          \n\t"
+      "vmla.f32 q10, q2, d1[1]         \n\t"
+      "vmla.f32 q11, q3, d1[1]         \n\t"
+      "vmla.f32 q12, q2, d2[0]         \n\t"
+      "vmla.f32 q13, q3, d2[0]         \n\t"
+      "vmla.f32 q14, q2, d2[1]         \n\t"
+      "vmla.f32 q15, q3, d2[1]         \n\t"
+
+      "subs %[kc1], %[kc1], #1         \n\t"
+      "bge loop_kc1_%=                 \n\t"
+      "end_kc1_%=:                     \n\t"
+
+      "subs %[kc2], %[kc2], #1         \n\t"
+      "blt end_kc2_%=                  \n\t"
+      "loop_kc2_%=:                    \n\t"
+
+      "vld1.32  {d0-d2}, [%[a_ptr]]!   \n\t"
+      "vld1.32  {q2, q3}, [%[b_ptr]]!  \n\t"
+
+      "vmla.f32 q4, q2, d0[0]          \n\t"
+      "vmla.f32 q5, q3, d0[0]          \n\t"
+      "vmla.f32 q6, q2, d0[1]          \n\t"
+      "vmla.f32 q7, q3, d0[1]          \n\t"
+      "vmla.f32 q8, q2, d1[0]          \n\t"
+      "vmla.f32 q9, q3, d1[0]          \n\t"
+      "vmla.f32 q10, q2, d1[1]         \n\t"
+      "vmla.f32 q11, q3, d1[1]         \n\t"
+      "vmla.f32 q12, q2, d2[0]         \n\t"
+      "vmla.f32 q13, q3, d2[0]         \n\t"
+      "vmla.f32 q14, q2, d2[1]         \n\t"
+      "vmla.f32 q15, q3, d2[1]         \n\t"
+
+      "subs %[kc2], %[kc2], #1         \n\t"
+      "bge loop_kc2_%=                 \n\t"
+      "end_kc2_%=:                     \n\t"
+
+      "mov r5, %[c]                    \n\t"
+      "mov r6, %[step]                 \n\t"
+      "vst1.32 {q4, q5}, [r5], r6      \n\t"
+      "vst1.32 {q6, q7}, [r5], r6      \n\t"
+      "vst1.32 {q8, q9}, [r5], r6      \n\t"
+      "vst1.32 {q10, q11}, [r5], r6    \n\t"
+      "vst1.32 {q12, q13}, [r5], r6    \n\t"
+      "vst1.32 {q14, q15}, [r5]        \n\t"
+
+      :
+      : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
+        [kc2] "r"(kc2), [step] "r"(step)
+      : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+        "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
+
+#endif  // __aarch64__
+#else
+
+#endif  // __ARM_NEON
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h
index d8b305a728..2044c264ed 100644
--- a/src/operators/math/gemm.h
+++ b/src/operators/math/gemm.h
@@ -19,7 +19,7 @@ limitations under the License. */
 #define B(i, j) B[(i)*ldb + (j)]
 #define C(i, j) C[(i)*ldc + (j)]

-#define MR 4
+#define MR 6
 #define NR 8

 #define s_min(i, j) ((i) < (j) ? (i) : (j))
@@ -39,12 +39,14 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
  */

 // Pack a block of matrix A into contiguous memory (RowMajor)
-void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
-                  float *buffer);
+void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
+                    float *buffer);
+void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
+                    float *buffer);

 // Pack a block of matrix B into contiguous memory (RowMajor)
-void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
-                  float *buffer);
+void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
+                    float *buffer);

 // Blocked matrix multiplication
 void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
@@ -67,6 +69,7 @@ void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
 // Compute a smaller block of the C matrix
 void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
 void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
+void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);

 // Write back the result of the blocked matrix multiplication
 // C = A * B
--
GitLab