提交 0612b573 编写于 作者: M Megvii Engine Team

perf(dnn/arm): optimize gevm by reducing access to memory of matrix A

GitOrigin-RevId: 89ed7bfd50114be4fc8c2c3283bf92883afc5283
上级 b622064a
...@@ -85,6 +85,18 @@ void hgemv_naive_n( ...@@ -85,6 +85,18 @@ void hgemv_naive_n(
} }
} // namespace } // namespace
#if defined(__aarch64__)
#define VFMAQ_N_F16(a, b, n) vfmaq_n_f16(a, b, n)
#else
#define VFMAQ_N_F16(a, b, n) vaddq_f16(a, vmulq_n_f16(b, n))
#endif
#if defined(__aarch64__)
#define VFMA_N_F16(a, b, n) vfma_n_f16(a, b, n)
#else
#define VFMA_N_F16(a, b, n) vadd_f16(a, vmul_n_f16(b, n))
#endif
void megdnn::arm_common::gemv_like( void megdnn::arm_common::gemv_like(
const __fp16* __restrict A, const __fp16* __restrict B, __fp16* __restrict C, const __fp16* __restrict A, const __fp16* __restrict B, __fp16* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
...@@ -98,33 +110,30 @@ void megdnn::arm_common::gemv_like( ...@@ -98,33 +110,30 @@ void megdnn::arm_common::gemv_like(
memset(C + m * Cstride, 0, 4 * sizeof(__fp16) * N); memset(C + m * Cstride, 0, 4 * sizeof(__fp16) * N);
for (; k + 4 <= K; k += 4) { for (; k + 4 <= K; k += 4) {
size_t n = 0; size_t n = 0;
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1],
a02 = A[m * Astride + k + 2], a03 = A[m * Astride + k + 3];
__fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1],
a12 = A[(m + 1) * Astride + k + 2],
a13 = A[(m + 1) * Astride + k + 3];
__fp16 a20 = A[(m + 2) * Astride + k], a21 = A[(m + 2) * Astride + k + 1],
a22 = A[(m + 2) * Astride + k + 2],
a23 = A[(m + 2) * Astride + k + 3];
__fp16 a30 = A[(m + 3) * Astride + k], a31 = A[(m + 3) * Astride + k + 1],
a32 = A[(m + 3) * Astride + k + 2],
a33 = A[(m + 3) * Astride + k + 3];
for (; n + 8 <= N; n += 8) { for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23,
a30, a31, a32, a33;
float16x8_t b0, b1, b2, b3; float16x8_t b0, b1, b2, b3;
float16x8_t c0, c1, c2, c3; float16x8_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4) UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 4) UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
UNROLL_OUT(loadA2, 4)
UNROLL_OUT(loadA3, 4)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#undef loadA2 #define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
#undef loadA3 #define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 4) UNROLL_OUT(calculate_row0, 4)
UNROLL_OUT(calculate_row1, 4) UNROLL_OUT(calculate_row1, 4)
UNROLL_OUT(calculate_row2, 4) UNROLL_OUT(calculate_row2, 4)
...@@ -138,32 +147,18 @@ void megdnn::arm_common::gemv_like( ...@@ -138,32 +147,18 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n + 4 <= N; n += 4) { for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23,
a30, a31, a32, a33;
float16x4_t b0, b1, b2, b3; float16x4_t b0, b1, b2, b3;
float16x4_t c0, c1, c2, c3; float16x4_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4) UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 4) UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
UNROLL_OUT(loadA2, 4)
UNROLL_OUT(loadA3, 4)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#undef loadA2 #define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
#undef loadA3 #define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 4) UNROLL_OUT(calculate_row0, 4)
UNROLL_OUT(calculate_row1, 4) UNROLL_OUT(calculate_row1, 4)
UNROLL_OUT(calculate_row2, 4) UNROLL_OUT(calculate_row2, 4)
...@@ -177,8 +172,6 @@ void megdnn::arm_common::gemv_like( ...@@ -177,8 +172,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n < N; n += 1) { for (; n < N; n += 1) {
__fp16 a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23, a30,
a31, a32, a33;
__fp16 b0, b1, b2, b3; __fp16 b0, b1, b2, b3;
__fp16 c0, c1, c2, c3; __fp16 c0, c1, c2, c3;
#define loadC(i) c##i = C[(m + i) * Cstride + n]; #define loadC(i) c##i = C[(m + i) * Cstride + n];
...@@ -187,18 +180,6 @@ void megdnn::arm_common::gemv_like( ...@@ -187,18 +180,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 4) UNROLL_OUT(loadB, 4)
#undef loadB #undef loadB
#undef loadC #undef loadC
#define loadA0(i) a0##i = A[m * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
UNROLL_OUT(loadA2, 4)
UNROLL_OUT(loadA3, 4)
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3; c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3;
c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3; c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3;
c2 += a20 * b0 + a21 * b1 + a22 * b2 + a23 * b3; c2 += a20 * b0 + a21 * b1 + a22 * b2 + a23 * b3;
...@@ -210,32 +191,23 @@ void megdnn::arm_common::gemv_like( ...@@ -210,32 +191,23 @@ void megdnn::arm_common::gemv_like(
} }
for (; k + 2 <= K; k += 2) { for (; k + 2 <= K; k += 2) {
size_t n = 0; size_t n = 0;
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1];
__fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1];
__fp16 a20 = A[(m + 2) * Astride + k], a21 = A[(m + 2) * Astride + k + 1];
__fp16 a30 = A[(m + 3) * Astride + k], a31 = A[(m + 3) * Astride + k + 1];
for (; n + 8 <= N; n += 8) { for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01, a10, a11, a20, a21, a30, a31;
float16x8_t b0, b1; float16x8_t b0, b1;
float16x8_t c0, c1, c2, c3; float16x8_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4) UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 2) UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
UNROLL_OUT(loadA2, 2)
UNROLL_OUT(loadA3, 2)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#undef loadA2 #define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
#undef loadA3 #define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 2) UNROLL_OUT(calculate_row0, 2)
UNROLL_OUT(calculate_row1, 2) UNROLL_OUT(calculate_row1, 2)
UNROLL_OUT(calculate_row2, 2) UNROLL_OUT(calculate_row2, 2)
...@@ -249,31 +221,18 @@ void megdnn::arm_common::gemv_like( ...@@ -249,31 +221,18 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n + 4 <= N; n += 4) { for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01, a10, a11, a20, a21, a30, a31;
float16x4_t b0, b1; float16x4_t b0, b1;
float16x4_t c0, c1, c2, c3; float16x4_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4) UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 2) UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
UNROLL_OUT(loadA2, 2)
UNROLL_OUT(loadA3, 2)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#undef loadA2 #define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
#undef loadA3 #define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 2) UNROLL_OUT(calculate_row0, 2)
UNROLL_OUT(calculate_row1, 2) UNROLL_OUT(calculate_row1, 2)
UNROLL_OUT(calculate_row2, 2) UNROLL_OUT(calculate_row2, 2)
...@@ -287,7 +246,6 @@ void megdnn::arm_common::gemv_like( ...@@ -287,7 +246,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n < N; n += 1) { for (; n < N; n += 1) {
__fp16 a00, a01, a10, a11, a20, a21, a30, a31;
__fp16 b0, b1; __fp16 b0, b1;
__fp16 c0, c1, c2, c3; __fp16 c0, c1, c2, c3;
#define loadC(i) c##i = C[(m + i) * Cstride + n]; #define loadC(i) c##i = C[(m + i) * Cstride + n];
...@@ -296,18 +254,6 @@ void megdnn::arm_common::gemv_like( ...@@ -296,18 +254,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 2) UNROLL_OUT(loadB, 2)
#undef loadB #undef loadB
#undef loadC #undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
UNROLL_OUT(loadA2, 2)
UNROLL_OUT(loadA3, 2)
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
c0 += a00 * b0 + a01 * b1; c0 += a00 * b0 + a01 * b1;
c1 += a10 * b0 + a11 * b1; c1 += a10 * b0 + a11 * b1;
c2 += a20 * b0 + a21 * b1; c2 += a20 * b0 + a21 * b1;
...@@ -319,32 +265,23 @@ void megdnn::arm_common::gemv_like( ...@@ -319,32 +265,23 @@ void megdnn::arm_common::gemv_like(
} }
for (; k < K; k += 1) { for (; k < K; k += 1) {
size_t n = 0; size_t n = 0;
__fp16 a00 = A[m * Astride + k];
__fp16 a10 = A[(m + 1) * Astride + k];
__fp16 a20 = A[(m + 2) * Astride + k];
__fp16 a30 = A[(m + 3) * Astride + k];
for (; n + 8 <= N; n += 8) { for (; n + 8 <= N; n += 8) {
float16x8_t a00, a10, a20, a30;
float16x8_t b0; float16x8_t b0;
float16x8_t c0, c1, c2, c3; float16x8_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4) UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 1) UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
UNROLL_OUT(loadA2, 1)
UNROLL_OUT(loadA3, 1)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#undef loadA2 #define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
#undef loadA3 #define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 1) UNROLL_OUT(calculate_row0, 1)
UNROLL_OUT(calculate_row1, 1) UNROLL_OUT(calculate_row1, 1)
UNROLL_OUT(calculate_row2, 1) UNROLL_OUT(calculate_row2, 1)
...@@ -358,31 +295,18 @@ void megdnn::arm_common::gemv_like( ...@@ -358,31 +295,18 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n + 4 <= N; n += 4) { for (; n + 4 <= N; n += 4) {
float16x4_t a00, a10, a20, a30;
float16x4_t b0; float16x4_t b0;
float16x4_t c0, c1, c2, c3; float16x4_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4) UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 1) UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
UNROLL_OUT(loadA2, 1)
UNROLL_OUT(loadA3, 1)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#undef loadA2 #define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
#undef loadA3 #define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 1) UNROLL_OUT(calculate_row0, 1)
UNROLL_OUT(calculate_row1, 1) UNROLL_OUT(calculate_row1, 1)
UNROLL_OUT(calculate_row2, 1) UNROLL_OUT(calculate_row2, 1)
...@@ -396,7 +320,6 @@ void megdnn::arm_common::gemv_like( ...@@ -396,7 +320,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n < N; n += 1) { for (; n < N; n += 1) {
__fp16 a00, a10, a20, a30;
__fp16 b0; __fp16 b0;
__fp16 c0, c1, c2, c3; __fp16 c0, c1, c2, c3;
#define loadC(i) c##i = C[(m + i) * Cstride + n]; #define loadC(i) c##i = C[(m + i) * Cstride + n];
...@@ -405,18 +328,6 @@ void megdnn::arm_common::gemv_like( ...@@ -405,18 +328,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 1) UNROLL_OUT(loadB, 1)
#undef loadB #undef loadB
#undef loadC #undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
UNROLL_OUT(loadA2, 1)
UNROLL_OUT(loadA3, 1)
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
c0 = c0 + a00 * b0; c0 = c0 + a00 * b0;
c1 = c1 + a10 * b0; c1 = c1 + a10 * b0;
c2 = c2 + a20 * b0; c2 = c2 + a20 * b0;
...@@ -432,24 +343,22 @@ void megdnn::arm_common::gemv_like( ...@@ -432,24 +343,22 @@ void megdnn::arm_common::gemv_like(
memset(C + m * Cstride, 0, 2 * sizeof(__fp16) * N); memset(C + m * Cstride, 0, 2 * sizeof(__fp16) * N);
for (; k + 4 <= K; k += 4) { for (; k + 4 <= K; k += 4) {
size_t n = 0; size_t n = 0;
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1],
a02 = A[m * Astride + k + 2], a03 = A[m * Astride + k + 3];
__fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1],
a12 = A[(m + 1) * Astride + k + 2],
a13 = A[(m + 1) * Astride + k + 3];
for (; n + 8 <= N; n += 8) { for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01, a02, a03, a10, a11, a12, a13;
float16x8_t b0, b1, b2, b3; float16x8_t b0, b1, b2, b3;
float16x8_t c0, c1; float16x8_t c0, c1;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2) UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 4) UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 4) UNROLL_OUT(calculate_row0, 4)
UNROLL_OUT(calculate_row1, 4) UNROLL_OUT(calculate_row1, 4)
#undef calculate_row0 #undef calculate_row0
...@@ -459,23 +368,16 @@ void megdnn::arm_common::gemv_like( ...@@ -459,23 +368,16 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n + 4 <= N; n += 4) { for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01, a02, a03, a10, a11, a12, a13;
float16x4_t b0, b1, b2, b3; float16x4_t b0, b1, b2, b3;
float16x4_t c0, c1; float16x4_t c0, c1;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2) UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 4) UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 4) UNROLL_OUT(calculate_row0, 4)
UNROLL_OUT(calculate_row1, 4) UNROLL_OUT(calculate_row1, 4)
#undef calculate_row0 #undef calculate_row0
...@@ -485,7 +387,6 @@ void megdnn::arm_common::gemv_like( ...@@ -485,7 +387,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n < N; n += 1) { for (; n < N; n += 1) {
__fp16 a00, a01, a02, a03, a10, a11, a12, a13;
__fp16 b0, b1, b2, b3; __fp16 b0, b1, b2, b3;
__fp16 c0, c1; __fp16 c0, c1;
#define loadC(i) c##i = C[(m + i) * Cstride + n]; #define loadC(i) c##i = C[(m + i) * Cstride + n];
...@@ -494,12 +395,6 @@ void megdnn::arm_common::gemv_like( ...@@ -494,12 +395,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 4) UNROLL_OUT(loadB, 4)
#undef loadB #undef loadB
#undef loadC #undef loadC
#define loadA0(i) a0##i = A[m * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
#undef loadA0
#undef loadA1
c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3; c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3;
c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3; c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3;
#define vstore(i) C[(m + i) * Cstride + n] = c##i; #define vstore(i) C[(m + i) * Cstride + n] = c##i;
...@@ -509,24 +404,19 @@ void megdnn::arm_common::gemv_like( ...@@ -509,24 +404,19 @@ void megdnn::arm_common::gemv_like(
} }
for (; k + 2 <= K; k += 2) { for (; k + 2 <= K; k += 2) {
size_t n = 0; size_t n = 0;
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1];
__fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1];
for (; n + 8 <= N; n += 8) { for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01, a10, a11;
float16x8_t b0, b1; float16x8_t b0, b1;
float16x8_t c0, c1; float16x8_t c0, c1;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2) UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 2) UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 2) UNROLL_OUT(calculate_row0, 2)
UNROLL_OUT(calculate_row1, 2) UNROLL_OUT(calculate_row1, 2)
#undef calculate_row0 #undef calculate_row0
...@@ -536,23 +426,16 @@ void megdnn::arm_common::gemv_like( ...@@ -536,23 +426,16 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n + 4 <= N; n += 4) { for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01, a10, a11;
float16x4_t b0, b1; float16x4_t b0, b1;
float16x4_t c0, c1; float16x4_t c0, c1;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2) UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 2) UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 2) UNROLL_OUT(calculate_row0, 2)
UNROLL_OUT(calculate_row1, 2) UNROLL_OUT(calculate_row1, 2)
#undef calculate_row0 #undef calculate_row0
...@@ -562,7 +445,6 @@ void megdnn::arm_common::gemv_like( ...@@ -562,7 +445,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n < N; n += 1) { for (; n < N; n += 1) {
__fp16 a00, a01, a10, a11;
__fp16 b0, b1; __fp16 b0, b1;
__fp16 c0, c1; __fp16 c0, c1;
#define loadC(i) c##i = C[(m + i) * Cstride + n]; #define loadC(i) c##i = C[(m + i) * Cstride + n];
...@@ -571,12 +453,6 @@ void megdnn::arm_common::gemv_like( ...@@ -571,12 +453,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 2) UNROLL_OUT(loadB, 2)
#undef loadB #undef loadB
#undef loadC #undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
#undef loadA0
#undef loadA1
c0 += a00 * b0 + a01 * b1; c0 += a00 * b0 + a01 * b1;
c1 += a10 * b0 + a11 * b1; c1 += a10 * b0 + a11 * b1;
#define vstore(i) C[(m + i) * Cstride + n] = c##i; #define vstore(i) C[(m + i) * Cstride + n] = c##i;
...@@ -586,24 +462,19 @@ void megdnn::arm_common::gemv_like( ...@@ -586,24 +462,19 @@ void megdnn::arm_common::gemv_like(
} }
for (; k < K; k += 1) { for (; k < K; k += 1) {
size_t n = 0; size_t n = 0;
__fp16 a00 = A[m * Astride + k];
__fp16 a10 = A[(m + 1) * Astride + k];
for (; n + 8 <= N; n += 8) { for (; n + 8 <= N; n += 8) {
float16x8_t a00, a10;
float16x8_t b0; float16x8_t b0;
float16x8_t c0, c1; float16x8_t c0, c1;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2) UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 1) UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 1) UNROLL_OUT(calculate_row0, 1)
UNROLL_OUT(calculate_row1, 1) UNROLL_OUT(calculate_row1, 1)
#undef calculate_row0 #undef calculate_row0
...@@ -613,23 +484,16 @@ void megdnn::arm_common::gemv_like( ...@@ -613,23 +484,16 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n + 4 <= N; n += 4) { for (; n + 4 <= N; n += 4) {
float16x4_t a00, a10;
float16x4_t b0; float16x4_t b0;
float16x4_t c0, c1; float16x4_t c0, c1;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); #define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2) UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 1) UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#undef loadA1 #define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 1) UNROLL_OUT(calculate_row0, 1)
UNROLL_OUT(calculate_row1, 1) UNROLL_OUT(calculate_row1, 1)
#undef calculate_row0 #undef calculate_row0
...@@ -639,7 +503,6 @@ void megdnn::arm_common::gemv_like( ...@@ -639,7 +503,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore #undef vstore
} }
for (; n < N; n += 1) { for (; n < N; n += 1) {
__fp16 a00, a10;
__fp16 b0; __fp16 b0;
__fp16 c0, c1; __fp16 c0, c1;
#define loadC(i) c##i = C[(m + i) * Cstride + n]; #define loadC(i) c##i = C[(m + i) * Cstride + n];
...@@ -648,12 +511,6 @@ void megdnn::arm_common::gemv_like( ...@@ -648,12 +511,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 1) UNROLL_OUT(loadB, 1)
#undef loadB #undef loadB
#undef loadC #undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
#undef loadA0
#undef loadA1
c0 = c0 + a00 * b0; c0 = c0 + a00 * b0;
c1 = c1 + a10 * b0; c1 = c1 + a10 * b0;
#define vstore(i) C[(m + i) * Cstride + n] = c##i; #define vstore(i) C[(m + i) * Cstride + n] = c##i;
...@@ -667,48 +524,61 @@ void megdnn::arm_common::gemv_like( ...@@ -667,48 +524,61 @@ void megdnn::arm_common::gemv_like(
memset(C + m * Cstride, 0, sizeof(__fp16) * N); memset(C + m * Cstride, 0, sizeof(__fp16) * N);
for (; k + 4 <= K; k += 4) { for (; k + 4 <= K; k += 4) {
size_t n = 0; size_t n = 0;
for (; n + 8 <= N; n += 8) { __fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1],
float16x8_t a00, a01, a02, a03; a02 = A[m * Astride + k + 2], a03 = A[m * Astride + k + 3];
float16x8_t b0, b1, b2, b3; {
float16x8_t c0; #if !defined(__aarch64__)
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); float16x8_t va00 = vdupq_n_f16(a00), va01 = vdupq_n_f16(a01),
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); va02 = vdupq_n_f16(a02), va03 = vdupq_n_f16(a03);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); #endif
UNROLL_OUT(loadC, 1) for (; n + 8 <= N; n += 8) {
UNROLL_OUT(loadB, 4) float16x8_t b0, b1, b2, b3;
UNROLL_OUT(loadA0, 4) float16x8_t c0;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 4)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #if defined(__aarch64__)
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); #define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
UNROLL_OUT(calculate_row0, 4) #else
#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 4)
#undef calculate_row0 #undef calculate_row0
#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); #define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1) UNROLL_OUT(vstore, 1)
#undef vstore #undef vstore
}
} }
for (; n + 4 <= N; n += 4) { {
float16x4_t a00, a01, a02, a03; #if !defined(__aarch64__)
float16x4_t b0, b1, b2, b3; float16x4_t va00 = vdup_n_f16(a00), va01 = vdup_n_f16(a01),
float16x4_t c0; va02 = vdup_n_f16(a02), va03 = vdup_n_f16(a03);
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); #endif
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); for (; n + 4 <= N; n += 4) {
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); float16x4_t b0, b1, b2, b3;
UNROLL_OUT(loadC, 1) float16x4_t c0;
UNROLL_OUT(loadB, 4) #define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
UNROLL_OUT(loadA0, 4) #define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 4)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #if defined(__aarch64__)
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); #define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
UNROLL_OUT(calculate_row0, 4) #else
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 4)
#undef calculate_row0 #undef calculate_row0
#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); #define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1) UNROLL_OUT(vstore, 1)
#undef vstore #undef vstore
}
} }
for (; n < N; n += 1) { for (; n < N; n += 1) {
__fp16 a00, a01, a02, a03;
__fp16 b0, b1, b2, b3; __fp16 b0, b1, b2, b3;
__fp16 c0; __fp16 c0;
#define loadC(i) c##i = C[(m + i) * Cstride + n]; #define loadC(i) c##i = C[(m + i) * Cstride + n];
...@@ -717,9 +587,6 @@ void megdnn::arm_common::gemv_like( ...@@ -717,9 +587,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 4) UNROLL_OUT(loadB, 4)
#undef loadB #undef loadB
#undef loadC #undef loadC
#define loadA0(i) a0##i = A[m * Astride + k + i];
UNROLL_OUT(loadA0, 4)
#undef loadA0
c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3; c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3;
#define vstore(i) C[(m + i) * Cstride + n] = c##i; #define vstore(i) C[(m + i) * Cstride + n] = c##i;
UNROLL_OUT(vstore, 1) UNROLL_OUT(vstore, 1)
...@@ -727,49 +594,59 @@ void megdnn::arm_common::gemv_like( ...@@ -727,49 +594,59 @@ void megdnn::arm_common::gemv_like(
} }
} }
for (; k + 2 <= K; k += 2) { for (; k + 2 <= K; k += 2) {
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1];
size_t n = 0; size_t n = 0;
for (; n + 8 <= N; n += 8) { {
float16x8_t a00, a01; #if !defined(__aarch64__)
float16x8_t b0, b1; float16x8_t va00 = vdupq_n_f16(a00), va01 = vdupq_n_f16(a01);
float16x8_t c0; #endif
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); for (; n + 8 <= N; n += 8) {
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); float16x8_t b0, b1;
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); float16x8_t c0;
UNROLL_OUT(loadC, 1) #define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
UNROLL_OUT(loadB, 2) #define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT(loadA0, 2) UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 2)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #if defined(__aarch64__)
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); #define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
UNROLL_OUT(calculate_row0, 2) #else
#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 2)
#undef calculate_row0 #undef calculate_row0
#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); #define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1) UNROLL_OUT(vstore, 1)
#undef vstore #undef vstore
}
} }
for (; n + 4 <= N; n += 4) { {
float16x4_t a00, a01; #if !defined(__aarch64__)
float16x4_t b0, b1; float16x4_t va00 = vdup_n_f16(a00), va01 = vdup_n_f16(a01);
float16x4_t c0; #endif
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); for (; n + 4 <= N; n += 4) {
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); float16x4_t b0, b1;
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); float16x4_t c0;
UNROLL_OUT(loadC, 1) #define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
UNROLL_OUT(loadB, 2) #define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT(loadA0, 2) UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 2)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #if defined(__aarch64__)
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); #define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
UNROLL_OUT(calculate_row0, 2) #else
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 2)
#undef calculate_row0 #undef calculate_row0
#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); #define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1) UNROLL_OUT(vstore, 1)
#undef vstore #undef vstore
}
} }
for (; n < N; n += 1) { for (; n < N; n += 1) {
__fp16 a00, a01;
__fp16 b0, b1; __fp16 b0, b1;
__fp16 c0; __fp16 c0;
#define loadC(i) c##i = C[(m + i) * Cstride + n]; #define loadC(i) c##i = C[(m + i) * Cstride + n];
...@@ -778,9 +655,6 @@ void megdnn::arm_common::gemv_like( ...@@ -778,9 +655,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 2) UNROLL_OUT(loadB, 2)
#undef loadB #undef loadB
#undef loadC #undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
UNROLL_OUT(loadA0, 2)
#undef loadA0
c0 += a00 * b0 + a01 * b1; c0 += a00 * b0 + a01 * b1;
#define vstore(i) C[(m + i) * Cstride + n] = c##i; #define vstore(i) C[(m + i) * Cstride + n] = c##i;
UNROLL_OUT(vstore, 1) UNROLL_OUT(vstore, 1)
...@@ -789,48 +663,58 @@ void megdnn::arm_common::gemv_like( ...@@ -789,48 +663,58 @@ void megdnn::arm_common::gemv_like(
} }
for (; k < K; k += 1) { for (; k < K; k += 1) {
size_t n = 0; size_t n = 0;
for (; n + 8 <= N; n += 8) { __fp16 a00 = A[m * Astride + k];
float16x8_t a00; {
float16x8_t b0; #if !defined(__aarch64__)
float16x8_t c0; float16x8_t va00 = vdupq_n_f16(a00);
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); #endif
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); for (; n + 8 <= N; n += 8) {
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); float16x8_t b0;
UNROLL_OUT(loadC, 1) float16x8_t c0;
UNROLL_OUT(loadB, 1) #define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
UNROLL_OUT(loadA0, 1) #define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 1)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #if defined(__aarch64__)
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i); #define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
UNROLL_OUT(calculate_row0, 1) #else
#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 1)
#undef calculate_row0 #undef calculate_row0
#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i); #define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1) UNROLL_OUT(vstore, 1)
#undef vstore #undef vstore
}
} }
for (; n + 4 <= N; n += 4) { {
float16x4_t a00; #if !defined(__aarch64__)
float16x4_t b0; float16x4_t va00 = vdup_n_f16(a00);
float16x4_t c0; #endif
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); for (; n + 4 <= N; n += 4) {
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); float16x4_t b0;
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); float16x4_t c0;
UNROLL_OUT(loadC, 1) #define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
UNROLL_OUT(loadB, 1) #define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT(loadA0, 1) UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 1)
#undef loadB #undef loadB
#undef loadC #undef loadC
#undef loadA0 #if defined(__aarch64__)
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i); #define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
UNROLL_OUT(calculate_row0, 1) #else
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 1)
#undef calculate_row0 #undef calculate_row0
#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i); #define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1) UNROLL_OUT(vstore, 1)
#undef vstore #undef vstore
}
} }
for (; n < N; n += 1) { for (; n < N; n += 1) {
__fp16 a00;
__fp16 b0; __fp16 b0;
__fp16 c0; __fp16 c0;
#define loadC(i) c##i = C[(m + i) * Cstride + n]; #define loadC(i) c##i = C[(m + i) * Cstride + n];
...@@ -839,9 +723,6 @@ void megdnn::arm_common::gemv_like( ...@@ -839,9 +723,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 1) UNROLL_OUT(loadB, 1)
#undef loadB #undef loadB
#undef loadC #undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
UNROLL_OUT(loadA0, 1)
#undef loadA0
c0 = c0 + a00 * b0; c0 = c0 + a00 * b0;
#define vstore(i) C[(m + i) * Cstride + n] = c##i; #define vstore(i) C[(m + i) * Cstride + n] = c##i;
UNROLL_OUT(vstore, 1) UNROLL_OUT(vstore, 1)
...@@ -850,6 +731,10 @@ void megdnn::arm_common::gemv_like( ...@@ -850,6 +731,10 @@ void megdnn::arm_common::gemv_like(
} }
} }
} }
#undef VFMA_N_F16
#undef VFMAQ_N_F16
bool megdnn::arm_common::is_hgemv_preferred( bool megdnn::arm_common::is_hgemv_preferred(
bool transposeA, bool transposeB, size_t M, size_t N, size_t K, size_t /*LDA*/, bool transposeA, bool transposeB, size_t M, size_t N, size_t K, size_t /*LDA*/,
size_t LDB, size_t /*LDC*/) { size_t LDB, size_t /*LDC*/) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册