diff --git a/src/memory/t_malloc.cpp b/src/memory/t_malloc.cpp
index 280391da5ac0b5c7bdbbbbe8df6772377ca075c5..92cd9ac0364ef12a14662d986419cc2691971a87 100644
--- a/src/memory/t_malloc.cpp
+++ b/src/memory/t_malloc.cpp
@@ -20,7 +20,7 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace memory {
 
-const int MALLOC_ALIGN = 16;
+const int MALLOC_ALIGN = 64;
 
 void Copy(void *dst, const void *src, size_t num) {
   std::memcpy(dst, src, num);
diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp
index 81261dc49414d72a799ca2a83f1c298895a298bd..7c42d6dce781dc35ccf8851db87d11c51adf6273 100644
--- a/src/operators/math/gemm.cpp
+++ b/src/operators/math/gemm.cpp
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "operators/math/gemm.h"
+#include "common/log.h"
+#include "memory/t_malloc.h"
 #ifndef X86
 #include <arm_neon.h>
 #endif
@@ -757,6 +759,10 @@ void sgemm(int m, int n, int k, float alpha, const float *A, int lda,
            const float *B, int ldb, float beta, float *C, int ldc) {
   int i, j, p, mc, nc, kc;
   float beta_;
+  if (m == 1) {
+    VectorKernel(1, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+    return;
+  }
   for (j = 0; j < n; j += NC) {
     nc = s_min(n - j, NC);
     for (p = 0; p < k; p += KC) {
@@ -803,6 +809,223 @@ void sgemm_relu(int m, int n, int k, float alpha, const float *A, int lda,
   }
 }
 
+void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
+                  const float *B, int ldb, float beta, float *C, int ldc) {
+  float *bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n));
+
+  const float *a0, *b0, *b1, *b2, *b3;
+  float *c0, *C0;
+
+  int volatile kc1 = k / 4;
+  int volatile kc2 = k % 4;
+  int volatile nc1 = n / 16;
+  int _nc1 = n % 16;
+  int volatile nc2 = _nc1 / 4;
+  int volatile nc3 = _nc1 % 4;
+  // DLOG << "GEMM VECTOR kc1 = " << kc1 << ", kc2 = " << kc2;
+  // DLOG << "GEMM VECTOR nc1 = " << nc1 << ", nc2 = " << nc2 << ", nc3 = " <<
+  // nc3;
+  for (int i = 0; i < kc1; i++) {
+    a0 = A + i * 4;
+    b0 = B + i * 4 * ldb;
+    b1 = b0 + ldb;
+    b2 = b1 + ldb;
+    b3 = b2 + ldb;
+    c0 = bufferC;
+    asm volatile(
+        "pld        [%[a0], #16]          \n\t"
+        "vld1.32    {q0}, [%[a0]]         \n\t"
+
+        "subs       %[nc1], %[nc1], #1    \n\t"
+        "blt        end_nc1_%=            \n\t"
+        "loop_nc1_%=:                     \n\t"
+
+        "cmp        %[i], #0              \n\t"
+        "beq        i_eq0_%=              \n\t"
+        "bne        i_ne0_%=              \n\t"
+
+        "i_eq0_%=:                        \n\t"
+        "vmov.f32   q10, #0.0             \n\t"
+        "vmov.f32   q11, #0.0             \n\t"
+        "vmov.f32   q12, #0.0             \n\t"
+        "vmov.f32   q13, #0.0             \n\t"
+        "b          gemm_nc1_%=           \n\t"
+
+        "i_ne0_%=:                        \n\t"
+        "pld        [%[c0], #64]          \n\t"
+        "vld1.32    {q10, q11}, [%[c0]]!  \n\t"
+        "vld1.32    {q12, q13}, [%[c0]]   \n\t"
+        "sub        %[c0], %[c0], #32     \n\t"
+
+        "gemm_nc1_%=:                     \n\t"
+        "pld        [%[b0], #64]          \n\t"
+        "vld1.32    {q2, q3}, [%[b0]]!    \n\t"
+        "vld1.32    {q4, q5}, [%[b0]]!    \n\t"
+        "vmla.f32   q10, q2, d0[0]        \n\t"
+        "vmla.f32   q11, q3, d0[0]        \n\t"
+        "vmla.f32   q12, q4, d0[0]        \n\t"
+        "vmla.f32   q13, q5, d0[0]        \n\t"
+
+        "pld        [%[b1], #64]          \n\t"
+        "vld1.32    {q2, q3}, [%[b1]]!    \n\t"
+        "vld1.32    {q4, q5}, [%[b1]]!    \n\t"
+        "vmla.f32   q10, q2, d0[1]        \n\t"
+        "vmla.f32   q11, q3, d0[1]        \n\t"
+        "vmla.f32   q12, q4, d0[1]        \n\t"
+        "vmla.f32   q13, q5, d0[1]        \n\t"
+
+        "pld        [%[b2], #64]          \n\t"
+        "vld1.32    {q2, q3}, [%[b2]]!    \n\t"
+        "vld1.32    {q4, q5}, [%[b2]]!    \n\t"
+        "vmla.f32   q10, q2, d1[0]        \n\t"
+        "vmla.f32   q11, q3, d1[0]        \n\t"
+        "vmla.f32   q12, q4, d1[0]        \n\t"
+        "vmla.f32   q13, q5, d1[0]        \n\t"
+
+        "pld        [%[b3], #64]          \n\t"
+        "vld1.32    {q2, q3}, [%[b3]]!    \n\t"
+        "vld1.32    {q4, q5}, [%[b3]]!    \n\t"
\n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q4, d1[1] \n\t" + "vmla.f32 q13, q5, d1[1] \n\t" + + "vst1.32 {q10, q11}, [%[c0]]! \n\t" + "vst1.32 {q12, q13}, [%[c0]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" + + "cmp %[i], #0 \n\t" + "beq ii_eq0_%= \n\t" + "bne ii_ne0_%= \n\t" + + "ii_eq0_%=: \n\t" + "vmov.f32 q10, #0.0 \n\t" + "b gemm_nc2_%= \n\t" + + "ii_ne0_%=: \n\t" + "pld [%[c0], #16] \n\t" + "vld1.32 {q10}, [%[c0]] \n\t" + + "gemm_nc2_%=: \n\t" + "pld [%[b0], #16] \n\t" + "vld1.32 {q2}, [%[b0]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" + + "pld [%[b1], #16] \n\t" + "vld1.32 {q3}, [%[b1]]! \n\t" + "vmla.f32 q10, q3, d0[1] \n\t" + + "pld [%[b2], #16] \n\t" + "vld1.32 {q4}, [%[b2]]! \n\t" + "vmla.f32 q10, q4, d1[0] \n\t" + + "pld [%[b3], #16] \n\t" + "vld1.32 {q5}, [%[b3]]! \n\t" + "vmla.f32 q10, q5, d1[1] \n\t" + + "vst1.32 {q10}, [%[c0]]! \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" + + : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3), + [c0] "+r"(c0) + : [a0] "r"(a0), [i] "r"(i), [nc1] "r"(nc1), [nc2] "r"(nc2) + : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13"); + + for (int j = 0; j < nc3; j++) { + if (i == 0) { + *c0 = (*a0) * (*b0++); + } else { + *c0 += (*a0) * (*b0++); + } + *c0 += (*(a0 + 1)) * (*b1++); + *c0 += (*(a0 + 2)) * (*b2++); + *c0 += (*(a0 + 3)) * (*b3++); + c0++; + } + } + + for (int i = 0; i < kc2; ++i) { + a0 = A + 4 * kc1 + i; + b0 = B + (4 * kc1 + i) * ldb; + c0 = bufferC; + asm volatile( + "pld [%[a0], #16] \n\t" + "vld1.32 {d0}, [%[a0]] \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "pld [%[c0], #64] \n\t" + "vld1.32 {q10, q11}, [%[c0]]! \n\t" + "vld1.32 {q12, q13}, [%[c0]] \n\t" + "sub %[c0], %[c0], #32 \n\t" + + "gemm_nc1_%=: \n\t" + "pld [%[b0], #64] \n\t" + "vld1.32 {q2, q3}, [%[b0]]! \n\t" + "vld1.32 {q4, q5}, [%[b0]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" + "vmla.f32 q11, q3, d0[0] \n\t" + "vmla.f32 q12, q4, d0[0] \n\t" + "vmla.f32 q13, q5, d0[0] \n\t" + + "vst1.32 {q10, q11}, [%[c0]]! \n\t" + "vst1.32 {q12, q13}, [%[c0]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" + + "pld [%[c0], #16] \n\t" + "vld1.32 {q10}, [%[c0]] \n\t" + + "gemm_nc2_%=: \n\t" + "vld1.32 {q2}, [%[b0]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" + + "vst1.32 {q10}, [%[c0]]! 
\n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" + + : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3), + [c0] "+r"(c0) + : [a0] "r"(a0), [nc1] "r"(nc1), [nc2] "r"(nc2) + : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13"); + + for (int j = 0; j < nc3; j++) { + *c0 += (*a0) * (*b0++); + c0++; + } + } + + c0 = bufferC; + C0 = C; + for (int i = 0; i < n; i++) { + if (beta == 1.0) { + *C0++ += *c0++; + } else { + *C0++ = *c0++; + } + } +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 00285aed94613ac7666c6c68df7b3208b09a777a..b5351dd1e8f9bc93f9e77cfe4adf572c890c1d37 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -53,6 +53,10 @@ void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, int first_time); +// 向量矩阵乘法 (M = 1) +void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc); + // 计算一个更小的 4 * 4 的 C 矩阵分块 void AddDot4x4(int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, int mc, int nc);