提交 866ab5fc 编写于 作者: E eclipsycn 提交者: GitHub

Merge pull request #540 from smilejames/develop

update gemm with batchnrom fusion
...@@ -458,8 +458,7 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda, ...@@ -458,8 +458,7 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
mc = s_min(m - i, MC); mc = s_min(m - i, MC);
PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA); PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA);
InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC, InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + ldc * i + j, &C(i, j), ldc, relu, new_scale + i, new_bias + i);
new_bias + ldc * i + j);
} }
} }
...@@ -1224,23 +1223,27 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, ...@@ -1224,23 +1223,27 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
"mov r5, %[nc1] \n\t" "mov r5, %[nc1] \n\t"
"mov r6, %[nc2] \n\t" "mov r6, %[nc2] \n\t"
"vld1.32 {d0}, [%[scale]] \n\t"
"vld1.32 {d1}, [%[bias]] \n\t"
"vdup.32 q1, d0[0] \n\t"
"vdup.32 q2, d1[0] \n\t"
"subs r5, r5, #1 \n\t" "subs r5, r5, #1 \n\t"
"blt end_nc1_%= \n\t" "blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t" "loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t" "vld1.32 {q3, q4}, [%[c]]! \n\t"
"vld1.32 {q2, q3}, [%[scale]]! \n\t" "vmul.f32 q10, q3, q1 \n\t"
"vld1.32 {q10, q11}, [%[bias]]! \n\t" "vmul.f32 q11, q4, q1 \n\t"
"vmla.f32 q10, q0, q2 \n\t" "vadd.f32 q10, q10, q2 \n\t"
"vmla.f32 q11, q1, q3 \n\t" "vadd.f32 q11, q11, q2 \n\t"
"vst1.32 {q10, q11}, [%[C]]! \n\t" "vst1.32 {q10, q11}, [%[C]]! \n\t"
"vld1.32 {q4, q5}, [%[c]]! \n\t" "vld1.32 {q5, q6}, [%[c]]! \n\t"
"vld1.32 {q6, q7}, [%[scale]]! \n\t" "vmul.f32 q12, q5, q1 \n\t"
"vld1.32 {q12, q13}, [%[bias]]! \n\t" "vmul.f32 q13, q6, q1 \n\t"
"vmla.f32 q12, q4, q6 \n\t" "vadd.f32 q12, q12, q2 \n\t"
"vmla.f32 q13, q5, q7 \n\t" "vadd.f32 q13, q13, q2 \n\t"
"vst1.32 {q12, q13}, [%[C]]! \n\t" "vst1.32 {q12, q13}, [%[C]]! \n\t"
"subs r5, r5, #1 \n\t" "subs r5, r5, #1 \n\t"
...@@ -1251,10 +1254,9 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, ...@@ -1251,10 +1254,9 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
"blt end_nc2_%= \n\t" "blt end_nc2_%= \n\t"
"loop_nc2_%=: \n\t" "loop_nc2_%=: \n\t"
"vld1.32 {q0}, [%[c]]! \n\t" "vld1.32 {q7}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t" "vmul.f32 q10, q7, q1 \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t" "vadd.f32 q10, q10, q2 \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t" "vst1.32 {q10}, [%[C]]! \n\t"
"subs r6, r6, #1 \n\t" "subs r6, r6, #1 \n\t"
...@@ -1265,20 +1267,17 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, ...@@ -1265,20 +1267,17 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
"beq end_nc3_%= \n\t" "beq end_nc3_%= \n\t"
"sub %[c], %[c], %[nc3] \n\t" "sub %[c], %[c], %[nc3] \n\t"
"sub %[scale], %[scale], %[nc3] \n\t"
"sub %[bias], %[bias], %[nc3] \n\t"
"sub %[C], %[C], %[nc3] \n\t" "sub %[C], %[C], %[nc3] \n\t"
"vld1.32 {q0}, [%[c]]! \n\t" "vld1.32 {q8}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t" "vmul.f32 q11, q8, q1 \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t" "vadd.f32 q11, q11, q2 \n\t"
"vmla.f32 q10, q0, q1 \n\t" "vst1.32 {q11}, [%[C]]! \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"end_nc3_%=: \n\t" "end_nc3_%=: \n\t"
"add %[scale], %[scale], #4 \n\t"
"add %[bias], %[bias], #4 \n\t"
"add %[c], %[c], %[step1] \n\t" "add %[c], %[c], %[step1] \n\t"
"add %[scale], %[scale], %[step] \n\t"
"add %[bias], %[bias], %[step] \n\t"
"add %[C], %[C], %[step] \n\t" "add %[C], %[C], %[step] \n\t"
"subs %[mc], %[mc], #1 \n\t" "subs %[mc], %[mc], #1 \n\t"
...@@ -1289,8 +1288,8 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, ...@@ -1289,8 +1288,8 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
: [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2),
[nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1),
[scale] "r"(scale), [bias] "r"(bias) [scale] "r"(scale), [bias] "r"(bias)
: "memory", "cc", "r5", "r6", "r7", "r8", "q0", "q1", "q2", "q3", "q4", : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q5", "q6", "q7", "q10", "q11", "q12", "q13"); "q8", "q10", "q11", "q12", "q13");
} }
// C = A * B, batchnorm(C), relu(C) // C = A * B, batchnorm(C), relu(C)
...@@ -1311,25 +1310,29 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, ...@@ -1311,25 +1310,29 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
"mov r5, %[nc1] \n\t" "mov r5, %[nc1] \n\t"
"mov r6, %[nc2] \n\t" "mov r6, %[nc2] \n\t"
"vld1.32 {d0}, [%[scale]] \n\t"
"vld1.32 {d1}, [%[bias]] \n\t"
"vdup.32 q1, d0[0] \n\t"
"vdup.32 q2, d1[0] \n\t"
"subs r5, r5, #1 \n\t" "subs r5, r5, #1 \n\t"
"blt end_nc1_%= \n\t" "blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t" "loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t" "vld1.32 {q3, q4}, [%[c]]! \n\t"
"vld1.32 {q2, q3}, [%[scale]]! \n\t" "vmul.f32 q10, q3, q1 \n\t"
"vld1.32 {q10, q11}, [%[bias]]! \n\t" "vmul.f32 q11, q4, q1 \n\t"
"vmla.f32 q10, q0, q2 \n\t" "vadd.f32 q10, q10, q2 \n\t"
"vmla.f32 q11, q1, q3 \n\t" "vadd.f32 q11, q11, q2 \n\t"
"vmax.f32 q10, q10, q14 \n\t" "vmax.f32 q10, q10, q14 \n\t"
"vmax.f32 q11, q11, q14 \n\t" "vmax.f32 q11, q11, q14 \n\t"
"vst1.32 {q10, q11}, [%[C]]! \n\t" "vst1.32 {q10, q11}, [%[C]]! \n\t"
"vld1.32 {q4, q5}, [%[c]]! \n\t" "vld1.32 {q5, q6}, [%[c]]! \n\t"
"vld1.32 {q6, q7}, [%[scale]]! \n\t" "vmul.f32 q12, q5, q1 \n\t"
"vld1.32 {q12, q13}, [%[bias]]! \n\t" "vmul.f32 q13, q6, q1 \n\t"
"vmla.f32 q12, q4, q6 \n\t" "vadd.f32 q12, q12, q2 \n\t"
"vmla.f32 q13, q5, q7 \n\t" "vadd.f32 q13, q13, q2 \n\t"
"vmax.f32 q12, q12, q14 \n\t" "vmax.f32 q12, q12, q14 \n\t"
"vmax.f32 q13, q13, q14 \n\t" "vmax.f32 q13, q13, q14 \n\t"
"vst1.32 {q12, q13}, [%[C]]! \n\t" "vst1.32 {q12, q13}, [%[C]]! \n\t"
...@@ -1342,10 +1345,9 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, ...@@ -1342,10 +1345,9 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
"blt end_nc2_%= \n\t" "blt end_nc2_%= \n\t"
"loop_nc2_%=: \n\t" "loop_nc2_%=: \n\t"
"vld1.32 {q0}, [%[c]]! \n\t" "vld1.32 {q7}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t" "vmul.f32 q10, q7, q1 \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t" "vadd.f32 q10, q10, q2 \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vmax.f32 q10, q10, q14 \n\t" "vmax.f32 q10, q10, q14 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t" "vst1.32 {q10}, [%[C]]! \n\t"
...@@ -1357,21 +1359,18 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, ...@@ -1357,21 +1359,18 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
"beq end_nc3_%= \n\t" "beq end_nc3_%= \n\t"
"sub %[c], %[c], %[nc3] \n\t" "sub %[c], %[c], %[nc3] \n\t"
"sub %[scale], %[scale], %[nc3] \n\t"
"sub %[bias], %[bias], %[nc3] \n\t"
"sub %[C], %[C], %[nc3] \n\t" "sub %[C], %[C], %[nc3] \n\t"
"vld1.32 {q0}, [%[c]]! \n\t" "vld1.32 {q8}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t" "vmul.f32 q11, q8, q1 \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t" "vadd.f32 q11, q11, q2 \n\t"
"vmla.f32 q10, q0, q1 \n\t" "vmax.f32 q11, q11, q14 \n\t"
"vmax.f32 q10, q10, q14 \n\t" "vst1.32 {q11}, [%[C]]! \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"end_nc3_%=: \n\t" "end_nc3_%=: \n\t"
"add %[scale], %[scale], #4 \n\t"
"add %[bias], %[bias], #4 \n\t"
"add %[c], %[c], %[step1] \n\t" "add %[c], %[c], %[step1] \n\t"
"add %[scale], %[scale], %[step] \n\t"
"add %[bias], %[bias], %[step] \n\t"
"add %[C], %[C], %[step] \n\t" "add %[C], %[C], %[step] \n\t"
"subs %[mc], %[mc], #1 \n\t" "subs %[mc], %[mc], #1 \n\t"
...@@ -1382,8 +1381,8 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, ...@@ -1382,8 +1381,8 @@ void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
: [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2),
[nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1),
[scale] "r"(scale), [bias] "r"(bias) [scale] "r"(scale), [bias] "r"(bias)
: "memory", "r5", "r6", "r7", "r8", "q0", "q1", "q2", "q3", "q4", "q5", : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q6", "q7", "q10", "q11", "q12", "q13", "q14"); "q8", "q10", "q11", "q12", "q13", "q14");
} }
// C = A * B // C = A * B
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册