Commit f2f05c0d authored by 李滨

Merge branch 'opt_gemm' into 'master'

Minor improvement of gemv asm

See merge request !984
@@ -19,7 +19,11 @@
 #include <algorithm>
 #if !defined(__aarch64__)
-#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+float vaddvq_f32(float32x4_t v) {
+  float32x2_t _sum = vadd_f32(vget_low_f32(v), vget_high_f32(v));
+  _sum = vpadd_f32(_sum, _sum);
+  return vget_lane_f32(_sum, 0);
+}
 #endif
 // Disable unroll by default, since cache set conflict could be significant
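The new ARMv7 fallback replaces a GCC-specific vector-indexing macro with NEON pairwise adds: vadd_f32 folds the high half of the vector onto the low half, and vpadd_f32 finishes the reduction. A standalone sketch (hypothetical test harness, not MACE code; assumes an ARMv7 NEON target) cross-checking it against a scalar sum:

    #include <arm_neon.h>
    #include <cstdio>

    // Horizontal sum of {a, b, c, d}: first {a+c, b+d}, then {a+b+c+d, ...}.
    static float horizontal_sum(float32x4_t v) {
      float32x2_t sum = vadd_f32(vget_low_f32(v), vget_high_f32(v));
      sum = vpadd_f32(sum, sum);
      return vget_lane_f32(sum, 0);
    }

    int main() {
      const float data[4] = {1.f, 2.f, 3.f, 4.f};
      float32x4_t v = vld1q_f32(data);
      // Both values should print 10.
      std::printf("neon: %f scalar: %f\n", horizontal_sum(v),
                  data[0] + data[1] + data[2] + data[3]);
      return 0;
    }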
@@ -202,8 +206,7 @@ MaceStatus Gemv::Compute(const OpContext *context,
         : // clobbers
         "cc", "memory", "r0", "r1", "r2", "r3", "r4", "r5",
         "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
-        "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
-        "d21");
+        "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19");
     lhs_ptr += w_block_count * w_block_size;
     rhs_ptr += w_block_count * w_block_size;
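The shortened clobber list tracks what the asm block (not fully shown here) evidently still writes: q10 (d20/d21) is no longer touched, and declaring it clobbered only constrained the register allocator for no benefit. A minimal hypothetical ARMv7 example (illustrative names, not MACE's API) of how a clobber list mirrors the registers an asm body overwrites:

    // The body writes q0 (d0/d1) and q1 (d2/d3), so exactly those four
    // d registers are declared clobbered; "memory" covers the store to out.
    void add4(const float* a, const float* b, float* out) {
      asm volatile(
          "vld1.f32 {d0-d1}, [%[a]]\n"
          "vld1.f32 {d2-d3}, [%[b]]\n"
          "vadd.f32 q0, q0, q1\n"
          "vst1.f32 {d0-d1}, [%[out]]\n"
          : // no register outputs
          : [a] "r"(a), [b] "r"(b), [out] "r"(out)
          : "memory", "d0", "d1", "d2", "d3");
    }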
@@ -257,7 +260,7 @@ MaceStatus Gemv::Compute(const OpContext *context,
     float32x4_t vbias = vdupq_n_f32(0);
     if (bias) {
-      vbias = vld1q_f32(bias_data + h_offset);
+      vbias = vld1q_f32(bias_data + h_start);
     }
     vo = vaddq_f32(vo, vbias);
     vst1q_f32(ret_ptr, vo);
@@ -268,24 +271,82 @@ MaceStatus Gemv::Compute(const OpContext *context,
     for (index_t h = 0; h < h_block_len; ++h) {
       lhs_ptr = tmp_lhs_ptr + h * lhs_width;
       rhs_ptr = tmp_rhs_ptr;
-      float32x4_t vo0 = vdupq_n_f32(0);
-      float32x4_t vo0n = vdupq_n_f32(0);
-      for (index_t w = 0; w < w_block_count; ++w) {
-        float32x4_t vr0 = vld1q_f32(rhs_ptr);
-        float32x4_t vr0n = vld1q_f32(rhs_ptr + 4);
-        float32x4_t vl0 = vld1q_f32(lhs_ptr);
-        float32x4_t vl0n = vld1q_f32(lhs_ptr + 4);
-        // may cause some precision error depending on the compute order
-        vo0 = vmlaq_f32(vo0, vl0, vr0);
-        vo0n = vmlaq_f32(vo0n, vl0n, vr0n);
-        lhs_ptr += 8;
-        rhs_ptr += 8;
-      } // w
-      vo0 = vaddq_f32(vo0, vo0n);
-      float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_start + h] : 0);
+      float s0 = bias ? bias_data[h_start + h] : 0;
+      if (w_block_count) {
+#if not defined(__aarch64__)
+        index_t r_w_block_count = w_block_count;
+        float32x4_t vo = vdupq_n_f32(0.f);
+        asm volatile(
+            "mov r0, #0\n"
+            "vdup.f32 q2, r0\n"
+            "vdup.f32 q3, r0\n"
+            // prologue: load the first lhs/rhs block before entering the loop
+            "vld1.f32 {d16-d17}, [%[rhs_ptr]]!\n"
+            "vld1.f32 {d18-d19}, [%[rhs_ptr]]!\n"
+            "subs %[r_w_block_count], #1\n"
+            "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n"
+            "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n"
+            "beq 1f\n"
+            "0: \n"
+            "vmla.f32 q2, q0, q8\n"
+            "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n"
+            "vld1.f32 {d16-d17}, [%[rhs_ptr]]!\n"
+            "subs %[r_w_block_count], #1\n"
+            "vmla.f32 q3, q1, q9\n"
+            "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n"
+            "vld1.f32 {d18-d19}, [%[rhs_ptr]]!\n"
+            "bne 0b\n"
+            // epilogue: consume the block loaded in the last iteration
+            "1:\n"
+            "vmla.f32 q2, q0, q8\n"
+            "vmla.f32 q3, q1, q9\n"
+            "vadd.f32 %q[vo], q2, q3\n"
+            : // outputs
+            [r_w_block_count] "+r"(r_w_block_count),
+            [lhs_ptr] "+r"(lhs_ptr),
+            [rhs_ptr] "+r"(rhs_ptr),
+            [vo] "+w"(vo)
+            : // inputs
+            : // clobbers
+            "cc", "memory", "r0",
+            "d0", "d1", "d2", "d3", // lhs
+            "d4", "d5", "d6", "d7", // output
+            "d16", "d17", "d18", "d19" // rhs
+            );
+        s0 += vaddvq_f32(vo);
+#else
+        float32x4_t vo0 = vdupq_n_f32(0);
+        float32x4_t vo0n = vdupq_n_f32(0);
+        for (index_t w = 0; w < w_block_count; ++w) {
+          float32x4_t vr0 = vld1q_f32(rhs_ptr);
+          float32x4_t vr0n = vld1q_f32(rhs_ptr + 4);
+          float32x4_t vl0 = vld1q_f32(lhs_ptr);
+          float32x4_t vl0n = vld1q_f32(lhs_ptr + 4);
+          vo0 = vmlaq_f32(vo0, vl0, vr0);
+          vo0n = vmlaq_f32(vo0n, vl0n, vr0n);
+          lhs_ptr += 8;
+          rhs_ptr += 8;
+        } // w
+        vo0 = vaddq_f32(vo0, vo0n);
+        s0 += vaddvq_f32(vo0);
+#endif // __aarch64__
+      } // if
       for (index_t w = 0; w < w_remain; ++w) {
         s0 += lhs_ptr[0] * rhs_ptr[0];
         ++lhs_ptr;
@@ -294,6 +355,7 @@ MaceStatus Gemv::Compute(const OpContext *context,
       ret_ptr[h] = s0;
     } // h
 #ifdef MACE_GEMV_UNROLL
   } // if
 #endif // MACE_GEMV_UNROLL
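The new ARMv7 path is software-pipelined: the prologue issues the first block of loads, the loop body starts the loads for iteration i+1 between the two vmla instructions of iteration i, and the epilogue consumes the final block after the branch. A scalar sketch of the same rotation (illustrative only; like the asm, it assumes at least one block, matching the if (w_block_count) guard):

    // Rotated dot-product loop: prologue loads, a body that consumes the
    // current pair while fetching the next, and an epilogue for the leftover.
    float pipelined_dot(const float* a, const float* b, int n) {
      float a0 = *a++;   // prologue: first loads
      float b0 = *b++;
      float acc = 0.f;
      while (--n > 0) {  // body runs n - 1 times
        acc += a0 * b0;  // consume the current pair...
        a0 = *a++;       // ...while the next pair is fetched
        b0 = *b++;
      }
      acc += a0 * b0;    // epilogue: consume the final pair
      return acc;
    }

The two accumulators q2/q3 serve the same purpose as the vo0/vo0n pair in the intrinsics path: each vmla result is not needed until the following iteration, which hides the multiply-accumulate latency on in-order cores.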