diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 815db53d037f25285871f603de248970ac4cb4e8..3730cf350a1399e5f3c1473fd1ce8d7b1d13b1b6 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -33,6 +33,14 @@ float *packedA; float *packedB; float *packedC; float *zero; + +typedef void (*FnPack)(int, int, int, const float *, int, float *); +typedef void (*FnAddDot)(int, const float *, const float *, float *, int); + +FnPack procPackA; +FnPack procPackB; +FnAddDot procAddDot; + /* // 将A矩阵分块复制到连续内存(ColMajor) void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, @@ -135,30 +143,32 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, float *buffer) { - const float *a0, *a1, *a2, *a3, *a4, *a5; - for (int i = 0; i < m - m_tail; i += MR) { - a0 = A + i * lda; - a1 = A + (i + 1) * lda; - a2 = A + (i + 2) * lda; - a3 = A + (i + 3) * lda; - a4 = A + (i + 4) * lda; - a5 = A + (i + 5) * lda; + const int i_length = m - m_tail; + for (int i = 0; i < i_length; i += MR) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + float *local_buffer = buffer + i * k; for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - *buffer++ = *a4++; - *buffer++ = *a5++; + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; } } if (m_tail != 0) { - a0 = &A(m - m_tail, 0); - a1 = a0 + lda; - a2 = a0 + 2 * lda; - a3 = a0 + 3 * lda; - a4 = a0 + 4 * lda; - a5 = a0 + 5 * lda; + const float *a0 = &A(i_length, 0); + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + float *local_buffer = buffer + i_length * k; switch (m_tail) { case 1: a1 = zero; @@ -175,48 +185,105 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, break; } for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - *buffer++ = *a4++; - *buffer++ = *a5++; + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } +} + +void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, + float *buffer) { + const int i_length = m - m_tail; +#pragma omp parallel for + for (int i = 0; i < i_length; i += MR) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + float *local_buffer = buffer + i * k; + for (int j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } + if (m_tail != 0) { + const float *a0 = &A(i_length, 0); + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + float *local_buffer = buffer + i_length * k; + switch (m_tail) { + case 1: + a1 = zero; + case 2: + a2 = zero; + case 3: + a3 = zero; + case 4: + a4 = zero; + case 5: + a5 = zero; + break; + default: + break; + } + for (int j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; } } } void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, float *buffer) { - const float *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; - for (int i = 0; i < m - m_tail; i += MR) { - a0 = A + i * lda; - a1 = A + (i + 1) * lda; - a2 = A + (i + 2) * lda; - a3 = A + (i + 3) * lda; - a4 = A + (i + 4) * lda; - a5 = A + (i + 5) * lda; - a6 = A + (i + 6) * lda; - a7 = A + (i + 7) * lda; + const int i_length = m - m_tail; + for (int i = 0; i < i_length; i += MR) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + const float *a6 = A + (i + 6) * lda; + const float *a7 = A + (i + 7) * lda; + float *local_buffer = buffer + i * k; for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - *buffer++ = *a4++; - *buffer++ = *a5++; - *buffer++ = *a6++; - *buffer++ = *a7++; + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + *local_buffer++ = *a6++; + *local_buffer++ = *a7++; } } if (m_tail != 0) { - a0 = &A(m - m_tail, 0); - a1 = a0 + lda; - a2 = a0 + 2 * lda; - a3 = a0 + 3 * lda; - a4 = a0 + 4 * lda; - a5 = a0 + 5 * lda; - a6 = a0 + 6 * lda; - a7 = a0 + 7 * lda; + const float *a0 = &A(i_length, 0); + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + const float *a6 = a0 + 6 * lda; + const float *a7 = a0 + 7 * lda; + float *local_buffer = buffer + i_length * k; switch (m_tail) { case 1: a1 = zero; @@ -237,14 +304,81 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, break; } for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - *buffer++ = *a4++; - *buffer++ = *a5++; - *buffer++ = *a6++; - *buffer++ = *a7++; + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + *local_buffer++ = *a6++; + *local_buffer++ = *a7++; + } + } +} + +void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, + float *buffer) { + const int i_length = m - m_tail; +#pragma omp parallel for + for (int i = 0; i < i_length; i += MR) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + const float *a6 = A + (i + 6) * lda; + const float *a7 = A + (i + 7) * lda; + float *local_buffer = buffer + i * k; + for (int j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + *local_buffer++ = *a6++; + *local_buffer++ = *a7++; + } + } + if (m_tail != 0) { + const float *a0 = &A(i_length, 0); + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + const float *a6 = a0 + 6 * lda; + const float *a7 = a0 + 7 * lda; + float *local_buffer = buffer + i_length * k; + switch (m_tail) { + case 1: + a1 = zero; + case 2: + a2 = zero; + case 3: + a3 = zero; + case 4: + a4 = zero; + case 5: + a5 = zero; + case 6: + a6 = zero; + case 7: + a7 = zero; + break; + default: + break; + } + for (int j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + *local_buffer++ = *a6++; + *local_buffer++ = *a7++; } } } @@ -252,48 +386,102 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, // 将B矩阵分块复制到连续内存(RowMajor) void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, float *buffer) { - const float *b0; - for (int j = 0; j < n - n_tail; j += NR) { + const int j_length = n - n_tail; + for (int j = 0; j < j_length; j += NR) { + float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, j); + const float *b0 = &B(i, j); #if __ARM_NEON #if __aarch64__ asm volatile( "prfm pldl1keep, [%[b0]] \n\t" "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s}, [%[buffer]], #32 \n\t" - : [buffer] "+r"(buffer) + "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t" + : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "v0", "v1"); #else asm volatile( "pld [%[b0]] \n\t" "vld1.32 {q0, q1}, [%[b0]] \n\t" - "vst1.32 {q0, q1}, [%[buffer]]! \n\t" - : [buffer] "+r"(buffer) + "vst1.32 {q0, q1}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "q0", "q1"); #endif // __aarch64__ #else - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; - *buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; #endif // __ARM_NEON } } if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, n - n_tail); - for (int j = n - n_tail; j < n; ++j) { - *buffer++ = *b0++; + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; } - for (int j = n; j < n + (NR - n_tail); ++j) { - *buffer++ = 0; + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer) { + const int j_length = n - n_tail; +#pragma omp parallel for + for (int j = 0; j < j_length; j += NR) { + float *local_buffer = buffer + j * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j); +#if __ARM_NEON +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n\t" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" + "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "v0", "v1"); +#else + asm volatile( + "pld [%[b0]] \n\t" + "vld1.32 {q0, q1}, [%[b0]] \n\t" + "vst1.32 {q0, q1}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "q0", "q1"); +#endif // __aarch64__ +#else + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; +#endif // __ARM_NEON + } + } + if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; } } } @@ -302,27 +490,60 @@ void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, #if __aarch64__ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, float *buffer) { - const float *b0; - for (int j = 0; j < n - n_tail; j += NR) { + const int j_length = n - n_tail; + for (int j = 0; j < j_length; j += NR) { + float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, j); + const float *b0 = &B(i, j); asm volatile( "prfm pldl2keep, [%[b0], #64] \n\t" "ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s, v2.4s}, [%[buffer]], #48 \n\t" - : [buffer] "+r"(buffer) + "st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t" + : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "v0", "v1", "v2"); } } if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, n - n_tail); - for (int j = n - n_tail; j < n; ++j) { - *buffer++ = *b0++; + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; } - for (int j = n; j < n + (NR - n_tail); ++j) { - *buffer++ = 0; + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer) { + const int j_length = n - n_tail; +#pragma omp parallel for + for (int j = 0; j < j_length; j += NR) { + float *local_buffer = buffer + j * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j); + asm volatile( + "prfm pldl2keep, [%[b0], #64] \n\t" + "ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t" + "st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "v0", "v1", "v2"); + } + } + if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; } } } @@ -330,27 +551,60 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, float *buffer) { - const float *b0; + const int j_length = n - n_tail; for (int j = 0; j < n - n_tail; j += NR) { + float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, j); + const float *b0 = &B(i, j); asm volatile( "prfm pldl2keep, [%[b0], #64] \n\t" "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[buffer]], #64 \n\t" - : [buffer] "+r"(buffer) + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t" + : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "v0", "v1", "v2", "v3"); } } if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; for (int i = 0; i < k; ++i) { - b0 = &B(i, n - n_tail); - for (int j = n - n_tail; j < n; ++j) { - *buffer++ = *b0++; + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; } - for (int j = n; j < n + (NR - n_tail); ++j) { - *buffer++ = 0; + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer) { + const int j_length = n - n_tail; +#pragma omp parallel for + for (int j = 0; j < n - n_tail; j += NR) { + float *local_buffer = buffer + j * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j); + asm volatile( + "prfm pldl2keep, [%[b0], #64] \n\t" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "v0", "v1", "v2", "v3"); + } + } + if (n_tail != 0) { + float *local_buffer = buffer + j_length * k; + for (int i = 0; i < k; ++i) { + const float *b0 = &B(i, j_length); + for (int j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; } } } @@ -2394,6 +2648,221 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, paddle_mobile::memory::Free(zero); } +// 32位 float 矩阵乘法 +void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *bias) { +#ifdef _OPENMP + int max_threads = omp_get_max_threads(); +#else + int max_threads = 1; +#endif + + int L1 = 32 * 1024; + KC = k; + if (m > n) { + // 对 A 分块 + MC = L1 / (KC * sizeof(float)); + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + // 补齐 B + NC = (n + NR - 1) / NR * NR; + +#if __aarch64__ + procPackA = PackMatrixA_6r; + procPackB = PackMatrixB_omp_16c; + procAddDot = AddDot6x16; +#else + procPackA = PackMatrixA_6r; + procPackB = PackMatrixB_omp_8c; + procAddDot = AddDot6x8; +#endif + + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + procPackB(KC, NC, NC % NR, B, ldb, packedB); + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); + } else { + // 对 B 分块 + NC = L1 / (KC * sizeof(float)); + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + // 补齐 A + MC = (m + MR - 1) / MR * MR; + +#if __aarch64__ + procPackA = PackMatrixA_omp_6r; + procPackB = PackMatrixB_16c; + procAddDot = AddDot6x16; +#else + procPackA = PackMatrixA_omp_6r; + procPackB = PackMatrixB_8c; + procAddDot = AddDot6x8; +#endif + + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + procPackA(MC, KC, MC % MR, A, lda, packedA); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); + } + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); + + if (m > n) { +#pragma omp parallel for + for (int i = 0; i < m; i += MC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif + + int mc; + mc = s_min(m - i, MC); + float *local_A = packedA + MC * KC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A); + InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C, + &C(i, 0), ldc, relu, bias + i); + } + } else { +#pragma omp parallel for + for (int j = 0; j < n; j += NC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif + + int nc; + nc = s_min(n - j, NC); + float *local_B = packedB + KC * NC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B); + InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C, + &C(0, j), ldc, relu, bias); + } + } + + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); +} + +void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *new_scale, float *new_bias) { +#ifdef _OPENMP + int max_threads = omp_get_max_threads(); +#else + int max_threads = 1; +#endif + + int L1 = 32 * 1024; + KC = k; + if (m > n) { + // 对 A 分块 + MC = L1 / (KC * sizeof(float)); + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + // 补齐 B + NC = (n + NR - 1) / NR * NR; + +#if __aarch64__ + procPackA = PackMatrixA_6r; + procPackB = PackMatrixB_omp_16c; + procAddDot = AddDot6x16; +#else + procPackA = PackMatrixA_6r; + procPackB = PackMatrixB_omp_8c; + procAddDot = AddDot6x8; +#endif + + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + procPackB(KC, NC, NC % NR, B, ldb, packedB); + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); + } else { + // 对 B 分块 + NC = L1 / (KC * sizeof(float)); + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + // 补齐 A + MC = (m + MR - 1) / MR * MR; + +#if __aarch64__ + procPackA = PackMatrixA_omp_6r; + procPackB = PackMatrixB_16c; + procAddDot = AddDot6x16; +#else + procPackA = PackMatrixA_omp_6r; + procPackB = PackMatrixB_8c; + procAddDot = AddDot6x8; +#endif + + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + procPackA(MC, KC, MC % MR, A, lda, packedA); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); + } + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); + + if (m > n) { +#pragma omp parallel for + for (int i = 0; i < m; i += MC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif + + int mc; + mc = s_min(m - i, MC); + float *local_A = packedA + MC * KC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A); + InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0), + ldc, relu, new_scale + i, new_bias + i); + } + } else { +#pragma omp parallel for + for (int j = 0; j < n; j += NC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif + + int nc; + nc = s_min(n - j, NC); + float *local_B = packedB + KC * NC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B); + InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j), + ldc, relu, new_scale, new_bias); + } + } + + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); +} + void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { #if __ARM_NEON #if __aarch64__ diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 625fce0323580545c1655c1d3c325f995aa054f2..40199faa4c30ec965a3980f44f1dbb6ae7d6799b 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -50,6 +50,10 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, float *buffer); void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, float *buffer); +void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, + float *buffer); +void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, + float *buffer); // 将 B 矩阵分块复制到连续内存(RowMajor) void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, @@ -58,6 +62,12 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, float *buffer); void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, float *buffer); +void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer); +void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer); +void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer); // 分块矩阵乘法 void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b, @@ -136,6 +146,16 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *new_scale, float *new_bias); +// 32位 float 矩阵乘法(openmp 多线程版本) +void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *bias); + +// 32位 float 矩阵乘法, 并对结果进行 batchnrom(openmp 多线程版本) +void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *new_scale, float *new_bias); + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index 9ac8d79e89b7a577f0a89807dc96c9f368fed6de..381624250af87f4eeff7cf316a2f0f346c399137 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -42,8 +42,13 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, int N = dim_out[1]; int K = (!trans_a) ? dim_a[1] : dim_a[0]; +#ifdef _OPENMP + Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), + N, beta, matrix_out->data(), N, relu, bias); +#else Sgemm(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, bias); +#endif } template <> @@ -70,10 +75,17 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, int N = dim_out[1]; int K = (!trans_a) ? dim_a[1] : dim_a[0]; +#ifdef _OPENMP + SgemmWithBn_omp(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), N, + relu, new_scale->data() + group, + new_bias->data() + group); +#else SgemmWithBn(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, new_scale->data() + group, new_bias->data() + group); +#endif } } // namespace math