diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 95f28ab9133a6ac668be11c3ba22f6c5b6db61e3..81261dc49414d72a799ca2a83f1c298895a298bd 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -473,7 +473,7 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b, void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, int ldb, float beta, float *C, int ldc, int mc, int nc, bool relu = false) { - int kc1 = k / 2, kc2 = k % 2; + int kc1 = k / 4, kc2 = k % 4; int bytes_ldc = 4 * ldc; int flag_alpha = (alpha == 1.0) ? 1 : 2; int flag_beta; @@ -484,16 +484,29 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, } else { flag_beta = 2; } - asm volatile( + "pld [%[a]] \n\t" + "pld [%[b]] \n\t" "vmov.f32 q10, #0.0 \n\t" "vmov.f32 q11, #0.0 \n\t" "vmov.f32 q12, #0.0 \n\t" "vmov.f32 q13, #0.0 \n\t" - "vmov.f32 q14, #0.0 \n\t" + "subs %[kc1], %[kc1], #1 \n\t" "blt end_kc1_%= \n\t" "loop_kc1_%=: \n\t" + "pld [%[a], #64] \n\t" + "pld [%[b], #64] \n\t" + "vld1.32 {q0, q1}, [%[a]]! \n\t" + "vld1.32 {q2, q3}, [%[b]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" + "vmla.f32 q11, q2, d0[1] \n\t" + "vmla.f32 q12, q2, d1[0] \n\t" + "vmla.f32 q13, q2, d1[1] \n\t" + "vmla.f32 q10, q3, d2[0] \n\t" + "vmla.f32 q11, q3, d2[1] \n\t" + "vmla.f32 q12, q3, d3[0] \n\t" + "vmla.f32 q13, q3, d3[1] \n\t" "vld1.32 {q0, q1}, [%[a]]! \n\t" "vld1.32 {q2, q3}, [%[b]]! \n\t" "vmla.f32 q10, q2, d0[0] \n\t" @@ -536,7 +549,7 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, "vmul.f32 q13, q13, d8[0] \n\t" "beta_%=: \n\t" - "cmp %[flag_beta], #0 \n\t" + "cmp %[flag_beta], #0 \n\t" "beq memory_%= \n\t" "mov r4, %[C] \n\t" @@ -545,7 +558,7 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, "vld1.32 {q1}, [r4], r6 \n\t" "vld1.32 {q2}, [r4], r6 \n\t" "vld1.32 {q3}, [r4] \n\t" - "cmp %[flag_beta], #1 \n\t" + "cmp %[flag_beta], #1 \n\t" "beq beta_eq1_%= \n\t" "bne beta_ne1_%= \n\t" @@ -569,7 +582,6 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, "vmax.f32 q13, q13, q14 \n\t" "mov r5, %[C] \n\t" "mov r6, %[bytes_ldc]\n\t" - "vst1.32 {q10}, [r5], r6 \n\t" "vst1.32 {q11}, [r5], r6 \n\t" "vst1.32 {q12}, [r5], r6 \n\t" @@ -585,8 +597,7 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, [kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha), [beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc), [flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta) - : "memory", "q0", "q1", "q2", "q3", "q4", "q10", "q11", "q12", "q13", - "q14"); + : "memory", "q0", "q1", "q2", "q3", "q4", "q10", "q11", "q12", "q13"); if (mc != MR || nc != NR) { int i, j;