diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc index 40c90f58c0a4c6c8fda054194b6a5cced71cece6..0e05106fe0d6ef492c20f53b6afca9008445b062 100644 --- a/mace/kernels/gemm.cc +++ b/mace/kernels/gemm.cc @@ -50,7 +50,7 @@ inline void GemmBlock(const float *A, #if defined(MACE_ENABLE_NEON) #if defined(__aarch64__) -#define MACE_GEMM_PART_CAL(RC, RA, RAN) \ +#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \ c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \ c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \ c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \ @@ -60,7 +60,7 @@ inline void GemmBlock(const float *A, c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \ c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3); #else -#define MACE_GEMM_PART_CAL(RC, RA, RAN) \ +#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \ c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \ c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \ c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \ @@ -72,6 +72,283 @@ inline void GemmBlock(const float *A, #endif #endif +#if defined(MACE_ENABLE_NEON) +#if defined(__aarch64__) +#define MACE_GEMM_PART_CAL_4(RC) \ + c##RC = vfmaq_laneq_f32(c##RC, b0, a##RC, 0); \ + c##RC = vfmaq_laneq_f32(c##RC, b1, a##RC, 1); \ + c##RC = vfmaq_laneq_f32(c##RC, b2, a##RC, 2); \ + c##RC = vfmaq_laneq_f32(c##RC, b3, a##RC, 3); +#else +#define MACE_GEMM_PART_CAL_4(RC) \ + c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0); \ + c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1); \ + c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \ + c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1); +#endif +#endif + +inline void Gemm144(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + MACE_UNUSED(stride_a); + MACE_UNUSED(stride_c); + float32x4_t a0; + float32x4_t b0, b1, b2, b3; + float32x4_t c0; + + a0 = vld1q_f32(a_ptr); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + + MACE_GEMM_PART_CAL_4(0); + + vst1q_f32(c_ptr, c0); +#else + GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm244(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); +#else + GemmBlock(a_ptr, b_ptr, 2, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm344(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1, c2; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + MACE_GEMM_PART_CAL_4(2); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); +#else + GemmBlock(a_ptr, b_ptr, 3, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm444(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1, c2, c3; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + a3 = vld1q_f32(a_ptr + 3 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + MACE_GEMM_PART_CAL_4(2); + MACE_GEMM_PART_CAL_4(3); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); +#else + GemmBlock(a_ptr, b_ptr, 4, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm544(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3, a4; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1, c2, c3, c4; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + a3 = vld1q_f32(a_ptr + 3 * stride_a); + a4 = vld1q_f32(a_ptr + 4 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + c4 = vld1q_f32(c_ptr + 4 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + MACE_GEMM_PART_CAL_4(2); + MACE_GEMM_PART_CAL_4(3); + MACE_GEMM_PART_CAL_4(4); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); + vst1q_f32(c_ptr + 4 * stride_c, c4); +#else + GemmBlock(a_ptr, b_ptr, 5, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm644(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3, a4, a5; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1, c2, c3, c4, c5; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + a3 = vld1q_f32(a_ptr + 3 * stride_a); + a4 = vld1q_f32(a_ptr + 4 * stride_a); + a5 = vld1q_f32(a_ptr + 5 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + c4 = vld1q_f32(c_ptr + 4 * stride_c); + c5 = vld1q_f32(c_ptr + 5 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + MACE_GEMM_PART_CAL_4(2); + MACE_GEMM_PART_CAL_4(3); + MACE_GEMM_PART_CAL_4(4); + MACE_GEMM_PART_CAL_4(5); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); + vst1q_f32(c_ptr + 4 * stride_c, c4); + vst1q_f32(c_ptr + 5 * stride_c, c5); +#else + GemmBlock(a_ptr, b_ptr, 6, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void GemmX44(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr, + int row) { + switch (row) { + case 1: + Gemm144(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 2: + Gemm244(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 3: + Gemm344(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 4: + Gemm444(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 5: + Gemm544(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 6: + Gemm644(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + default: + MACE_NOT_IMPLEMENTED; + } +} + inline void Gemm884(const float *a_ptr, const float *b_ptr, const index_t stride_a, @@ -119,25 +396,14 @@ inline void Gemm884(const float *a_ptr, c6 = vld1q_f32(c_ptr + 6 * stride_c); c7 = vld1q_f32(c_ptr + 7 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); - MACE_GEMM_PART_CAL(6, 12, 13); - MACE_GEMM_PART_CAL(7, 14, 15); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); - MACE_GEMM_PART_CAL(6, 12, 13); - MACE_GEMM_PART_CAL(7, 14, 15); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); + MACE_GEMM_PART_CAL_8(4, 8, 9); + MACE_GEMM_PART_CAL_8(5, 10, 11); + MACE_GEMM_PART_CAL_8(6, 12, 13); + MACE_GEMM_PART_CAL_8(7, 14, 15); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -180,11 +446,7 @@ inline void Gemm184(const float *a_ptr, c0 = vld1q_f32(c_ptr); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); -#else - MACE_GEMM_PART_CAL(0, 0, 1); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); vst1q_f32(c_ptr, c0); #else @@ -220,13 +482,8 @@ inline void Gemm284(const float *a_ptr, c0 = vld1q_f32(c_ptr); c1 = vld1q_f32(c_ptr + 1 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -266,15 +523,9 @@ inline void Gemm384(const float *a_ptr, c1 = vld1q_f32(c_ptr + 1 * stride_c); c2 = vld1q_f32(c_ptr + 2 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -318,17 +569,10 @@ inline void Gemm484(const float *a_ptr, c2 = vld1q_f32(c_ptr + 2 * stride_c); c3 = vld1q_f32(c_ptr + 3 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -376,19 +620,11 @@ inline void Gemm584(const float *a_ptr, c3 = vld1q_f32(c_ptr + 3 * stride_c); c4 = vld1q_f32(c_ptr + 4 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); + MACE_GEMM_PART_CAL_8(4, 8, 9); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -440,21 +676,12 @@ inline void Gemm684(const float *a_ptr, c4 = vld1q_f32(c_ptr + 4 * stride_c); c5 = vld1q_f32(c_ptr + 5 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); + MACE_GEMM_PART_CAL_8(4, 8, 9); + MACE_GEMM_PART_CAL_8(5, 10, 11); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -511,23 +738,13 @@ inline void Gemm784(const float *a_ptr, c5 = vld1q_f32(c_ptr + 5 * stride_c); c6 = vld1q_f32(c_ptr + 6 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); - MACE_GEMM_PART_CAL(6, 12, 13); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); - MACE_GEMM_PART_CAL(6, 12, 13); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); + MACE_GEMM_PART_CAL_8(4, 8, 9); + MACE_GEMM_PART_CAL_8(5, 10, 11); + MACE_GEMM_PART_CAL_8(6, 12, 13); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -589,9 +806,19 @@ inline void GemmTile(const float *A, const index_t stride_c, float *C) { #if defined(MACE_ENABLE_NEON) - index_t h, w, k; - for (h = 0; h < height - 7; h += 8) { - for (k = 0; k < K - 7; k += 8) { + index_t h = 0; + index_t w = 0; + index_t k = 0; +#if defined(__aarch64__) + int reg_height_tile = 8; + int reg_K_tile = 8; +#else + int reg_height_tile = 6; + int reg_K_tile = 4; +#endif + + for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) { + for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) { const float *a_ptr = A + (h * stride_a + k); #if defined(__aarch64__) && defined(__clang__) int nw = width >> 2; @@ -833,43 +1060,180 @@ inline void GemmTile(const float *A, w = (width >> 2) << 2; } -#else // gcc || armv7a +#elif defined(__aarch64__) // gcc for (w = 0; w + 3 < width; w += 4) { const float *b_ptr = B + (k * stride_b + w); float *c_ptr = C + (h * stride_c + w); Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); } -#endif // clang && armv8a +#else // armv7 + int nw = width >> 2; + if (nw > 0) { + float32x4_t a0, a1, a2, a3, a4, a5; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + a3 = vld1q_f32(a_ptr + 3 * stride_a); + a4 = vld1q_f32(a_ptr + 4 * stride_a); + a5 = vld1q_f32(a_ptr + 5 * stride_a); + + const float *b_ptr0 = B + k * stride_b; + const float *b_ptr1 = B + (k + 1) * stride_b; + const float *b_ptr2 = B + (k + 2) * stride_b; + const float *b_ptr3 = B + (k + 3) * stride_b; + + float *c_ptr0 = C + h * stride_c; + float *c_ptr1 = C + (h + 1) * stride_c; + float *c_ptr2 = C + (h + 2) * stride_c; + float *c_ptr3 = C + (h + 3) * stride_c; + float *c_ptr4 = C + (h + 4) * stride_c; + float *c_ptr5 = C + (h + 5) * stride_c; + + asm volatile( + "pld [%7, #128] \n" + "vld1.f32 {d12-d13}, [%7]! \n" + "pld [%1, #128] \n" + "vld1.f32 {d16-d17}, [%1] \n" + "pld [%2, #128] \n" + "vld1.f32 {d18-d19}, [%2] \n" + + "0: \n" + + "pld [%3, #128] \n" + "vld1.f32 {d20-d21}, [%3] \n" + "pld [%4, #128] \n" + "vld1.f32 {d22-d23}, [%4] \n" + "pld [%5, #128] \n" + "vld1.f32 {d24-d25}, [%5] \n" + "pld [%6, #128] \n" + "vld1.f32 {d26-d27}, [%6] \n" + + "pld [%8, #128] \n" + "vld1.f32 {d14-d15}, [%8]! \n" + + "vmla.f32 q8, q6, %e22[0] \n" + "vmla.f32 q9, q6, %e23[0] \n" + "vmla.f32 q10, q6, %e24[0] \n" + "vmla.f32 q11, q6, %e25[0] \n" + "vmla.f32 q12, q6, %e26[0] \n" + "vmla.f32 q13, q6, %e27[0] \n" + + "pld [%9, #128] \n" + "vld1.f32 {d12-d13}, [%9]! \n" + + "vmla.f32 q8, q7, %e22[1] \n" + "vmla.f32 q9, q7, %e23[1] \n" + "vmla.f32 q10, q7, %e24[1] \n" + "vmla.f32 q11, q7, %e25[1] \n" + "vmla.f32 q12, q7, %e26[1] \n" + "vmla.f32 q13, q7, %e27[1] \n" + + "pld [%10, #128] \n" + "vld1.f32 {d14-d15}, [%10]! \n" + + "vmla.f32 q8, q6, %f22[0] \n" + "vmla.f32 q9, q6, %f23[0] \n" + "vmla.f32 q10, q6, %f24[0] \n" + "vmla.f32 q11, q6, %f25[0] \n" + "vmla.f32 q12, q6, %f26[0] \n" + "vmla.f32 q13, q6, %f27[0] \n" + + "vmla.f32 q8, q7, %f22[1] \n" + "vmla.f32 q9, q7, %f23[1] \n" + "vmla.f32 q10, q7, %f24[1] \n" + "vmla.f32 q11, q7, %f25[1] \n" + "vmla.f32 q12, q7, %f26[1] \n" + "vmla.f32 q13, q7, %f27[1] \n" + + "vst1.f32 {d16-d17}, [%1]! \n" + "vst1.f32 {d18-d19}, [%2]! \n" + + "pld [%7, #128] \n" + "vld1.f32 {d12-d13}, [%7]! \n" + + "vst1.f32 {d20-d21}, [%3]! \n" + "vst1.f32 {d22-d23}, [%4]! \n" + + "pld [%1, #128] \n" + "vld1.f32 {d16-d17}, [%1] \n" + + "vst1.f32 {d24-d25}, [%5]! \n" + "vst1.f32 {d26-d27}, [%6]! \n" + + "pld [%2, #128] \n" + "vld1.f32 {d18-d19}, [%2] \n" + + "subs %0, #1 \n" + "bne 0b \n" + : "=r"(nw), // 0 + "=r"(c_ptr0), // 1 + "=r"(c_ptr1), // 2 + "=r"(c_ptr2), // 3 + "=r"(c_ptr3), // 4 + "=r"(c_ptr4), // 5 + "=r"(c_ptr5), // 6 + "=r"(b_ptr0), // 7 + "=r"(b_ptr1), // 8 + "=r"(b_ptr2), // 9 + "=r"(b_ptr3) // 10 + : "0"(nw), // 11 + "1"(c_ptr0), // 12 + "2"(c_ptr1), // 13 + "3"(c_ptr2), // 14 + "4"(c_ptr3), // 15 + "5"(c_ptr4), // 16 + "6"(c_ptr5), // 17 + "7"(b_ptr0), // 18 + "8"(b_ptr1), // 19 + "9"(b_ptr2), // 20 + "10"(b_ptr3), // 21 + "w"(a0), // 22 + "w"(a1), // 23 + "w"(a2), // 24 + "w"(a3), // 25 + "w"(a4), // 26 + "w"(a5) // 27 + : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12", + "q13", "q14", "q15"); + + w = (width >> 2) << 2; + } +#endif if (w < width) { const float *b_ptr = B + (k * stride_b + w); float *c_ptr = C + (h * stride_c + w); - GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_a, stride_b, stride_c, - c_ptr); + GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - w, + stride_a, stride_b, stride_c, c_ptr); } } if (k < K) { const float *a_ptr = A + (h * stride_a + k); const float *b_ptr = B + k * stride_b; float *c_ptr = C + h * stride_c; - GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_a, stride_b, stride_c, - c_ptr); + GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a, stride_b, + stride_c, c_ptr); } } if (h < height) { index_t remain_h = height - h; - for (k = 0; k < K - 7; k += 8) { + for (k = 0; k < K - reg_K_tile; k += reg_K_tile) { const float *a_ptr = A + (h * stride_a + k); index_t w; for (w = 0; w + 3 < width; w += 4) { const float *b_ptr = B + (k * stride_b + w); float *c_ptr = C + (h * stride_c + w); +#if defined(__aarch64__) GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h); +#else + GemmX44(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h); +#endif } if (w < width) { const float *b_ptr = B + (k * stride_b + w); float *c_ptr = C + (h * stride_c + w); - GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_a, stride_b, - stride_c, c_ptr); + GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a, + stride_b, stride_c, c_ptr); } } if (k < K) { diff --git a/mace/ops/shape_test.cc b/mace/ops/shape_test.cc index 5798be7f8309970445cb3c8bf10e6327c2f52144..08ccb88b86958bb4fdbd3a1677fe1b728355f5fe 100644 --- a/mace/ops/shape_test.cc +++ b/mace/ops/shape_test.cc @@ -38,7 +38,9 @@ void TestShapeOp(const std::vector &input_shape) { std::vector expected_input_shape(input_shape.begin(), input_shape.end()); if (!expected_input_shape.empty()) { - net.AddInputFromArray("ExpectedOutput", {input_shape.size()}, + net.AddInputFromArray("ExpectedOutput", + {static_cast( + input_shape.size())}, expected_input_shape); } else { net.AddInputFromArray("ExpectedOutput", {}, {0}); diff --git a/mace/ops/strided_slice_test.cc b/mace/ops/strided_slice_test.cc index 2aa4af2820488a7ea7fb0a293f05e7b7ad1802bf..6cd46f4e110e0cb001932a18db6db2a5c69d866b 100644 --- a/mace/ops/strided_slice_test.cc +++ b/mace/ops/strided_slice_test.cc @@ -37,11 +37,18 @@ void TestSlice(const std::vector &input_shape, const std::vector &output) { OpsTestNet net; net.AddInputFromArray("Input", input_shape, input); - net.AddInputFromArray("BeginIndices", {input_shape.size()}, + net.AddInputFromArray("BeginIndices", + {static_cast( + input_shape.size())}, begin_indices); - net.AddInputFromArray("EndIndices", {input_shape.size()}, + net.AddInputFromArray("EndIndices", + {static_cast( + input_shape.size())}, end_indices); - net.AddInputFromArray("Strides", {input_shape.size()}, strides); + net.AddInputFromArray("Strides", + {static_cast( + input_shape.size())}, + strides); OpDefBuilder("StridedSlice", "StridedSliceOpTest") .Input("Input")