diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc
index 9b3aa599f53923925119bf21378b7ba7896ac189..477831c7acc7b94544118d05a988d63b0c49128a 100644
--- a/mace/kernels/gemm.cc
+++ b/mace/kernels/gemm.cc
@@ -20,11 +20,13 @@
 #include <arm_neon.h>
 #endif
 
+#include "mace/core/macros.h"
 #include "mace/kernels/gemm.h"
-#include "mace/utils/utils.h"
 #include "mace/utils/logging.h"
-
+#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
+#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+#endif
 
 namespace mace {
 namespace kernels {
@@ -47,15 +49,38 @@ inline void GemmBlock(const float *A,
   }
 }
 
-// TODO(liyin): may need implement 883 since RGB
+#if defined(MACE_ENABLE_NEON)
+#if defined(__aarch64__)
+#define MACE_GEMM_PART_CAL(RC, RA, RAN)           \
+  c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0);   \
+  c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1);   \
+  c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2);   \
+  c##RC = vfmaq_laneq_f32(c##RC, b3, a##RA, 3);   \
+  c##RC = vfmaq_laneq_f32(c##RC, b4, a##RAN, 0);  \
+  c##RC = vfmaq_laneq_f32(c##RC, b5, a##RAN, 1);  \
+  c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2);  \
+  c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
+#else
+#define MACE_GEMM_PART_CAL(RC, RA, RAN)                        \
+  c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0);   \
+  c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1);   \
+  c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0);  \
+  c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RA), 1);  \
+  c##RC = vmlaq_lane_f32(c##RC, b4, vget_low_f32(a##RAN), 0);  \
+  c##RC = vmlaq_lane_f32(c##RC, b5, vget_low_f32(a##RAN), 1);  \
+  c##RC = vmlaq_lane_f32(c##RC, b6, vget_high_f32(a##RAN), 0); \
+  c##RC = vmlaq_lane_f32(c##RC, b7, vget_high_f32(a##RAN), 1);
+#endif
+#endif
+
 inline void Gemm884(const float *a_ptr,
                     const float *b_ptr,
                     index_t stride_k,
                     index_t stride_w,
                     float *c_ptr) {
-#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+#if defined(MACE_ENABLE_NEON)
   float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14,
-    a15;
+      a15;
   float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
   float32x4_t c0, c1, c2, c3, c4, c5, c6, c7;
@@ -94,24 +119,25 @@ inline void Gemm884(const float *a_ptr,
   c6 = vld1q_f32(c_ptr + 6 * stride_w);
   c7 = vld1q_f32(c_ptr + 7 * stride_w);
 
-#define MACE_CONV_1x1_REG_CAL(RC, RA, RAN)       \
-  c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0);  \
-  c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1);  \
-  c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2);  \
-  c##RC = vfmaq_laneq_f32(c##RC, b3, a##RA, 3);  \
-  c##RC = vfmaq_laneq_f32(c##RC, b4, a##RAN, 0); \
-  c##RC = vfmaq_laneq_f32(c##RC, b5, a##RAN, 1); \
-  c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \
-  c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
-
-  MACE_CONV_1x1_REG_CAL(0, 0, 1);
-  MACE_CONV_1x1_REG_CAL(1, 2, 3);
-  MACE_CONV_1x1_REG_CAL(2, 4, 5);
-  MACE_CONV_1x1_REG_CAL(3, 6, 7);
-  MACE_CONV_1x1_REG_CAL(4, 8, 9);
-  MACE_CONV_1x1_REG_CAL(5, 10, 11);
-  MACE_CONV_1x1_REG_CAL(6, 12, 13);
-  MACE_CONV_1x1_REG_CAL(7, 14, 15);
+#if defined(__aarch64__)
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+  MACE_GEMM_PART_CAL(5, 10, 11);
+  MACE_GEMM_PART_CAL(6, 12, 13);
+  MACE_GEMM_PART_CAL(7, 14, 15);
+#else
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+  MACE_GEMM_PART_CAL(5, 10, 11);
+  MACE_GEMM_PART_CAL(6, 12, 13);
+  MACE_GEMM_PART_CAL(7, 14, 15);
+#endif
 
   vst1q_f32(c_ptr, c0);
   vst1q_f32(c_ptr + 1
* stride_w, c1); @@ -121,12 +147,428 @@ inline void Gemm884(const float *a_ptr, vst1q_f32(c_ptr + 5 * stride_w, c5); vst1q_f32(c_ptr + 6 * stride_w, c6); vst1q_f32(c_ptr + 7 * stride_w, c7); - #else GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr); #endif } +inline void Gemm184(const float *a_ptr, + const float *b_ptr, + index_t stride_k, + index_t stride_w, + float *c_ptr) { + MACE_UNUSED(stride_k); +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1; + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + float32x4_t c0; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 4); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_w); + b2 = vld1q_f32(b_ptr + 2 * stride_w); + b3 = vld1q_f32(b_ptr + 3 * stride_w); + b4 = vld1q_f32(b_ptr + 4 * stride_w); + b5 = vld1q_f32(b_ptr + 5 * stride_w); + b6 = vld1q_f32(b_ptr + 6 * stride_w); + b7 = vld1q_f32(b_ptr + 7 * stride_w); + + c0 = vld1q_f32(c_ptr); + +#if defined(__aarch64__) + MACE_GEMM_PART_CAL(0, 0, 1); +#else + MACE_GEMM_PART_CAL(0, 0, 1); +#endif + + vst1q_f32(c_ptr, c0); +#else + GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_k, stride_w, c_ptr); +#endif +} + +inline void Gemm284(const float *a_ptr, + const float *b_ptr, + index_t stride_k, + index_t stride_w, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3; + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + float32x4_t c0, c1; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_k); + a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_w); + b2 = vld1q_f32(b_ptr + 2 * stride_w); + b3 = vld1q_f32(b_ptr + 3 * stride_w); + b4 = vld1q_f32(b_ptr + 4 * stride_w); + b5 = vld1q_f32(b_ptr + 5 * stride_w); + b6 = vld1q_f32(b_ptr + 6 * stride_w); + b7 = vld1q_f32(b_ptr + 7 * stride_w); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_w); + +#if defined(__aarch64__) + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); +#else + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); +#endif + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_w, c1); +#else + GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_k, stride_w, c_ptr); +#endif +} + +inline void Gemm384(const float *a_ptr, + const float *b_ptr, + index_t stride_k, + index_t stride_w, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3, a4, a5; + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + float32x4_t c0, c1, c2; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_k); + a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_k); + a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_w); + b2 = vld1q_f32(b_ptr + 2 * stride_w); + b3 = vld1q_f32(b_ptr + 3 * stride_w); + b4 = vld1q_f32(b_ptr + 4 * stride_w); + b5 = vld1q_f32(b_ptr + 5 * stride_w); + b6 = vld1q_f32(b_ptr + 6 * stride_w); + b7 = vld1q_f32(b_ptr + 7 * stride_w); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_w); + c2 = vld1q_f32(c_ptr + 2 * stride_w); + +#if defined(__aarch64__) + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); +#else + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); +#endif + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_w, c1); + vst1q_f32(c_ptr + 2 * stride_w, c2); +#else + GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_k, stride_w, c_ptr); +#endif +} + +inline void 
Gemm484(const float *a_ptr, + const float *b_ptr, + index_t stride_k, + index_t stride_w, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3, a4, a5, a6, a7; + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + float32x4_t c0, c1, c2, c3; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_k); + a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_k); + a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_k); + a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_w); + b2 = vld1q_f32(b_ptr + 2 * stride_w); + b3 = vld1q_f32(b_ptr + 3 * stride_w); + b4 = vld1q_f32(b_ptr + 4 * stride_w); + b5 = vld1q_f32(b_ptr + 5 * stride_w); + b6 = vld1q_f32(b_ptr + 6 * stride_w); + b7 = vld1q_f32(b_ptr + 7 * stride_w); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_w); + c2 = vld1q_f32(c_ptr + 2 * stride_w); + c3 = vld1q_f32(c_ptr + 3 * stride_w); + +#if defined(__aarch64__) + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); + MACE_GEMM_PART_CAL(3, 6, 7); +#else + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); + MACE_GEMM_PART_CAL(3, 6, 7); +#endif + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_w, c1); + vst1q_f32(c_ptr + 2 * stride_w, c2); + vst1q_f32(c_ptr + 3 * stride_w, c3); +#else + GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_k, stride_w, c_ptr); +#endif +} + +inline void Gemm584(const float *a_ptr, + const float *b_ptr, + index_t stride_k, + index_t stride_w, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9; + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + float32x4_t c0, c1, c2, c3, c4; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_k); + a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_k); + a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_k); + a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); + a8 = vld1q_f32(a_ptr + 4 * stride_k); + a9 = vld1q_f32(a_ptr + 4 * stride_k + 4); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_w); + b2 = vld1q_f32(b_ptr + 2 * stride_w); + b3 = vld1q_f32(b_ptr + 3 * stride_w); + b4 = vld1q_f32(b_ptr + 4 * stride_w); + b5 = vld1q_f32(b_ptr + 5 * stride_w); + b6 = vld1q_f32(b_ptr + 6 * stride_w); + b7 = vld1q_f32(b_ptr + 7 * stride_w); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_w); + c2 = vld1q_f32(c_ptr + 2 * stride_w); + c3 = vld1q_f32(c_ptr + 3 * stride_w); + c4 = vld1q_f32(c_ptr + 4 * stride_w); + +#if defined(__aarch64__) + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); + MACE_GEMM_PART_CAL(3, 6, 7); + MACE_GEMM_PART_CAL(4, 8, 9); +#else + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); + MACE_GEMM_PART_CAL(3, 6, 7); + MACE_GEMM_PART_CAL(4, 8, 9); +#endif + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_w, c1); + vst1q_f32(c_ptr + 2 * stride_w, c2); + vst1q_f32(c_ptr + 3 * stride_w, c3); + vst1q_f32(c_ptr + 4 * stride_w, c4); +#else + GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_k, stride_w, c_ptr); +#endif +} + +inline void Gemm684(const float *a_ptr, + const float *b_ptr, + index_t stride_k, + index_t stride_w, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11; + 
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + float32x4_t c0, c1, c2, c3, c4, c5; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_k); + a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_k); + a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_k); + a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); + a8 = vld1q_f32(a_ptr + 4 * stride_k); + a9 = vld1q_f32(a_ptr + 4 * stride_k + 4); + a10 = vld1q_f32(a_ptr + 5 * stride_k); + a11 = vld1q_f32(a_ptr + 5 * stride_k + 4); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_w); + b2 = vld1q_f32(b_ptr + 2 * stride_w); + b3 = vld1q_f32(b_ptr + 3 * stride_w); + b4 = vld1q_f32(b_ptr + 4 * stride_w); + b5 = vld1q_f32(b_ptr + 5 * stride_w); + b6 = vld1q_f32(b_ptr + 6 * stride_w); + b7 = vld1q_f32(b_ptr + 7 * stride_w); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_w); + c2 = vld1q_f32(c_ptr + 2 * stride_w); + c3 = vld1q_f32(c_ptr + 3 * stride_w); + c4 = vld1q_f32(c_ptr + 4 * stride_w); + c5 = vld1q_f32(c_ptr + 5 * stride_w); + +#if defined(__aarch64__) + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); + MACE_GEMM_PART_CAL(3, 6, 7); + MACE_GEMM_PART_CAL(4, 8, 9); + MACE_GEMM_PART_CAL(5, 10, 11); +#else + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); + MACE_GEMM_PART_CAL(3, 6, 7); + MACE_GEMM_PART_CAL(4, 8, 9); + MACE_GEMM_PART_CAL(5, 10, 11); +#endif + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_w, c1); + vst1q_f32(c_ptr + 2 * stride_w, c2); + vst1q_f32(c_ptr + 3 * stride_w, c3); + vst1q_f32(c_ptr + 4 * stride_w, c4); + vst1q_f32(c_ptr + 5 * stride_w, c5); + +#else + GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_k, stride_w, c_ptr); +#endif +} + +inline void Gemm784(const float *a_ptr, + const float *b_ptr, + index_t stride_k, + index_t stride_w, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13; + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + float32x4_t c0, c1, c2, c3, c4, c5, c6; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_k); + a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_k); + a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_k); + a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); + a8 = vld1q_f32(a_ptr + 4 * stride_k); + a9 = vld1q_f32(a_ptr + 4 * stride_k + 4); + a10 = vld1q_f32(a_ptr + 5 * stride_k); + a11 = vld1q_f32(a_ptr + 5 * stride_k + 4); + a12 = vld1q_f32(a_ptr + 6 * stride_k); + a13 = vld1q_f32(a_ptr + 6 * stride_k + 4); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_w); + b2 = vld1q_f32(b_ptr + 2 * stride_w); + b3 = vld1q_f32(b_ptr + 3 * stride_w); + b4 = vld1q_f32(b_ptr + 4 * stride_w); + b5 = vld1q_f32(b_ptr + 5 * stride_w); + b6 = vld1q_f32(b_ptr + 6 * stride_w); + b7 = vld1q_f32(b_ptr + 7 * stride_w); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_w); + c2 = vld1q_f32(c_ptr + 2 * stride_w); + c3 = vld1q_f32(c_ptr + 3 * stride_w); + c4 = vld1q_f32(c_ptr + 4 * stride_w); + c5 = vld1q_f32(c_ptr + 5 * stride_w); + c6 = vld1q_f32(c_ptr + 6 * stride_w); + +#if defined(__aarch64__) + MACE_GEMM_PART_CAL(0, 0, 1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); + MACE_GEMM_PART_CAL(3, 6, 7); + MACE_GEMM_PART_CAL(4, 8, 9); + MACE_GEMM_PART_CAL(5, 10, 11); + MACE_GEMM_PART_CAL(6, 12, 13); +#else + MACE_GEMM_PART_CAL(0, 0, 
1); + MACE_GEMM_PART_CAL(1, 2, 3); + MACE_GEMM_PART_CAL(2, 4, 5); + MACE_GEMM_PART_CAL(3, 6, 7); + MACE_GEMM_PART_CAL(4, 8, 9); + MACE_GEMM_PART_CAL(5, 10, 11); + MACE_GEMM_PART_CAL(6, 12, 13); +#endif + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_w, c1); + vst1q_f32(c_ptr + 2 * stride_w, c2); + vst1q_f32(c_ptr + 3 * stride_w, c3); + vst1q_f32(c_ptr + 4 * stride_w, c4); + vst1q_f32(c_ptr + 5 * stride_w, c5); + vst1q_f32(c_ptr + 6 * stride_w, c6); + +#else + GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_k, stride_w, c_ptr); +#endif +} + +inline void GemmX84(const float *a_ptr, + const float *b_ptr, + index_t stride_k, + index_t stride_w, + float *c_ptr, + int row) { + switch (row) { + case 1: + Gemm184(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + break; + case 2: + Gemm284(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + break; + case 3: + Gemm384(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + break; + case 4: + Gemm484(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + break; + case 5: + Gemm584(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + break; + case 6: + Gemm684(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + break; + case 7: + Gemm784(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + break; + case 8: + Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + break; + default: + MACE_NOT_IMPLEMENTED; + } +} + inline void GemmTile(const float *A, const float *B, const index_t height, @@ -137,18 +579,15 @@ inline void GemmTile(const float *A, float *C) { #if defined(MACE_ENABLE_NEON) index_t h, w, k; -#endif - -#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) for (h = 0; h < height - 7; h += 8) { for (k = 0; k < K - 7; k += 8) { const float *a_ptr = A + (h * stride_k + k); -#ifdef __clang__ +#if defined(__aarch64__) && defined(__clang__) int nw = width >> 2; if (nw > 0) { // load A - float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, - a8, a9, a10, a11, a12, a13, a14, a15; + float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, + a14, a15; a0 = vld1q_f32(a_ptr); a1 = vld1q_f32(a_ptr + 4); a2 = vld1q_f32(a_ptr + 1 * stride_k); @@ -185,223 +624,212 @@ inline void GemmTile(const float *A, float *c_ptr7 = C + (h + 7) * stride_w; asm volatile( - "prfm pldl1keep, [%9, #128] \n" - "ld1 {v16.4s}, [%9], #16 \n" - - "prfm pldl1keep, [%1, #128] \n" - "ld1 {v18.4s}, [%1] \n" - - "prfm pldl1keep, [%2, #128] \n" - "ld1 {v19.4s}, [%2] \n" - - "0: \n" - - "prfm pldl1keep, [%3, #128] \n" - "ld1 {v20.4s}, [%3] \n" - "prfm pldl1keep, [%4, #128] \n" - "ld1 {v21.4s}, [%4] \n" - "prfm pldl1keep, [%5, #128] \n" - "ld1 {v22.4s}, [%5] \n" - "prfm pldl1keep, [%6, #128] \n" - "ld1 {v23.4s}, [%6] \n" - "prfm pldl1keep, [%7, #128] \n" - "ld1 {v24.4s}, [%7] \n" - "prfm pldl1keep, [%8, #128] \n" - "ld1 {v25.4s}, [%8] \n" - "prfm pldl1keep, [%10, #128] \n" - "ld1 {v17.4s}, [%10], #16 \n" - - "fmla v18.4s, v16.4s, %34.s[0] \n" - "fmla v19.4s, v16.4s, %35.s[0] \n" - "fmla v20.4s, v16.4s, %36.s[0] \n" - "fmla v21.4s, v16.4s, %37.s[0] \n" - - "fmla v22.4s, v16.4s, %38.s[0] \n" - "fmla v23.4s, v16.4s, %39.s[0] \n" - "fmla v24.4s, v16.4s, %40.s[0] \n" - "fmla v25.4s, v16.4s, %41.s[0] \n" - - "fmla v18.4s, v17.4s, %34.s[1] \n" - "fmla v19.4s, v17.4s, %35.s[1] \n" - "fmla v20.4s, v17.4s, %36.s[1] \n" - "fmla v21.4s, v17.4s, %37.s[1] \n" - - "prfm pldl1keep, [%11, #128] \n" - "ld1 {v16.4s}, [%11], #16 \n" - - "fmla v22.4s, v17.4s, %38.s[1] \n" - "fmla v23.4s, v17.4s, %39.s[1] \n" - "fmla v24.4s, v17.4s, %40.s[1] \n" - "fmla v25.4s, v17.4s, %41.s[1] \n" - - "fmla v18.4s, v16.4s, %34.s[2] \n" - "fmla v19.4s, v16.4s, %35.s[2] \n" - "fmla 
v20.4s, v16.4s, %36.s[2] \n" - "fmla v21.4s, v16.4s, %37.s[2] \n" - - "prfm pldl1keep, [%12, #128] \n" - "ld1 {v17.4s}, [%12], #16 \n" - - "fmla v22.4s, v16.4s, %38.s[2] \n" - "fmla v23.4s, v16.4s, %39.s[2] \n" - "fmla v24.4s, v16.4s, %40.s[2] \n" - "fmla v25.4s, v16.4s, %41.s[2] \n" - - "fmla v18.4s, v17.4s, %34.s[3] \n" - "fmla v19.4s, v17.4s, %35.s[3] \n" - "fmla v20.4s, v17.4s, %36.s[3] \n" - "fmla v21.4s, v17.4s, %37.s[3] \n" - - "prfm pldl1keep, [%13, #128] \n" - "ld1 {v16.4s}, [%13], #16 \n" - - "fmla v22.4s, v17.4s, %38.s[3] \n" - "fmla v23.4s, v17.4s, %39.s[3] \n" - "fmla v24.4s, v17.4s, %40.s[3] \n" - "fmla v25.4s, v17.4s, %41.s[3] \n" - - "fmla v18.4s, v16.4s, %42.s[0] \n" - "fmla v19.4s, v16.4s, %43.s[0] \n" - "fmla v20.4s, v16.4s, %44.s[0] \n" - "fmla v21.4s, v16.4s, %45.s[0] \n" - - "prfm pldl1keep, [%14, #128] \n" - "ld1 {v17.4s}, [%14], #16 \n" - - "fmla v22.4s, v16.4s, %46.s[0] \n" - "fmla v23.4s, v16.4s, %47.s[0] \n" - "fmla v24.4s, v16.4s, %48.s[0] \n" - "fmla v25.4s, v16.4s, %49.s[0] \n" - - "fmla v18.4s, v17.4s, %42.s[1] \n" - "fmla v19.4s, v17.4s, %43.s[1] \n" - "fmla v20.4s, v17.4s, %44.s[1] \n" - "fmla v21.4s, v17.4s, %45.s[1] \n" - - "prfm pldl1keep, [%15, #128] \n" - "ld1 {v16.4s}, [%15], #16 \n" - - "fmla v22.4s, v17.4s, %46.s[1] \n" - "fmla v23.4s, v17.4s, %47.s[1] \n" - "fmla v24.4s, v17.4s, %48.s[1] \n" - "fmla v25.4s, v17.4s, %49.s[1] \n" - - "fmla v18.4s, v16.4s, %42.s[2] \n" - "fmla v19.4s, v16.4s, %43.s[2] \n" - "fmla v20.4s, v16.4s, %44.s[2] \n" - "fmla v21.4s, v16.4s, %45.s[2] \n" - - "prfm pldl1keep, [%16, #128] \n" - "ld1 {v17.4s}, [%16], #16 \n" - - "fmla v22.4s, v16.4s, %46.s[2] \n" - "fmla v23.4s, v16.4s, %47.s[2] \n" - "fmla v24.4s, v16.4s, %48.s[2] \n" - "fmla v25.4s, v16.4s, %49.s[2] \n" - - "fmla v18.4s, v17.4s, %42.s[3] \n" - "fmla v19.4s, v17.4s, %43.s[3] \n" - "fmla v20.4s, v17.4s, %44.s[3] \n" - "fmla v21.4s, v17.4s, %45.s[3] \n" - - "st1 {v18.4s}, [%1], #16 \n" - "st1 {v19.4s}, [%2], #16 \n" - "st1 {v20.4s}, [%3], #16 \n" - "st1 {v21.4s}, [%4], #16 \n" - - "fmla v22.4s, v17.4s, %46.s[3] \n" - "fmla v23.4s, v17.4s, %47.s[3] \n" - "fmla v24.4s, v17.4s, %48.s[3] \n" - "fmla v25.4s, v17.4s, %49.s[3] \n" - - "st1 {v22.4s}, [%5], #16 \n" - "st1 {v23.4s}, [%6], #16 \n" - "st1 {v24.4s}, [%7], #16 \n" - "st1 {v25.4s}, [%8], #16 \n" - - "prfm pldl1keep, [%9, #128] \n" - "ld1 {v16.4s}, [%9], #16 \n" - "prfm pldl1keep, [%1, #128] \n" - "ld1 {v18.4s}, [%1] \n" - "prfm pldl1keep, [%2, #128] \n" - "ld1 {v19.4s}, [%2] \n" - - "subs %w0, %w0, #1 \n" - "bne 0b \n" - : "=r"(nw), // 0 - "=r"(c_ptr0), // 1 - "=r"(c_ptr1), // 2 - "=r"(c_ptr2), // 3 - "=r"(c_ptr3), // 4 - "=r"(c_ptr4), // 5 - "=r"(c_ptr5), // 6 - "=r"(c_ptr6), // 7 - "=r"(c_ptr7), // 8 - "=r"(b_ptr0), // 9 - "=r"(b_ptr1), // 10 - "=r"(b_ptr2), // 11 - "=r"(b_ptr3), // 12 - "=r"(b_ptr4), // 13 - "=r"(b_ptr5), // 14 - "=r"(b_ptr6), // 15 - "=r"(b_ptr7) // 16 - : "0"(nw), // 17 - "1"(c_ptr0), // 18 - "2"(c_ptr1), // 19 - "3"(c_ptr2), // 20 - "4"(c_ptr3), // 21 - "5"(c_ptr4), // 22 - "6"(c_ptr5), // 23 - "7"(c_ptr6), // 24 - "8"(c_ptr7), // 25 - "9"(b_ptr0), // 26 - "10"(b_ptr1), // 27 - "11"(b_ptr2), // 28 - "12"(b_ptr3), // 29 - "13"(b_ptr4), // 30 - "14"(b_ptr5), // 31 - "15"(b_ptr6), // 32 - "16"(b_ptr7), // 33 - "w"(a0), // 34 - "w"(a2), // 35 - "w"(a4), // 36 - "w"(a6), // 37 - "w"(a8), // 38 - "w"(a10), // 39 - "w"(a12), // 40 - "w"(a14), // 41 - "w"(a1), // 42 - "w"(a3), // 43 - "w"(a5), // 44 - "w"(a7), // 45 - "w"(a9), // 46 - "w"(a11), // 47 - "w"(a13), // 48 - "w"(a15) // 49 - : "cc", 
"memory", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25" - ); + "prfm pldl1keep, [%9, #128] \n" + "ld1 {v16.4s}, [%9], #16 \n" + + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v18.4s}, [%1] \n" + + "prfm pldl1keep, [%2, #128] \n" + "ld1 {v19.4s}, [%2] \n" + + "0: \n" + + "prfm pldl1keep, [%3, #128] \n" + "ld1 {v20.4s}, [%3] \n" + "prfm pldl1keep, [%4, #128] \n" + "ld1 {v21.4s}, [%4] \n" + "prfm pldl1keep, [%5, #128] \n" + "ld1 {v22.4s}, [%5] \n" + "prfm pldl1keep, [%6, #128] \n" + "ld1 {v23.4s}, [%6] \n" + "prfm pldl1keep, [%7, #128] \n" + "ld1 {v24.4s}, [%7] \n" + "prfm pldl1keep, [%8, #128] \n" + "ld1 {v25.4s}, [%8] \n" + "prfm pldl1keep, [%10, #128] \n" + "ld1 {v17.4s}, [%10], #16 \n" + + "fmla v18.4s, v16.4s, %34.s[0] \n" + "fmla v19.4s, v16.4s, %35.s[0] \n" + "fmla v20.4s, v16.4s, %36.s[0] \n" + "fmla v21.4s, v16.4s, %37.s[0] \n" + + "fmla v22.4s, v16.4s, %38.s[0] \n" + "fmla v23.4s, v16.4s, %39.s[0] \n" + "fmla v24.4s, v16.4s, %40.s[0] \n" + "fmla v25.4s, v16.4s, %41.s[0] \n" + + "fmla v18.4s, v17.4s, %34.s[1] \n" + "fmla v19.4s, v17.4s, %35.s[1] \n" + "fmla v20.4s, v17.4s, %36.s[1] \n" + "fmla v21.4s, v17.4s, %37.s[1] \n" + + "prfm pldl1keep, [%11, #128] \n" + "ld1 {v16.4s}, [%11], #16 \n" + + "fmla v22.4s, v17.4s, %38.s[1] \n" + "fmla v23.4s, v17.4s, %39.s[1] \n" + "fmla v24.4s, v17.4s, %40.s[1] \n" + "fmla v25.4s, v17.4s, %41.s[1] \n" + + "fmla v18.4s, v16.4s, %34.s[2] \n" + "fmla v19.4s, v16.4s, %35.s[2] \n" + "fmla v20.4s, v16.4s, %36.s[2] \n" + "fmla v21.4s, v16.4s, %37.s[2] \n" + + "prfm pldl1keep, [%12, #128] \n" + "ld1 {v17.4s}, [%12], #16 \n" + + "fmla v22.4s, v16.4s, %38.s[2] \n" + "fmla v23.4s, v16.4s, %39.s[2] \n" + "fmla v24.4s, v16.4s, %40.s[2] \n" + "fmla v25.4s, v16.4s, %41.s[2] \n" + + "fmla v18.4s, v17.4s, %34.s[3] \n" + "fmla v19.4s, v17.4s, %35.s[3] \n" + "fmla v20.4s, v17.4s, %36.s[3] \n" + "fmla v21.4s, v17.4s, %37.s[3] \n" + + "prfm pldl1keep, [%13, #128] \n" + "ld1 {v16.4s}, [%13], #16 \n" + + "fmla v22.4s, v17.4s, %38.s[3] \n" + "fmla v23.4s, v17.4s, %39.s[3] \n" + "fmla v24.4s, v17.4s, %40.s[3] \n" + "fmla v25.4s, v17.4s, %41.s[3] \n" + + "fmla v18.4s, v16.4s, %42.s[0] \n" + "fmla v19.4s, v16.4s, %43.s[0] \n" + "fmla v20.4s, v16.4s, %44.s[0] \n" + "fmla v21.4s, v16.4s, %45.s[0] \n" + + "prfm pldl1keep, [%14, #128] \n" + "ld1 {v17.4s}, [%14], #16 \n" + + "fmla v22.4s, v16.4s, %46.s[0] \n" + "fmla v23.4s, v16.4s, %47.s[0] \n" + "fmla v24.4s, v16.4s, %48.s[0] \n" + "fmla v25.4s, v16.4s, %49.s[0] \n" + + "fmla v18.4s, v17.4s, %42.s[1] \n" + "fmla v19.4s, v17.4s, %43.s[1] \n" + "fmla v20.4s, v17.4s, %44.s[1] \n" + "fmla v21.4s, v17.4s, %45.s[1] \n" + + "prfm pldl1keep, [%15, #128] \n" + "ld1 {v16.4s}, [%15], #16 \n" + + "fmla v22.4s, v17.4s, %46.s[1] \n" + "fmla v23.4s, v17.4s, %47.s[1] \n" + "fmla v24.4s, v17.4s, %48.s[1] \n" + "fmla v25.4s, v17.4s, %49.s[1] \n" + + "fmla v18.4s, v16.4s, %42.s[2] \n" + "fmla v19.4s, v16.4s, %43.s[2] \n" + "fmla v20.4s, v16.4s, %44.s[2] \n" + "fmla v21.4s, v16.4s, %45.s[2] \n" + + "prfm pldl1keep, [%16, #128] \n" + "ld1 {v17.4s}, [%16], #16 \n" + + "fmla v22.4s, v16.4s, %46.s[2] \n" + "fmla v23.4s, v16.4s, %47.s[2] \n" + "fmla v24.4s, v16.4s, %48.s[2] \n" + "fmla v25.4s, v16.4s, %49.s[2] \n" + + "fmla v18.4s, v17.4s, %42.s[3] \n" + "fmla v19.4s, v17.4s, %43.s[3] \n" + "fmla v20.4s, v17.4s, %44.s[3] \n" + "fmla v21.4s, v17.4s, %45.s[3] \n" + + "st1 {v18.4s}, [%1], #16 \n" + "st1 {v19.4s}, [%2], #16 \n" + "st1 {v20.4s}, [%3], #16 \n" + "st1 {v21.4s}, [%4], #16 \n" + + "fmla v22.4s, v17.4s, 
%46.s[3] \n" + "fmla v23.4s, v17.4s, %47.s[3] \n" + "fmla v24.4s, v17.4s, %48.s[3] \n" + "fmla v25.4s, v17.4s, %49.s[3] \n" + + "st1 {v22.4s}, [%5], #16 \n" + "st1 {v23.4s}, [%6], #16 \n" + "st1 {v24.4s}, [%7], #16 \n" + "st1 {v25.4s}, [%8], #16 \n" + + "prfm pldl1keep, [%9, #128] \n" + "ld1 {v16.4s}, [%9], #16 \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v18.4s}, [%1] \n" + "prfm pldl1keep, [%2, #128] \n" + "ld1 {v19.4s}, [%2] \n" + + "subs %w0, %w0, #1 \n" + "bne 0b \n" + : "=r"(nw), // 0 + "=r"(c_ptr0), // 1 + "=r"(c_ptr1), // 2 + "=r"(c_ptr2), // 3 + "=r"(c_ptr3), // 4 + "=r"(c_ptr4), // 5 + "=r"(c_ptr5), // 6 + "=r"(c_ptr6), // 7 + "=r"(c_ptr7), // 8 + "=r"(b_ptr0), // 9 + "=r"(b_ptr1), // 10 + "=r"(b_ptr2), // 11 + "=r"(b_ptr3), // 12 + "=r"(b_ptr4), // 13 + "=r"(b_ptr5), // 14 + "=r"(b_ptr6), // 15 + "=r"(b_ptr7) // 16 + : "0"(nw), // 17 + "1"(c_ptr0), // 18 + "2"(c_ptr1), // 19 + "3"(c_ptr2), // 20 + "4"(c_ptr3), // 21 + "5"(c_ptr4), // 22 + "6"(c_ptr5), // 23 + "7"(c_ptr6), // 24 + "8"(c_ptr7), // 25 + "9"(b_ptr0), // 26 + "10"(b_ptr1), // 27 + "11"(b_ptr2), // 28 + "12"(b_ptr3), // 29 + "13"(b_ptr4), // 30 + "14"(b_ptr5), // 31 + "15"(b_ptr6), // 32 + "16"(b_ptr7), // 33 + "w"(a0), // 34 + "w"(a2), // 35 + "w"(a4), // 36 + "w"(a6), // 37 + "w"(a8), // 38 + "w"(a10), // 39 + "w"(a12), // 40 + "w"(a14), // 41 + "w"(a1), // 42 + "w"(a3), // 43 + "w"(a5), // 44 + "w"(a7), // 45 + "w"(a9), // 46 + "w"(a11), // 47 + "w"(a13), // 48 + "w"(a15) // 49 + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25"); w = (width >> 2) << 2; } -#else // gcc +#else // gcc || armv7a for (w = 0; w + 3 < width; w += 4) { const float *b_ptr = B + (k * stride_w + w); float *c_ptr = C + (h * stride_w + w); Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr); } -#endif // clang +#endif // clang && armv8a if (w < width) { - const float *a_ptr = A + (h * stride_k + k); const float *b_ptr = B + (k * stride_w + w); float *c_ptr = C + (h * stride_w + w); GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr); @@ -411,154 +839,37 @@ inline void GemmTile(const float *A, const float *a_ptr = A + (h * stride_k + k); const float *b_ptr = B + k * stride_w; float *c_ptr = C + h * stride_w; - GemmBlock(a_ptr, - b_ptr, - 8, - K - k, - width, - stride_k, - stride_w, - c_ptr); + GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_k, stride_w, c_ptr); } } if (h < height) { - // TODO(liyin): may use Gemm444 - const float *a_ptr = A + (h * stride_k); - const float *b_ptr = B; - float *c_ptr = C + h * stride_w; - GemmBlock(a_ptr, - b_ptr, - height - h, - K, - width, - stride_k, - stride_w, - c_ptr); - } -#else - -#if defined(MACE_ENABLE_NEON) // armv7 - w = (width >> 2) << 2; - for (h = 0; h + 3 < height; h += 4) { - for (k = 0; k + 3 < K; k += 4) { + index_t remain_h = height - h; + for (k = 0; k < K - 7; k += 8) { const float *a_ptr = A + (h * stride_k + k); - int nw = width >> 2; - if (nw > 0) { - // load A - float32x2_t a00, a01, a10, a11, a20, a21, a30, a31; - a00 = vld1_f32(a_ptr); - a01 = vld1_f32(a_ptr + 2); - a10 = vld1_f32(a_ptr + 1 * stride_k); - a11 = vld1_f32(a_ptr + 1 * stride_k + 2); - a20 = vld1_f32(a_ptr + 2 * stride_k); - a21 = vld1_f32(a_ptr + 2 * stride_k + 2); - a30 = vld1_f32(a_ptr + 3 * stride_k); - a31 = vld1_f32(a_ptr + 3 * stride_k + 2); - - const float *b_ptr0 = B + k * stride_w; - const float *b_ptr1 = B + (k + 1) * stride_w; - const float *b_ptr2 = B + (k + 2) * stride_w; - const float *b_ptr3 = B + (k + 3) * stride_w; - - float *c_ptr0 = C + h * stride_w; 
- float *c_ptr1 = C + (h + 1) * stride_w; - float *c_ptr2 = C + (h + 2) * stride_w; - float *c_ptr3 = C + (h + 3) * stride_w; - - // TODO(liyin): asm v7 prefetch and load optimization - while (nw--) { - float32x4_t b0, b1, b2, b3; - float32x4_t c0, c1, c2, c3; - - c0 = vld1q_f32(c_ptr0); - - b0 = vld1q_f32(b_ptr0); - b1 = vld1q_f32(b_ptr1); - b2 = vld1q_f32(b_ptr2); - b3 = vld1q_f32(b_ptr3); - - c1 = vld1q_f32(c_ptr1); - c2 = vld1q_f32(c_ptr2); - c3 = vld1q_f32(c_ptr3); - - c0 = vmlaq_lane_f32(c0, b0, a00, 0); - c0 = vmlaq_lane_f32(c0, b1, a00, 1); - c0 = vmlaq_lane_f32(c0, b2, a01, 0); - c0 = vmlaq_lane_f32(c0, b3, a01, 1); - - vst1q_f32(c_ptr0, c0); - - c1 = vmlaq_lane_f32(c1, b0, a10, 0); - c1 = vmlaq_lane_f32(c1, b1, a10, 1); - c1 = vmlaq_lane_f32(c1, b2, a11, 0); - c1 = vmlaq_lane_f32(c1, b3, a11, 1); - - vst1q_f32(c_ptr1, c1); - - c2 = vmlaq_lane_f32(c2, b0, a20, 0); - c2 = vmlaq_lane_f32(c2, b1, a20, 1); - c2 = vmlaq_lane_f32(c2, b2, a21, 0); - c2 = vmlaq_lane_f32(c2, b3, a21, 1); - - vst1q_f32(c_ptr2, c2); - - c3 = vmlaq_lane_f32(c3, b0, a30, 0); - c3 = vmlaq_lane_f32(c3, b1, a30, 1); - c3 = vmlaq_lane_f32(c3, b2, a31, 0); - c3 = vmlaq_lane_f32(c3, b3, a31, 1); - - vst1q_f32(c_ptr3, c3); - - b_ptr0 += 4; - b_ptr1 += 4; - b_ptr2 += 4; - b_ptr3 += 4; - - c_ptr0 += 4; - c_ptr1 += 4; - c_ptr2 += 4; - c_ptr3 += 4; - } + index_t w; + for (w = 0; w + 3 < width; w += 4) { + const float *b_ptr = B + (k * stride_w + w); + float *c_ptr = C + (h * stride_w + w); + GemmX84(a_ptr, b_ptr, stride_k, stride_w, c_ptr, remain_h); } if (w < width) { - const float *a_ptr = A + (h * stride_k + k); const float *b_ptr = B + (k * stride_w + w); float *c_ptr = C + (h * stride_w + w); - GemmBlock(a_ptr, b_ptr, 4, 4, width - w, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_k, stride_w, + c_ptr); } } if (k < K) { const float *a_ptr = A + (h * stride_k + k); const float *b_ptr = B + k * stride_w; float *c_ptr = C + h * stride_w; - GemmBlock(a_ptr, - b_ptr, - 4, - K - k, - width, - stride_k, - stride_w, + GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_k, stride_w, c_ptr); } } - if (h < height) { - const float *a_ptr = A + (h * stride_k); - const float *b_ptr = B; - float *c_ptr = C + h * stride_w; - GemmBlock(a_ptr, - b_ptr, - height - h, - K, - width, - stride_k, - stride_w, - c_ptr); - } -#else // cpu +#else GemmBlock(A, B, height, K, width, stride_k, stride_w, C); -#endif // armv7 - -#endif // aarch64 +#endif // MACE_ENABLE_NEON } } // namespace @@ -601,166 +912,163 @@ void Gemm(const float *A, const index_t ih_begin = bh * block_size; const index_t ih_end = - bh * block_size + (bh == block_tile_height - 1 && remain_height > 0 - ? remain_height : block_size); + bh * block_size + (bh == block_tile_height - 1 && remain_height > 0 + ? remain_height + : block_size); const index_t iw_begin = bw * block_size; const index_t iw_end = - bw * block_size - + (bw == block_tile_width - 1 && remain_width > 0 ? remain_width - : block_size); + bw * block_size + (bw == block_tile_width - 1 && remain_width > 0 + ? remain_width + : block_size); for (index_t bk = 0; bk < block_tile_k; ++bk) { const index_t ik_begin = bk * block_size; const index_t ik_end = - bk * block_size - + (bk == block_tile_k - 1 && remain_k > 0 ? remain_k - : block_size); + bk * block_size + + (bk == block_tile_k - 1 && remain_k > 0 ? 
remain_k : block_size); // inside block: // calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k GemmTile(a_base + (ih_begin * K + ik_begin), - b_base + (ik_begin * width + iw_begin), - ih_end - ih_begin, - ik_end - ik_begin, - iw_end - iw_begin, - K, - width, + b_base + (ik_begin * width + iw_begin), ih_end - ih_begin, + ik_end - ik_begin, iw_end - iw_begin, K, width, c_base + (ih_begin * width + iw_begin)); } // bk - } // bw - } // bh - } // n + } // bw + } // bh + } // n } // A: height x K, B: K x width, C: height x width void GemmRef(const float *A, const float *B, + const index_t batch, const index_t height, const index_t K, const index_t width, float *C) { - memset(C, 0, sizeof(float) * height * width); - for (int i = 0; i < height; ++i) { - for (int j = 0; j < width; ++j) { - for (int k = 0; k < K; ++k) { - C[i * width + j] += A[i * K + k] * B[k * width + j]; + memset(C, 0, sizeof(float) * batch * height * width); + for (index_t b = 0; b < batch; ++b) { + for (index_t i = 0; i < height; ++i) { + for (index_t j = 0; j < width; ++j) { + for (index_t k = 0; k < K; ++k) { + C[(b * height + i) * width + j] += + A[(b * height + i) * K + k] * B[(b * K + k) * width + j]; + } } } } } - void GemvRef(const float *m_ptr, const float *v_ptr, const index_t batch, const index_t width, const index_t height, float *out_ptr) { - memset(out_ptr, 0, sizeof(float) * height * batch); + memset(out_ptr, 0, batch * height * sizeof(float)); +#pragma omp parallel for collapse(2) for (int b = 0; b < batch; ++b) { for (int h = 0; h < height; ++h) { for (int w = 0; w < width; ++w) { - out_ptr[h + b * height] += v_ptr[w + b * width] * m_ptr[h * width + w]; + out_ptr[b * height + h] += v_ptr[b * width + w] * m_ptr[h * width + w]; } } } } -// M: height x width, Vin: width x 1, Vout: height x 1 +// TODO(liyin): batched gemv can be transformed to gemm (w/ transpose) void Gemv(const float *m_ptr, const float *v_ptr, const index_t batch, const index_t width, const index_t height, float *out_ptr) { -#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) - index_t height_d4 = height >> 2; - index_t width_d4 = width >> 2; - index_t remain_w = width - (width_d4 << 2); - index_t remain_h = height - (height_d4 << 2); - +#if defined(MACE_ENABLE_NEON) +// TODO(liyin/wch): try height tiling = 8 +#pragma omp parallel for collapse(2) for (index_t b = 0; b < batch; ++b) { -#pragma omp parallel for - for (index_t h = 0; h < height_d4; ++h) { - const float *m_ptr0 = m_ptr + h * width * 4; - const float *m_ptr1 = m_ptr0 + width; - const float *m_ptr2 = m_ptr1 + width; - const float *m_ptr3 = m_ptr2 + width; - const float *v_ptr0 = v_ptr + b * width; - float *out_ptr0 = out_ptr + h * 4 + b * height; - - float32x4_t vm0, vm1, vm2, vm3; - float32x4_t vv; - - float32x4_t vsum0 = vdupq_n_f32(0.f); - float32x4_t vsum1 = vdupq_n_f32(0.f); - float32x4_t vsum2 = vdupq_n_f32(0.f); - float32x4_t vsum3 = vdupq_n_f32(0.f); - - for (index_t w = 0; w < width_d4; ++w) { - vm0 = vld1q_f32(m_ptr0); - vm1 = vld1q_f32(m_ptr1); - vm2 = vld1q_f32(m_ptr2); - vm3 = vld1q_f32(m_ptr3); - vv = vld1q_f32(v_ptr0); - - vsum0 = vmlaq_f32(vsum0, vm0, vv); - vsum1 = vmlaq_f32(vsum1, vm1, vv); - vsum2 = vmlaq_f32(vsum2, vm2, vv); - vsum3 = vmlaq_f32(vsum3, vm3, vv); - - m_ptr0 += 4; - m_ptr1 += 4; - m_ptr2 += 4; - m_ptr3 += 4; - v_ptr0 += 4; - } - float sum0 = vaddvq_f32(vsum0); - float sum1 = vaddvq_f32(vsum1); - float sum2 = vaddvq_f32(vsum2); - float sum3 = vaddvq_f32(vsum3); - - // handle remaining w - for (index_t w = 0; w < remain_w; ++w) { - sum0 += m_ptr0[0] 
* v_ptr0[0]; - sum1 += m_ptr1[0] * v_ptr0[0]; - sum2 += m_ptr2[0] * v_ptr0[0]; - sum3 += m_ptr3[0] * v_ptr0[0]; - m_ptr0++; - m_ptr1++; - m_ptr2++; - m_ptr3++; - v_ptr0++; - } - *out_ptr0++ = sum0; - *out_ptr0++ = sum1; - *out_ptr0++ = sum2; - *out_ptr0++ = sum3; - } - - // handle remaining h - index_t remain_start_height = height_d4 << 2; -#pragma omp parallel for - for (index_t h = 0; h < remain_h; ++h) { - float32x4_t vsum0 = vdupq_n_f32(0.f); - const float *m_ptr0 = m_ptr + (h + remain_start_height) * width; - const float *v_ptr0 = v_ptr + b * width; - for (index_t w = 0; w < width_d4; ++w) { - float32x4_t vm = vld1q_f32(m_ptr0); - float32x4_t vv = vld1q_f32(v_ptr0); - vsum0 = vmlaq_f32(vsum0, vm, vv); - m_ptr0 += 4; - v_ptr0 += 4; - } - float sum = vaddvq_f32(vsum0); - for (index_t w = 0; w < remain_w; ++w) { - sum += m_ptr0[0] * v_ptr0[0]; - m_ptr0++; - v_ptr0++; - } - out_ptr[remain_start_height + h + b * height] = sum; - } - } + for (index_t h = 0; h < height; h += 4) { + if (h + 3 < height) { + const float *m_ptr0 = m_ptr + h * width; + const float *m_ptr1 = m_ptr0 + width; + const float *m_ptr2 = m_ptr1 + width; + const float *m_ptr3 = m_ptr2 + width; + const float *v_ptr0 = v_ptr + b * width; + float *out_ptr0 = out_ptr + b * height + h; + + float32x4_t vm0, vm1, vm2, vm3; + float32x4_t vv; + + float32x4_t vsum0 = vdupq_n_f32(0.f); + float32x4_t vsum1 = vdupq_n_f32(0.f); + float32x4_t vsum2 = vdupq_n_f32(0.f); + float32x4_t vsum3 = vdupq_n_f32(0.f); + + index_t w; + for (w = 0; w + 3 < width; w += 4) { + vm0 = vld1q_f32(m_ptr0); + vm1 = vld1q_f32(m_ptr1); + vm2 = vld1q_f32(m_ptr2); + vm3 = vld1q_f32(m_ptr3); + vv = vld1q_f32(v_ptr0); + + vsum0 = vmlaq_f32(vsum0, vm0, vv); + vsum1 = vmlaq_f32(vsum1, vm1, vv); + vsum2 = vmlaq_f32(vsum2, vm2, vv); + vsum3 = vmlaq_f32(vsum3, vm3, vv); + + m_ptr0 += 4; + m_ptr1 += 4; + m_ptr2 += 4; + m_ptr3 += 4; + v_ptr0 += 4; + } + float sum0 = vaddvq_f32(vsum0); + float sum1 = vaddvq_f32(vsum1); + float sum2 = vaddvq_f32(vsum2); + float sum3 = vaddvq_f32(vsum3); + + // handle remaining w + for (; w < width; ++w) { + sum0 += m_ptr0[0] * v_ptr0[0]; + sum1 += m_ptr1[0] * v_ptr0[0]; + sum2 += m_ptr2[0] * v_ptr0[0]; + sum3 += m_ptr3[0] * v_ptr0[0]; + m_ptr0++; + m_ptr1++; + m_ptr2++; + m_ptr3++; + v_ptr0++; + } + *out_ptr0++ = sum0; + *out_ptr0++ = sum1; + *out_ptr0++ = sum2; + *out_ptr0++ = sum3; + } else { + for (index_t hh = h; hh < height; ++hh) { + float32x4_t vsum0 = vdupq_n_f32(0.f); + const float *m_ptr0 = m_ptr + hh * width; + const float *v_ptr0 = v_ptr + b * width; + index_t w; + for (w = 0; w + 3 < width; w += 4) { + float32x4_t vm = vld1q_f32(m_ptr0); + float32x4_t vv = vld1q_f32(v_ptr0); + vsum0 = vmlaq_f32(vsum0, vm, vv); + m_ptr0 += 4; + v_ptr0 += 4; + } + float sum = vaddvq_f32(vsum0); + for (; w < width; ++w) { + sum += m_ptr0[0] * v_ptr0[0]; + m_ptr0++; + v_ptr0++; + } + out_ptr[b * height + hh] = sum; + } + } // if + } // h + } // b #else GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr); #endif diff --git a/mace/kernels/gemm.h b/mace/kernels/gemm.h index e1fcfad6a51b9e4611e538709818cf5126311a5a..9a7ce77bcab2138bc98bcf9863cf9e2e146f4637 100644 --- a/mace/kernels/gemm.h +++ b/mace/kernels/gemm.h @@ -34,6 +34,7 @@ void Gemm(const float *A, void GemmRef(const float *A, const float *B, + const index_t batch, const index_t height, const index_t K, const index_t width, diff --git a/mace/kernels/gemm_test.cc b/mace/kernels/gemm_test.cc index 8400ca857dade6212f96ecfd4ff17b48316bbc81..90a792ef236455617107b311c7b2eea7c7d56aa0 100644 --- 
a/mace/kernels/gemm_test.cc
+++ b/mace/kernels/gemm_test.cc
@@ -21,62 +21,98 @@ namespace mace {
 
-TEST(GEMMTest, gemm) {
-  index_t N = 17;
-  index_t M = 33;
-  index_t K = 64;
-  std::unique_ptr<float[]> A(new float[N * K]);
-  std::unique_ptr<float[]> B(new float[K * M]);
-  std::unique_ptr<float[]> C(new float[N * M]);
-  std::unique_ptr<float[]> C_ref(new float[N * M]);
+namespace {
+
+void GemmTest(index_t batch, index_t N, index_t K, index_t M) {
+  std::unique_ptr<float[]> A(new float[batch * N * K]);
+  std::unique_ptr<float[]> B(new float[batch * K * M]);
+  std::unique_ptr<float[]> C(new float[batch * N * M]);
+  std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
 
   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
-  std::generate(A.get(), A.get() + N * K,
+  std::generate(A.get(), A.get() + batch * N * K,
                 [&gen, &nd] { return nd(gen); });
-  std::generate(B.get(), B.get() + K * M,
+  std::generate(B.get(), B.get() + batch * K * M,
                 [&gen, &nd] { return nd(gen); });
 
-  kernels::Gemm(A.get(), B.get(), 1, N, K, M, C.get());
-  kernels::GemmRef(A.get(), B.get(), N, K, M, C_ref.get());
+  kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get());
+  kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get());
 
-  for (int i = 0; i < N * M; ++i) {
+  for (int i = 0; i < batch * N * M; ++i) {
     EXPECT_NEAR(C_ref[i], C[i], 0.1);
   }
 }
 
-TEST(GEMMTest, gemv) {
-  index_t N = 17;
-  index_t K = 63;
-  std::unique_ptr<float[]> A(new float[N * K]);
-  std::unique_ptr<float[]> B(new float[K]);
-  std::unique_ptr<float[]> C(new float[N]);
-  std::unique_ptr<float[]> C_ref(new float[N]);
+void GemvTest(index_t batch, index_t N, index_t M) {
+  std::unique_ptr<float[]> A(new float[N * M]);
+  std::unique_ptr<float[]> B(new float[batch * M]);
+  std::unique_ptr<float[]> C(new float[batch * N]);
+  std::unique_ptr<float[]> C_ref(new float[batch * N]);
 
   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
-  std::generate(A.get(), A.get() + N * K,
+  std::generate(A.get(), A.get() + N * M,
                 [&gen, &nd] { return nd(gen); });
-  std::generate(B.get(), B.get() + K,
+  std::generate(B.get(), B.get() + batch * M,
                 [&gen, &nd] { return nd(gen); });
 
-  kernels::Gemv(A.get(), B.get(), 1, K, N, C.get());
-  kernels::GemvRef(A.get(), B.get(), 1, K, N, C_ref.get());
+  kernels::Gemv(A.get(), B.get(), batch, M, N, C.get());
+  kernels::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get());
 
-  for (int i = 0; i < N; ++i) {
+  for (int i = 0; i < batch * N; ++i) {
     EXPECT_NEAR(C_ref[i], C[i], 0.1);
   }
 }
 
+}  // namespace
+
+TEST(GEMMTest, AlignedWithoutBatch) {
+  GemmTest(1, 1, 64, 128);
+  GemmTest(1, 2, 64, 128);
+  GemmTest(1, 3, 64, 128);
+  GemmTest(1, 4, 64, 128);
+  GemmTest(1, 5, 64, 128);
+  GemmTest(1, 6, 64, 128);
+  GemmTest(1, 7, 64, 128);
+  GemmTest(1, 17, 64, 128);
+}
+
+TEST(GEMMTest, UnalignedWithoutBatch) {
+  GemmTest(1, 1, 63, 127);
+  GemmTest(1, 2, 63, 127);
+  GemmTest(1, 3, 63, 127);
+  GemmTest(1, 4, 63, 127);
+  GemmTest(1, 5, 63, 127);
+  GemmTest(1, 6, 63, 127);
+  GemmTest(1, 7, 63, 127);
+  GemmTest(1, 17, 63, 127);
+}
+
+TEST(GEMMTest, UnalignedWithBatch) {
+  GemmTest(3, 1, 63, 127);
+  GemmTest(3, 2, 63, 127);
+  GemmTest(3, 3, 63, 127);
+  GemmTest(3, 4, 63, 127);
+  GemmTest(3, 5, 63, 127);
+  GemmTest(3, 6, 63, 127);
+  GemmTest(3, 7, 63, 127);
+  GemmTest(3, 17, 63, 127);
+}
+
+TEST(GEMMTest, gemv) {
+  GemvTest(1, 17, 63);
+  GemvTest(3, 17, 63);
+}
 }  // namespace mace
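
Note (not part of the patch): MACE_GEMM_PART_CAL(RC, RA, RAN) accumulates one 1x4 row of the C tile from eight rows of B, using the eight A scalars held in the two registers a##RA (k = 0..3) and a##RAN (k = 4..7). The aarch64 branch expresses this with lane-broadcast FMAs (vfmaq_laneq_f32); the armv7 branch uses vmlaq_lane_f32 on the low/high halves of the same registers. Below is a minimal scalar sketch of what one invocation is expected to compute, assuming the row-major tile layout and the stride_k/stride_w strides used by Gemm884; the helper name GemmPartCalRef is illustrative only:

#include <cstddef>

// Scalar model of MACE_GEMM_PART_CAL(RC, RA, RAN): row RC of the C tile gets
// an 8-deep dot-product update against 4 columns of B. a_row points at the
// 8 floats the kernel keeps in a##RA (lanes 0-3) and a##RAN (lanes 4-7);
// b_ptr and c_row follow the Gemm884 tile layout.
inline void GemmPartCalRef(const float *a_row,       // 8 packed A scalars
                           const float *b_ptr,       // top-left of 8x4 B tile
                           float *c_row,             // 1x4 row RC of C tile
                           std::ptrdiff_t stride_w) {
  for (int k = 0; k < 8; ++k) {
    for (int col = 0; col < 4; ++col) {
      c_row[col] += a_row[k] * b_ptr[k * stride_w + col];
    }
  }
}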
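vaddvq_f32 is a built-in intrinsic only on aarch64, so the armv7 fallback added at the top of gemm.cc defines it as element subscripting on float32x4_t, which relies on the GCC/Clang vector-extension operator[]. If a toolchain rejects that subscripting, a pairwise-add reduction gives the same result; a sketch using only armv7 NEON intrinsics (the function name is illustrative, not part of the patch):

#if defined(__ARM_NEON) && !defined(__aarch64__)
#include <arm_neon.h>

// Horizontal sum of a float32x4_t, equivalent to the vaddvq_f32(v) macro
// defined for the armv7 build of gemm.cc.
inline float Vaddvq_f32_compat(float32x4_t v) {
  float32x2_t sum2 = vadd_f32(vget_low_f32(v), vget_high_f32(v));  // {v0+v2, v1+v3}
  sum2 = vpadd_f32(sum2, sum2);                                    // {sum, sum}
  return vget_lane_f32(sum2, 0);
}
#endif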
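Because GemmRef (and the test helpers) now take a leading batch argument, existing callers have to be updated. A usage sketch of the new batched signatures, assuming the patched mace/kernels/gemm.h and mace::index_t are available on the include path; buffer shapes follow the test helpers above:

#include <algorithm>
#include <memory>

#include "mace/kernels/gemm.h"

void RunBatchedGemmExample() {
  const mace::index_t batch = 3, height = 17, K = 63, width = 127;
  std::unique_ptr<float[]> A(new float[batch * height * K]);      // batch x height x K
  std::unique_ptr<float[]> B(new float[batch * K * width]);       // batch x K x width
  std::unique_ptr<float[]> C(new float[batch * height * width]);  // batch x height x width
  std::unique_ptr<float[]> C_ref(new float[batch * height * width]);
  std::fill(A.get(), A.get() + batch * height * K, 1.0f);
  std::fill(B.get(), B.get() + batch * K * width, 1.0f);

  // Optimized path (NEON-tiled when MACE_ENABLE_NEON is defined).
  mace::kernels::Gemm(A.get(), B.get(), batch, height, K, width, C.get());
  // Reference path now takes the same leading batch argument.
  mace::kernels::GemmRef(A.get(), B.get(), batch, height, K, width, C_ref.get());
}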