提交 ba6972ec 编写于 作者: 李寅

v7 asm works

上级 d7fbfee1
...@@ -30,7 +30,7 @@ namespace kernels { ...@@ -30,7 +30,7 @@ namespace kernels {
void SGemm::operator()(const MatrixMap<const float> &lhs, void SGemm::operator()(const MatrixMap<const float> &lhs,
const MatrixMap<const float> &rhs, const MatrixMap<const float> &rhs,
MatrixMap<float> *result) { MatrixMap<float> *result) {
if (rhs.col() < 16 && lhs.row() >= 16) { if (rhs.col() < lhs.row()) {
MatrixMap<const float> lhs_transpose = lhs.transpose(); MatrixMap<const float> lhs_transpose = lhs.transpose();
MatrixMap<const float> rhs_transpose = rhs.transpose(); MatrixMap<const float> rhs_transpose = rhs.transpose();
MatrixMap<float> result_transpose = result->transpose(); MatrixMap<float> result_transpose = result->transpose();
...@@ -45,14 +45,13 @@ void SGemm::operator()(const MatrixMap<const float> &lhs, ...@@ -45,14 +45,13 @@ void SGemm::operator()(const MatrixMap<const float> &lhs,
} }
packed_ = true; packed_ = true;
PackedBlock<float> packed_result;
operator()(packed_lhs_, operator()(packed_lhs_,
packed_rhs_, packed_rhs_,
lhs.row(), lhs.row(),
lhs.col(), lhs.col(),
rhs.col(), rhs.col(),
&packed_result); &packed_result_);
UnPack(packed_result, result); UnPack(packed_result_, result);
} }
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
...@@ -161,7 +160,8 @@ void SGemm::operator()(const PackedBlock<float> &lhs, ...@@ -161,7 +160,8 @@ void SGemm::operator()(const PackedBlock<float> &lhs,
#endif #endif
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
// TODO(liyin): collapse loop // TODO(liyin): make better use l2(l1) cache, try to fit as much lhs data as
// as possible to cache, by tiling lhs by height and rhs by width.
// w: 4 // w: 4
#pragma omp parallel for #pragma omp parallel for
...@@ -319,11 +319,11 @@ void SGemm::operator()(const PackedBlock<float> &lhs, ...@@ -319,11 +319,11 @@ void SGemm::operator()(const PackedBlock<float> &lhs,
c2 = vdupq_n_f32(0.f); c2 = vdupq_n_f32(0.f);
c3 = vdupq_n_f32(0.f); c3 = vdupq_n_f32(0.f);
#if defined(__aarch64__) // d: 8
block_d = remain_d >> 3; block_d = remain_d >> 3;
remain_d -= (block_d << 3); remain_d -= (block_d << 3);
// d: 8 #if defined(__aarch64__)
for (index_t bd = 0; bd < block_d; ++bd) { for (index_t bd = 0; bd < block_d; ++bd) {
// 4.8.4 // 4.8.4
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7; float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
...@@ -359,12 +359,99 @@ void SGemm::operator()(const PackedBlock<float> &lhs, ...@@ -359,12 +359,99 @@ void SGemm::operator()(const PackedBlock<float> &lhs,
lhs_ptr += 32; lhs_ptr += 32;
rhs_ptr += 32; rhs_ptr += 32;
} }
#else // arm v7
// 4.8.4
if (block_d > 0) {
asm volatile(
"0: \n"
"vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
"vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
"vld1.f32 {d4-d5}, [%[lhs_ptr]]! \n"
"vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n"
"vld1.f32 {d22-d23}, [%[rhs_ptr]]! \n"
"vld1.f32 {d24-d25}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q10, d0[0] \n"
"vmla.f32 %[c1], q10, d0[1] \n"
"vmla.f32 %[c2], q10, d1[0] \n"
"vmla.f32 %[c3], q10, d1[1] \n"
"vld1.f32 {d6-d7}, [%[lhs_ptr]]! \n"
"vld1.f32 {d26-d27}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q11, d2[0] \n"
"vmla.f32 %[c1], q11, d2[1] \n"
"vmla.f32 %[c2], q11, d3[0] \n"
"vmla.f32 %[c3], q11, d3[1] \n"
"vld1.f32 {d8-d9}, [%[lhs_ptr]]! \n"
"vld1.f32 {d28-d29}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q12, d4[0] \n"
"vmla.f32 %[c1], q12, d4[1] \n"
"vmla.f32 %[c2], q12, d5[0] \n"
"vmla.f32 %[c3], q12, d5[1] \n"
"vld1.f32 {d10-d11}, [%[lhs_ptr]]! \n"
"vld1.f32 {d30-d31}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q13, d6[0] \n"
"vmla.f32 %[c1], q13, d6[1] \n"
"vmla.f32 %[c2], q13, d7[0] \n"
"vmla.f32 %[c3], q13, d7[1] \n"
"vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
"vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
"vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n"
"vld1.f32 {d22-d23}, [%[rhs_ptr]]! \n"
"vmla.f32 %[c0], q14, d8[0] \n"
"vmla.f32 %[c1], q14, d8[1] \n"
"vmla.f32 %[c2], q14, d9[0] \n"
"vmla.f32 %[c3], q14, d9[1] \n"
"vmla.f32 %[c0], q15, d10[0] \n"
"vmla.f32 %[c1], q15, d10[1] \n"
"vmla.f32 %[c2], q15, d11[0] \n"
"vmla.f32 %[c3], q15, d11[1] \n"
"vmla.f32 %[c0], q10, d0[0] \n"
"vmla.f32 %[c1], q10, d0[1] \n"
"vmla.f32 %[c2], q10, d1[0] \n"
"vmla.f32 %[c3], q10, d1[1] \n"
"subs %[block_d], %[block_d], #1 \n"
"vmla.f32 %[c0], q11, d2[0] \n"
"vmla.f32 %[c1], q11, d2[1] \n"
"vmla.f32 %[c2], q11, d3[0] \n"
"vmla.f32 %[c3], q11, d3[1] \n"
"bne 0b \n"
: // outputs
[lhs_ptr] "+r"(lhs_ptr),
[rhs_ptr] "+r"(rhs_ptr),
[res_ptr] "+r"(res_ptr),
[block_d] "+r"(block_d),
[c0] "+w"(c0),
[c1] "+w"(c1),
[c2] "+w"(c2),
[c3] "+w"(c3)
: // inputs
: // clabbers
"cc", "memory",
"q0", "q1", "q2", "q3", "q4", "q5",
"q10", "q11", "q12", "q13", "q14", "q15");
}
#endif // __aarch64__ #endif // __aarch64__
// d: 4
block_d = remain_d >> 2; block_d = remain_d >> 2;
remain_d -= (block_d << 2); remain_d -= (block_d << 2);
// d: 4
for (index_t bd = 0; bd < block_d; ++bd) { for (index_t bd = 0; bd < block_d; ++bd) {
// 4.4.4 // 4.4.4
float32x4_t a0, a1, a2, a3; float32x4_t a0, a1, a2, a3;
...@@ -639,8 +726,30 @@ void SGemm::operator()(const PackedBlock<float> &lhs, ...@@ -639,8 +726,30 @@ void SGemm::operator()(const PackedBlock<float> &lhs,
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
index_t block_d = 0; index_t block_d = 0;
float32x4_t c0; float32x4_t c0, c1;
c0 = vdupq_n_f32(0.f); c0 = vdupq_n_f32(0.f);
c1 = vdupq_n_f32(0.f);
block_d = remain_d >> 3;
remain_d -= (block_d << 3);
// d: 8
for (index_t bd = 0; bd < block_d; ++bd) {
// 1.8.1
float32x4_t a0, a1;
float32x4_t b0, b1;
a0 = vld1q_f32(lhs_ptr);
a1 = vld1q_f32(lhs_ptr + 4);
b0 = vld1q_f32(rhs_ptr);
b1 = vld1q_f32(rhs_ptr + 4);
c0 = vmlaq_f32(c0, a0, b0);
c1 = vmlaq_f32(c1, a1, b1);
lhs_ptr += 8;
rhs_ptr += 8;
}
block_d = remain_d >> 2; block_d = remain_d >> 2;
remain_d -= (block_d << 2); remain_d -= (block_d << 2);
...@@ -659,7 +768,8 @@ void SGemm::operator()(const PackedBlock<float> &lhs, ...@@ -659,7 +768,8 @@ void SGemm::operator()(const PackedBlock<float> &lhs,
lhs_ptr += 4; lhs_ptr += 4;
rhs_ptr += 4; rhs_ptr += 4;
} }
sum = vaddvq_f32(c0); sum += vaddvq_f32(c0);
sum += vaddvq_f32(c1);
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
// d: remain // d: remain
...@@ -699,7 +809,7 @@ void SGemm::UnPack(const PackedBlock<float> &packed_result, ...@@ -699,7 +809,7 @@ void SGemm::UnPack(const PackedBlock<float> &packed_result,
// This is for non-transposed result // This is for non-transposed result
index_t w = 0; index_t w = 0;
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
#pragma omp parallel for #pragma omp parallel for
for (index_t iw = w; iw <= width - 4; iw += 4) { for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *packed_data_ptr = packed_data + iw * height; const float *packed_data_ptr = packed_data + iw * height;
float *unpacked_data_ptr = unpacked_data + iw; float *unpacked_data_ptr = unpacked_data + iw;
...@@ -724,7 +834,7 @@ void SGemm::UnPack(const PackedBlock<float> &packed_result, ...@@ -724,7 +834,7 @@ void SGemm::UnPack(const PackedBlock<float> &packed_result,
// This is for transposed result // This is for transposed result
index_t w = 0; index_t w = 0;
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
#pragma omp parallel for #pragma omp parallel for
for (index_t iw = w; iw <= width - 4; iw += 4) { for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *packed_data_ptr = packed_data + iw * height; const float *packed_data_ptr = packed_data + iw * height;
float *unpacked_data_ptr = unpacked_data + iw * height; float *unpacked_data_ptr = unpacked_data + iw * height;
...@@ -763,7 +873,7 @@ void SGemm::Pack(const MatrixMap<const float> &src, ...@@ -763,7 +873,7 @@ void SGemm::Pack(const MatrixMap<const float> &src,
// This is for packing no-transpose lhs. // This is for packing no-transpose lhs.
index_t h = 0; index_t h = 0;
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__) #if defined(__aarch64__)
#pragma omp parallel for #pragma omp parallel for
for (index_t ih = h; ih <= height - 8; ih += 8) { for (index_t ih = h; ih <= height - 8; ih += 8) {
const float *src_data_ptr = src_data + ih * width; const float *src_data_ptr = src_data + ih * width;
...@@ -809,7 +919,7 @@ void SGemm::Pack(const MatrixMap<const float> &src, ...@@ -809,7 +919,7 @@ void SGemm::Pack(const MatrixMap<const float> &src,
// This is for packing transpose-needed lhs. // This is for packing transpose-needed lhs.
index_t h = 0; index_t h = 0;
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__) #if defined(__aarch64__)
#pragma omp parallel for #pragma omp parallel for
for (index_t ih = h; ih <= height - 8; ih += 8) { for (index_t ih = h; ih <= height - 8; ih += 8) {
const float *src_data_ptr = src_data + ih; const float *src_data_ptr = src_data + ih;
...@@ -850,7 +960,7 @@ void SGemm::Pack(const MatrixMap<const float> &src, ...@@ -850,7 +960,7 @@ void SGemm::Pack(const MatrixMap<const float> &src,
// This is for packing no-transpose rhs. // This is for packing no-transpose rhs.
index_t w = 0; index_t w = 0;
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
#pragma omp parallel for #pragma omp parallel for
for (index_t iw = w; iw <= width - 4; iw += 4) { for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *src_data_ptr = src_data + iw; const float *src_data_ptr = src_data + iw;
float *packed_data_ptr = packed_data + iw * height; float *packed_data_ptr = packed_data + iw * height;
...@@ -875,7 +985,7 @@ void SGemm::Pack(const MatrixMap<const float> &src, ...@@ -875,7 +985,7 @@ void SGemm::Pack(const MatrixMap<const float> &src,
// This is for packing transpose-needed rhs. // This is for packing transpose-needed rhs.
index_t w = 0; index_t w = 0;
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
#pragma omp parallel for #pragma omp parallel for
for (index_t iw = w; iw <= width - 4; iw += 4) { for (index_t iw = w; iw <= width - 4; iw += 4) {
const float *src_data_ptr = src_data + iw * height; const float *src_data_ptr = src_data + iw * height;
float *packed_data_ptr = packed_data + iw * height; float *packed_data_ptr = packed_data + iw * height;
......
...@@ -148,6 +148,7 @@ class SGemm { ...@@ -148,6 +148,7 @@ class SGemm {
PackedBlock<float> packed_lhs_; PackedBlock<float> packed_lhs_;
PackedBlock<float> packed_rhs_; PackedBlock<float> packed_rhs_;
PackedBlock<float> packed_result_;
bool packed_; bool packed_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册