提交 0612b573 编写于 作者: M Megvii Engine Team

perf(dnn/arm): optimize gevm by reducing access to memory of matrix A

GitOrigin-RevId: 89ed7bfd50114be4fc8c2c3283bf92883afc5283
上级 b622064a
......@@ -85,6 +85,18 @@ void hgemv_naive_n(
}
} // namespace
#if defined(__aarch64__)
#define VFMAQ_N_F16(a, b, n) vfmaq_n_f16(a, b, n)
#else
#define VFMAQ_N_F16(a, b, n) vaddq_f16(a, vmulq_n_f16(b, n))
#endif
#if defined(__aarch64__)
#define VFMA_N_F16(a, b, n) vfma_n_f16(a, b, n)
#else
#define VFMA_N_F16(a, b, n) vadd_f16(a, vmul_n_f16(b, n))
#endif
void megdnn::arm_common::gemv_like(
const __fp16* __restrict A, const __fp16* __restrict B, __fp16* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
......@@ -98,33 +110,30 @@ void megdnn::arm_common::gemv_like(
memset(C + m * Cstride, 0, 4 * sizeof(__fp16) * N);
for (; k + 4 <= K; k += 4) {
size_t n = 0;
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1],
a02 = A[m * Astride + k + 2], a03 = A[m * Astride + k + 3];
__fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1],
a12 = A[(m + 1) * Astride + k + 2],
a13 = A[(m + 1) * Astride + k + 3];
__fp16 a20 = A[(m + 2) * Astride + k], a21 = A[(m + 2) * Astride + k + 1],
a22 = A[(m + 2) * Astride + k + 2],
a23 = A[(m + 2) * Astride + k + 3];
__fp16 a30 = A[(m + 3) * Astride + k], a31 = A[(m + 3) * Astride + k + 1],
a32 = A[(m + 3) * Astride + k + 2],
a33 = A[(m + 3) * Astride + k + 3];
for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23,
a30, a31, a32, a33;
float16x8_t b0, b1, b2, b3;
float16x8_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
UNROLL_OUT(loadA2, 4)
UNROLL_OUT(loadA3, 4)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 4)
UNROLL_OUT(calculate_row1, 4)
UNROLL_OUT(calculate_row2, 4)
......@@ -138,32 +147,18 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23,
a30, a31, a32, a33;
float16x4_t b0, b1, b2, b3;
float16x4_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
UNROLL_OUT(loadA2, 4)
UNROLL_OUT(loadA3, 4)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 4)
UNROLL_OUT(calculate_row1, 4)
UNROLL_OUT(calculate_row2, 4)
......@@ -177,8 +172,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n < N; n += 1) {
__fp16 a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23, a30,
a31, a32, a33;
__fp16 b0, b1, b2, b3;
__fp16 c0, c1, c2, c3;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
......@@ -187,18 +180,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 4)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[m * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
UNROLL_OUT(loadA2, 4)
UNROLL_OUT(loadA3, 4)
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3;
c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3;
c2 += a20 * b0 + a21 * b1 + a22 * b2 + a23 * b3;
......@@ -210,32 +191,23 @@ void megdnn::arm_common::gemv_like(
}
for (; k + 2 <= K; k += 2) {
size_t n = 0;
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1];
__fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1];
__fp16 a20 = A[(m + 2) * Astride + k], a21 = A[(m + 2) * Astride + k + 1];
__fp16 a30 = A[(m + 3) * Astride + k], a31 = A[(m + 3) * Astride + k + 1];
for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01, a10, a11, a20, a21, a30, a31;
float16x8_t b0, b1;
float16x8_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
UNROLL_OUT(loadA2, 2)
UNROLL_OUT(loadA3, 2)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 2)
UNROLL_OUT(calculate_row1, 2)
UNROLL_OUT(calculate_row2, 2)
......@@ -249,31 +221,18 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01, a10, a11, a20, a21, a30, a31;
float16x4_t b0, b1;
float16x4_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
UNROLL_OUT(loadA2, 2)
UNROLL_OUT(loadA3, 2)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 2)
UNROLL_OUT(calculate_row1, 2)
UNROLL_OUT(calculate_row2, 2)
......@@ -287,7 +246,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n < N; n += 1) {
__fp16 a00, a01, a10, a11, a20, a21, a30, a31;
__fp16 b0, b1;
__fp16 c0, c1, c2, c3;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
......@@ -296,18 +254,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 2)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
UNROLL_OUT(loadA2, 2)
UNROLL_OUT(loadA3, 2)
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
c0 += a00 * b0 + a01 * b1;
c1 += a10 * b0 + a11 * b1;
c2 += a20 * b0 + a21 * b1;
......@@ -319,32 +265,23 @@ void megdnn::arm_common::gemv_like(
}
for (; k < K; k += 1) {
size_t n = 0;
__fp16 a00 = A[m * Astride + k];
__fp16 a10 = A[(m + 1) * Astride + k];
__fp16 a20 = A[(m + 2) * Astride + k];
__fp16 a30 = A[(m + 3) * Astride + k];
for (; n + 8 <= N; n += 8) {
float16x8_t a00, a10, a20, a30;
float16x8_t b0;
float16x8_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
UNROLL_OUT(loadA2, 1)
UNROLL_OUT(loadA3, 1)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 1)
UNROLL_OUT(calculate_row1, 1)
UNROLL_OUT(calculate_row2, 1)
......@@ -358,31 +295,18 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n + 4 <= N; n += 4) {
float16x4_t a00, a10, a20, a30;
float16x4_t b0;
float16x4_t c0, c1, c2, c3;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
UNROLL_OUT(loadC, 4)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
UNROLL_OUT(loadA2, 1)
UNROLL_OUT(loadA3, 1)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
UNROLL_OUT(calculate_row0, 1)
UNROLL_OUT(calculate_row1, 1)
UNROLL_OUT(calculate_row2, 1)
......@@ -396,7 +320,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n < N; n += 1) {
__fp16 a00, a10, a20, a30;
__fp16 b0;
__fp16 c0, c1, c2, c3;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
......@@ -405,18 +328,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 1)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
UNROLL_OUT(loadA2, 1)
UNROLL_OUT(loadA3, 1)
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
c0 = c0 + a00 * b0;
c1 = c1 + a10 * b0;
c2 = c2 + a20 * b0;
......@@ -432,24 +343,22 @@ void megdnn::arm_common::gemv_like(
memset(C + m * Cstride, 0, 2 * sizeof(__fp16) * N);
for (; k + 4 <= K; k += 4) {
size_t n = 0;
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1],
a02 = A[m * Astride + k + 2], a03 = A[m * Astride + k + 3];
__fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1],
a12 = A[(m + 1) * Astride + k + 2],
a13 = A[(m + 1) * Astride + k + 3];
for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01, a02, a03, a10, a11, a12, a13;
float16x8_t b0, b1, b2, b3;
float16x8_t c0, c1;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 4)
UNROLL_OUT(calculate_row1, 4)
#undef calculate_row0
......@@ -459,23 +368,16 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01, a02, a03, a10, a11, a12, a13;
float16x4_t b0, b1, b2, b3;
float16x4_t c0, c1;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 4)
UNROLL_OUT(calculate_row1, 4)
#undef calculate_row0
......@@ -485,7 +387,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n < N; n += 1) {
__fp16 a00, a01, a02, a03, a10, a11, a12, a13;
__fp16 b0, b1, b2, b3;
__fp16 c0, c1;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
......@@ -494,12 +395,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 4)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[m * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
UNROLL_OUT(loadA0, 4)
UNROLL_OUT(loadA1, 4)
#undef loadA0
#undef loadA1
c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3;
c1 += a10 * b0 + a11 * b1 + a12 * b2 + a13 * b3;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
......@@ -509,24 +404,19 @@ void megdnn::arm_common::gemv_like(
}
for (; k + 2 <= K; k += 2) {
size_t n = 0;
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1];
__fp16 a10 = A[(m + 1) * Astride + k], a11 = A[(m + 1) * Astride + k + 1];
for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01, a10, a11;
float16x8_t b0, b1;
float16x8_t c0, c1;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 2)
UNROLL_OUT(calculate_row1, 2)
#undef calculate_row0
......@@ -536,23 +426,16 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01, a10, a11;
float16x4_t b0, b1;
float16x4_t c0, c1;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 2)
UNROLL_OUT(calculate_row1, 2)
#undef calculate_row0
......@@ -562,7 +445,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n < N; n += 1) {
__fp16 a00, a01, a10, a11;
__fp16 b0, b1;
__fp16 c0, c1;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
......@@ -571,12 +453,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 2)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
UNROLL_OUT(loadA0, 2)
UNROLL_OUT(loadA1, 2)
#undef loadA0
#undef loadA1
c0 += a00 * b0 + a01 * b1;
c1 += a10 * b0 + a11 * b1;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
......@@ -586,24 +462,19 @@ void megdnn::arm_common::gemv_like(
}
for (; k < K; k += 1) {
size_t n = 0;
__fp16 a00 = A[m * Astride + k];
__fp16 a10 = A[(m + 1) * Astride + k];
for (; n + 8 <= N; n += 8) {
float16x8_t a00, a10;
float16x8_t b0;
float16x8_t c0, c1;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 1)
UNROLL_OUT(calculate_row1, 1)
#undef calculate_row0
......@@ -613,23 +484,16 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n + 4 <= N; n += 4) {
float16x4_t a00, a10;
float16x4_t b0;
float16x4_t c0, c1;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
UNROLL_OUT(loadC, 2)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
UNROLL_OUT(calculate_row0, 1)
UNROLL_OUT(calculate_row1, 1)
#undef calculate_row0
......@@ -639,7 +503,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for (; n < N; n += 1) {
__fp16 a00, a10;
__fp16 b0;
__fp16 c0, c1;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
......@@ -648,12 +511,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 1)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
UNROLL_OUT(loadA0, 1)
UNROLL_OUT(loadA1, 1)
#undef loadA0
#undef loadA1
c0 = c0 + a00 * b0;
c1 = c1 + a10 * b0;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
......@@ -667,48 +524,61 @@ void megdnn::arm_common::gemv_like(
memset(C + m * Cstride, 0, sizeof(__fp16) * N);
for (; k + 4 <= K; k += 4) {
size_t n = 0;
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1],
a02 = A[m * Astride + k + 2], a03 = A[m * Astride + k + 3];
{
#if !defined(__aarch64__)
float16x8_t va00 = vdupq_n_f16(a00), va01 = vdupq_n_f16(a01),
va02 = vdupq_n_f16(a02), va03 = vdupq_n_f16(a03);
#endif
for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01, a02, a03;
float16x8_t b0, b1, b2, b3;
float16x8_t c0;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 4)
#undef calculate_row0
#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1)
#undef vstore
}
}
{
#if !defined(__aarch64__)
float16x4_t va00 = vdup_n_f16(a00), va01 = vdup_n_f16(a01),
va02 = vdup_n_f16(a02), va03 = vdup_n_f16(a03);
#endif
for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01, a02, a03;
float16x4_t b0, b1, b2, b3;
float16x4_t c0;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 4)
UNROLL_OUT(loadA0, 4)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 4)
#undef calculate_row0
#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1)
#undef vstore
}
}
for (; n < N; n += 1) {
__fp16 a00, a01, a02, a03;
__fp16 b0, b1, b2, b3;
__fp16 c0;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
......@@ -717,9 +587,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 4)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[m * Astride + k + i];
UNROLL_OUT(loadA0, 4)
#undef loadA0
c0 += a00 * b0 + a01 * b1 + a02 * b2 + a03 * b3;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
UNROLL_OUT(vstore, 1)
......@@ -727,49 +594,59 @@ void megdnn::arm_common::gemv_like(
}
}
for (; k + 2 <= K; k += 2) {
__fp16 a00 = A[m * Astride + k], a01 = A[m * Astride + k + 1];
size_t n = 0;
{
#if !defined(__aarch64__)
float16x8_t va00 = vdupq_n_f16(a00), va01 = vdupq_n_f16(a01);
#endif
for (; n + 8 <= N; n += 8) {
float16x8_t a00, a01;
float16x8_t b0, b1;
float16x8_t c0;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 2)
#undef calculate_row0
#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1)
#undef vstore
}
}
{
#if !defined(__aarch64__)
float16x4_t va00 = vdup_n_f16(a00), va01 = vdup_n_f16(a01);
#endif
for (; n + 4 <= N; n += 4) {
float16x4_t a00, a01;
float16x4_t b0, b1;
float16x4_t c0;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 2)
UNROLL_OUT(loadA0, 2)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 2)
#undef calculate_row0
#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1)
#undef vstore
}
}
for (; n < N; n += 1) {
__fp16 a00, a01;
__fp16 b0, b1;
__fp16 c0;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
......@@ -778,9 +655,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 2)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
UNROLL_OUT(loadA0, 2)
#undef loadA0
c0 += a00 * b0 + a01 * b1;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
UNROLL_OUT(vstore, 1)
......@@ -789,48 +663,58 @@ void megdnn::arm_common::gemv_like(
}
for (; k < K; k += 1) {
size_t n = 0;
__fp16 a00 = A[m * Astride + k];
{
#if !defined(__aarch64__)
float16x8_t va00 = vdupq_n_f16(a00);
#endif
for (; n + 8 <= N; n += 8) {
float16x8_t a00;
float16x8_t b0;
float16x8_t c0;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 1)
#undef calculate_row0
#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1)
#undef vstore
}
}
{
#if !defined(__aarch64__)
float16x4_t va00 = vdup_n_f16(a00);
#endif
for (; n + 4 <= N; n += 4) {
float16x4_t a00;
float16x4_t b0;
float16x4_t c0;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT(loadC, 1)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(loadA0, 1)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT(calculate_row0, 1)
#undef calculate_row0
#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT(vstore, 1)
#undef vstore
}
}
for (; n < N; n += 1) {
__fp16 a00;
__fp16 b0;
__fp16 c0;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
......@@ -839,9 +723,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT(loadB, 1)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
UNROLL_OUT(loadA0, 1)
#undef loadA0
c0 = c0 + a00 * b0;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
UNROLL_OUT(vstore, 1)
......@@ -850,6 +731,10 @@ void megdnn::arm_common::gemv_like(
}
}
}
#undef VFMA_N_F16
#undef VFMAQ_N_F16
bool megdnn::arm_common::is_hgemv_preferred(
bool transposeA, bool transposeB, size_t M, size_t N, size_t K, size_t /*LDA*/,
size_t LDB, size_t /*LDC*/) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册