diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp
index 9f0a18f04f9f247cc06ccf73a36b574cb19d92ad..04a76465b834bad7553870e0a1eb0e5d47cdd71f 100644
--- a/src/operators/math/gemm.cpp
+++ b/src/operators/math/gemm.cpp
@@ -2155,34 +2155,32 @@ void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,
     if (bias1 == nullptr) {
       for (int i = 0; i < mc; ++i) {
         for (int j = 0; j < nc; ++j) {
-          float r = *c + *bias;
+          float r = c[i * NC + j] + bias[i];
           if (r < 0) {
-            r = *p;
+            r *= p[i];
           }
-          c++;
+          C[i * ldc + j] = r;
         }
-        bias++;
-        p++;
       }
     } else {
       for (int i = 0; i < mc; ++i) {
         for (int j = 0; j < nc; ++j) {
-          float r = *c + *bias;
-          r += *bias1;
+          float r = c[i * NC + j] + bias[i];
+          r += bias1[i * ldc + j];
           if (r < 0) {
-            r *= *p;
+            r *= p[i];
           }
-          c++;
-          bias1++;
+          C[i * ldc + j] = r;
         }
-        bias++;
-        p++;
       }
     }
     return;
   }

-  int nc1 = nc / 8;
+  int nc1 = nc / 16;
+  int _nc1 = nc % 16;
+  int nc2 = _nc1 / 4;
+  int nc3 = 16 - 4 * (_nc1 % 4);
   int step = 4 * (ldc - nc);
   int step1 = 4 * (NC - nc);

@@ -2194,6 +2192,7 @@ void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,

         "loop_mc_%=: \n\t"
         "mov r5, %[nc1] \n\t"
+        "mov r6, %[nc2] \n\t"
         "vld1.32 {d0}, [%[bias]] \n\t"
         "vld1.32 {d1}, [%[p]] \n\t"
         "vdup.32 q1, d0[0] \n\t"
@@ -2205,20 +2204,64 @@ void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,

         "pld [%[c], #32] \n\t"
         "vld1.32 {q3, q4}, [%[c]]! \n\t"
+        "vld1.32 {q9, q10}, [%[c]]! \n\t"
+
         "vadd.f32 q3, q3, q1 \n\t"
         "vadd.f32 q4, q4, q1 \n\t"
+        "vadd.f32 q9, q9, q1 \n\t"
+        "vadd.f32 q10, q10, q1 \n\t"
+
         "vmax.f32 q5, q3, q14 \n\t"
         "vmin.f32 q7, q3, q14 \n\t"
         "vmax.f32 q6, q4, q14 \n\t"
         "vmin.f32 q8, q4, q14 \n\t"
+
+        "vmax.f32 q11, q9, q14 \n\t"
+        "vmin.f32 q13, q9, q14 \n\t"
+        "vmax.f32 q12, q10, q14 \n\t"
+        "vmin.f32 q15, q10, q14 \n\t"
+
         "vmla.f32 q5, q7, q2 \n\t"
         "vmla.f32 q6, q8, q2 \n\t"
+        "vmla.f32 q11, q13, q2 \n\t"
+        "vmla.f32 q12, q15, q2 \n\t"
+
         "vst1.32 {q5, q6}, [%[C]]! \n\t"
+        "vst1.32 {q11, q12}, [%[C]]! \n\t"

         "subs r5, r5, #1 \n\t"
         "bge loop_nc1_%= \n\t"
         "end_nc1_%=: \n\t"

+        "subs r6, r6, #1 \n\t"
+        "blt end_nc2_%= \n\t"
+        "loop_nc2_%=: \n\t"
+
+        "vld1.32 {q3}, [%[c]]! \n\t"
+        "vadd.f32 q3, q3, q1 \n\t"
+        "vmax.f32 q5, q3, q14 \n\t"
+        "vmin.f32 q7, q3, q14 \n\t"
+        "vmla.f32 q5, q7, q2 \n\t"
+        "vst1.32 {q5}, [%[C]]! \n\t"
+
+        "subs r6, r6, #1 \n\t"
+        "bge loop_nc2_%= \n\t"
+        "end_nc2_%=: \n\t"
+
+        "cmp %[nc3], #16 \n\t"
+        "beq end_nc3_%= \n\t"
+
+        "sub %[c], %[c], %[nc3] \n\t"
+        "sub %[C], %[C], %[nc3] \n\t"
+
+        "vld1.32 {q4}, [%[c]]! \n\t"
+        "vadd.f32 q4, q4, q1 \n\t"
+        "vmax.f32 q6, q4, q14 \n\t"
+        "vmin.f32 q8, q4, q14 \n\t"
+        "vmla.f32 q6, q8, q2 \n\t"
+        "vst1.32 {q6}, [%[C]]! \n\t"
+        "end_nc3_%=: \n\t"
+
         "add %[p], %[p], #4 \n\t"
         "add %[bias], %[bias], #4 \n\t"
         "add %[c], %[c], %[step1] \n\t"
@@ -2229,10 +2272,11 @@ void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,
         "end_mc_%=: \n\t"

         :
-        : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1),
-          [step] "r"(step), [step1] "r"(step1), [p] "r"(p), [bias] "r"(bias),
-          [bias1] "r"(bias1)
-        : "memory", "r5", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8");
+        : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2),
+          [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [p] "r"(p),
+          [bias] "r"(bias), [bias1] "r"(bias1)
+        : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+          "q8");
   } else {
     asm volatile(
         "vmov.f32 q14, #0.0 \n\t"
@@ -2241,6 +2285,7 @@ void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,

         "loop_mc_%=: \n\t"
         "mov r5, %[nc1] \n\t"
+        "mov r6, %[nc2] \n\t"
         "vld1.32 {d0}, [%[bias]] \n\t"
         "vld1.32 {d1}, [%[p]] \n\t"
         "vdup.32 q1, d0[0] \n\t"
@@ -2266,25 +2311,74 @@ void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,
         "vmla.f32 q6, q8, q2 \n\t"
         "vst1.32 {q5, q6}, [%[C]]! \n\t"

+        "vld1.32 {q3, q4}, [%[c]]! \n\t"
+        "vld1.32 {q9, q10}, [%[bias1]]! \n\t"
+        "vadd.f32 q3, q3, q1 \n\t"
+        "vadd.f32 q4, q4, q1 \n\t"
+        "vadd.f32 q3, q3, q9 \n\t"
+        "vadd.f32 q4, q4, q10 \n\t"
+        "vmax.f32 q5, q3, q14 \n\t"
+        "vmin.f32 q7, q3, q14 \n\t"
+        "vmax.f32 q6, q4, q14 \n\t"
+        "vmin.f32 q8, q4, q14 \n\t"
+        "vmla.f32 q5, q7, q2 \n\t"
+        "vmla.f32 q6, q8, q2 \n\t"
+        "vst1.32 {q5, q6}, [%[C]]! \n\t"
+
         "subs r5, r5, #1 \n\t"
         "bge loop_nc1_%= \n\t"
         "end_nc1_%=: \n\t"

+        "subs r6, r6, #1 \n\t"
+        "blt end_nc2_%= \n\t"
+        "loop_nc2_%=: \n\t"
+
+        "vld1.32 {q3}, [%[c]]! \n\t"
+        "vld1.32 {q9}, [%[bias1]]! \n\t"
+        "vadd.f32 q3, q3, q1 \n\t"
+        "vadd.f32 q3, q3, q9 \n\t"
+        "vmax.f32 q5, q3, q14 \n\t"
+        "vmin.f32 q7, q3, q14 \n\t"
+        "vmla.f32 q5, q7, q2 \n\t"
+        "vst1.32 {q5}, [%[C]]! \n\t"
+
+        "subs r6, r6, #1 \n\t"
+        "bge loop_nc2_%= \n\t"
+        "end_nc2_%=: \n\t"
+
+        "cmp %[nc3], #16 \n\t"
+        "beq end_nc3_%= \n\t"
+
+        "sub %[c], %[c], %[nc3] \n\t"
+        "sub %[C], %[C], %[nc3] \n\t"
+        "sub %[bias1], %[bias1], %[nc3] \n\t"
+
+        "vld1.32 {q4}, [%[c]]! \n\t"
+        "vld1.32 {q10}, [%[bias1]]! \n\t"
+        "vadd.f32 q4, q4, q1 \n\t"
+        "vadd.f32 q4, q4, q10 \n\t"
+        "vmax.f32 q6, q4, q14 \n\t"
+        "vmin.f32 q8, q4, q14 \n\t"
+        "vmla.f32 q6, q8, q2 \n\t"
+        "vst1.32 {q6}, [%[C]]! \n\t"
+        "end_nc3_%=: \n\t"
+
         "add %[p], %[p], #4 \n\t"
         "add %[bias], %[bias], #4 \n\t"
         "add %[c], %[c], %[step1] \n\t"
         "add %[C], %[C], %[step] \n\t"
+        "add %[bias1], %[bias1], %[step] \n\t"

         "subs %[mc], %[mc], #1 \n\t"
         "bge loop_mc_%= \n\t"
         "end_mc_%=: \n\t"

         :
-        : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1),
-          [step] "r"(step), [step1] "r"(step1), [p] "r"(p), [bias] "r"(bias),
-          [bias1] "r"(bias1)
-        : "memory", "r5", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
-          "q9", "q10");
+        : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2),
+          [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [p] "r"(p),
+          [bias] "r"(bias), [bias1] "r"(bias1)
+        : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+          "q8", "q9", "q10");
   }
 }
@@ -3331,7 +3425,7 @@ void SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
   int max_threads = 1;
 #endif

-  int L1 = 32 * 1024;
+  int L1 = 8 * 1024;
   KC = k;
   if (m > n) {
     // 对 A 分块