diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index f9b562fc425eac128c032105d4433133c60e6f3a..4044ee77d02fc168965f177146144ebe84e8a93e 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -459,8 +459,7 @@ void Gemmer::SgemmWithBn(int m, int n, int k, float alpha, const float *A, mc = s_min(m - i, MC); PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA); InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC, - &C(i, j), ldc, relu, new_scale + ldc * i + j, - new_bias + ldc * i + j); + &C(i, j), ldc, relu, new_scale + i, new_bias + i); } } @@ -1227,23 +1226,27 @@ void Gemmer::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, "mov r5, %[nc1] \n\t" "mov r6, %[nc2] \n\t" + "vld1.32 {d0}, [%[scale]] \n\t" + "vld1.32 {d1}, [%[bias]] \n\t" + "vdup.32 q1, d0[0] \n\t" + "vdup.32 q2, d1[0] \n\t" "subs r5, r5, #1 \n\t" "blt end_nc1_%= \n\t" "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" + "vld1.32 {q3, q4}, [%[c]]! \n\t" + "vmul.f32 q10, q3, q1 \n\t" + "vmul.f32 q11, q4, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vadd.f32 q11, q11, q2 \n\t" "vst1.32 {q10, q11}, [%[C]]! \n\t" - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" + "vld1.32 {q5, q6}, [%[c]]! \n\t" + "vmul.f32 q12, q5, q1 \n\t" + "vmul.f32 q13, q6, q1 \n\t" + "vadd.f32 q12, q12, q2 \n\t" + "vadd.f32 q13, q13, q2 \n\t" "vst1.32 {q12, q13}, [%[C]]! \n\t" "subs r5, r5, #1 \n\t" @@ -1254,11 +1257,10 @@ void Gemmer::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, "blt end_nc2_%= \n\t" "loop_nc2_%=: \n\t" - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + "vld1.32 {q7}, [%[c]]! \n\t" + "vmul.f32 q10, q7, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" "subs r6, r6, #1 \n\t" "bge loop_nc2_%= \n\t" @@ -1268,20 +1270,17 @@ void Gemmer::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, "beq end_nc3_%= \n\t" "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" "sub %[C], %[C], %[nc3] \n\t" - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + "vld1.32 {q8}, [%[c]]! \n\t" + "vmul.f32 q11, q8, q1 \n\t" + "vadd.f32 q11, q11, q2 \n\t" + "vst1.32 {q11}, [%[C]]! \n\t" "end_nc3_%=: \n\t" + "add %[scale], %[scale], #4 \n\t" + "add %[bias], %[bias], #4 \n\t" "add %[c], %[c], %[step1] \n\t" - "add %[scale], %[scale], %[step] \n\t" - "add %[bias], %[bias], %[step] \n\t" "add %[C], %[C], %[step] \n\t" "subs %[mc], %[mc], #1 \n\t" @@ -1292,8 +1291,8 @@ void Gemmer::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [scale] "r"(scale), [bias] "r"(bias) - : "memory", "cc", "r5", "r6", "r7", "r8", "q0", "q1", "q2", "q3", "q4", - "q5", "q6", "q7", "q10", "q11", "q12", "q13"); + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q10", "q11", "q12", "q13"); } // C = A * B, batchnorm(C), relu(C) @@ -1314,25 +1313,29 @@ void Gemmer::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, "mov r5, %[nc1] \n\t" "mov r6, %[nc2] \n\t" + "vld1.32 {d0}, [%[scale]] \n\t" + "vld1.32 {d1}, [%[bias]] \n\t" + "vdup.32 q1, d0[0] \n\t" + "vdup.32 q2, d1[0] \n\t" "subs r5, r5, #1 \n\t" "blt end_nc1_%= \n\t" "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" + "vld1.32 {q3, q4}, [%[c]]! \n\t" + "vmul.f32 q10, q3, q1 \n\t" + "vmul.f32 q11, q4, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vadd.f32 q11, q11, q2 \n\t" "vmax.f32 q10, q10, q14 \n\t" "vmax.f32 q11, q11, q14 \n\t" "vst1.32 {q10, q11}, [%[C]]! \n\t" - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" + "vld1.32 {q5, q6}, [%[c]]! \n\t" + "vmul.f32 q12, q5, q1 \n\t" + "vmul.f32 q13, q6, q1 \n\t" + "vadd.f32 q12, q12, q2 \n\t" + "vadd.f32 q13, q13, q2 \n\t" "vmax.f32 q12, q12, q14 \n\t" "vmax.f32 q13, q13, q14 \n\t" "vst1.32 {q12, q13}, [%[C]]! \n\t" @@ -1345,12 +1348,11 @@ void Gemmer::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, "blt end_nc2_%= \n\t" "loop_nc2_%=: \n\t" - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + "vld1.32 {q7}, [%[c]]! \n\t" + "vmul.f32 q10, q7, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" "subs r6, r6, #1 \n\t" "bge loop_nc2_%= \n\t" @@ -1360,21 +1362,18 @@ void Gemmer::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, "beq end_nc3_%= \n\t" "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" "sub %[C], %[C], %[nc3] \n\t" - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + "vld1.32 {q8}, [%[c]]! \n\t" + "vmul.f32 q11, q8, q1 \n\t" + "vadd.f32 q11, q11, q2 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q11}, [%[C]]! \n\t" "end_nc3_%=: \n\t" + "add %[scale], %[scale], #4 \n\t" + "add %[bias], %[bias], #4 \n\t" "add %[c], %[c], %[step1] \n\t" - "add %[scale], %[scale], %[step] \n\t" - "add %[bias], %[bias], %[step] \n\t" "add %[C], %[C], %[step] \n\t" "subs %[mc], %[mc], #1 \n\t" @@ -1385,8 +1384,8 @@ void Gemmer::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [scale] "r"(scale), [bias] "r"(bias) - : "memory", "r5", "r6", "r7", "r8", "q0", "q1", "q2", "q3", "q4", "q5", - "q6", "q7", "q10", "q11", "q12", "q13", "q14"); + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q10", "q11", "q12", "q13", "q14"); } // C = A * B