diff --git a/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp b/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp
index 391cc9f807f6ca73080341cbb18fd68195ab7f6e..7222eaacf023c11a3603f3ab6884188f731495b1 100644
--- a/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp
+++ b/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp
@@ -85,6 +85,18 @@ void hgemv_naive_n(
 }
 }  // namespace
 
+#if defined(__aarch64__)
+#define VFMAQ_N_F16(a, b, n) vfmaq_n_f16(a, b, n)
+#else
+#define VFMAQ_N_F16(a, b, n) vaddq_f16(a, vmulq_n_f16(b, n))
+#endif
+
+#if defined(__aarch64__)
+#define VFMA_N_F16(a, b, n) vfma_n_f16(a, b, n)
+#else
+#define VFMA_N_F16(a, b, n) vadd_f16(a, vmul_n_f16(b, n))
+#endif
+
 void megdnn::arm_common::gemv_like(
         const __fp16* __restrict A, const __fp16* __restrict B, __fp16* __restrict C,
         size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
@@ -98,33 +110,30 @@ void megdnn::arm_common::gemv_like(
         memset(C + m * Cstride, 0, 4 * sizeof(__fp16) * N);
         for (; k + 4 <= K; k += 4) {
             size_t n = 0;
+            __fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1],
+                   a02 = A[m * Astride + k + 2], a03 = A[m * Astride + k + 3];
+            __fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1],
+                   a12 = A[(m + 1) * Astride + k + 2],
+                   a13 = A[(m + 1) * Astride + k + 3];
+            __fp16 a20 = A[(m + 2) * Astride + k], a21 = A[(m + 2) * Astride + k + 1],
+                   a22 = A[(m + 2) * Astride + k + 2],
+                   a23 = A[(m + 2) * Astride + k + 3];
+            __fp16 a30 = A[(m + 3) * Astride + k], a31 = A[(m + 3) * Astride + k + 1],
+                   a32 = A[(m + 3) * Astride + k + 2],
+                   a33 = A[(m + 3) * Astride + k + 3];
             for (; n + 8 <= N; n += 8) {
-                float16x8_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23,
-                        a30, a31, a32, a33;
                 float16x8_t b0, b1, b2, b3;
                 float16x8_t c0, c1, c2, c3;
-#define loadB(i)  b##i = vld1q_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1q_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
-#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
-#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
+#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 4)
                 UNROLL_OUT(loadB, 4)
-                UNROLL_OUT(loadA0, 4)
-                UNROLL_OUT(loadA1, 4)
-                UNROLL_OUT(loadA2, 4)
-                UNROLL_OUT(loadA3, 4)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#undef loadA2
-#undef loadA3
-#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
-#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
-#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
+#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
+#define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
+#define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
                 UNROLL_OUT(calculate_row0, 4)
                 UNROLL_OUT(calculate_row1, 4)
                 UNROLL_OUT(calculate_row2, 4)
@@ -138,32 +147,18 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n + 4 <= N; n += 4) {
-                float16x4_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23,
-                        a30, a31, a32, a33;
                 float16x4_t b0, b1, b2, b3;
                 float16x4_t c0, c1, c2, c3;
-#define loadB(i)  b##i = vld1_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
-#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
-#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
+#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 4)
                 UNROLL_OUT(loadB, 4)
-                UNROLL_OUT(loadA0, 4)
-                UNROLL_OUT(loadA1, 4)
-                UNROLL_OUT(loadA2, 4)
-                UNROLL_OUT(loadA3, 4)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#undef loadA2
-#undef loadA3
-#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
-#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
-#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
+#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
+#define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
+#define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
                 UNROLL_OUT(calculate_row0, 4)
                 UNROLL_OUT(calculate_row1, 4)
                 UNROLL_OUT(calculate_row2, 4)
@@ -177,8 +172,6 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n < N; n += 1) {
-                __fp16 a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23, a30,
-                        a31, a32, a33;
                 __fp16 b0, b1, b2, b3;
                 __fp16 c0, c1, c2, c3;
 #define loadC(i) c##i = C[(m + i) * Cstride + n];
@@ -187,18 +180,6 @@ void megdnn::arm_common::gemv_like(
                 UNROLL_OUT(loadB, 4)
 #undef loadB
 #undef loadC
-#define loadA0(i) a0##i = A[m * Astride + k + i];
-#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
-#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
-#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
-                UNROLL_OUT(loadA0, 4)
-                UNROLL_OUT(loadA1, 4)
-                UNROLL_OUT(loadA2, 4)
-                UNROLL_OUT(loadA3, 4)
-#undef loadA0
-#undef loadA1
-#undef loadA2
-#undef loadA3
                 c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3;
                 c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3;
                 c2 += a20 * b0 + a21 * b1 + a22 * b2 + a23 * b3;
@@ -210,32 +191,23 @@ void megdnn::arm_common::gemv_like(
         }
         for (; k + 2 <= K; k += 2) {
             size_t n = 0;
+            __fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1];
+            __fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1];
+            __fp16 a20 = A[(m + 2) * Astride + k], a21 = A[(m + 2) * Astride + k + 1];
+            __fp16 a30 = A[(m + 3) * Astride + k], a31 = A[(m + 3) * Astride + k + 1];
             for (; n + 8 <= N; n += 8) {
-                float16x8_t a00, a01, a10, a11, a20, a21, a30, a31;
                 float16x8_t b0, b1;
                 float16x8_t c0, c1, c2, c3;
-#define loadB(i)  b##i = vld1q_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1q_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
-#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
-#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
+#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 4)
                 UNROLL_OUT(loadB, 2)
-                UNROLL_OUT(loadA0, 2)
-                UNROLL_OUT(loadA1, 2)
-                UNROLL_OUT(loadA2, 2)
-                UNROLL_OUT(loadA3, 2)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#undef loadA2
-#undef loadA3
-#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
-#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
-#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
+#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
+#define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
+#define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
                 UNROLL_OUT(calculate_row0, 2)
                 UNROLL_OUT(calculate_row1, 2)
                 UNROLL_OUT(calculate_row2, 2)
@@ -249,31 +221,18 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n + 4 <= N; n += 4) {
-                float16x4_t a00, a01, a10, a11, a20, a21, a30, a31;
                 float16x4_t b0, b1;
                 float16x4_t c0, c1, c2, c3;
-#define loadB(i)  b##i = vld1_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
-#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
-#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
+#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 4)
                 UNROLL_OUT(loadB, 2)
-                UNROLL_OUT(loadA0, 2)
-                UNROLL_OUT(loadA1, 2)
-                UNROLL_OUT(loadA2, 2)
-                UNROLL_OUT(loadA3, 2)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#undef loadA2
-#undef loadA3
-#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
-#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
-#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
+#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
+#define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
+#define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
                 UNROLL_OUT(calculate_row0, 2)
                 UNROLL_OUT(calculate_row1, 2)
                 UNROLL_OUT(calculate_row2, 2)
@@ -287,7 +246,6 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n < N; n += 1) {
-                __fp16 a00, a01, a10, a11, a20, a21, a30, a31;
                 __fp16 b0, b1;
                 __fp16 c0, c1, c2, c3;
 #define loadC(i) c##i = C[(m + i) * Cstride + n];
@@ -296,18 +254,6 @@ void megdnn::arm_common::gemv_like(
                 UNROLL_OUT(loadB, 2)
 #undef loadB
 #undef loadC
-#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
-#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
-#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
-#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
-                UNROLL_OUT(loadA0, 2)
-                UNROLL_OUT(loadA1, 2)
-                UNROLL_OUT(loadA2, 2)
-                UNROLL_OUT(loadA3, 2)
-#undef loadA0
-#undef loadA1
-#undef loadA2
-#undef loadA3
                 c0 += a00 * b0 + a01 * b1;
                 c1 += a10 * b0 + a11 * b1;
                 c2 += a20 * b0 + a21 * b1;
@@ -319,32 +265,23 @@ void megdnn::arm_common::gemv_like(
         }
         for (; k < K; k += 1) {
            size_t n = 0;
+            __fp16 a00 = A[m * Astride + k];
+            __fp16 a10 = A[(m + 1) * Astride + k];
+            __fp16 a20 = A[(m + 2) * Astride + k];
+            __fp16 a30 = A[(m + 3) * Astride + k];
             for (; n + 8 <= N; n += 8) {
-                float16x8_t a00, a10, a20, a30;
                 float16x8_t b0;
                 float16x8_t c0, c1, c2, c3;
-#define loadB(i)  b##i = vld1q_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1q_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
-#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
-#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
+#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 4)
                 UNROLL_OUT(loadB, 1)
-                UNROLL_OUT(loadA0, 1)
-                UNROLL_OUT(loadA1, 1)
-                UNROLL_OUT(loadA2, 1)
-                UNROLL_OUT(loadA3, 1)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#undef loadA2
-#undef loadA3
-#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
-#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
-#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
+#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
+#define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
+#define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
                 UNROLL_OUT(calculate_row0, 1)
                 UNROLL_OUT(calculate_row1, 1)
                 UNROLL_OUT(calculate_row2, 1)
@@ -358,31 +295,18 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n + 4 <= N; n += 4) {
-                float16x4_t a00, a10, a20, a30;
                 float16x4_t b0;
                 float16x4_t c0, c1, c2, c3;
-#define loadB(i)  b##i = vld1_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
-#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
-#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
+#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 4)
                 UNROLL_OUT(loadB, 1)
-                UNROLL_OUT(loadA0, 1)
-                UNROLL_OUT(loadA1, 1)
-                UNROLL_OUT(loadA2, 1)
-                UNROLL_OUT(loadA3, 1)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#undef loadA2
-#undef loadA3
-#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
-#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
-#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
+#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
+#define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
+#define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
                 UNROLL_OUT(calculate_row0, 1)
                 UNROLL_OUT(calculate_row1, 1)
                 UNROLL_OUT(calculate_row2, 1)
@@ -396,7 +320,6 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n < N; n += 1) {
-                __fp16 a00, a10, a20, a30;
                 __fp16 b0;
                 __fp16 c0, c1, c2, c3;
 #define loadC(i) c##i = C[(m + i) * Cstride + n];
@@ -405,18 +328,6 @@ void megdnn::arm_common::gemv_like(
                 UNROLL_OUT(loadB, 1)
 #undef loadB
 #undef loadC
-#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
-#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
-#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
-#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
-                UNROLL_OUT(loadA0, 1)
-                UNROLL_OUT(loadA1, 1)
-                UNROLL_OUT(loadA2, 1)
-                UNROLL_OUT(loadA3, 1)
-#undef loadA0
-#undef loadA1
-#undef loadA2
-#undef loadA3
                 c0 = c0 + a00 * b0;
                 c1 = c1 + a10 * b0;
                 c2 = c2 + a20 * b0;
@@ -432,24 +343,22 @@ void megdnn::arm_common::gemv_like(
         memset(C + m * Cstride, 0, 2 * sizeof(__fp16) * N);
         for (; k + 4 <= K; k += 4) {
             size_t n = 0;
+            __fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1],
+                   a02 = A[m * Astride + k + 2], a03 = A[m * Astride + k + 3];
+            __fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1],
+                   a12 = A[(m + 1) * Astride + k + 2],
+                   a13 = A[(m + 1) * Astride + k + 3];
             for (; n + 8 <= N; n += 8) {
-                float16x8_t a00, a01, a02, a03, a10, a11, a12, a13;
                 float16x8_t b0, b1, b2, b3;
                 float16x8_t c0, c1;
-#define loadB(i)  b##i = vld1q_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1q_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
+#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 2)
                 UNROLL_OUT(loadB, 4)
-                UNROLL_OUT(loadA0, 4)
-                UNROLL_OUT(loadA1, 4)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
+#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
                 UNROLL_OUT(calculate_row0, 4)
                 UNROLL_OUT(calculate_row1, 4)
 #undef calculate_row0
@@ -459,23 +368,16 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n + 4 <= N; n += 4) {
-                float16x4_t a00, a01, a02, a03, a10, a11, a12, a13;
                 float16x4_t b0, b1, b2, b3;
                 float16x4_t c0, c1;
-#define loadB(i)  b##i = vld1_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
+#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 2)
                 UNROLL_OUT(loadB, 4)
-                UNROLL_OUT(loadA0, 4)
-                UNROLL_OUT(loadA1, 4)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
+#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
                 UNROLL_OUT(calculate_row0, 4)
                 UNROLL_OUT(calculate_row1, 4)
 #undef calculate_row0
@@ -485,7 +387,6 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n < N; n += 1) {
-                __fp16 a00, a01, a02, a03, a10, a11, a12, a13;
                 __fp16 b0, b1, b2, b3;
                 __fp16 c0, c1;
 #define loadC(i) c##i = C[(m + i) * Cstride + n];
@@ -494,12 +395,6 @@ void megdnn::arm_common::gemv_like(
                 UNROLL_OUT(loadB, 4)
 #undef loadB
 #undef loadC
-#define loadA0(i) a0##i = A[m * Astride + k + i];
-#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
-                UNROLL_OUT(loadA0, 4)
-                UNROLL_OUT(loadA1, 4)
-#undef loadA0
-#undef loadA1
                 c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3;
                 c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3;
 #define vstore(i) C[(m + i) * Cstride + n] = c##i;
@@ -509,24 +404,19 @@ void megdnn::arm_common::gemv_like(
         }
         for (; k + 2 <= K; k += 2) {
             size_t n = 0;
+            __fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1];
+            __fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1];
             for (; n + 8 <= N; n += 8) {
-                float16x8_t a00, a01, a10, a11;
                 float16x8_t b0, b1;
                 float16x8_t c0, c1;
-#define loadB(i)  b##i = vld1q_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1q_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
+#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 2)
                 UNROLL_OUT(loadB, 2)
-                UNROLL_OUT(loadA0, 2)
-                UNROLL_OUT(loadA1, 2)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
+#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
                 UNROLL_OUT(calculate_row0, 2)
                 UNROLL_OUT(calculate_row1, 2)
 #undef calculate_row0
@@ -536,23 +426,16 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n + 4 <= N; n += 4) {
-                float16x4_t a00, a01, a10, a11;
                 float16x4_t b0, b1;
                 float16x4_t c0, c1;
-#define loadB(i)  b##i = vld1_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
+#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 2)
                 UNROLL_OUT(loadB, 2)
-                UNROLL_OUT(loadA0, 2)
-                UNROLL_OUT(loadA1, 2)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
+#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
                 UNROLL_OUT(calculate_row0, 2)
                 UNROLL_OUT(calculate_row1, 2)
 #undef calculate_row0
@@ -562,7 +445,6 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n < N; n += 1) {
-                __fp16 a00, a01, a10, a11;
                 __fp16 b0, b1;
                 __fp16 c0, c1;
 #define loadC(i) c##i = C[(m + i) * Cstride + n];
@@ -571,12 +453,6 @@ void megdnn::arm_common::gemv_like(
                 UNROLL_OUT(loadB, 2)
 #undef loadB
 #undef loadC
-#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
-#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
-                UNROLL_OUT(loadA0, 2)
-                UNROLL_OUT(loadA1, 2)
-#undef loadA0
-#undef loadA1
                 c0 += a00 * b0 + a01 * b1;
                 c1 += a10 * b0 + a11 * b1;
 #define vstore(i) C[(m + i) * Cstride + n] = c##i;
@@ -586,24 +462,19 @@ void megdnn::arm_common::gemv_like(
         }
         for (; k < K; k += 1) {
             size_t n = 0;
+            __fp16 a00 = A[m * Astride + k];
+            __fp16 a10 = A[(m + 1) * Astride + k];
             for (; n + 8 <= N; n += 8) {
-                float16x8_t a00, a10;
                 float16x8_t b0;
                 float16x8_t c0, c1;
-#define loadB(i)  b##i = vld1q_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1q_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
+#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 2)
                 UNROLL_OUT(loadB, 1)
-                UNROLL_OUT(loadA0, 1)
-                UNROLL_OUT(loadA1, 1)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
+#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
                 UNROLL_OUT(calculate_row0, 1)
                 UNROLL_OUT(calculate_row1, 1)
 #undef calculate_row0
@@ -613,23 +484,16 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n + 4 <= N; n += 4) {
-                float16x4_t a00, a10;
                 float16x4_t b0;
                 float16x4_t c0, c1;
-#define loadB(i)  b##i = vld1_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
-#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
+#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
                 UNROLL_OUT(loadC, 2)
                 UNROLL_OUT(loadB, 1)
-                UNROLL_OUT(loadA0, 1)
-                UNROLL_OUT(loadA1, 1)
 #undef loadB
 #undef loadC
-#undef loadA0
-#undef loadA1
-#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
-#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
+#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
+#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
                 UNROLL_OUT(calculate_row0, 1)
                 UNROLL_OUT(calculate_row1, 1)
 #undef calculate_row0
@@ -639,7 +503,6 @@ void megdnn::arm_common::gemv_like(
 #undef vstore
             }
             for (; n < N; n += 1) {
-                __fp16 a00, a10;
                 __fp16 b0;
                 __fp16 c0, c1;
 #define loadC(i) c##i = C[(m + i) * Cstride + n];
@@ -648,12 +511,6 @@ void megdnn::arm_common::gemv_like(
                 UNROLL_OUT(loadB, 1)
 #undef loadB
 #undef loadC
-#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
-#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
-                UNROLL_OUT(loadA0, 1)
-                UNROLL_OUT(loadA1, 1)
-#undef loadA0
-#undef loadA1
                 c0 = c0 + a00 * b0;
                 c1 = c1 + a10 * b0;
 #define vstore(i) C[(m + i) * Cstride + n] = c##i;
@@ -667,48 +524,61 @@ void megdnn::arm_common::gemv_like(
         memset(C + m * Cstride, 0, sizeof(__fp16) * N);
         for (; k + 4 <= K; k += 4) {
             size_t n = 0;
-            for (; n + 8 <= N; n += 8) {
-                float16x8_t a00, a01, a02, a03;
-                float16x8_t b0, b1, b2, b3;
-                float16x8_t c0;
-#define loadB(i)  b##i = vld1q_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1q_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
-                UNROLL_OUT(loadC, 1)
-                UNROLL_OUT(loadB, 4)
-                UNROLL_OUT(loadA0, 4)
+            __fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1],
+                   a02 = A[m * Astride + k + 2], a03 = A[m * Astride + k + 3];
+            {
+#if !defined(__aarch64__)
+                float16x8_t va00 = vdupq_n_f16(a00), va01 = vdupq_n_f16(a01),
+                            va02 = vdupq_n_f16(a02), va03 = vdupq_n_f16(a03);
+#endif
+                for (; n + 8 <= N; n += 8) {
+                    float16x8_t b0, b1, b2, b3;
+                    float16x8_t c0;
+#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
+                    UNROLL_OUT(loadC, 1)
+                    UNROLL_OUT(loadB, 4)
 #undef loadB
 #undef loadC
-#undef loadA0
-#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
-                UNROLL_OUT(calculate_row0, 4)
+#if defined(__aarch64__)
+#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
+#else
+#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
+#endif
+                    UNROLL_OUT(calculate_row0, 4)
 #undef calculate_row0
 #define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
-                UNROLL_OUT(vstore, 1)
+                    UNROLL_OUT(vstore, 1)
 #undef vstore
+                }
             }
-            for (; n + 4 <= N; n += 4) {
-                float16x4_t a00, a01, a02, a03;
-                float16x4_t b0, b1, b2, b3;
-                float16x4_t c0;
-#define loadB(i)  b##i = vld1_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
-                UNROLL_OUT(loadC, 1)
-                UNROLL_OUT(loadB, 4)
-                UNROLL_OUT(loadA0, 4)
+            {
+#if !defined(__aarch64__)
+                float16x4_t va00 = vdup_n_f16(a00), va01 = vdup_n_f16(a01),
+                            va02 = vdup_n_f16(a02), va03 = vdup_n_f16(a03);
+#endif
+                for (; n + 4 <= N; n += 4) {
+                    float16x4_t b0, b1, b2, b3;
+                    float16x4_t c0;
+#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
+                    UNROLL_OUT(loadC, 1)
+                    UNROLL_OUT(loadB, 4)
 #undef loadB
 #undef loadC
-#undef loadA0
-#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
-                UNROLL_OUT(calculate_row0, 4)
+#if defined(__aarch64__)
+#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
+#else
+#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
+#endif
+                    UNROLL_OUT(calculate_row0, 4)
 #undef calculate_row0
 #define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
-                UNROLL_OUT(vstore, 1)
+                    UNROLL_OUT(vstore, 1)
 #undef vstore
+                }
             }
             for (; n < N; n += 1) {
-                __fp16 a00, a01, a02, a03;
                 __fp16 b0, b1, b2, b3;
                 __fp16 c0;
 #define loadC(i) c##i = C[(m + i) * Cstride + n];
@@ -717,9 +587,6 @@ void megdnn::arm_common::gemv_like(
                 UNROLL_OUT(loadB, 4)
 #undef loadB
 #undef loadC
-#define loadA0(i) a0##i = A[m * Astride + k + i];
-                UNROLL_OUT(loadA0, 4)
-#undef loadA0
                 c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3;
 #define vstore(i) C[(m + i) * Cstride + n] = c##i;
                 UNROLL_OUT(vstore, 1)
@@ -727,49 +594,59 @@ void megdnn::arm_common::gemv_like(
             }
         }
         for (; k + 2 <= K; k += 2) {
+            __fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1];
             size_t n = 0;
-            for (; n + 8 <= N; n += 8) {
-                float16x8_t a00, a01;
-                float16x8_t b0, b1;
-                float16x8_t c0;
-#define loadB(i)  b##i = vld1q_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1q_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
-                UNROLL_OUT(loadC, 1)
-                UNROLL_OUT(loadB, 2)
-                UNROLL_OUT(loadA0, 2)
+            {
+#if !defined(__aarch64__)
+                float16x8_t va00 = vdupq_n_f16(a00), va01 = vdupq_n_f16(a01);
+#endif
+                for (; n + 8 <= N; n += 8) {
+                    float16x8_t b0, b1;
+                    float16x8_t c0;
+#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
+                    UNROLL_OUT(loadC, 1)
+                    UNROLL_OUT(loadB, 2)
 #undef loadB
 #undef loadC
-#undef loadA0
-#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
-                UNROLL_OUT(calculate_row0, 2)
+#if defined(__aarch64__)
+#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
+#else
+#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
+#endif
+                    UNROLL_OUT(calculate_row0, 2)
 #undef calculate_row0
 #define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
-                UNROLL_OUT(vstore, 1)
+                    UNROLL_OUT(vstore, 1)
 #undef vstore
+                }
             }
-            for (; n + 4 <= N; n += 4) {
-                float16x4_t a00, a01;
-                float16x4_t b0, b1;
-                float16x4_t c0;
-#define loadB(i)  b##i = vld1_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
-                UNROLL_OUT(loadC, 1)
-                UNROLL_OUT(loadB, 2)
-                UNROLL_OUT(loadA0, 2)
+            {
+#if !defined(__aarch64__)
+                float16x4_t va00 = vdup_n_f16(a00), va01 = vdup_n_f16(a01);
+#endif
+                for (; n + 4 <= N; n += 4) {
+                    float16x4_t b0, b1;
+                    float16x4_t c0;
+#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
+                    UNROLL_OUT(loadC, 1)
+                    UNROLL_OUT(loadB, 2)
 #undef loadB
 #undef loadC
-#undef loadA0
-#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
-                UNROLL_OUT(calculate_row0, 2)
+#if defined(__aarch64__)
+#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
+#else
+#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
+#endif
+                    UNROLL_OUT(calculate_row0, 2)
 #undef calculate_row0
 #define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
-                UNROLL_OUT(vstore, 1)
+                    UNROLL_OUT(vstore, 1)
 #undef vstore
+                }
             }
             for (; n < N; n += 1) {
-                __fp16 a00, a01;
                 __fp16 b0, b1;
                 __fp16 c0;
 #define loadC(i) c##i = C[(m + i) * Cstride + n];
@@ -778,9 +655,6 @@ void megdnn::arm_common::gemv_like(
                 UNROLL_OUT(loadB, 2)
 #undef loadB
 #undef loadC
-#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
-                UNROLL_OUT(loadA0, 2)
-#undef loadA0
                 c0 += a00 * b0 + a01 * b1;
 #define vstore(i) C[(m + i) * Cstride + n] = c##i;
                 UNROLL_OUT(vstore, 1)
@@ -789,48 +663,58 @@ void megdnn::arm_common::gemv_like(
         }
         for (; k < K; k += 1) {
             size_t n = 0;
-            for (; n + 8 <= N; n += 8) {
-                float16x8_t a00;
-                float16x8_t b0;
-                float16x8_t c0;
-#define loadB(i)  b##i = vld1q_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1q_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
-                UNROLL_OUT(loadC, 1)
-                UNROLL_OUT(loadB, 1)
-                UNROLL_OUT(loadA0, 1)
+            __fp16 a00 = A[m * Astride + k];
+            {
+#if !defined(__aarch64__)
+                float16x8_t va00 = vdupq_n_f16(a00);
+#endif
+                for (; n + 8 <= N; n += 8) {
+                    float16x8_t b0;
+                    float16x8_t c0;
+#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
+                    UNROLL_OUT(loadC, 1)
+                    UNROLL_OUT(loadB, 1)
 #undef loadB
 #undef loadC
-#undef loadA0
-#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
-                UNROLL_OUT(calculate_row0, 1)
+#if defined(__aarch64__)
+#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
+#else
+#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
+#endif
+                    UNROLL_OUT(calculate_row0, 1)
 #undef calculate_row0
 #define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
-                UNROLL_OUT(vstore, 1)
+                    UNROLL_OUT(vstore, 1)
 #undef vstore
+                }
             }
-            for (; n + 4 <= N; n += 4) {
-                float16x4_t a00;
-                float16x4_t b0;
-                float16x4_t c0;
-#define loadB(i)  b##i = vld1_f16(B + (k + i) * Bstride + n);
-#define loadC(i)  c##i = vld1_f16(C + (m + i) * Cstride + n);
-#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
-                UNROLL_OUT(loadC, 1)
-                UNROLL_OUT(loadB, 1)
-                UNROLL_OUT(loadA0, 1)
+            {
+#if !defined(__aarch64__)
+                float16x4_t va00 = vdup_n_f16(a00);
+#endif
+                for (; n + 4 <= N; n += 4) {
+                    float16x4_t b0;
+                    float16x4_t c0;
+#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
+#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
+                    UNROLL_OUT(loadC, 1)
+                    UNROLL_OUT(loadB, 1)
 #undef loadB
 #undef loadC
-#undef loadA0
-#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
-                UNROLL_OUT(calculate_row0, 1)
+#if defined(__aarch64__)
+#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
+#else
+#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
+#endif
+                    UNROLL_OUT(calculate_row0, 1)
 #undef calculate_row0
 #define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
-                UNROLL_OUT(vstore, 1)
+                    UNROLL_OUT(vstore, 1)
 #undef vstore
+                }
             }
             for (; n < N; n += 1) {
-                __fp16 a00;
                 __fp16 b0;
                 __fp16 c0;
 #define loadC(i) c##i = C[(m + i) * Cstride + n];
@@ -839,9 +723,6 @@ void megdnn::arm_common::gemv_like(
                 UNROLL_OUT(loadB, 1)
 #undef loadB
 #undef loadC
-#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
-                UNROLL_OUT(loadA0, 1)
-#undef loadA0
                 c0 = c0 + a00 * b0;
 #define vstore(i) C[(m + i) * Cstride + n] = c##i;
                 UNROLL_OUT(vstore, 1)
@@ -850,6 +731,10 @@ void megdnn::arm_common::gemv_like(
         }
     }
 }
+
+#undef VFMA_N_F16
+#undef VFMAQ_N_F16
+
 bool megdnn::arm_common::is_hgemv_preferred(
         bool transposeA, bool transposeB, size_t M, size_t N, size_t K,
         size_t /*LDA*/, size_t LDB, size_t /*LDC*/) {