提交 f6ac7dd2 编写于 作者: Z zp7 提交者: Yanzhan Yang

fix gemm function sgemv_mx1 (#1743)

上级 dd381236
...@@ -407,14 +407,10 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha, ...@@ -407,14 +407,10 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
_b = vandq_f32_u32(_b, vmask); _b = vandq_f32_u32(_b, vmask);
_sum0 = vmlaq_f32(_sum0, _r0, _b); _sum0 = vmlaq_f32(_sum0, _r0, _b);
} }
_sum0 = vpaddq_f32(_sum0, _sum0); float32x2_t _ss = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0));
_sum0 = vmulq_f32(_sum0, _valpha); float32x2_t _sss2 = vpadd_f32(_ss, _ss);
if (beta != 0.f) { *output =
float32x4_t _sum2 = vmulq_n_f32(vld1q_f32(output), beta); vget_lane_f32(_sss2, 0) * vgetq_lane_f32(_valpha, 0) + beta * (*output);
_sum0 = vpaddq_f32(_sum0, _sum2);
}
// restore
*output = vgetq_lane_f32(_sum0, 0) + vgetq_lane_f32(_sum0, 1);
} }
} }
...@@ -536,7 +532,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha, ...@@ -536,7 +532,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha,
} else { // beta != 0.f } else { // beta != 0.f
float32x4_t _vbeta = vdupq_n_f32(beta); float32x4_t _vbeta = vdupq_n_f32(beta);
#pragma omp parallel for #pragma omp parallel for
for (int m = 0; m < M; m += 4) { for (int m = 0; m < M - 3; m += 4) {
float32x4_t _sum0 = vld1q_f32(buf_c + m); float32x4_t _sum0 = vld1q_f32(buf_c + m);
for (int tid = 1; tid < threads_num; ++tid) { for (int tid = 1; tid < threads_num; ++tid) {
_sum0 += vld1q_f32(buf_c + tid * M + m); _sum0 += vld1q_f32(buf_c + tid * M + m);
...@@ -545,7 +541,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha, ...@@ -545,7 +541,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha,
vst1q_f32(C + m, _sum0 * _valpha + _vbeta * _vc); vst1q_f32(C + m, _sum0 * _valpha + _vbeta * _vc);
} }
for (int m = (M & 0xfffffffc); m < M - 3; ++m) { for (int m = (M & 0xfffffffc); m < M; ++m) {
float _sum0 = *(buf_c + m); float _sum0 = *(buf_c + m);
for (int tid = 1; tid < threads_num; ++tid) { for (int tid = 1; tid < threads_num; ++tid) {
_sum0 += *(buf_c + tid * M + m); _sum0 += *(buf_c + tid * M + m);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册