diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 9582c18cbcfb6e502c42ab4195b553bd3b20093b..b165af0bb2a4b3493b2c74e04c43e63d52b0a698 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -458,8 +458,7 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, mc = s_min(m - i, MC); PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA); InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC, - &C(i, j), ldc, relu, new_scale + ldc * i + j, - new_bias + ldc * i + j); + &C(i, j), ldc, relu, new_scale + i, new_bias + i); } } @@ -1224,23 +1223,27 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, "mov r5, %[nc1] \n\t" "mov r6, %[nc2] \n\t" + "vld1.32 {d0}, [%[scale]] \n\t" + "vld1.32 {d1}, [%[bias]] \n\t" + "vdup.32 q1, d0[0] \n\t" + "vdup.32 q2, d1[0] \n\t" "subs r5, r5, #1 \n\t" "blt end_nc1_%= \n\t" "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" + "vld1.32 {q3, q4}, [%[c]]! \n\t" + "vmul.f32 q10, q3, q1 \n\t" + "vmul.f32 q11, q4, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vadd.f32 q11, q11, q2 \n\t" "vst1.32 {q10, q11}, [%[C]]! \n\t" - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" + "vld1.32 {q5, q6}, [%[c]]! \n\t" + "vmul.f32 q12, q5, q1 \n\t" + "vmul.f32 q13, q6, q1 \n\t" + "vadd.f32 q12, q12, q2 \n\t" + "vadd.f32 q13, q13, q2 \n\t" "vst1.32 {q12, q13}, [%[C]]! \n\t" "subs r5, r5, #1 \n\t" @@ -1251,11 +1254,10 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, "blt end_nc2_%= \n\t" "loop_nc2_%=: \n\t" - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + "vld1.32 {q7}, [%[c]]! \n\t" + "vmul.f32 q10, q7, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" "subs r6, r6, #1 \n\t" "bge loop_nc2_%= \n\t" @@ -1265,20 +1267,17 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, "beq end_nc3_%= \n\t" "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" "sub %[C], %[C], %[nc3] \n\t" - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + "vld1.32 {q8}, [%[c]]! \n\t" + "vmul.f32 q11, q8, q1 \n\t" + "vadd.f32 q11, q11, q2 \n\t" + "vst1.32 {q11}, [%[C]]! \n\t" "end_nc3_%=: \n\t" + "add %[scale], %[scale], #4 \n\t" + "add %[bias], %[bias], #4 \n\t" "add %[c], %[c], %[step1] \n\t" - "add %[scale], %[scale], %[step] \n\t" - "add %[bias], %[bias], %[step] \n\t" "add %[C], %[C], %[step] \n\t" "subs %[mc], %[mc], #1 \n\t" @@ -1289,8 +1288,8 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [scale] "r"(scale), [bias] "r"(bias) - : "memory", "cc", "r5", "r6", "r7", "r8", "q0", "q1", "q2", "q3", "q4", - "q5", "q6", "q7", "q10", "q11", "q12", "q13"); + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q10", "q11", "q12", "q13"); } // C = A * B, batchnorm(C), relu(C) @@ -1311,25 +1310,29 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, "mov r5, %[nc1] \n\t" "mov r6, %[nc2] \n\t" + "vld1.32 {d0}, [%[scale]] \n\t" + "vld1.32 {d1}, [%[bias]] \n\t" + "vdup.32 q1, d0[0] \n\t" + "vdup.32 q2, d1[0] \n\t" "subs r5, r5, #1 \n\t" "blt end_nc1_%= \n\t" "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" + "vld1.32 {q3, q4}, [%[c]]! \n\t" + "vmul.f32 q10, q3, q1 \n\t" + "vmul.f32 q11, q4, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vadd.f32 q11, q11, q2 \n\t" "vmax.f32 q10, q10, q14 \n\t" "vmax.f32 q11, q11, q14 \n\t" "vst1.32 {q10, q11}, [%[C]]! \n\t" - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" + "vld1.32 {q5, q6}, [%[c]]! \n\t" + "vmul.f32 q12, q5, q1 \n\t" + "vmul.f32 q13, q6, q1 \n\t" + "vadd.f32 q12, q12, q2 \n\t" + "vadd.f32 q13, q13, q2 \n\t" "vmax.f32 q12, q12, q14 \n\t" "vmax.f32 q13, q13, q14 \n\t" "vst1.32 {q12, q13}, [%[C]]! \n\t" @@ -1342,12 +1345,11 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, "blt end_nc2_%= \n\t" "loop_nc2_%=: \n\t" - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + "vld1.32 {q7}, [%[c]]! \n\t" + "vmul.f32 q10, q7, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" "subs r6, r6, #1 \n\t" "bge loop_nc2_%= \n\t" @@ -1357,21 +1359,18 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, "beq end_nc3_%= \n\t" "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" "sub %[C], %[C], %[nc3] \n\t" - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + "vld1.32 {q8}, [%[c]]! \n\t" + "vmul.f32 q11, q8, q1 \n\t" + "vadd.f32 q11, q11, q2 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q11}, [%[C]]! \n\t" "end_nc3_%=: \n\t" + "add %[scale], %[scale], #4 \n\t" + "add %[bias], %[bias], #4 \n\t" "add %[c], %[c], %[step1] \n\t" - "add %[scale], %[scale], %[step] \n\t" - "add %[bias], %[bias], %[step] \n\t" "add %[C], %[C], %[step] \n\t" "subs %[mc], %[mc], #1 \n\t" @@ -1382,8 +1381,8 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [scale] "r"(scale), [bias] "r"(bias) - : "memory", "r5", "r6", "r7", "r8", "q0", "q1", "q2", "q3", "q4", "q5", - "q6", "q7", "q10", "q11", "q12", "q13", "q14"); + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q10", "q11", "q12", "q13", "q14"); } // C = A * B