Commit 3e867dca authored by wuchenghui

fix gemv multi-batch case

Parent: 6c8cc84e
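Summary of the change: `Gemv` and `GemvRef` gain a `batch` parameter. The weight matrix stays shared across batches, while the input and output vectors advance by one batch stride per iteration. A minimal sketch of the intended semantics, mirroring the `GemvRef` change below (the function name and the `index_t` alias here are illustrative, not part of the commit):

```cpp
#include <cstring>

using index_t = long;  // assumption: MACE's index_t is a signed integer type

// out[b * height + h] = sum over w of m[h * width + w] * v[b * width + w];
// the matrix m (height x width) is shared by all batches.
void BatchedGemvSketch(const float *m, const float *v, index_t batch,
                       index_t width, index_t height, float *out) {
  std::memset(out, 0, sizeof(float) * batch * height);
  for (index_t b = 0; b < batch; ++b) {
    for (index_t h = 0; h < height; ++h) {
      for (index_t w = 0; w < width; ++w) {
        out[b * height + h] += v[b * width + w] * m[h * width + w];
      }
    }
  }
}
```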
......
@@ -34,10 +34,10 @@ void FullyConnectedFunctor<DeviceType::NEON,
   const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
   float *output_ptr = output->mutable_data<float>();
-  Gemv(weight_ptr, input_ptr, input_size, output_size, output_ptr);
-  for (int j = 0; j < output_size; ++j) {
-    output_ptr[j] += bias_ptr[j];
-  }
+  Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
+  for (int i = 0; i < N; ++i) {
+    for (int j = 0; j < output_size; ++j) {
+      output_ptr[j + i * output_size] += bias_ptr[j];
+    }
+  }
......
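Note on the hunk above: before this fix, `Gemv` computed only a single batch and the bias loop touched only the first `output_size` outputs, so for batch size `N > 1` the remaining rows were never filled. The fixed kernel passes `N` through to `Gemv` and adds the shared bias to every batch row. The bias step in isolation (a sketch; assumes `bias_ptr` is non-null, which the surrounding code presumably guards):

```cpp
// Sketch: add one shared bias vector to each of the N batch rows.
void AddBiasPerBatch(float *output_ptr, const float *bias_ptr,
                     int N, int output_size) {
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < output_size; ++j) {
      output_ptr[i * output_size + j] += bias_ptr[j];
    }
  }
}
```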
......
@@ -566,6 +566,7 @@ inline void GemmTile(const float *A,
   }
 }  // namespace
+// A: height x K, B: K x width, C: height x width
 void Gemm(const float *A,
           const float *B,
           const index_t batch,
......
@@ -573,6 +574,12 @@ void Gemm(const float *A,
           const index_t K,
           const index_t width,
           float *C) {
+  if (width == 1) {
+    for (index_t b = 0; b < batch; ++b) {
+      Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height);
+    }
+    return;
+  }
   memset(C, 0, sizeof(float) * batch * height * width);
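Why the early return is valid: when `width == 1`, batch `b` of the GEMM computes `C_b` (height x 1) = `A_b` (height x K) * `B_b` (K x 1), which is exactly one matrix-vector product, so dispatching to the NEON `Gemv` once per batch avoids the tiled GEMM path. A naive equivalent of what that branch computes (a sketch, row-major layout assumed):

```cpp
// Sketch: the width == 1 case of Gemm written as plain loops; each batch is
// one GEMV, matching Gemv(A + b*height*K, B + b*K, 1, K, height, C + b*height).
void GemmWidth1Naive(const float *A, const float *B, long batch,
                     long height, long K, float *C) {
  for (long b = 0; b < batch; ++b) {
    for (long h = 0; h < height; ++h) {
      float sum = 0.f;
      for (long k = 0; k < K; ++k) {
        sum += A[(b * height + h) * K + k] * B[b * K + k];
      }
      C[b * height + h] = sum;
    }
  }
}
```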
......
@@ -628,6 +635,7 @@ void Gemm(const float *A,
   }  // n
 }
+// A: height x K, B: K x width, C: height x width
 void GemmRef(const float *A,
              const float *B,
              const index_t height,
......
@@ -647,19 +655,24 @@ void GemmRef(const float *A,
 void GemvRef(const float *m_ptr,
              const float *v_ptr,
+             const index_t batch,
              const index_t width,
              const index_t height,
              float *out_ptr) {
-  memset(out_ptr, 0, sizeof(float) * height);
-  for (int h = 0; h < height; ++h) {
-    for (int w = 0; w < width; ++w) {
-      out_ptr[h] += v_ptr[w] * m_ptr[h * width + w];
-    }
-  }
+  memset(out_ptr, 0, sizeof(float) * height * batch);
+  for (int b = 0; b < batch; ++b) {
+    for (int h = 0; h < height; ++h) {
+      for (int w = 0; w < width; ++w) {
+        out_ptr[h + b * height] += v_ptr[w + b * width] * m_ptr[h * width + w];
+      }
+    }
+  }
 }
+// M: height x width, Vin: width x 1, Vout: height x 1
 void Gemv(const float *m_ptr,
           const float *v_ptr,
+          const index_t batch,
           const index_t width,
           const index_t height,
           float *out_ptr) {
......
@@ -669,88 +682,90 @@ void Gemv(const float *m_ptr,
   index_t remain_w = width - (width_d4 << 2);
   index_t remain_h = height - (height_d4 << 2);
+  for (index_t b = 0; b < batch; ++b) {
 #pragma omp parallel for
     for (index_t h = 0; h < height_d4; ++h) {
       const float *m_ptr0 = m_ptr + h * width * 4;
       const float *m_ptr1 = m_ptr0 + width;
       const float *m_ptr2 = m_ptr1 + width;
       const float *m_ptr3 = m_ptr2 + width;
-      const float *v_ptr0 = v_ptr;
-      float *out_ptr0 = out_ptr + h * 4;
+      const float *v_ptr0 = v_ptr + b * width;
+      float *out_ptr0 = out_ptr + h * 4 + b * height;
       float32x4_t vm0, vm1, vm2, vm3;
       float32x4_t vv;
       float32x4_t vsum0 = vdupq_n_f32(0.f);
       float32x4_t vsum1 = vdupq_n_f32(0.f);
       float32x4_t vsum2 = vdupq_n_f32(0.f);
       float32x4_t vsum3 = vdupq_n_f32(0.f);
       for (index_t w = 0; w < width_d4; ++w) {
         vm0 = vld1q_f32(m_ptr0);
         vm1 = vld1q_f32(m_ptr1);
         vm2 = vld1q_f32(m_ptr2);
         vm3 = vld1q_f32(m_ptr3);
         vv = vld1q_f32(v_ptr0);
         vsum0 = vmlaq_f32(vsum0, vm0, vv);
         vsum1 = vmlaq_f32(vsum1, vm1, vv);
         vsum2 = vmlaq_f32(vsum2, vm2, vv);
         vsum3 = vmlaq_f32(vsum3, vm3, vv);
         m_ptr0 += 4;
         m_ptr1 += 4;
         m_ptr2 += 4;
         m_ptr3 += 4;
         v_ptr0 += 4;
       }
       float sum0 = vaddvq_f32(vsum0);
       float sum1 = vaddvq_f32(vsum1);
       float sum2 = vaddvq_f32(vsum2);
       float sum3 = vaddvq_f32(vsum3);
       // handle remaining w
       for (index_t w = 0; w < remain_w; ++w) {
         sum0 += m_ptr0[0] * v_ptr0[0];
         sum1 += m_ptr1[0] * v_ptr0[0];
         sum2 += m_ptr2[0] * v_ptr0[0];
         sum3 += m_ptr3[0] * v_ptr0[0];
         m_ptr0++;
         m_ptr1++;
         m_ptr2++;
         m_ptr3++;
         v_ptr0++;
       }
       *out_ptr0++ = sum0;
       *out_ptr0++ = sum1;
       *out_ptr0++ = sum2;
       *out_ptr0++ = sum3;
     }
     // handle remaining h
     index_t remain_start_height = height_d4 << 2;
 #pragma omp parallel for
     for (index_t h = 0; h < remain_h; ++h) {
       float32x4_t vsum0 = vdupq_n_f32(0.f);
       const float *m_ptr0 = m_ptr + (h + remain_start_height) * width;
       const float *v_ptr0 = v_ptr;
       for (index_t w = 0; w < width_d4; ++w) {
         float32x4_t vm = vld1q_f32(m_ptr0);
         float32x4_t vv = vld1q_f32(v_ptr0);
         vsum0 = vmlaq_f32(vsum0, vm, vv);
         m_ptr0 += 4;
         v_ptr0 += 4;
       }
       float sum = vaddvq_f32(vsum0);
       for (index_t w = 0; w < remain_w; ++w) {
         sum += m_ptr0[0] * v_ptr0[0];
         m_ptr0++;
         v_ptr0++;
       }
       out_ptr[remain_start_height + h] = sum;
     }
+  }
 #else
-  GemvRef(m_ptr, v_ptr, width, height, out_ptr);
+  GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr);
 #endif
 }
......
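The NEON body above blocks the matrix four rows at a time and the vector four lanes at a time: `vmlaq_f32` accumulates four products per row into a `float32x4_t`, `vaddvq_f32` reduces the lanes, and scalar loops mop up the `width % 4` and `height % 4` remainders. One row's dot product in isolation (a sketch; note `vaddvq_f32` is an AArch64 intrinsic, so 32-bit ARM would need a pairwise-add fallback):

```cpp
#include <arm_neon.h>

// Sketch: dot product of one matrix row with the input vector, four floats
// at a time, matching the inner loop of the NEON Gemv above.
float RowDotNeon(const float *row, const float *vec, long width) {
  long width_d4 = width >> 2;
  float32x4_t vsum = vdupq_n_f32(0.f);
  for (long w = 0; w < width_d4; ++w) {
    // per-lane multiply-accumulate: vsum += row[0..3] * vec[0..3]
    vsum = vmlaq_f32(vsum, vld1q_f32(row), vld1q_f32(vec));
    row += 4;
    vec += 4;
  }
  float sum = vaddvq_f32(vsum);  // horizontal add of the four lanes (AArch64)
  for (long w = width_d4 << 2; w < width; ++w) {
    sum += (*row++) * (*vec++);  // scalar tail for width % 4 elements
  }
  return sum;
}
```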
......
@@ -41,12 +41,14 @@ void GemmRef(const float *A,
 void Gemv(const float *m_ptr,
           const float *v_ptr,
+          const index_t batch,
           const index_t width,
           const index_t height,
           float *out_ptr);
 void GemvRef(const float *m_ptr,
              const float *v_ptr,
+             const index_t batch,
              const index_t width,
              const index_t height,
              float *out_ptr);
......
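With the new declarations, callers supply the batch count explicitly; single-input callers pass `1`, as the updated test below does. A hypothetical batched call site (assumes the header above is included; buffer sizing is the caller's responsibility):

```cpp
#include <vector>

void ExampleBatchedGemv() {
  const index_t batch = 3, width = 16, height = 8;
  std::vector<float> m(height * width, 1.f);   // one shared 8 x 16 matrix
  std::vector<float> v(batch * width, 1.f);    // three input vectors, length 16
  std::vector<float> out(batch * height);      // three output vectors, length 8
  Gemv(m.data(), v.data(), batch, width, height, out.data());
  // with all-ones inputs, every output element equals width, i.e. 16.f
}
```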
......
@@ -70,8 +70,8 @@ TEST(GEMMTest, gemv) {
                 [&gen, &nd] {
                   return nd(gen);
                 });
-  kernels::Gemv(A.get(), B.get(), K, N, C.get());
-  kernels::GemvRef(A.get(), B.get(), K, N, C_ref.get());
+  kernels::Gemv(A.get(), B.get(), 1, K, N, C.get());
+  kernels::GemvRef(A.get(), B.get(), 1, K, N, C_ref.get());
   for (int i = 0; i < N; ++i) {
     EXPECT_NEAR(C_ref[i], C[i], 0.1);
......
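The updated calls pin this existing test to `batch = 1`. A `batch > 1` variant of the same check would exercise the new path directly (hypothetical, not part of this commit; `B` would need `3 * K` floats and `C`, `C_ref` would need `3 * N`, while the matrix `A` stays shared):

```cpp
kernels::Gemv(A.get(), B.get(), 3, K, N, C.get());
kernels::GemvRef(A.get(), B.get(), 3, K, N, C_ref.get());
for (int i = 0; i < 3 * N; ++i) {
  EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
```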
......
@@ -826,7 +826,7 @@ static void TestNeonArbitraryPadConvNxN(const std::vector<index_t> &shape,
   for (int kernel_size : {1, 3, 5}) {
     for (int stride : {1, 2}) {
-      if (stride < kernel_size) {
+      if (stride <= kernel_size) {
         func(kernel_size, kernel_size, stride, stride);
       }
     }
......
......
@@ -337,6 +337,8 @@ TEST_F(FullyConnectedOpTest, TestNEON) {
   FullyConnectedTestNEON(1, 7, 7, 32, 16);
   FullyConnectedTestNEON(1, 7, 7, 512, 128);
   FullyConnectedTestNEON(1, 1, 1, 2048, 1024);
+  FullyConnectedTestNEON(3, 1, 1, 16, 8);
+  FullyConnectedTestNEON(3, 7, 7, 32, 16);
 }
 }  // namespace test
......