提交 04f826f1 编写于 作者: H hjchen2

Fix multi-thread bug for gemv when n==1

上级 53956140
......@@ -332,7 +332,6 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
uint32_t mask[4] = {0, 1, 2, 3};
int remain_n = N & 0x3;
uint32x4_t vmask = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n));
float32x4_t _sum0, _sum1, _sum2, _sum3;
float32x4_t _valpha = vdupq_n_f32(alpha);
#pragma omp parallel for
......@@ -342,11 +341,12 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
const float *in2 = in1 + lda;
const float *in3 = in2 + lda;
float *output = C + m;
float32x4_t _sum0, _sum1, _sum2, _sum3;
_sum0 = vdupq_n_f32(0.f);
_sum1 = vdupq_n_f32(0.f);
_sum2 = vdupq_n_f32(0.f);
_sum3 = vdupq_n_f32(0.f);
int n = 0;
for (; n < N - 3; n += 4) {
float32x4_t _r0 = vld1q_f32(in0 + n);
......@@ -390,7 +390,7 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
for (int m = (M & 0xfffc); m < M; ++m) {
const float *in0 = A + m * lda;
float *output = C + m;
_sum0 = vdupq_n_f32(0.f);
float32x4_t _sum0 = vdupq_n_f32(0.f);
int n = 0;
for (; n < N - 3; n += 4) {
......@@ -408,7 +408,7 @@ void sgemv_notrans_mx1(const int M, const int N, const float alpha,
_sum0 = vpaddq_f32(_sum0, _sum0);
_sum0 = vmulq_f32(_sum0, _valpha);
if (beta != 0.f) {
_sum2 = vmulq_n_f32(vld1q_f32(output), beta);
float32x4_t _sum2 = vmulq_n_f32(vld1q_f32(output), beta);
_sum0 = vpaddq_f32(_sum0, _sum2);
}
// restore
......
......@@ -206,11 +206,11 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels,
const Otype *output_data = output->data<Otype>();
Otype *output_cmp_data = output_cmp.data<Otype>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
float gap = abs(output_data[i] - output_cmp_data[i]);
// PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3,
// "output[%d] = %d, output_cmp[%d] = %d", i,
// output_data[i], i, output_cmp_data[i]);
if (gap > 1e-2 && std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
if (gap > 1e-2 && abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
std::cerr << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i << "] = " << output_cmp_data[i]
<< std::endl;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册