diff --git a/src/operators/math/gemm/gemm_kernel.h b/src/operators/math/gemm/gemm_kernel.h index cc9b9f453af55fb63f1494a0525087b6b17fed7d..813205dab902468eac20247dc18055a315cdb81e 100644 --- a/src/operators/math/gemm/gemm_kernel.h +++ b/src/operators/math/gemm/gemm_kernel.h @@ -407,14 +407,10 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha, _b = vandq_f32_u32(_b, vmask); _sum0 = vmlaq_f32(_sum0, _r0, _b); } - _sum0 = vpaddq_f32(_sum0, _sum0); - _sum0 = vmulq_f32(_sum0, _valpha); - if (beta != 0.f) { - float32x4_t _sum2 = vmulq_n_f32(vld1q_f32(output), beta); - _sum0 = vpaddq_f32(_sum0, _sum2); - } - // restore - *output = vgetq_lane_f32(_sum0, 0) + vgetq_lane_f32(_sum0, 1); + float32x2_t _ss = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0)); + float32x2_t _sss2 = vpadd_f32(_ss, _ss); + *output = + vget_lane_f32(_sss2, 0) * vgetq_lane_f32(_valpha, 0) + beta * (*output); } } @@ -536,7 +532,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha, } else { // beta != 0.f float32x4_t _vbeta = vdupq_n_f32(beta); #pragma omp parallel for - for (int m = 0; m < M; m += 4) { + for (int m = 0; m < M - 3; m += 4) { float32x4_t _sum0 = vld1q_f32(buf_c + m); for (int tid = 1; tid < threads_num; ++tid) { _sum0 += vld1q_f32(buf_c + tid * M + m); @@ -545,7 +541,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha, vst1q_f32(C + m, _sum0 * _valpha + _vbeta * _vc); } - for (int m = (M & 0xfffffffc); m < M - 3; ++m) { + for (int m = (M & 0xfffffffc); m < M; ++m) { float _sum0 = *(buf_c + m); for (int tid = 1; tid < threads_num; ++tid) { _sum0 += *(buf_c + tid * M + m);