diff --git a/mace/ops/arm/fp32/gemv.cc b/mace/ops/arm/fp32/gemv.cc
index a146de4ce75ddd511834935f4d16886da5c8b923..703e39449663a66d8076d7b2500a9820c209938c 100644
--- a/mace/ops/arm/fp32/gemv.cc
+++ b/mace/ops/arm/fp32/gemv.cc
@@ -19,7 +19,11 @@
 #include <arm_neon.h>
 
 #if !defined(__aarch64__)
-#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+float vaddvq_f32(float32x4_t v) {
+  float32x2_t _sum = vadd_f32(vget_low_f32(v), vget_high_f32(v));
+  _sum = vpadd_f32(_sum, _sum);
+  return vget_lane_f32(_sum, 0);
+}
 #endif
 
 // Disable unroll by default, since cache set conflict could be significant
@@ -202,8 +206,7 @@ MaceStatus Gemv::Compute(const OpContext *context,
       : // clobbers
       "cc", "memory", "r0", "r1", "r2", "r3", "r4", "r5",
       "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
-      "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
-      "d21");
+      "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19");
 
   lhs_ptr += w_block_count * w_block_size;
   rhs_ptr += w_block_count * w_block_size;
@@ -257,7 +260,7 @@ MaceStatus Gemv::Compute(const OpContext *context,
 
         float32x4_t vbias = vdupq_n_f32(0);
         if (bias) {
-          vbias = vld1q_f32(bias_data + h_offset);
+          vbias = vld1q_f32(bias_data + h_start);
         }
         vo = vaddq_f32(vo, vbias);
         vst1q_f32(ret_ptr, vo);
@@ -268,24 +271,82 @@ MaceStatus Gemv::Compute(const OpContext *context,
       for (index_t h = 0; h < h_block_len; ++h) {
         lhs_ptr = tmp_lhs_ptr + h * lhs_width;
         rhs_ptr = tmp_rhs_ptr;
-        float32x4_t vo0 = vdupq_n_f32(0);
-        float32x4_t vo0n = vdupq_n_f32(0);
-        for (index_t w = 0; w < w_block_count; ++w) {
-          float32x4_t vr0 = vld1q_f32(rhs_ptr);
-          float32x4_t vr0n = vld1q_f32(rhs_ptr + 4);
-
-          float32x4_t vl0 = vld1q_f32(lhs_ptr);
-          float32x4_t vl0n = vld1q_f32(lhs_ptr + 4);
-          // may cause some precision error depending on the compute order
-          vo0 = vmlaq_f32(vo0, vl0, vr0);
-          vo0n = vmlaq_f32(vo0n, vl0n, vr0n);
+        float s0 = bias ? bias_data[h_start + h] : 0;
 
-          lhs_ptr += 8;
-          rhs_ptr += 8;
-        }  // w
-        vo0 = vaddq_f32(vo0, vo0n);
-        float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_start + h] : 0);
+        if (w_block_count) {
+#if not defined(__aarch64__)
+          index_t r_w_block_count = w_block_count;
+          float32x4_t vo = vdupq_n_f32(0.f);
+
+          asm volatile(
+              "mov r0, #0\n"
+              "vdup.f32 q2, r0\n"
+              "vdup.f32 q3, r0\n"
+
+              // prologue
+              "vld1.f32 {d16-d17}, [%[rhs_ptr]]!\n"
+              "vld1.f32 {d18-d19}, [%[rhs_ptr]]!\n"
+
+              "subs %[r_w_block_count], #1\n"
+
+              "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n"
+              "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n"
+
+              "beq 1f\n"
+
+              "0: \n"
+              "vmla.f32 q2, q0, q8\n"
+              "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n"
+              "vld1.f32 {d16-d17}, [%[rhs_ptr]]!\n"
+
+              "subs %[r_w_block_count], #1\n"
+
+              "vmla.f32 q3, q1, q9\n"
+              "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n"
+              "vld1.f32 {d18-d19}, [%[rhs_ptr]]!\n"
+
+              "bne 0b\n"
+
+              // epilogue
+              "1:\n"
+              "vmla.f32 q2, q0, q8\n"
+              "vmla.f32 q3, q1, q9\n"
+              "vadd.f32 %q[vo], q2, q3\n"
+              : // outputs
+              [r_w_block_count] "+r"(r_w_block_count),
+              [lhs_ptr] "+r"(lhs_ptr),
+              [rhs_ptr] "+r"(rhs_ptr),
+              [vo] "+w"(vo)
+              : // inputs
+              : // clobbers
+              "cc", "memory", "r0",
+              "d0", "d1", "d2", "d3",     // lhs
+              "d4", "d5", "d6", "d7",     // output
+              "d16", "d17", "d18", "d19"  // rhs
+          );
+
+          s0 += vaddvq_f32(vo);
+#else
+          float32x4_t vo0 = vdupq_n_f32(0);
+          float32x4_t vo0n = vdupq_n_f32(0);
+          for (index_t w = 0; w < w_block_count; ++w) {
+            float32x4_t vr0 = vld1q_f32(rhs_ptr);
+            float32x4_t vr0n = vld1q_f32(rhs_ptr + 4);
+
+            float32x4_t vl0 = vld1q_f32(lhs_ptr);
+            float32x4_t vl0n = vld1q_f32(lhs_ptr + 4);
+
+            vo0 = vmlaq_f32(vo0, vl0, vr0);
+            vo0n = vmlaq_f32(vo0n, vl0n, vr0n);
+
+            lhs_ptr += 8;
+            rhs_ptr += 8;
+          }  // w
+          vo0 = vaddq_f32(vo0, vo0n);
+          s0 += vaddvq_f32(vo0);
+#endif  // __aarch64__
+        }  // if
         for (index_t w = 0; w < w_remain; ++w) {
           s0 += lhs_ptr[0] * rhs_ptr[0];
           ++lhs_ptr;
           ++rhs_ptr;
@@ -294,6 +355,7 @@ MaceStatus Gemv::Compute(const OpContext *context,
         ret_ptr[h] = s0;
       }  // h
+
 #ifdef MACE_GEMV_UNROLL
     }  // if
 #endif  // MACE_GEMV_UNROLL
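
Note (not part of the patch): the armv7 fallback above replaces the per-lane indexing macro for vaddvq_f32 with a pairwise-add reduction. Below is a minimal standalone sketch, assuming only <arm_neon.h> and a NEON-enabled armv7 or aarch64 toolchain, that checks the same reduction strategy against a plain scalar sum; the helper name horizontal_sum is illustrative and does not come from the MACE sources.

#include <arm_neon.h>
#include <cstdio>

// Same reduction idea as the fallback added above: fold the high half onto
// the low half, then pairwise-add so both lanes hold the full sum.
static float horizontal_sum(float32x4_t v) {
  float32x2_t sum = vadd_f32(vget_low_f32(v), vget_high_f32(v));
  sum = vpadd_f32(sum, sum);
  return vget_lane_f32(sum, 0);
}

int main() {
  const float data[4] = {1.5f, -2.0f, 4.25f, 0.75f};
  float32x4_t v = vld1q_f32(data);
  // 4.5f is exactly representable, so an equality check is safe here.
  float expected = data[0] + data[1] + data[2] + data[3];
  float actual = horizontal_sum(v);
  std::printf("neon=%f scalar=%f\n", actual, expected);
  return actual == expected ? 0 : 1;
}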