提交 f6ac7dd2 编写于 作者: Z zp7 提交者: Yanzhan Yang

Fix gemm functions sgemv_notrans_mx1 and sgemv_trans_mx1 (#1743)

上级 dd381236
......@@ -407,14 +407,10 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
_b = vandq_f32_u32(_b, vmask);
_sum0 = vmlaq_f32(_sum0, _r0, _b);
}
_sum0 = vpaddq_f32(_sum0, _sum0);
_sum0 = vmulq_f32(_sum0, _valpha);
if (beta != 0.f) {
float32x4_t _sum2 = vmulq_n_f32(vld1q_f32(output), beta);
_sum0 = vpaddq_f32(_sum0, _sum2);
}
// restore
*output = vgetq_lane_f32(_sum0, 0) + vgetq_lane_f32(_sum0, 1);
float32x2_t _ss = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0));
float32x2_t _sss2 = vpadd_f32(_ss, _ss);
*output =
vget_lane_f32(_sss2, 0) * vgetq_lane_f32(_valpha, 0) + beta * (*output);
}
}
......@@ -536,7 +532,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha,
} else { // beta != 0.f
float32x4_t _vbeta = vdupq_n_f32(beta);
#pragma omp parallel for
for (int m = 0; m < M; m += 4) {
for (int m = 0; m < M - 3; m += 4) {
float32x4_t _sum0 = vld1q_f32(buf_c + m);
for (int tid = 1; tid < threads_num; ++tid) {
_sum0 += vld1q_f32(buf_c + tid * M + m);
......@@ -545,7 +541,7 @@ void sgemv_trans_mx1(const int M, const int N, const float alpha,
vst1q_f32(C + m, _sum0 * _valpha + _vbeta * _vc);
}
for (int m = (M & 0xfffffffc); m < M - 3; ++m) {
for (int m = (M & 0xfffffffc); m < M; ++m) {
float _sum0 = *(buf_c + m);
for (int tid = 1; tid < threads_num; ++tid) {
_sum0 += *(buf_c + tid * M + m);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册