diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp
index b9b61f4d1c59a0e2c8e7822742c54472ad540981..20d71907ff9e391d97ce75e38b6e08dc1286a9a3 100644
--- a/src/operators/math/gemm.cpp
+++ b/src/operators/math/gemm.cpp
@@ -107,20 +107,22 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
       *buffer++ = *a3++;
     }
   }
-  int i = m - m_tail;
-  a0 = &A(i, 0);
-  a1 = a0 + lda;
-  a2 = a0 + 2 * lda;
-  a3 = a0 + 3 * lda;
+  if (m_tail != 0) {
-  if (m_tail <= 3) {
-    a3 = zero;
-  }
-  if (m_tail <= 2) {
-    a2 = zero;
-  }
-  if (m_tail <= 1) {
-    a1 = zero;
+    a0 = &A(m - m_tail, 0);
+    a1 = a0 + lda;
+    a2 = a0 + 2 * lda;
+    a3 = a0 + 3 * lda;
+    switch (m_tail) {
+      case 1:
+        a1 = zero;  // intentional fallthrough: zero-pad every row past m_tail
+      case 2:
+        a2 = zero;
+      case 3:
+        a3 = zero;
+        break;
+      default:
+        break;
     }
     for (int j = 0; j < k; ++j) {
       *buffer++ = *a0++;
@@ -150,28 +152,89 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
       *buffer++ = *a5++;
     }
   }
-  int i = m - m_tail;
-  a0 = &A(i, 0);
-  a1 = a0 + lda;
-  a2 = a0 + 2 * lda;
-  a3 = a0 + 3 * lda;
-  a4 = a0 + 4 * lda;
-  a5 = a0 + 5 * lda;
   if (m_tail != 0) {
-    if (m_tail <= 5) {
-      a5 = zero;
+    a0 = &A(m - m_tail, 0);
+    a1 = a0 + lda;
+    a2 = a0 + 2 * lda;
+    a3 = a0 + 3 * lda;
+    a4 = a0 + 4 * lda;
+    a5 = a0 + 5 * lda;
+    switch (m_tail) {
+      case 1:
+        a1 = zero;  // intentional fallthrough: zero-pad every row past m_tail
+      case 2:
+        a2 = zero;
+      case 3:
+        a3 = zero;
+      case 4:
+        a4 = zero;
+      case 5:
+        a5 = zero;
+        break;
+      default:
+        break;
     }
-    if (m_tail <= 4) {
-      a4 = zero;
-    }
-    if (m_tail <= 3) {
-      a3 = zero;
+    for (int j = 0; j < k; ++j) {
+      *buffer++ = *a0++;
+      *buffer++ = *a1++;
+      *buffer++ = *a2++;
+      *buffer++ = *a3++;
+      *buffer++ = *a4++;
+      *buffer++ = *a5++;
     }
-    if (m_tail <= 2) {
-      a2 = zero;
+  }
+}
+
+void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
+                    float *buffer) {
+  const float *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
+  for (int i = 0; i < m - m_tail; i += MR) {
+    a0 = A + i * lda;
+    a1 = A + (i + 1) * lda;
+    a2 = A + (i + 2) * lda;
+    a3 = A + (i + 3) * lda;
+    a4 = A + (i + 4) * lda;
+    a5 = A + (i + 5) * lda;
+    a6 = A + (i + 6) * lda;
+    a7 = A + (i + 7) * lda;
+    for (int j = 0; j < k; ++j) {
+      *buffer++ = *a0++;
+      *buffer++ = *a1++;
+      *buffer++ = *a2++;
+      *buffer++ = *a3++;
+      *buffer++ = *a4++;
+      *buffer++ = *a5++;
+      *buffer++ = *a6++;
+      *buffer++ = *a7++;
     }
-    if (m_tail <= 1) {
-      a1 = zero;
+  }
+  if (m_tail != 0) {
+    a0 = &A(m - m_tail, 0);
+    a1 = a0 + lda;
+    a2 = a0 + 2 * lda;
+    a3 = a0 + 3 * lda;
+    a4 = a0 + 4 * lda;
+    a5 = a0 + 5 * lda;
+    a6 = a0 + 6 * lda;
+    a7 = a0 + 7 * lda;
+    switch (m_tail) {
+      case 1:
+        a1 = zero;  // intentional fallthrough: zero-pad every row past m_tail
+      case 2:
+        a2 = zero;
+      case 3:
+        a3 = zero;
+      case 4:
+        a4 = zero;
+      case 5:
+        a5 = zero;
+      case 6:
+        a6 = zero;
+      case 7:
+        a7 = zero;
+        break;
+      default:
+        break;
     }
     for (int j = 0; j < k; ++j) {
       *buffer++ = *a0++;
@@ -180,6 +243,8 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
       *buffer++ = *a3++;
       *buffer++ = *a4++;
       *buffer++ = *a5++;
+      *buffer++ = *a6++;
+      *buffer++ = *a7++;
     }
   }
 }
@@ -234,15 +299,78 @@ void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
     }
   }
 }
 
+#if __aarch64__
+void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
+                     float *buffer) {
+  const float *b0;
+  for (int j = 0; j < n - n_tail; j += NR) {
+    for (int i = 0; i < k; ++i) {
+      b0 = &B(i, j);
+      asm volatile(
+          "prfm pldl2keep, [%[b0], #64]                   \n\t"
+          "ld1  {v0.4s, v1.4s, v2.4s}, [%[b0]]            \n\t"
+          "st1  {v0.4s, v1.4s, v2.4s}, [%[buffer]], #48   \n\t"
+          : [buffer] "+r"(buffer)
+          : [b0] "r"(b0)
+          : "memory", "v0", "v1", "v2");
+    }
+  }
+  if (n_tail != 0) {
+    for (int i = 0; i < k; ++i) {
+      b0 = &B(i, n - n_tail);
+      for (int j = n - n_tail; j < n; ++j) {
+        *buffer++ = *b0++;
+      }
+      for (int j = n; j < n + (NR - n_tail); ++j) {
+        *buffer++ = 0;
+      }
+    }
+  }
+}
+
+void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
+                     float *buffer) {
+  const float *b0;
+  for (int j = 0; j < n - n_tail; j += NR) {
+    for (int i = 0; i < k; ++i) {
+      b0 = &B(i, j);
+      asm volatile(
+          "prfm pldl2keep, [%[b0], #64]                          \n\t"
+          "ld1  {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]]            \n\t"
+          "st1  {v0.4s, v1.4s, v2.4s, v3.4s}, [%[buffer]], #64   \n\t"
+          : [buffer] "+r"(buffer)
+          : [b0] "r"(b0)
+          : "memory", "v0", "v1", "v2", "v3");
+    }
+  }
+  if (n_tail != 0) {
+    for (int i = 0; i < k; ++i) {
+      b0 = &B(i, n - n_tail);
+      for (int j = n - n_tail; j < n; ++j) {
+        *buffer++ = *b0++;
+      }
+      for (int j = n; j < n + (NR - n_tail); ++j) {
+        *buffer++ = 0;
+      }
+    }
+  }
+}
+#endif  // __aarch64__
+
 // Blocked matrix multiplication
 void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
                  float beta, float *c, float *C, int ldc, bool relu) {
 #pragma omp parallel for
   for (int j = 0; j < nc; j += NR) {
     for (int i = 0; i < mc; i += MR) {
+#if __aarch64__
+      // AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+#else
       // AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
       // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
       AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+#endif
     }
   }
 
@@ -271,9 +399,14 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
 #pragma omp parallel for
   for (int j = 0; j < nc; j += NR) {
     for (int i = 0; i < mc; i += MR) {
+#if __aarch64__
+      // AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+#else
       // AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
       // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
       AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+#endif
     }
   }
 
@@ -1956,10 +2089,20 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
   int mc, nc;
   for (int j = 0; j < n; j += NC) {
     nc = s_min(n - j, NC);
+#if __aarch64__
+    // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+#else
     PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+#endif
     for (int i = 0; i < m; i += MC) {
       mc = s_min(m - i, MC);
+#if __aarch64__
       PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+#else
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+#endif
       InnerKernel(mc, nc, alpha, packedA, packedB, beta, packedC, &C(i, j),
                   ldc, relu);
     }
@@ -2009,10 +2152,20 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
   int mc, nc;
   for (int j = 0; j < n; j += NC) {
     nc = s_min(n - j, NC);
+#if __aarch64__
+    // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+#else
     PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+#endif
     for (int i = 0; i < m; i += MC) {
       mc = s_min(m - i, MC);
+#if __aarch64__
       PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+#else
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+#endif
       InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
                         &C(i, j), ldc, relu, new_scale + i, new_bias + i);
     }
@@ -2239,6 +2392,192 @@ void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
 #endif  // __ARM_NEON
 }
 
+#if __aarch64__
+void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc) {
+  const float *a_ptr, *b_ptr;
+  a_ptr = a;
+  b_ptr = b;
+  int kc1 = k;
+  int step = 4 * ldc;  // row stride of C in bytes
+  asm volatile(
+      "dup v5.4s, wzr \n\t"  // zero the 8x12 accumulator tile (v5..v28)
+      "dup v6.4s, wzr \n\t"
+      "dup v7.4s, wzr \n\t"
+      "dup v8.4s, wzr \n\t"
+      "dup v9.4s, wzr \n\t"
+      "dup v10.4s, wzr \n\t"
+      "dup v11.4s, wzr \n\t"
+      "dup v12.4s, wzr \n\t"
+      "dup v13.4s, wzr \n\t"
+      "dup v14.4s, wzr \n\t"
+      "dup v15.4s, wzr \n\t"
+      "dup v16.4s, wzr \n\t"
+
+      "dup v17.4s, wzr \n\t"
+      "dup v18.4s, wzr \n\t"
+      "dup v19.4s, wzr \n\t"
+      "dup v20.4s, wzr \n\t"
+      "dup v21.4s, wzr \n\t"
+      "dup v22.4s, wzr \n\t"
+      "dup v23.4s, wzr \n\t"
+      "dup v24.4s, wzr \n\t"
+      "dup v25.4s, wzr \n\t"
+      "dup v26.4s, wzr \n\t"
+      "dup v27.4s, wzr \n\t"
+      "dup v28.4s, wzr \n\t"
+
+      "subs %[kc1], %[kc1], #1 \n\t"
+      "blt end_kc1_%= \n\t"
+      "loop_kc1_%=: \n\t"
+
+      "prfm pldl1keep, [%[a_ptr], #32] \n\t"
+      "prfm pldl1keep, [%[b_ptr], #48] \n\t"
+
+      "ld1 {v0.4s, v1.4s}, [%[a_ptr]], #32 \n\t"
+      "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48 \n\t"
+
+      "fmla v5.4s, v2.4s, v0.s[0] \n\t"
+      "fmla v6.4s, v3.4s, v0.s[0] \n\t"
+      "fmla v7.4s, v4.4s, v0.s[0] \n\t"
+      "fmla v8.4s, v2.4s, v0.s[1] \n\t"
+      "fmla v9.4s, v3.4s, v0.s[1] \n\t"
+      "fmla v10.4s, v4.4s, v0.s[1] \n\t"
+      "fmla v11.4s, v2.4s, v0.s[2] \n\t"
+      "fmla v12.4s, v3.4s, v0.s[2] \n\t"
+      "fmla v13.4s, v4.4s, v0.s[2] \n\t"
+      "fmla v14.4s, v2.4s, v0.s[3] \n\t"
+      "fmla v15.4s, v3.4s, v0.s[3] \n\t"
+      "fmla v16.4s, v4.4s, v0.s[3] \n\t"
+
+      "fmla v17.4s, v2.4s, v1.s[0] \n\t"
+      "fmla v18.4s, v3.4s, v1.s[0] \n\t"
+      "fmla v19.4s, v4.4s, v1.s[0] \n\t"
+      "fmla v20.4s, v2.4s, v1.s[1] \n\t"
+      "fmla v21.4s, v3.4s, v1.s[1] \n\t"
+      "fmla v22.4s, v4.4s, v1.s[1] \n\t"
+      "fmla v23.4s, v2.4s, v1.s[2] \n\t"
+      "fmla v24.4s, v3.4s, v1.s[2] \n\t"
+      "fmla v25.4s, v4.4s, v1.s[2] \n\t"
+      "fmla v26.4s, v2.4s, v1.s[3] \n\t"
+      "fmla v27.4s, v3.4s, v1.s[3] \n\t"
+      "fmla v28.4s, v4.4s, v1.s[3] \n\t"
+
+      "subs %[kc1], %[kc1], #1 \n\t"
+      "bge loop_kc1_%= \n\t"
+      "end_kc1_%=: \n\t"
+
+      "st1 {v5.4s, v6.4s, v7.4s}, [%[c]], %[step] \n\t"
+      "st1 {v8.4s, v9.4s, v10.4s}, [%[c]], %[step] \n\t"
+      "st1 {v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t"
+      "st1 {v14.4s, v15.4s, v16.4s}, [%[c]], %[step] \n\t"
+      "st1 {v17.4s, v18.4s, v19.4s}, [%[c]], %[step] \n\t"
+      "st1 {v20.4s, v21.4s, v22.4s}, [%[c]], %[step] \n\t"
+      "st1 {v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
+      "st1 {v26.4s, v27.4s, v28.4s}, [%[c]], %[step] \n\t"
+      : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c] "+r"(c),
+        [kc1] "+r"(kc1)
+      : [step] "r"(step)
+      : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+        "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+        "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28");
+}
+
+void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc) {
+  const float *a_ptr, *b_ptr;
+  a_ptr = a;
+  b_ptr = b;
+  int kc1 = k;
+  int step = 4 * ldc;  // row stride of C in bytes
+  int step1 = 4 * 6;   // one packed A column: six floats
+  asm volatile(
+
+      "dup v6.4s, wzr \n\t"  // zero the 6x16 accumulator tile (v6..v29)
+      "dup v7.4s, wzr \n\t"
+      "dup v8.4s, wzr \n\t"
+      "dup v9.4s, wzr \n\t"
+      "dup v10.4s, wzr \n\t"
+      "dup v11.4s, wzr \n\t"
+      "dup v12.4s, wzr \n\t"
+      "dup v13.4s, wzr \n\t"
+
+      "dup v14.4s, wzr \n\t"
+      "dup v15.4s, wzr \n\t"
+      "dup v16.4s, wzr \n\t"
+      "dup v17.4s, wzr \n\t"
+      "dup v18.4s, wzr \n\t"
+      "dup v19.4s, wzr \n\t"
+      "dup v20.4s, wzr \n\t"
+      "dup v21.4s, wzr \n\t"
+
+      "dup v22.4s, wzr \n\t"
+      "dup v23.4s, wzr \n\t"
"dup v24.4s, wzr \n\t" + "dup v25.4s, wzr \n\t" + "dup v26.4s, wzr \n\t" + "dup v27.4s, wzr \n\t" + "dup v28.4s, wzr \n\t" + "dup v29.4s, wzr \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "blt end_kc1_%= \n\t" + "loop_kc1_%=: \n\t" + + "prfm pldl1keep, [%[a_ptr], #24] \n\t" + "prfm pldl1keep, [%[b_ptr], #64] \n\t" + + "ld1 {v0.4s, v1.4s}, [%[a_ptr]], %[step1] \n\t" + "ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], #64 \n\t" + + "fmla v6.4s, v2.4s, v0.s[0] \n\t" + "fmla v7.4s, v3.4s, v0.s[0] \n\t" + "fmla v8.4s, v4.4s, v0.s[0] \n\t" + "fmla v9.4s, v5.4s, v0.s[0] \n\t" + + "fmla v10.4s, v2.4s, v0.s[1] \n\t" + "fmla v11.4s, v3.4s, v0.s[1] \n\t" + "fmla v12.4s, v4.4s, v0.s[1] \n\t" + "fmla v13.4s, v5.4s, v0.s[1] \n\t" + + "fmla v14.4s, v2.4s, v0.s[2] \n\t" + "fmla v15.4s, v3.4s, v0.s[2] \n\t" + "fmla v16.4s, v4.4s, v0.s[2] \n\t" + "fmla v17.4s, v5.4s, v0.s[2] \n\t" + + "fmla v18.4s, v2.4s, v0.s[3] \n\t" + "fmla v19.4s, v3.4s, v0.s[3] \n\t" + "fmla v20.4s, v4.4s, v0.s[3] \n\t" + "fmla v21.4s, v5.4s, v0.s[3] \n\t" + + "fmla v22.4s, v2.4s, v1.s[0] \n\t" + "fmla v23.4s, v3.4s, v1.s[0] \n\t" + "fmla v24.4s, v4.4s, v1.s[0] \n\t" + "fmla v25.4s, v5.4s, v1.s[0] \n\t" + + "fmla v26.4s, v2.4s, v1.s[1] \n\t" + "fmla v27.4s, v3.4s, v1.s[1] \n\t" + "fmla v28.4s, v4.4s, v1.s[1] \n\t" + "fmla v29.4s, v5.4s, v1.s[1] \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "bge loop_kc1_%= \n\t" + "end_kc1_%=: \n\t" + + "st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t" + "st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t" + "st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t" + "st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t" + "st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t" + "st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t" + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [step] "r"(step), [step1] "r"(step1) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29"); +} + +#endif // __aarch64__ + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 2044c264ed1c0f8624690874ed248661a753804c..a9593b15ae73f46aa287028ba74efdb0d303fdde 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -19,8 +19,13 @@ limitations under the License. */ #define B(i, j) B[(i)*ldb + (j)] #define C(i, j) C[(i)*ldc + (j)] +#if __aarch64__ +#define MR 6 +#define NR 16 +#else #define MR 6 #define NR 8 +#endif #define s_min(i, j) ((i) < (j) ? 
@@ -43,10 +48,16 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
                     float *buffer);
 void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
                     float *buffer);
+void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
+                    float *buffer);
 
 // Copy blocks of matrix B into contiguous memory (RowMajor)
 void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
                     float *buffer);
+void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
+                     float *buffer);
+void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
+                     float *buffer);
 
 // Blocked matrix multiplication
 void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
@@ -70,6 +81,8 @@ void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
 void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
 void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
 void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
+void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
+void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
 
 // Write back the results of the blocked matrix multiplication
 // C = A * B
@@ -114,10 +127,6 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
                  const float *B, int ldb, float beta, float *C, int ldc,
                  bool relu, float *new_scale, float *new_bias);
 
-// 64-bit double matrix multiplication
-void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
-           const double *B, int ldb, float beta, double *C, int ldc);
-
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
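
Reviewer note: below is a minimal scalar sketch of the contract the new AArch64 path assumes, useful for checking the NEON kernels against a plain reference. The names pack_b_16c and scalar_add_dot_6x16 are hypothetical helpers, not part of this patch; they mirror what PackMatrixB_16c and AddDot6x16 are expected to compute: B packed into k-row panels of 16 consecutive floats with a zero-padded tail panel, and a 6x16 tile of the temporary C buffer overwritten with the panel product (alpha/beta handling stays in the write-back step, as in the 6x8 path).

    // Hypothetical scalar references for the packed-panel contract above.
    #include <cstring>

    // Mirrors PackMatrixB_16c: each 16-column panel of B is laid out as
    // k rows of 16 consecutive floats; the last panel is zero-padded.
    void pack_b_16c(int k, int n, const float *B, int ldb, float *out) {
      for (int j = 0; j < n; j += 16) {
        for (int i = 0; i < k; ++i) {
          for (int jj = 0; jj < 16; ++jj) {
            *out++ = (j + jj < n) ? B[i * ldb + j + jj] : 0.0f;
          }
        }
      }
    }

    // Mirrors AddDot6x16: consume one packed A panel (6 floats per k step)
    // and one packed B panel (16 floats per k step), overwriting a 6x16
    // tile of the temporary C buffer (ldc floats per row).
    void scalar_add_dot_6x16(int k, const float *a, const float *b, float *c,
                             int ldc) {
      float acc[6][16] = {};
      for (int p = 0; p < k; ++p) {
        for (int r = 0; r < 6; ++r) {
          for (int col = 0; col < 16; ++col) {
            acc[r][col] += a[p * 6 + r] * b[p * 16 + col];
          }
        }
      }
      for (int r = 0; r < 6; ++r) {
        std::memcpy(c + r * ldc, acc[r], 16 * sizeof(float));
      }
    }

Design note on the MR = 6, NR = 16 choice for AArch64: one k-step of AddDot6x16 holds a full 16-wide row of packed B in four q-registers (v2-v5) and the six A values in v0/v1, which leaves v6-v29 free for the 6x16 accumulator tile, so C is touched only once, by the final stores after the k loop.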