diff --git a/mace/kernels/sgemm.cc b/mace/kernels/sgemm.cc
index 5cf3264e2ac3f70750cba292ccc1c1c07831c99f..49c2717815c49d6fcc96ba7f5ef65d08cade8f3b 100644
--- a/mace/kernels/sgemm.cc
+++ b/mace/kernels/sgemm.cc
@@ -30,7 +30,7 @@ namespace kernels {
 void SGemm::operator()(const MatrixMap &lhs,
                        const MatrixMap &rhs,
                        MatrixMap *result) {
-  if (rhs.col() < 16 && lhs.row() >= 16) {
+  if (rhs.col() < lhs.row()) {
     MatrixMap lhs_transpose = lhs.transpose();
     MatrixMap rhs_transpose = rhs.transpose();
     MatrixMap result_transpose = result->transpose();
@@ -45,14 +45,13 @@ void SGemm::operator()(const MatrixMap &lhs,
   }
   packed_ = true;
 
-  PackedBlock packed_result;
   operator()(packed_lhs_,
              packed_rhs_,
              lhs.row(),
              lhs.col(),
              rhs.col(),
-             &packed_result);
-  UnPack(packed_result, result);
+             &packed_result_);
+  UnPack(packed_result_, result);
 }
 
 #if defined(MACE_ENABLE_NEON)
@@ -161,7 +160,8 @@ void SGemm::operator()(const PackedBlock &lhs,
 #endif
 
 #if defined(MACE_ENABLE_NEON)
 
-  // TODO(liyin): collapse loop
+  // TODO(liyin): make better use of the L2 (L1) cache: try to fit as much
+  // lhs data into cache as possible, by tiling lhs by height and rhs by width.
   // w: 4
 #pragma omp parallel for
@@ -319,11 +319,11 @@ void SGemm::operator()(const PackedBlock &lhs,
       c2 = vdupq_n_f32(0.f);
      c3 = vdupq_n_f32(0.f);
 
-#if defined(__aarch64__)
+      // d: 8
       block_d = remain_d >> 3;
       remain_d -= (block_d << 3);
 
-      // d: 8
+#if defined(__aarch64__)
       for (index_t bd = 0; bd < block_d; ++bd) {
         // 4.8.4
         float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
@@ -359,12 +359,99 @@ void SGemm::operator()(const PackedBlock &lhs,
         lhs_ptr += 32;
         rhs_ptr += 32;
       }
+#else  // arm v7
+      // 4.8.4
+      if (block_d > 0) {
+        asm volatile(
+            "0: \n"
+
+            "vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
+            "vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
+            "vld1.f32 {d4-d5}, [%[lhs_ptr]]! \n"
+
+            "vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n"
+            "vld1.f32 {d22-d23}, [%[rhs_ptr]]! \n"
+            "vld1.f32 {d24-d25}, [%[rhs_ptr]]! \n"
+
+            "vmla.f32 %[c0], q10, d0[0] \n"
+            "vmla.f32 %[c1], q10, d0[1] \n"
+            "vmla.f32 %[c2], q10, d1[0] \n"
+            "vmla.f32 %[c3], q10, d1[1] \n"
+
+            "vld1.f32 {d6-d7}, [%[lhs_ptr]]! \n"
+            "vld1.f32 {d26-d27}, [%[rhs_ptr]]! \n"
+
+            "vmla.f32 %[c0], q11, d2[0] \n"
+            "vmla.f32 %[c1], q11, d2[1] \n"
+            "vmla.f32 %[c2], q11, d3[0] \n"
+            "vmla.f32 %[c3], q11, d3[1] \n"
+
+            "vld1.f32 {d8-d9}, [%[lhs_ptr]]! \n"
+            "vld1.f32 {d28-d29}, [%[rhs_ptr]]! \n"
+
+            "vmla.f32 %[c0], q12, d4[0] \n"
+            "vmla.f32 %[c1], q12, d4[1] \n"
+            "vmla.f32 %[c2], q12, d5[0] \n"
+            "vmla.f32 %[c3], q12, d5[1] \n"
+
+            "vld1.f32 {d10-d11}, [%[lhs_ptr]]! \n"
+            "vld1.f32 {d30-d31}, [%[rhs_ptr]]! \n"
+
+            "vmla.f32 %[c0], q13, d6[0] \n"
+            "vmla.f32 %[c1], q13, d6[1] \n"
+            "vmla.f32 %[c2], q13, d7[0] \n"
+            "vmla.f32 %[c3], q13, d7[1] \n"
+
+            "vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n"
+            "vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n"
+
+            "vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n"
+            "vld1.f32 {d22-d23}, [%[rhs_ptr]]! \n"
+
+            "vmla.f32 %[c0], q14, d8[0] \n"
+            "vmla.f32 %[c1], q14, d8[1] \n"
+            "vmla.f32 %[c2], q14, d9[0] \n"
+            "vmla.f32 %[c3], q14, d9[1] \n"
+
+            "vmla.f32 %[c0], q15, d10[0] \n"
+            "vmla.f32 %[c1], q15, d10[1] \n"
+            "vmla.f32 %[c2], q15, d11[0] \n"
+            "vmla.f32 %[c3], q15, d11[1] \n"
+
+            "vmla.f32 %[c0], q10, d0[0] \n"
+            "vmla.f32 %[c1], q10, d0[1] \n"
+            "vmla.f32 %[c2], q10, d1[0] \n"
+            "vmla.f32 %[c3], q10, d1[1] \n"
+
+            "subs %[block_d], %[block_d], #1 \n"
+
+            "vmla.f32 %[c0], q11, d2[0] \n"
+            "vmla.f32 %[c1], q11, d2[1] \n"
+            "vmla.f32 %[c2], q11, d3[0] \n"
+            "vmla.f32 %[c3], q11, d3[1] \n"
+
+            "bne 0b \n"
+            :  // outputs
+            [lhs_ptr] "+r"(lhs_ptr),
+            [rhs_ptr] "+r"(rhs_ptr),
+            [res_ptr] "+r"(res_ptr),
+            [block_d] "+r"(block_d),
+            [c0] "+w"(c0),
+            [c1] "+w"(c1),
+            [c2] "+w"(c2),
+            [c3] "+w"(c3)
+            :  // inputs
+            :  // clobbers
+            "cc", "memory",
+            "q0", "q1", "q2", "q3", "q4", "q5",
+            "q10", "q11", "q12", "q13", "q14", "q15");
+      }
 #endif  // __aarch64__
 
+      // d: 4
       block_d = remain_d >> 2;
       remain_d -= (block_d << 2);
 
-      // d: 4
       for (index_t bd = 0; bd < block_d; ++bd) {
         // 4.4.4
         float32x4_t a0, a1, a2, a3;
@@ -639,8 +726,30 @@ void SGemm::operator()(const PackedBlock &lhs,
 
 #if defined(MACE_ENABLE_NEON)
     index_t block_d = 0;
-    float32x4_t c0;
+    float32x4_t c0, c1;
     c0 = vdupq_n_f32(0.f);
+    c1 = vdupq_n_f32(0.f);
+
+    block_d = remain_d >> 3;
+    remain_d -= (block_d << 3);
+
+    // d: 8
+    for (index_t bd = 0; bd < block_d; ++bd) {
+      // 1.8.1
+      float32x4_t a0, a1;
+      float32x4_t b0, b1;
+
+      a0 = vld1q_f32(lhs_ptr);
+      a1 = vld1q_f32(lhs_ptr + 4);
+      b0 = vld1q_f32(rhs_ptr);
+      b1 = vld1q_f32(rhs_ptr + 4);
+
+      c0 = vmlaq_f32(c0, a0, b0);
+      c1 = vmlaq_f32(c1, a1, b1);
+
+      lhs_ptr += 8;
+      rhs_ptr += 8;
+    }
 
     block_d = remain_d >> 2;
     remain_d -= (block_d << 2);
@@ -659,7 +768,8 @@ void SGemm::operator()(const PackedBlock &lhs,
       lhs_ptr += 4;
       rhs_ptr += 4;
     }
-    sum = vaddvq_f32(c0);
+    sum += vaddvq_f32(c0);
+    sum += vaddvq_f32(c1);
 #endif  // MACE_ENABLE_NEON
 
     // d: remain
@@ -699,7 +809,7 @@ void SGemm::UnPack(const PackedBlock &packed_result,
     // This is for non-transposed result
     index_t w = 0;
 #if defined(MACE_ENABLE_NEON)
-#pragma omp parallel for
+  #pragma omp parallel for
     for (index_t iw = w; iw <= width - 4; iw += 4) {
       const float *packed_data_ptr = packed_data + iw * height;
       float *unpacked_data_ptr = unpacked_data + iw;
@@ -724,7 +834,7 @@ void SGemm::UnPack(const PackedBlock &packed_result,
     // This is for transposed result
     index_t w = 0;
 #if defined(MACE_ENABLE_NEON)
-#pragma omp parallel for
+  #pragma omp parallel for
     for (index_t iw = w; iw <= width - 4; iw += 4) {
       const float *packed_data_ptr = packed_data + iw * height;
       float *unpacked_data_ptr = unpacked_data + iw * height;
@@ -763,7 +873,7 @@ void SGemm::Pack(const MatrixMap &src,
     // This is for packing no-transpose lhs.
     index_t h = 0;
 #if defined(MACE_ENABLE_NEON)
-#if defined(__aarch64__)
+  #if defined(__aarch64__)
 #pragma omp parallel for
     for (index_t ih = h; ih <= height - 8; ih += 8) {
      const float *src_data_ptr = src_data + ih * width;
@@ -809,7 +919,7 @@ void SGemm::Pack(const MatrixMap &src,
     // This is for packing transpose-needed lhs.
     index_t h = 0;
 #if defined(MACE_ENABLE_NEON)
-#if defined(__aarch64__)
+  #if defined(__aarch64__)
 #pragma omp parallel for
     for (index_t ih = h; ih <= height - 8; ih += 8) {
       const float *src_data_ptr = src_data + ih;
@@ -850,7 +960,7 @@ void SGemm::Pack(const MatrixMap &src,
     // This is for packing no-transpose rhs.
     index_t w = 0;
 #if defined(MACE_ENABLE_NEON)
-#pragma omp parallel for
+  #pragma omp parallel for
     for (index_t iw = w; iw <= width - 4; iw += 4) {
       const float *src_data_ptr = src_data + iw;
       float *packed_data_ptr = packed_data + iw * height;
@@ -875,7 +985,7 @@ void SGemm::Pack(const MatrixMap &src,
     // This is for packing transpose-needed rhs.
     index_t w = 0;
 #if defined(MACE_ENABLE_NEON)
-#pragma omp parallel for
+  #pragma omp parallel for
     for (index_t iw = w; iw <= width - 4; iw += 4) {
       const float *src_data_ptr = src_data + iw * height;
       float *packed_data_ptr = packed_data + iw * height;
diff --git a/mace/kernels/sgemm.h b/mace/kernels/sgemm.h
index 263aed80029f43994c48ee119668b97df1f0b71b..daed206a349d86d119637b6d367b8cfa0d8dc5c8 100644
--- a/mace/kernels/sgemm.h
+++ b/mace/kernels/sgemm.h
@@ -148,6 +148,7 @@ class SGemm {
 
   PackedBlock packed_lhs_;
   PackedBlock packed_rhs_;
+  PackedBlock packed_result_;
 
   bool packed_;
 };
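
Reviewer note (not part of the patch): the new ARMv7 block hand-schedules a 4x8x4 micro-kernel, interleaving vld1 loads with vmla multiply-accumulates so that fetching the next depth step's operands overlaps the current arithmetic. Ignoring that scheduling, one iteration of the "0:" loop is roughly equivalent to the intrinsics sketch below, where lhs is packed as 4 rows by 8 depths, rhs as 8 depths by 4 columns, and c0..c3 each accumulate one output row's 4 columns; the loop variable d is introduced here for illustration only.

    // Sketch of one asm loop iteration: 8 depth steps of a 4x4 output tile.
    for (int d = 0; d < 8; ++d) {
      float32x4_t a = vld1q_f32(lhs_ptr);  // lhs rows 0..3 at this depth
      float32x4_t b = vld1q_f32(rhs_ptr);  // rhs cols 0..3 at this depth
      c0 = vmlaq_lane_f32(c0, b, vget_low_f32(a), 0);   // row 0 += a[0] * b
      c1 = vmlaq_lane_f32(c1, b, vget_low_f32(a), 1);   // row 1 += a[1] * b
      c2 = vmlaq_lane_f32(c2, b, vget_high_f32(a), 0);  // row 2 += a[2] * b
      c3 = vmlaq_lane_f32(c3, b, vget_high_f32(a), 1);  // row 3 += a[3] * b
      lhs_ptr += 4;  // 4 rows x 1 depth; 32 floats per full iteration
      rhs_ptr += 4;  // 4 cols x 1 depth; 32 floats per full iteration
    }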
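
The reworked 1.8.1 path applies the same idea in miniature: unrolling the depth loop by 8 with two independent accumulators (c0, c1) breaks the dependency chain between consecutive vmla operations, and the partial sums are only combined at the final horizontal reduction (sum += vaddvq_f32(c0); sum += vaddvq_f32(c1);). A hypothetical scalar rendering of the same trick, assuming a depth divisible by 8:

    // Two independent accumulation chains expose more instruction-level
    // parallelism than one serial chain; combine them once at the end.
    float dot_unroll8(const float *a, const float *b, int n) {
      float s0 = 0.f, s1 = 0.f;
      for (int i = 0; i < n; i += 8) {
        for (int j = 0; j < 4; ++j) s0 += a[i + j] * b[i + j];
        for (int j = 4; j < 8; ++j) s1 += a[i + j] * b[i + j];
      }
      return s0 + s1;
    }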