提交 0921c33f 编写于 作者: L lixian

fix fp32 kernel on arm32 and ReluFp32

上级 2da29bce
...@@ -17,16 +17,18 @@ ...@@ -17,16 +17,18 @@
IndirectGemmFp32_8x4: IndirectGemmFp32_8x4:
.macro INIT_BIAS .macro INIT_BIAS
veor q10, q10, q10 veor q8, q8, q8
cmp r3, #0 cmp r3, #0
beq InitBias beq InitBias
vld1.32 q10, [r3] vld1.32 {q8}, [r3]
InitBias: InitBias:
vmov q11, q10 vmov q9, q8
vmov q12, q10 vmov q10, q8
vmov q13, q10 vmov q11, q8
vmov q14, q10 vmov q12, q8
vmov q15, q10 vmov q13, q8
vmov q14, q8
vmov q15, q8
.endm .endm
// at return, clang generates "push {lr}, pop {pc}" while gcc will generate "bx lr" // at return, clang generates "push {lr}, pop {pc}" while gcc will generate "bx lr"
...@@ -36,7 +38,7 @@ IndirectGemmFp32_8x4: ...@@ -36,7 +38,7 @@ IndirectGemmFp32_8x4:
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
push {r4-r8, r10, r11, lr} push {r4-r8, r10, r11, lr}
vpush {q4-q7} vpush {q4-q7}
add sp, sp, #160 add sp, sp, #96
ldr r4, [sp] ldr r4, [sp]
ldr r5, [sp, #4] ldr r5, [sp, #4]
...@@ -66,8 +68,8 @@ IndirectGemmFp32_8x4: ...@@ -66,8 +68,8 @@ IndirectGemmFp32_8x4:
// load weight // load weight
vld1.32 {q4, q5}, [r2]! vld1.32 {q4, q5}, [r2]!
// step for output 1-2 // step for output 1-2
vmul.f32 q8, q4, d0[0] vmla.f32 q8, q4, d0[0]
vmul.f32 q9, q4, d2[0] vmla.f32 q9, q4, d2[0]
vmla.f32 q8, q5, d0[1] vmla.f32 q8, q5, d0[1]
vmla.f32 q9, q5, d2[1] vmla.f32 q9, q5, d2[1]
vld1.32 {q6, q7}, [r2]! vld1.32 {q6, q7}, [r2]!
...@@ -158,31 +160,31 @@ IndirectGemmFp32_8x4: ...@@ -158,31 +160,31 @@ IndirectGemmFp32_8x4:
bne Relu bne Relu
b WriteStart b WriteStart
Relu6: Relu6:
vmov.i32 q14, #6 vmov.i32 q7, #6
vcvt.f32.s32 q14, q14 vcvt.f32.s32 q7, q7
vmin.f32 q0, q0, q14 vmin.f32 q8, q8, q7
vmin.f32 q1, q1, q14 vmin.f32 q9, q9, q7
vmin.f32 q2, q2, q14 vmin.f32 q10, q10, q7
vmin.f32 q3, q3, q14 vmin.f32 q11, q11, q7
vmin.f32 q4, q4, q14 vmin.f32 q12, q12, q7
vmin.f32 q5, q5, q14 vmin.f32 q13, q13, q7
vmin.f32 q6, q6, q14 vmin.f32 q14, q14, q7
vmin.f32 q7, q15, q14 vmin.f32 q15, q15, q7
Relu: Relu:
veor q7, q7, q7 veor q7, q7, q7
vmax.f32 q0, q8, q7 vmax.f32 q8, q8, q7
vmax.f32 q1, q9, q7 vmax.f32 q9, q9, q7
vmax.f32 q2, q10, q7 vmax.f32 q10, q10, q7
vmax.f32 q3, q11, q7 vmax.f32 q11, q11, q7
vmax.f32 q4, q12, q7 vmax.f32 q12, q12, q7
vmax.f32 q5, q13, q7 vmax.f32 q13, q13, q7
vmax.f32 q6, q14, q7 vmax.f32 q14, q14, q7
vmax.f32 q15, q15, q7 vmax.f32 q15, q15, q7
WriteStart: WriteStart:
ldr r10, [sp, #20] ldr r10, [sp, #20]
cmp r10, #0 cmp r10, #0
bne WriteC4 bne Write4
cmp r6, #1 cmp r6, #1
beq Write1 beq Write1
cmp r6, #2 cmp r6, #2
...@@ -191,98 +193,91 @@ IndirectGemmFp32_8x4: ...@@ -191,98 +193,91 @@ IndirectGemmFp32_8x4:
beq Write3 beq Write3
b Write4 b Write4
Write1: Write1:
vst1.32 d0[0], [r11] vst1.32 d16[0], [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d2[0], [r11] vst1.32 d18[0], [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d4[0], [r11] vst1.32 d20[0], [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d6[0], [r11] vst1.32 d22[0], [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d8[0], [r11] vst1.32 d24[0], [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d10[0], [r11] vst1.32 d26[0], [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d12[0], [r11] vst1.32 d28[0], [r11]
add r11, r11, r7
vst1.32 d30[0], [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d14[0], [r11]
add r0, r0, #4 add r0, r0, #4
b WriteEnd b WriteEnd
Write2: Write2:
vst1.32 d0, [r11] vst1.32 d16, [r11]
add r11, r11, r7
vst1.32 d18, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d2, [r11] vst1.32 d20, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d4, [r11] vst1.32 d22, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d6, [r11] vst1.32 d24, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d8, [r11] vst1.32 d26, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d10, [r11] vst1.32 d28, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d12, [r11] vst1.32 d30, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d14, [r11]
add r0, r0, #8 add r0, r0, #8
b WriteEnd b WriteEnd
Write3: Write3:
add r12, r11, #8 add lr, r11, #8
vst1.32 d0, [r11] vst1.32 d16, [r11]
add r11, r11, r7
vst1.32 d17[0], [lr]
add lr, lr, r7
vst1.32 d18, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d1[0], [r12] vst1.32 d19[0], [lr]
add r12, r12, r7 add lr, lr, r7
vst1.32 d2, [r11] vst1.32 d20, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d3[0], [r12] vst1.32 d21[0], [lr]
add r12, r12, r7 add lr, lr, r7
vst1.32 d4, [r11] vst1.32 d22, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d5[0], [r12] vst1.32 d23[0], [lr]
add r12, r12, r7 add lr, lr, r7
vst1.32 d6, [r11] vst1.32 d24, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d7[0], [r12] vst1.32 d25[0], [lr]
add r12, r12, r7 add lr, lr, r7
vst1.32 d8, [r11] vst1.32 d26, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d9[0], [r12] vst1.32 d27[0], [lr]
add r12, r12, r7 add lr, lr, r7
vst1.32 d10, [r11] vst1.32 d28, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d11[0], [r12] vst1.32 d29[0], [lr]
add r12, r12, r7 add lr, lr, r7
vst1.32 d12, [r11] vst1.32 d30, [r11]
add r11, r11, r7 add r11, r11, r7
vst1.32 d13[0], [r12] vst1.32 d31[0], [lr]
add r12, r12, r7 add lr, lr, r7
vst1.32 d14, [r11]
vst1.32 d15[0], [r12]
add r0, r0, #12 add r0, r0, #12
b WriteEnd b WriteEnd
WriteC4:
vst1.32 q0, [r11], r7
vst1.32 q1, [r11], r7
vst1.32 q2, [r11], r7
vst1.32 q3, [r11], r7
vst1.32 q4, [r11], r7
vst1.32 q5, [r11], r7
vst1.32 q6, [r11], r7
vst1.32 q7, [r11]
add r0, r0, #16
b WriteEnd
Write4: Write4:
// prefetching is not preferred while writing results in spite of cache misses // prefetching is not preferred while writing results in spite of cache misses
// you could try prfm pstl2vst1.32m // you could try pld
// there are almost no benefits observed though // there are almost no benefits observed though
vst1.32 q0, [r11], r7 vst1.32 {q8}, [r11], r7
vst1.32 q1, [r11], r7 vst1.32 {q9}, [r11], r7
vst1.32 q2, [r11], r7 vst1.32 {q10}, [r11], r7
vst1.32 q3, [r11], r7 vst1.32 {q11}, [r11], r7
vst1.32 q4, [r11], r7 vst1.32 {q12}, [r11], r7
vst1.32 q5, [r11], r7 vst1.32 {q13}, [r11], r7
vst1.32 q6, [r11], r7 vst1.32 {q14}, [r11], r7
vst1.32 q7, [r11] vst1.32 {q15}, [r11], r7
add r0, r0, #16 add r0, r0, #16
WriteEnd: WriteEnd:
...@@ -290,14 +285,17 @@ IndirectGemmFp32_8x4: ...@@ -290,14 +285,17 @@ IndirectGemmFp32_8x4:
subs r8, r8, #1 subs r8, r8, #1
bne LoopKsize bne LoopKsize
subs r6, r6, #4 cmp r6, #4
ble LoopOcEnd
sub r6, r6, #4
cmp r3, #0 cmp r3, #0
beq NoStepFowrard beq NoStepFowrard
add r3, r3, #16 add r3, r3, #16
NoStepFowrard: NoStepFowrard:
bgt LoopOc b LoopOc
add sp, sp, #160 LoopOcEnd:
sub sp, sp, #96
vpop {q4-q7} vpop {q4-q7}
pop {r4-r8, r10, r11, pc} pop {r4-r8, r10, r11, pc}
#endif #endif
......
...@@ -31,7 +31,7 @@ IndirectGemmInt8_2x4: ...@@ -31,7 +31,7 @@ IndirectGemmInt8_2x4:
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
push {r4-r8, r10, r11, lr} push {r4-r8, r10, r11, lr}
vpush {q4-q7} vpush {q4-q7}
add sp, sp, #160 add sp, sp, #96
ldr r4, [sp] ldr r4, [sp]
ldr r5, [sp, #4] ldr r5, [sp, #4]
...@@ -226,14 +226,17 @@ IndirectGemmInt8_2x4: ...@@ -226,14 +226,17 @@ IndirectGemmInt8_2x4:
subs r8, r8, #1 subs r8, r8, #1
bne LoopKsize bne LoopKsize
subs r6, r6, #4 cmp r6, #4
ble LoopOcEnd
sub r6, r6, #4
cmp r3, #0 cmp r3, #0
beq NoStepFowrard beq NoStepFowrard
add r3, r3, #16 add r3, r3, #16
NoStepFowrard: NoStepFowrard:
bgt LoopOc b LoopOc
add sp, sp, #160 LoopOcEnd:
sub sp, sp, #96
vpop {q4-q7} vpop {q4-q7}
pop {r4-r8, r10, r11, pc} pop {r4-r8, r10, r11, pc}
#endif #endif
......
...@@ -159,6 +159,7 @@ void ReluFp32(float *data, int ele_num) { ...@@ -159,6 +159,7 @@ void ReluFp32(float *data, int ele_num) {
float32x4_t relu_data = vld1q_f32(data + index); float32x4_t relu_data = vld1q_f32(data + index);
float32x4_t zero_data = vdupq_n_f32(0); float32x4_t zero_data = vdupq_n_f32(0);
relu_data = vmaxq_f32(relu_data, zero_data); relu_data = vmaxq_f32(relu_data, zero_data);
vst1q_f32(data + index, relu_data);
#else #else
data[index] = data[index] < 0 ? 0 : data[index]; data[index] = data[index] < 0 ? 0 : data[index];
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1]; data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
...@@ -181,6 +182,7 @@ void Relu6Fp32(float *data, int ele_num) { ...@@ -181,6 +182,7 @@ void Relu6Fp32(float *data, int ele_num) {
float32x4_t six_data = vdupq_n_f32(6); float32x4_t six_data = vdupq_n_f32(6);
relu6_data = vmaxq_f32(relu6_data, zero_data); relu6_data = vmaxq_f32(relu6_data, zero_data);
relu6_data = vminq_f32(relu6_data, six_data); relu6_data = vminq_f32(relu6_data, six_data);
vst1q_f32(data + index, relu6_data);
#else #else
data[index] = data[index] < 0 ? 0 : data[index]; data[index] = data[index] < 0 ? 0 : data[index];
data[index] = data[index] > 6 ? 6 : data[index]; data[index] = data[index] > 6 ? 6 : data[index];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册