提交 04f826f1 作者: hjchen2

Fix multi-thread bug for gemv when n==1

上级 53956140
...@@ -332,7 +332,6 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha, ...@@ -332,7 +332,6 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
uint32_t mask[4] = {0, 1, 2, 3}; uint32_t mask[4] = {0, 1, 2, 3};
int remain_n = N & 0x3; int remain_n = N & 0x3;
uint32x4_t vmask = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n)); uint32x4_t vmask = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n));
float32x4_t _sum0, _sum1, _sum2, _sum3;
float32x4_t _valpha = vdupq_n_f32(alpha); float32x4_t _valpha = vdupq_n_f32(alpha);
#pragma omp parallel for #pragma omp parallel for
...@@ -342,11 +341,12 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha, ...@@ -342,11 +341,12 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
const float *in2 = in1 + lda; const float *in2 = in1 + lda;
const float *in3 = in2 + lda; const float *in3 = in2 + lda;
float *output = C + m; float *output = C + m;
float32x4_t _sum0, _sum1, _sum2, _sum3;
_sum0 = vdupq_n_f32(0.f); _sum0 = vdupq_n_f32(0.f);
_sum1 = vdupq_n_f32(0.f); _sum1 = vdupq_n_f32(0.f);
_sum2 = vdupq_n_f32(0.f); _sum2 = vdupq_n_f32(0.f);
_sum3 = vdupq_n_f32(0.f); _sum3 = vdupq_n_f32(0.f);
int n = 0; int n = 0;
for (; n < N - 3; n += 4) { for (; n < N - 3; n += 4) {
float32x4_t _r0 = vld1q_f32(in0 + n); float32x4_t _r0 = vld1q_f32(in0 + n);
...@@ -390,7 +390,7 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha, ...@@ -390,7 +390,7 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
for (int m = (M & 0xfffc); m < M; ++m) { for (int m = (M & 0xfffc); m < M; ++m) {
const float *in0 = A + m * lda; const float *in0 = A + m * lda;
float *output = C + m; float *output = C + m;
_sum0 = vdupq_n_f32(0.f); float32x4_t _sum0 = vdupq_n_f32(0.f);
int n = 0; int n = 0;
for (; n < N - 3; n += 4) { for (; n < N - 3; n += 4) {
...@@ -408,7 +408,7 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha, ...@@ -408,7 +408,7 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
_sum0 = vpaddq_f32(_sum0, _sum0); _sum0 = vpaddq_f32(_sum0, _sum0);
_sum0 = vmulq_f32(_sum0, _valpha); _sum0 = vmulq_f32(_sum0, _valpha);
if (beta != 0.f) { if (beta != 0.f) {
_sum2 = vmulq_n_f32(vld1q_f32(output), beta); float32x4_t _sum2 = vmulq_n_f32(vld1q_f32(output), beta);
_sum0 = vpaddq_f32(_sum0, _sum2); _sum0 = vpaddq_f32(_sum0, _sum2);
} }
// restore // restore
......
...@@ -206,11 +206,11 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels, ...@@ -206,11 +206,11 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels,
const Otype *output_data = output->data<Otype>(); const Otype *output_data = output->data<Otype>();
Otype *output_cmp_data = output_cmp.data<Otype>(); Otype *output_cmp_data = output_cmp.data<Otype>();
for (int i = 0; i < output->numel(); ++i) { for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i]; float gap = abs(output_data[i] - output_cmp_data[i]);
// PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3, // PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3,
// "output[%d] = %d, output_cmp[%d] = %d", i, // "output[%d] = %d, output_cmp[%d] = %d", i,
// output_data[i], i, output_cmp_data[i]); // output_data[i], i, output_cmp_data[i]);
if (gap > 1e-2 && std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) { if (gap > 1e-2 && abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
std::cerr << "output_data[" << i << "] = " << output_data[i] std::cerr << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i << "] = " << output_cmp_data[i] << ", output_cmp_data[" << i << "] = " << output_cmp_data[i]
<< std::endl; << std::endl;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册