From 3e867dcad7600de1cd1ffbb1ed929d70063d7bb3 Mon Sep 17 00:00:00 2001 From: wuchenghui Date: Thu, 19 Apr 2018 17:59:59 +0800 Subject: [PATCH] fix gemv multi-batch case --- mace/kernels/arm/fully_connected.cc | 4 +- mace/kernels/gemm.cc | 173 +++++++++++++++------------- mace/kernels/gemm.h | 2 + mace/kernels/gemm_test.cc | 4 +- mace/ops/conv_2d_test.cc | 2 +- mace/ops/fully_connected_test.cc | 2 + 6 files changed, 103 insertions(+), 84 deletions(-) diff --git a/mace/kernels/arm/fully_connected.cc b/mace/kernels/arm/fully_connected.cc index 0944480e..5df39ce5 100644 --- a/mace/kernels/arm/fully_connected.cc +++ b/mace/kernels/arm/fully_connected.cc @@ -34,10 +34,10 @@ void FullyConnectedFunctordata(); float *output_ptr = output->mutable_data(); + Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr); for (int i = 0; i < N; ++i) { - Gemv(weight_ptr, input_ptr, input_size, output_size, output_ptr); for (int j = 0; j < output_size; ++j) { - output_ptr[j] += bias_ptr[j]; + output_ptr[j + i * output_size] += bias_ptr[j]; } } diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc index cb11fa5c..b252949a 100644 --- a/mace/kernels/gemm.cc +++ b/mace/kernels/gemm.cc @@ -566,6 +566,7 @@ inline void GemmTile(const float *A, } } // namespace +// A: height x K, B: K x width, C: height x width void Gemm(const float *A, const float *B, const index_t batch, @@ -573,6 +574,12 @@ void Gemm(const float *A, const index_t K, const index_t width, float *C) { + if (width == 1) { + for (index_t b = 0; b < batch; ++b) { + Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height); + } + return; + } memset(C, 0, sizeof(float) * batch * height * width); @@ -628,6 +635,7 @@ void Gemm(const float *A, } // n } +// A: height x K, B: K x width, C: height x width void GemmRef(const float *A, const float *B, const index_t height, @@ -647,19 +655,24 @@ void GemmRef(const float *A, void GemvRef(const float *m_ptr, const float *v_ptr, + const index_t batch, const index_t width, 
const index_t height, float *out_ptr) { - memset(out_ptr, 0, sizeof(float) * height); - for (int h = 0; h < height; ++h) { - for (int w = 0; w < width; ++w) { - out_ptr[h] += v_ptr[w] * m_ptr[h * width + w]; + memset(out_ptr, 0, sizeof(float) * height * batch); + for (int b = 0; b < batch; ++b) { + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + out_ptr[h + b * height] += v_ptr[w + b * width] * m_ptr[h * width + w]; + } } } } +// M: height x width, Vin: width x 1, Vout: height x 1 void Gemv(const float *m_ptr, const float *v_ptr, + const index_t batch, const index_t width, const index_t height, float *out_ptr) { @@ -669,88 +682,90 @@ void Gemv(const float *m_ptr, index_t remain_w = width - (width_d4 << 2); index_t remain_h = height - (height_d4 << 2); + for (index_t b = 0; b < batch; ++b) { #pragma omp parallel for - for (index_t h = 0; h < height_d4; ++h) { - const float *m_ptr0 = m_ptr + h * width * 4; - const float *m_ptr1 = m_ptr0 + width; - const float *m_ptr2 = m_ptr1 + width; - const float *m_ptr3 = m_ptr2 + width; - const float *v_ptr0 = v_ptr; - float *out_ptr0 = out_ptr + h * 4; - - float32x4_t vm0, vm1, vm2, vm3; - float32x4_t vv; - - float32x4_t vsum0 = vdupq_n_f32(0.f); - float32x4_t vsum1 = vdupq_n_f32(0.f); - float32x4_t vsum2 = vdupq_n_f32(0.f); - float32x4_t vsum3 = vdupq_n_f32(0.f); - - for (index_t w = 0; w < width_d4; ++w) { - vm0 = vld1q_f32(m_ptr0); - vm1 = vld1q_f32(m_ptr1); - vm2 = vld1q_f32(m_ptr2); - vm3 = vld1q_f32(m_ptr3); - vv = vld1q_f32(v_ptr0); - - vsum0 = vmlaq_f32(vsum0, vm0, vv); - vsum1 = vmlaq_f32(vsum1, vm1, vv); - vsum2 = vmlaq_f32(vsum2, vm2, vv); - vsum3 = vmlaq_f32(vsum3, vm3, vv); - - m_ptr0 += 4; - m_ptr1 += 4; - m_ptr2 += 4; - m_ptr3 += 4; - v_ptr0 += 4; - } - float sum0 = vaddvq_f32(vsum0); - float sum1 = vaddvq_f32(vsum1); - float sum2 = vaddvq_f32(vsum2); - float sum3 = vaddvq_f32(vsum3); - - // handle remaining w - for (index_t w = 0; w < remain_w; ++w) { - sum0 += m_ptr0[0] * v_ptr0[0]; - 
sum1 += m_ptr1[0] * v_ptr0[0]; - sum2 += m_ptr2[0] * v_ptr0[0]; - sum3 += m_ptr3[0] * v_ptr0[0]; - m_ptr0++; - m_ptr1++; - m_ptr2++; - m_ptr3++; - v_ptr0++; + for (index_t h = 0; h < height_d4; ++h) { + const float *m_ptr0 = m_ptr + h * width * 4; + const float *m_ptr1 = m_ptr0 + width; + const float *m_ptr2 = m_ptr1 + width; + const float *m_ptr3 = m_ptr2 + width; + const float *v_ptr0 = v_ptr + b * width; + float *out_ptr0 = out_ptr + h * 4 + b * height; + + float32x4_t vm0, vm1, vm2, vm3; + float32x4_t vv; + + float32x4_t vsum0 = vdupq_n_f32(0.f); + float32x4_t vsum1 = vdupq_n_f32(0.f); + float32x4_t vsum2 = vdupq_n_f32(0.f); + float32x4_t vsum3 = vdupq_n_f32(0.f); + + for (index_t w = 0; w < width_d4; ++w) { + vm0 = vld1q_f32(m_ptr0); + vm1 = vld1q_f32(m_ptr1); + vm2 = vld1q_f32(m_ptr2); + vm3 = vld1q_f32(m_ptr3); + vv = vld1q_f32(v_ptr0); + + vsum0 = vmlaq_f32(vsum0, vm0, vv); + vsum1 = vmlaq_f32(vsum1, vm1, vv); + vsum2 = vmlaq_f32(vsum2, vm2, vv); + vsum3 = vmlaq_f32(vsum3, vm3, vv); + + m_ptr0 += 4; + m_ptr1 += 4; + m_ptr2 += 4; + m_ptr3 += 4; + v_ptr0 += 4; + } + float sum0 = vaddvq_f32(vsum0); + float sum1 = vaddvq_f32(vsum1); + float sum2 = vaddvq_f32(vsum2); + float sum3 = vaddvq_f32(vsum3); + + // handle remaining w + for (index_t w = 0; w < remain_w; ++w) { + sum0 += m_ptr0[0] * v_ptr0[0]; + sum1 += m_ptr1[0] * v_ptr0[0]; + sum2 += m_ptr2[0] * v_ptr0[0]; + sum3 += m_ptr3[0] * v_ptr0[0]; + m_ptr0++; + m_ptr1++; + m_ptr2++; + m_ptr3++; + v_ptr0++; + } + *out_ptr0++ = sum0; + *out_ptr0++ = sum1; + *out_ptr0++ = sum2; + *out_ptr0++ = sum3; } - *out_ptr0++ = sum0; - *out_ptr0++ = sum1; - *out_ptr0++ = sum2; - *out_ptr0++ = sum3; - } - // handle remaining h - index_t remain_start_height = height_d4 << 2; + // handle remaining h + index_t remain_start_height = height_d4 << 2; #pragma omp parallel for - for (index_t h = 0; h < remain_h; ++h) { - float32x4_t vsum0 = vdupq_n_f32(0.f); - const float *m_ptr0 = m_ptr + (h + remain_start_height) * width; - const 
float *v_ptr0 = v_ptr; - for (index_t w = 0; w < width_d4; ++w) { - float32x4_t vm = vld1q_f32(m_ptr0); - float32x4_t vv = vld1q_f32(v_ptr0); - vsum0 = vmlaq_f32(vsum0, vm, vv); - m_ptr0 += 4; - v_ptr0 += 4; - } - float sum = vaddvq_f32(vsum0); - for (index_t w = 0; w < remain_w; ++w) { - sum += m_ptr0[0] * v_ptr0[0]; - m_ptr0++; - v_ptr0++; + for (index_t h = 0; h < remain_h; ++h) { + float32x4_t vsum0 = vdupq_n_f32(0.f); + const float *m_ptr0 = m_ptr + (h + remain_start_height) * width; + const float *v_ptr0 = v_ptr + b * width; + for (index_t w = 0; w < width_d4; ++w) { + float32x4_t vm = vld1q_f32(m_ptr0); + float32x4_t vv = vld1q_f32(v_ptr0); + vsum0 = vmlaq_f32(vsum0, vm, vv); + m_ptr0 += 4; + v_ptr0 += 4; + } + float sum = vaddvq_f32(vsum0); + for (index_t w = 0; w < remain_w; ++w) { + sum += m_ptr0[0] * v_ptr0[0]; + m_ptr0++; + v_ptr0++; + } + out_ptr[remain_start_height + h + b * height] = sum; } - out_ptr[remain_start_height + h] = sum; } #else - GemvRef(m_ptr, v_ptr, width, height, out_ptr); + GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr); #endif } diff --git a/mace/kernels/gemm.h b/mace/kernels/gemm.h index ba4f812d..e1fcfad6 100644 --- a/mace/kernels/gemm.h +++ b/mace/kernels/gemm.h @@ -41,12 +41,14 @@ void GemmRef(const float *A, void Gemv(const float *m_ptr, const float *v_ptr, + const index_t batch, const index_t width, const index_t height, float *out_ptr); void GemvRef(const float *m_ptr, const float *v_ptr, + const index_t batch, const index_t width, const index_t height, float *out_ptr); diff --git a/mace/kernels/gemm_test.cc b/mace/kernels/gemm_test.cc index 217543ed..8400ca85 100644 --- a/mace/kernels/gemm_test.cc +++ b/mace/kernels/gemm_test.cc @@ -70,8 +70,8 @@ TEST(GEMMTest, gemv) { [&gen, &nd] { return nd(gen); }); - kernels::Gemv(A.get(), B.get(), K, N, C.get()); - kernels::GemvRef(A.get(), B.get(), K, N, C_ref.get()); + kernels::Gemv(A.get(), B.get(), 1, K, N, C.get()); + kernels::GemvRef(A.get(), B.get(), 1, K, N, C_ref.get()); for (int i = 0; i < N;
++i) { EXPECT_NEAR(C_ref[i], C[i], 0.1); diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 219e4af3..41a6546a 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -826,7 +826,7 @@ static void TestNeonArbitraryPadConvNxN(const std::vector &shape, for (int kernel_size : {1, 3, 5}) { for (int stride : {1, 2}) { - if (stride < kernel_size) { + if (stride <= kernel_size) { func(kernel_size, kernel_size, stride, stride); } } diff --git a/mace/ops/fully_connected_test.cc b/mace/ops/fully_connected_test.cc index e994213a..daef7402 100644 --- a/mace/ops/fully_connected_test.cc +++ b/mace/ops/fully_connected_test.cc @@ -337,6 +337,8 @@ TEST_F(FullyConnectedOpTest, TestNEON) { FullyConnectedTestNEON(1, 7, 7, 32, 16); FullyConnectedTestNEON(1, 7, 7, 512, 128); FullyConnectedTestNEON(1, 1, 1, 2048, 1024); + FullyConnectedTestNEON(3, 1, 1, 16, 8); + FullyConnectedTestNEON(3, 7, 7, 32, 16); } } // namespace test -- GitLab