From b158770b25a58420c2d2f2efbd09eb8aedc97037 Mon Sep 17 00:00:00 2001
From: 李寅
Date: Thu, 26 Apr 2018 16:53:21 +0800
Subject: [PATCH] Optimize gemm v7 output pipeline

---
 mace/kernels/gemm.cc | 43 +++++++++++++++++++++++--------------------
 1 file changed, 23 insertions(+), 20 deletions(-)

diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc
index 572bbff8..8cf397ed 100644
--- a/mace/kernels/gemm.cc
+++ b/mace/kernels/gemm.cc
@@ -471,7 +471,7 @@ inline void GemmTile(const float *A,
       // TODO(liyin): asm v7 prefetch and load optimization
       while (nw--) {
         float32x4_t b0, b1, b2, b3;
-        float32x4_t c0;
+        float32x4_t c0, c1, c2, c3;
 
         c0 = vld1q_f32(c_ptr0);
 
@@ -480,36 +480,37 @@ inline void GemmTile(const float *A,
         b2 = vld1q_f32(b_ptr2);
         b3 = vld1q_f32(b_ptr3);
 
+        c1 = vld1q_f32(c_ptr1);
+        c2 = vld1q_f32(c_ptr2);
+        c3 = vld1q_f32(c_ptr3);
+
         c0 = vmlaq_lane_f32(c0, b0, a00, 0);
         c0 = vmlaq_lane_f32(c0, b1, a00, 1);
         c0 = vmlaq_lane_f32(c0, b2, a01, 0);
         c0 = vmlaq_lane_f32(c0, b3, a01, 1);
 
         vst1q_f32(c_ptr0, c0);
-        c0 = vld1q_f32(c_ptr1);
 
-        c0 = vmlaq_lane_f32(c0, b0, a10, 0);
-        c0 = vmlaq_lane_f32(c0, b1, a10, 1);
-        c0 = vmlaq_lane_f32(c0, b2, a11, 0);
-        c0 = vmlaq_lane_f32(c0, b3, a11, 1);
+        c1 = vmlaq_lane_f32(c1, b0, a10, 0);
+        c1 = vmlaq_lane_f32(c1, b1, a10, 1);
+        c1 = vmlaq_lane_f32(c1, b2, a11, 0);
+        c1 = vmlaq_lane_f32(c1, b3, a11, 1);
 
-        vst1q_f32(c_ptr1, c0);
-        c0 = vld1q_f32(c_ptr2);
+        vst1q_f32(c_ptr1, c1);
 
-        c0 = vmlaq_lane_f32(c0, b0, a20, 0);
-        c0 = vmlaq_lane_f32(c0, b1, a20, 1);
-        c0 = vmlaq_lane_f32(c0, b2, a21, 0);
-        c0 = vmlaq_lane_f32(c0, b3, a21, 1);
+        c2 = vmlaq_lane_f32(c2, b0, a20, 0);
+        c2 = vmlaq_lane_f32(c2, b1, a20, 1);
+        c2 = vmlaq_lane_f32(c2, b2, a21, 0);
+        c2 = vmlaq_lane_f32(c2, b3, a21, 1);
 
-        vst1q_f32(c_ptr2, c0);
-        c0 = vld1q_f32(c_ptr3);
+        vst1q_f32(c_ptr2, c2);
 
-        c0 = vmlaq_lane_f32(c0, b0, a30, 0);
-        c0 = vmlaq_lane_f32(c0, b1, a30, 1);
-        c0 = vmlaq_lane_f32(c0, b2, a31, 0);
-        c0 = vmlaq_lane_f32(c0, b3, a31, 1);
+        c3 = vmlaq_lane_f32(c3, b0, a30, 0);
+        c3 = vmlaq_lane_f32(c3, b1, a30, 1);
+        c3 = vmlaq_lane_f32(c3, b2, a31, 0);
+        c3 = vmlaq_lane_f32(c3, b3, a31, 1);
 
-        vst1q_f32(c_ptr3, c0);
+        vst1q_f32(c_ptr3, c3);
 
         b_ptr0 += 4;
         b_ptr1 += 4;
@@ -586,7 +587,9 @@ void Gemm(const float *A,
   // It is better to use a large block size if it fits in the fast cache.
   // Assume the L1 cache size is 32 KB; we load three blocks (A, B, C) at a
   // time, so the block size should be sqrt(32k / sizeof(T) / 3).
-  const index_t block_size = 48;
+  // Since the number of input channels of a convolution is normally a power of
+  // 2 and we have not optimized the remainder tiles, use this magic number.
+  const index_t block_size = 64;
   const index_t block_tile_height = RoundUpDiv(height, block_size);
   const index_t block_tile_width = RoundUpDiv(width, block_size);
   const index_t block_tile_k = RoundUpDiv(K, block_size);
-- 
GitLab
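
A note on what the diff is doing: before this patch, all four output rows of the 4x4 tile were funneled through the single register c0, so each row's load had to wait for the previous row's store of the same register. Giving every output row its own accumulator (c0..c3) and hoisting the loads removes that serialization and lets loads, multiply-accumulates, and stores overlap in the ARMv7 pipeline. The following standalone sketch shows the resulting pattern for one 4x4 tile; the function name and row-major pointer layout are illustrative assumptions, not MACE's actual interface, and it needs a NEON-enabled ARMv7 build (e.g. -mfpu=neon).

#include <arm_neon.h>

// Hypothetical 4x4 micro-kernel step: C += A * B, all tiles row-major.
// Mirrors the patched pipeline: one accumulator per output row, loads
// issued up front, stores issued as each row finishes.
inline void Tile4x4Step(const float *a, const float *b, float *c) {
  // Each A row split into two 2-lane halves, as in the kernel above.
  float32x2_t a00 = vld1_f32(a),      a01 = vld1_f32(a + 2);
  float32x2_t a10 = vld1_f32(a + 4),  a11 = vld1_f32(a + 6);
  float32x2_t a20 = vld1_f32(a + 8),  a21 = vld1_f32(a + 10);
  float32x2_t a30 = vld1_f32(a + 12), a31 = vld1_f32(a + 14);

  float32x4_t b0 = vld1q_f32(b);
  float32x4_t b1 = vld1q_f32(b + 4);
  float32x4_t b2 = vld1q_f32(b + 8);
  float32x4_t b3 = vld1q_f32(b + 12);

  // Distinct accumulators: the c1..c3 loads no longer queue behind the
  // c0 store, as they did when one register was recycled for every row.
  float32x4_t c0 = vld1q_f32(c);
  float32x4_t c1 = vld1q_f32(c + 4);
  float32x4_t c2 = vld1q_f32(c + 8);
  float32x4_t c3 = vld1q_f32(c + 12);

  c0 = vmlaq_lane_f32(c0, b0, a00, 0);
  c0 = vmlaq_lane_f32(c0, b1, a00, 1);
  c0 = vmlaq_lane_f32(c0, b2, a01, 0);
  c0 = vmlaq_lane_f32(c0, b3, a01, 1);
  vst1q_f32(c, c0);

  c1 = vmlaq_lane_f32(c1, b0, a10, 0);
  c1 = vmlaq_lane_f32(c1, b1, a10, 1);
  c1 = vmlaq_lane_f32(c1, b2, a11, 0);
  c1 = vmlaq_lane_f32(c1, b3, a11, 1);
  vst1q_f32(c + 4, c1);

  c2 = vmlaq_lane_f32(c2, b0, a20, 0);
  c2 = vmlaq_lane_f32(c2, b1, a20, 1);
  c2 = vmlaq_lane_f32(c2, b2, a21, 0);
  c2 = vmlaq_lane_f32(c2, b3, a21, 1);
  vst1q_f32(c + 8, c2);

  c3 = vmlaq_lane_f32(c3, b0, a30, 0);
  c3 = vmlaq_lane_f32(c3, b1, a30, 1);
  c3 = vmlaq_lane_f32(c3, b2, a31, 0);
  c3 = vmlaq_lane_f32(c3, b3, a31, 1);
  vst1q_f32(c + 12, c3);
}

The gain should be largest on in-order ARMv7 cores, where the false register dependency stalls the pipeline directly; out-of-order cores can rename away some of it, but the hoisted loads still hide latency.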
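
The block-size change follows from the formula already in the comment: with T = float and a 32 KB L1 data cache, sqrt(32 * 1024 / 4 / 3) ≈ 52.3, which the old code rounded down to 48. The patch moves to 64, the next power of two, so that the power-of-two channel counts common in convolutions divide into whole blocks rather than unoptimized remainder tiles. A quick standalone check of that arithmetic (assuming only the float type and the 32 KB figure stated in the comment):

#include <cmath>
#include <cstdio>

int main() {
  // One block each of A, B, and C should fit in a 32 KB L1 data cache:
  // block_size^2 * sizeof(float) * 3 <= 32 * 1024.
  const double limit = std::sqrt(32.0 * 1024 / sizeof(float) / 3);
  std::printf("cache-fitting block size <= %.1f\n", limit);  // ~52.3
  // 48 is that limit rounded down; 64 is the next power of two, trading
  // a tight cache fit for fewer remainder tiles.
  return 0;
}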