Commit 3e867dca authored by wuchenghui

fix gemv multi-batch case

Parent 6c8cc84e
......@@ -34,10 +34,10 @@ void FullyConnectedFunctor<DeviceType::NEON,
   const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
   float *output_ptr = output->mutable_data<float>();
-  Gemv(weight_ptr, input_ptr, input_size, output_size, output_ptr);
+  Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
+  for (int i = 0; i < N; ++i) {
     for (int j = 0; j < output_size; ++j) {
-      output_ptr[j] += bias_ptr[j];
+      output_ptr[j + i * output_size] += bias_ptr[j];
     }
+  }
......
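Note on this hunk: with N > 1 the old code ran Gemv over a single sample and then added the bias to output_ptr[j] for every sample, i.e. always into row 0. The fix passes N as the batch count and indexes the output as a row-major [N, output_size] buffer. A minimal sketch of the corrected bias broadcast (illustrative only; the layout is assumed from the indexing in the diff):

  // bias has length output_size and is broadcast over the N output rows.
  for (int i = 0; i < N; ++i) {
    float *row = output_ptr + i * output_size;  // i-th sample's outputs
    for (int j = 0; j < output_size; ++j) {
      row[j] += bias_ptr[j];
    }
  }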
......@@ -566,6 +566,7 @@ inline void GemmTile(const float *A,
   }
 }  // namespace
+// A: height x K, B: K x width, C: height x width
 void Gemm(const float *A,
           const float *B,
           const index_t batch,
......@@ -573,6 +574,12 @@ void Gemm(const float *A,
           const index_t K,
           const index_t width,
           float *C) {
+  if (width == 1) {
+    for (index_t b = 0; b < batch; ++b) {
+      Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height);
+    }
+    return;
+  }
   memset(C, 0, sizeof(float) * batch * height * width);
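Note on the early-out above: when B has a single column, each batch item of the GEMM degenerates to a matrix-vector product, so the tiled kernel can be bypassed. A plain-loop illustration of why (a sketch, not the library's tiled implementation):

  #include <cstring>

  // One batch item of C (height x width) = A (height x K) * B (K x width).
  void GemmOneBatchNaive(const float *A, const float *B,
                         int height, int K, int width, float *C) {
    std::memset(C, 0, sizeof(float) * height * width);
    for (int h = 0; h < height; ++h)
      for (int k = 0; k < K; ++k)
        for (int w = 0; w < width; ++w)
          C[h * width + w] += A[h * K + k] * B[k * width + w];
  }

With width == 1 the w-loop vanishes and the body becomes C[h] += A[h * K + k] * B[k], a matrix-vector product; the early-out therefore forwards the b-th batch item to Gemv at offsets A + b * height * K, B + b * K and C + b * height, calling with batch = 1.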
......@@ -628,6 +635,7 @@ void Gemm(const float *A,
   }  // n
 }
+// A: height x K, B: K x width, C: height x width
 void GemmRef(const float *A,
              const float *B,
              const index_t height,
......@@ -647,19 +655,24 @@ void GemmRef(const float *A,
 void GemvRef(const float *m_ptr,
              const float *v_ptr,
+             const index_t batch,
              const index_t width,
              const index_t height,
              float *out_ptr) {
-  memset(out_ptr, 0, sizeof(float) * height);
+  memset(out_ptr, 0, sizeof(float) * height * batch);
+  for (int b = 0; b < batch; ++b) {
   for (int h = 0; h < height; ++h) {
     for (int w = 0; w < width; ++w) {
-      out_ptr[h] += v_ptr[w] * m_ptr[h * width + w];
+      out_ptr[h + b * height] += v_ptr[w + b * width] * m_ptr[h * width + w];
     }
   }
+  }
 }
+// M: height x width, Vin: width x 1, Vout: height x 1
 void Gemv(const float *m_ptr,
           const float *v_ptr,
+          const index_t batch,
           const index_t width,
           const index_t height,
           float *out_ptr) {
......@@ -669,14 +682,15 @@ void Gemv(const float *m_ptr,
   index_t remain_w = width - (width_d4 << 2);
   index_t remain_h = height - (height_d4 << 2);
+  for (index_t b = 0; b < batch; ++b) {
 #pragma omp parallel for
   for (index_t h = 0; h < height_d4; ++h) {
     const float *m_ptr0 = m_ptr + h * width * 4;
     const float *m_ptr1 = m_ptr0 + width;
     const float *m_ptr2 = m_ptr1 + width;
     const float *m_ptr3 = m_ptr2 + width;
-    const float *v_ptr0 = v_ptr;
-    float *out_ptr0 = out_ptr + h * 4;
+    const float *v_ptr0 = v_ptr + b * width;
+    float *out_ptr0 = out_ptr + h * 4 + b * height;
     float32x4_t vm0, vm1, vm2, vm3;
     float32x4_t vv;
......@@ -749,8 +763,9 @@ void Gemv(const float *m_ptr,
}
out_ptr[remain_start_height + h] = sum;
}
}
#else
GemvRef(m_ptr, v_ptr, width, height, out_ptr);
GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr);
#endif
}
......
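Note on the NEON path: the whole kernel body is wrapped in the new for (b ...) loop, and only the vector and output pointers move per batch item (v_ptr + b * width, out_ptr + ... + b * height); m_ptr stays fixed because the same matrix multiplies every batch item, matching the fully-connected use where the weights are shared across the batch. A scalar sketch of that offset scheme (illustrative; it mirrors GemvRef, not the intrinsics code):

  // m: [height x width], shared; v: [batch x width]; out: [batch x height].
  void GemvBatchedSketch(const float *m, const float *v, int batch,
                         int width, int height, float *out) {
    for (int b = 0; b < batch; ++b) {
      const float *v_b = v + b * width;  // b-th input vector
      float *out_b = out + b * height;   // b-th output vector
      for (int h = 0; h < height; ++h) {
        float sum = 0.f;
        for (int w = 0; w < width; ++w) {
          sum += m[h * width + w] * v_b[w];
        }
        out_b[h] = sum;
      }
    }
  }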
......@@ -41,12 +41,14 @@ void GemmRef(const float *A,
 void Gemv(const float *m_ptr,
           const float *v_ptr,
+          const index_t batch,
           const index_t width,
           const index_t height,
           float *out_ptr);
 void GemvRef(const float *m_ptr,
              const float *v_ptr,
+             const index_t batch,
              const index_t width,
              const index_t height,
              float *out_ptr);
......
......@@ -70,8 +70,8 @@ TEST(GEMMTest, gemv) {
                 [&gen, &nd] {
                   return nd(gen);
                 });
-  kernels::Gemv(A.get(), B.get(), K, N, C.get());
-  kernels::GemvRef(A.get(), B.get(), K, N, C_ref.get());
+  kernels::Gemv(A.get(), B.get(), 1, K, N, C.get());
+  kernels::GemvRef(A.get(), B.get(), 1, K, N, C_ref.get());
   for (int i = 0; i < N; ++i) {
     EXPECT_NEAR(C_ref[i], C[i], 0.1);
......
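The existing test is only updated to pass batch = 1 explicitly for the new signatures; the commit adds no multi-batch check here. A hedged sketch of what one could look like, reusing the test's A (an N x K matrix) and tolerance; the names batch, V, Out and OutRef are hypothetical, not part of the commit:

  const int batch = 3;
  std::unique_ptr<float[]> V(new float[batch * K]);  // batched input vectors
  std::unique_ptr<float[]> Out(new float[batch * N]);
  std::unique_ptr<float[]> OutRef(new float[batch * N]);
  // ... fill V with the same random generator used for A and B above ...
  kernels::Gemv(A.get(), V.get(), batch, K, N, Out.get());
  kernels::GemvRef(A.get(), V.get(), batch, K, N, OutRef.get());
  for (int i = 0; i < batch * N; ++i) {
    EXPECT_NEAR(OutRef[i], Out[i], 0.1);
  }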
......@@ -826,7 +826,7 @@ static void TestNeonArbitraryPadConvNxN(const std::vector<index_t> &shape,
   for (int kernel_size : {1, 3, 5}) {
     for (int stride : {1, 2}) {
-      if (stride < kernel_size) {
+      if (stride <= kernel_size) {
         func(kernel_size, kernel_size, stride, stride);
       }
     }
......
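Note: with kernel_size in {1, 3, 5} and stride in {1, 2}, relaxing the guard from stride < kernel_size to stride <= kernel_size adds exactly one combination to the sweep, kernel 1 with stride 1, which the strict inequality had been skipping.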
......@@ -337,6 +337,8 @@ TEST_F(FullyConnectedOpTest, TestNEON) {
   FullyConnectedTestNEON(1, 7, 7, 32, 16);
   FullyConnectedTestNEON(1, 7, 7, 512, 128);
   FullyConnectedTestNEON(1, 1, 1, 2048, 1024);
+  FullyConnectedTestNEON(3, 1, 1, 16, 8);
+  FullyConnectedTestNEON(3, 7, 7, 32, 16);
 }
 }  // namespace test
......
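The two added cases appear to use a batch size of 3 as their first argument, so TestNEON now exercises the multi-batch path that the old bias loop mishandled; every earlier case ran with batch 1.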