From 1ebac1c0a726af3c3562e8bb4d5600c991e08090 Mon Sep 17 00:00:00 2001
From: TianXiaogang
Date: Tue, 3 Dec 2019 15:43:11 +0800
Subject: [PATCH] Armv8 4x4 gemm (#2528)

* feat: add sgemm4x4 for armv8

* fix: fix armv7 gemm choose condition
---
 lite/backends/arm/math/packed_sgemm.cc | 644 ++++++++++++++++++++++++-
 1 file changed, 625 insertions(+), 19 deletions(-)

diff --git a/lite/backends/arm/math/packed_sgemm.cc b/lite/backends/arm/math/packed_sgemm.cc
index 0d6eed9904..092e6937c4 100644
--- a/lite/backends/arm/math/packed_sgemm.cc
+++ b/lite/backends/arm/math/packed_sgemm.cc
@@ -53,6 +53,38 @@ void sgemm_prepacked_8x12(bool is_transB,
                           bool has_bias,
                           bool has_relu,
                           ARMContext *ctx);
+
+void pack_m4(float *out,
+             const float *in,
+             float alpha,
+             int ldin,
+             int m0,
+             int mmax,
+             int k0,
+             int kmax);
+
+void pack_trans_m4(float *out,
+                   const float *in,
+                   float alpha,
+                   int ldin,
+                   int m0,
+                   int mmax,
+                   int k0,
+                   int kmax);
+void sgemm_prepacked_4x4(bool is_transB,
+                         int M,
+                         int N,
+                         int K,
+                         const float *A_packed,
+                         const float *B,
+                         int ldb,
+                         float beta,
+                         float *C,
+                         int ldc,
+                         const float *bias,
+                         bool has_bias,
+                         bool has_relu,
+                         ARMContext *ctx);
 #else
 // for kA72
 void prepackA_6x8(float *out,
@@ -139,13 +171,21 @@ void prepackA(float *out,
               bool is_trans,
               ARMContext *ctx) {
 #ifdef __aarch64__
-  if (is_trans) {
-    prepackA_trans_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax);
+  if (mmax <= 4) {
+    if (is_trans) {
+      pack_trans_m4(out, in, alpha, ldin, m0, mmax, k0, kmax);
+    } else {
+      pack_m4(out, in, alpha, ldin, m0, mmax, k0, kmax);
+    }
   } else {
-    prepackA_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax);
+    if (is_trans) {
+      prepackA_trans_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax);
+    } else {
+      prepackA_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax);
+    }
   }
 #else
-  if (ctx->arch() == kA73) {
+  if (ctx->arch() == kA73 || mmax <= 4) {
     if (is_trans) {
       prepackA_trans_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
     } else {
@@ -212,22 +252,39 @@ void sgemm_prepack(bool is_transB,
                    bool has_relu,
                    ARMContext *ctx) {
 #ifdef __aarch64__
-  sgemm_prepacked_8x12(is_transB,
-                       M,
-                       N,
-                       K,
-                       A_packed,
-                       B,
-                       ldb,
-                       beta,
-                       C,
-                       ldc,
-                       bias,
-                       has_bias,
-                       has_relu,
-                       ctx);
+  if (M <= 4) {
+    sgemm_prepacked_4x4(is_transB,
+                        M,
+                        N,
+                        K,
+                        A_packed,
+                        B,
+                        ldb,
+                        beta,
+                        C,
+                        ldc,
+                        bias,
+                        has_bias,
+                        has_relu,
+                        ctx);
+  } else {
+    sgemm_prepacked_8x12(is_transB,
+                         M,
+                         N,
+                         K,
+                         A_packed,
+                         B,
+                         ldb,
+                         beta,
+                         C,
+                         ldc,
+                         bias,
+                         has_bias,
+                         has_relu,
+                         ctx);
+  }
 #else   // armv7
-  if (ctx->arch() == kA73) {
+  if (ctx->arch() == kA73 || M <= 4) {
     sgemm_prepacked_4x8(is_transB,
                         M,
                         N,
@@ -522,6 +579,147 @@ void prepackA_8x12(float *dout,
     }
   }
 }
+void pack_m4(float *dout,
+             const float *inptr,
+             float alpha,
+             int ldin,
+             int m0,
+             int mmax,
+             int k0,
+             int kmax) {
+  int x_len = kmax - k0;
+  int stride = x_len * 4;
+  float zerobuff[x_len];  // NOLINT
+  memset(zerobuff, 0, sizeof(float) * x_len);
+  bool has_alpha = fabsf(alpha - 1.f) > 1e-8f;
+
+#pragma omp parallel for
+  for (int y = m0; y < mmax; y += 4) {
+    float *outptr = dout + stride * (y - m0) / 4;
+
+    const float *inptr0 = inptr + y * ldin + k0;
+    const float *inptr1 = inptr0 + ldin;
+    const float *inptr2 = inptr1 + ldin;
+    const float *inptr3 = inptr2 + ldin;
+
+    asm volatile(
+        "prfm pldl1keep, [%[ptr0]]        \n"
+        "prfm pldl1keep, [%[ptr0], #64]   \n"
+        "prfm pldl1keep, [%[ptr1]]        \n"
+        "prfm pldl1keep, [%[ptr1], #64]   \n"
+        "prfm pldl1keep, [%[ptr2]]        \n"
+        "prfm pldl1keep, [%[ptr2], #64]   \n"
+        "prfm pldl1keep, [%[ptr3]]        \n"
+        "prfm pldl1keep, [%[ptr3], #64]   \n"
#64] \n" + : + : [ptr0] "r"(inptr0), + [ptr1] "r"(inptr1), + [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3) + : "memory"); + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 3) >= mmax) { + switch ((y + 3) - mmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + default: + break; + } + } + for (; x > 7; x -= 8) { + asm volatile( + "cbz %w[has_alpha], 0f\n" /* check alpha == 1.f? */ + "dup v31.4s, %w[alpha]\n" /* alpha to vector */ + "ldp q0, q1, [%[inptr0]], #32\n" /* load r0, a0~a7 */ + "ldp q2, q3, [%[inptr1]], #32\n" /* load r1, b0~b7 */ + "fmul v0.4s, v31.4s, v0.4s\n" /* mul alpha */ + "fmul v1.4s, v31.4s, v1.4s\n" /* mul alpha */ + "ldp q4, q5, [%[inptr2]], #32\n" /* load r2, c0~c7 */ + "fmul v2.4s, v31.4s, v2.4s\n" /* mul alpha */ + "fmul v3.4s, v31.4s, v3.4s\n" /* mul alpha */ + "ldp q6, q7, [%[inptr3]], #32\n" /* load r3, d0~d7 */ + "fmul v4.4s, v31.4s, v4.4s\n" /* mul alpha */ + "fmul v5.4s, v31.4s, v5.4s\n" /* mul alpha */ + "fmul v6.4s, v31.4s, v6.4s\n" /* mul alpha */ + "fmul v7.4s, v31.4s, v7.4s\n" /* mul alpha */ + "b 1f\n" /* to main process */ + "0: \n" /* alpha == 1 */ + "ldp q0, q1, [%[inptr0]], #32\n" /* load r0, a0~a7 */ + "ldp q2, q3, [%[inptr1]], #32\n" /* load r1, b0~b7 */ + "ldp q4, q5, [%[inptr2]], #32\n" /* load r2, c0~c7 */ + "ldp q6, q7, [%[inptr3]], #32\n" /* load r3, d0~d7 */ + "1: \n" /* main process */ + "trn1 v8.4s, v0.4s, v2.4s\n" /* a0b0a2b2*/ + "trn2 v9.4s, v0.4s, v2.4s\n" /* a1b1a3b3*/ + "trn1 v10.4s, v1.4s, v3.4s\n" /* a4b4a6b6*/ + "trn2 v11.4s, v1.4s, v3.4s\n" /* a5b5a7b7*/ + + "trn1 v12.4s, v4.4s, v6.4s\n" /* c0d0c2d2*/ + "trn2 v13.4s, v4.4s, v6.4s\n" /* c1d1c3d3*/ + "trn1 v14.4s, v5.4s, v7.4s\n" /* c4d4c6d6*/ + "trn2 v15.4s, v5.4s, v7.4s\n" /* c5d5c7d7*/ + + "trn1 v0.2d, v8.2d, v12.2d\n" /* a0b0c0d0 */ + "trn1 v1.2d, v9.2d, v13.2d\n" /* a1b1c1d1 */ + "trn1 v2.2d, v10.2d, v14.2d\n" /* a4b4c4d4 */ + "trn1 v3.2d, v11.2d, v15.2d\n" /* a5b5c5d5 */ + + "trn2 v4.2d, v8.2d, v12.2d\n" /* a2b2c2d2 */ + "trn2 v5.2d, v9.2d, v13.2d\n" /* a3b3c3d3 */ + "stp q0, q1, [%[outptr]], #32\n" /* save q0, q1, a0~h0*/ + "trn2 v6.2d, v10.2d, v14.2d\n" /* a6b6c6d6 */ + "trn2 v7.2d, v11.2d, v15.2d\n" /* a7b7c7d7 */ + "stp q4, q5, [%[outptr]], #32\n" /* save q2, q3, a1~h1*/ + "stp q2, q3, [%[outptr]], #32\n" /* save q4, q5, a2~h2*/ + "stp q6, q7, [%[outptr]], #32\n" /* save q6, q7, a3~h3*/ + + : [inptr0] "+r"(inptr0), + [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : [alpha] "r"(alpha), [has_alpha] "r"(has_alpha) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "cc", + "memory"); + } + + for (; x > 0; x--) { + if (has_alpha) { + *outptr++ = *inptr0++ * alpha; + *outptr++ = *inptr1++ * alpha; + *outptr++ = *inptr2++ * alpha; + *outptr++ = *inptr3++ * alpha; + } else { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + } + } +} void prepackA_trans_8x12(float *outptr, const float *in, @@ -682,6 +880,128 @@ void prepackA_trans_8x12(float *outptr, } } } +void pack_trans_m4(float *outptr, + const float *in, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + auto inptr = in + k0 * ldin + m0; + uint32_t mask_buffer[4] = {0, 1, 2, 3}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 4 * (x_len / 4); + int stride_out = 4 * y_len; + + float32x4_t vzero = 
+  uint32x4_t vmask1 =
+      vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));

+  bool has_alpha = fabsf(alpha - 1.f) > 1e-8f;
+  float32x4_t valpha = vdupq_n_f32(alpha);

+#pragma omp parallel for
+  for (int y = 0; y < y_len - 3; y += 4) {
+    const float *ptr0 = inptr + y * ldin;
+    const float *ptr1 = ptr0 + ldin;
+    const float *ptr2 = ptr1 + ldin;
+    const float *ptr3 = ptr2 + ldin;

+    asm volatile(
+        "prfm pldl1keep, [%[ptr0]]        \n"
+        "prfm pldl1keep, [%[ptr0], #64]   \n"
+        "prfm pldl1keep, [%[ptr1]]        \n"
+        "prfm pldl1keep, [%[ptr1], #64]   \n"
+        "prfm pldl1keep, [%[ptr2]]        \n"
+        "prfm pldl1keep, [%[ptr2], #64]   \n"
+        "prfm pldl1keep, [%[ptr3]]        \n"
+        "prfm pldl1keep, [%[ptr3], #64]   \n"
+        :
+        : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3)
+        : "memory");

+    float *outptr_row_col = outptr + y * 4;
+    int i = 0;
+    for (; i < x_len - 3; i += 4) {
+      float32x4_t vr00 = vld1q_f32(ptr0);
+      float32x4_t vr10 = vld1q_f32(ptr1);
+      float32x4_t vr20 = vld1q_f32(ptr2);
+      float32x4_t vr30 = vld1q_f32(ptr3);
+      if (has_alpha) {
+        vr00 = vmulq_f32(vr00, valpha);
+        vr10 = vmulq_f32(vr10, valpha);
+        vr20 = vmulq_f32(vr20, valpha);
+        vr30 = vmulq_f32(vr30, valpha);
+      }

+      vst1q_f32(outptr_row_col, vr00);
+      vst1q_f32(outptr_row_col + 4, vr10);
+      vst1q_f32(outptr_row_col + 8, vr20);
+      vst1q_f32(outptr_row_col + 12, vr30);

+      ptr0 += 4;
+      ptr1 += 4;
+      ptr2 += 4;
+      ptr3 += 4;

+      outptr_row_col += stride_out;
+    }
+    if (right_remain > 0) {
+      float32x4_t vr00 = vld1q_f32(ptr0);
+      float32x4_t vr10 = vld1q_f32(ptr1);
+      float32x4_t vr20 = vld1q_f32(ptr2);
+      float32x4_t vr30 = vld1q_f32(ptr3);

+      if (has_alpha) {
+        vr00 = vmulq_f32(vr00, valpha);
+        vr10 = vmulq_f32(vr10, valpha);
+        vr20 = vmulq_f32(vr20, valpha);
+        vr30 = vmulq_f32(vr30, valpha);
+      }

+      float32x4_t vr00_1 = vbslq_f32(vmask1, vr00, vzero);
+      float32x4_t vr10_1 = vbslq_f32(vmask1, vr10, vzero);
+      float32x4_t vr20_1 = vbslq_f32(vmask1, vr20, vzero);
+      float32x4_t vr30_1 = vbslq_f32(vmask1, vr30, vzero);

+      vst1q_f32(outptr_row_col, vr00_1);
+      vst1q_f32(outptr_row_col + 4, vr10_1);
+      vst1q_f32(outptr_row_col + 8, vr20_1);
+      vst1q_f32(outptr_row_col + 12, vr30_1);
+    }
+  }

+#pragma omp parallel for
+  for (int y = 4 * (y_len / 4); y < y_len; ++y) {
+    const float *ptr0 = inptr + y * ldin;
+    float *outptr_row_col = outptr + y * 4;
+    int i = 0;
+    for (; i < x_len - 3; i += 4) {
+      float32x4_t vr0 = vld1q_f32(ptr0);
+      if (has_alpha) {
+        vr0 = vmulq_f32(vr0, valpha);
+      }
+      vst1q_f32(outptr_row_col, vr0);

+      ptr0 += 4;

+      outptr_row_col += stride_out;
+    }
+    if (right_remain > 0) {
+      float32x4_t vr0 = vld1q_f32(ptr0);

+      if (has_alpha) {
+        vr0 = vmulq_f32(vr0, valpha);
+      }

+      float32x4_t vr0_1 = vbslq_f32(vmask1, vr0, vzero);

+      vst1q_f32(outptr_row_col, vr0_1);
+    }
+  }
+}
 
 #else  // __aarch64__
 void prepackA_6x8(float* outptr,
@@ -2592,6 +2912,292 @@ void sgemm_prepacked_8x12(bool is_transB,
     }
   }
 }
+
+void sgemm_prepacked_4x4(bool is_transB,
+                         int M,
+                         int N,
+                         int K,
+                         const float *A_packed,
+                         const float *B,
+                         int ldb,
+                         float beta,
+                         float *C,
+                         int ldc,
+                         const float *bias,
+                         bool has_bias,
+                         bool has_relu,
+                         ARMContext *ctx) {
+  size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
+  auto workspace = ctx->workspace_data<float>();
+  int threads = ctx->threads();

+  const int n_block = 4;
+  const int m_block = 4;
+  //! m_block * x (result) + m_block * K (A) + x * K (B) = l2
+  int x_block = (l2_cache - (m_block * K)) / (sizeof(float) * (K + m_block));
+  x_block /= n_block;
+  x_block *= n_block;
+  int x_num = (N + (x_block - 1)) / x_block;
+  x_block = (N + x_num - 1) / x_num;
+  x_block = (x_block + n_block - 1) / n_block;
+  x_block *= n_block;
+  x_block = x_block < n_block ? n_block : x_block;

+  //! the K loop is unrolled by KBLOCK: k_pre full iterations plus a tail
+  int tail_pre = (K & (KBLOCK - 1));
+  int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
+  if (tail_pre == 0) {
+    tail_pre = KBLOCK;
+  }

+  bool flag_p_remain = false;
+  int remain = 0;

+  int has_beta = fabsf(beta) > 1e-8f ? 1 : 0;
+  //! the A panel is precomputed outside the gemm by prepackA
+  for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
+    unsigned int xmax = x0 + x_block;
+    if (xmax > N) {
+      xmax = N;
+    }
+    int bblocks = (xmax - x0 + n_block - 1) / n_block;
+    remain = xmax - x0 - (bblocks - 1) * n_block;
+    if (remain > 0) {
+      flag_p_remain = true;
+    }
+    //! pack a panel of B into the workspace
+    float *b_pannel = workspace;
+    if (is_transB) {
+      pack_m4(b_pannel, B, 1.0f, ldb, x0, xmax, 0, K);
+    } else {
+      pack_trans_m4(b_pannel, B, 1.0f, ldb, x0, xmax, 0, K);
+    }
+#pragma omp parallel for num_threads(threads)
+    for (unsigned int y = 0; y < M; y += m_block) {
+      unsigned int ymax = y + m_block;
+      if (ymax > M) {
+        ymax = M;
+      }

+      float bias_local[4] = {0};
+      if (has_bias) {
+        bias_local[0] = bias[y];
+        bias_local[1] = bias[y + 1];
+        bias_local[2] = bias[y + 2];
+        bias_local[3] = bias[y + 3];
+      }

+      float cout0[n_block];  // NOLINT
+      float cout1[n_block];  // NOLINT
+      float cout2[n_block];  // NOLINT
+      float cout3[n_block];  // NOLINT

+      float *c_ptr0 = C + y * ldc + x0;
+      float *c_ptr1 = c_ptr0 + ldc;
+      float *c_ptr2 = c_ptr1 + ldc;
+      float *c_ptr3 = c_ptr2 + ldc;

+      float *pout0 = c_ptr0;
+      float *pout1 = c_ptr1;
+      float *pout2 = c_ptr2;
+      float *pout3 = c_ptr3;

+      const float *a_ptr_l = A_packed + y * K;
+      const float *b_ptr_l = b_pannel;
+      for (int xb = 0; xb < bblocks; xb++) {
+        if ((y + 3) >= ymax) {
+          switch ((y + 3) - ymax) {
+            case 2:
+              c_ptr1 = cout1;
+            case 1:
+              c_ptr2 = cout2;
+            case 0:
+              c_ptr3 = cout3;
+            default:
+              break;
+          }
+        }
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          pout0 = c_ptr0;
+          pout1 = c_ptr1;
+          pout2 = c_ptr2;
+          pout3 = c_ptr3;

+          c_ptr0 = cout0;
+          c_ptr1 = cout1;
+          c_ptr2 = cout2;
+          c_ptr3 = cout3;
+          if (has_beta) {
+            for (int i = 0; i < remain; ++i) {
+              cout0[i] = pout0[i];
+              cout1[i] = pout1[i];
+              cout2[i] = pout2[i];
+              cout3[i] = pout3[i];
+            }
+          }
+        }
+        const float *a_ptr = a_ptr_l;
+        const float *b_ptr = b_ptr_l + xb * K * 4;
+        int tail = tail_pre;
+        int k = k_pre;
+        // clang-format off
+        asm volatile(
+            "prfm pldl1keep, [%[a_ptr]]\n"      /* preload a */
+            "ld1 {v2.4s}, [%[bias_ptr]]\n"      /* load bias to v2 */
+            "dup v8.4s, v2.s[0]\n"              /* out0 = bias[0] */
+            "prfm pldl1keep, [%[b_ptr]]\n"      /* preload b */
+            "dup v9.4s, v2.s[1]\n"              /* out1 = bias[1] */
+            "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a */
+            "dup v10.4s, v2.s[2]\n"             /* out2 = bias[2] */
+            "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b */
+            "dup v11.4s, v2.s[3]\n"             /* out3 = bias[3] */
+            "cbz %w[has_beta], 0f\n"            /* check beta == 0? */
+            /* process beta */
+            "dup v7.4s, %w[beta]\n"             /* beta to vector */
+            "ld1 {v0.4s}, [%[c_ptr0]]\n"        /* load output r0 */
+            "ld1 {v1.4s}, [%[c_ptr1]]\n"        /* load output r1 */
+            "fmla v8.4s, v0.4s, v7.4s\n"        /* cr00 += beta * c_r00 */
+            "fmla v9.4s, v1.4s, v7.4s\n"        /* cr10 += beta * c_r10 */
+            "ld1 {v2.4s}, [%[c_ptr2]]\n"        /* load output r2 */
+            "ld1 {v3.4s}, [%[c_ptr3]]\n"        /* load output r3 */
+            "fmla v10.4s, v2.4s, v7.4s\n"       /* cr20 += beta * c_r20 */
+            "fmla v11.4s, v3.4s, v7.4s\n"       /* cr30 += beta * c_r30 */

+            "0: \n"                             /* check loop count */
+            "ldp q0, q1, [%[a_ptr]], #32\n"     /* load a00, a10 to q0, q1 */
+            "ldp q4, q5, [%[b_ptr]], #32\n"     /* load b0, b1 to q4, q5 */
+            "cbz %w[k], 2f\n"                   /* check loop count > 0 */
+            /* main loop */
+            /* unroll 0 */
+            "1:\n"                              /* main loop */
+            "fmla v8.4s, v4.4s, v0.s[0]\n"      /* out0 = b0 * a00[0], b0 = q4 */
+            "fmla v9.4s, v4.4s, v0.s[1]\n"      /* out1 = b0 * a00[1], b0 = q4 */
+            "ldp q6, q7, [%[b_ptr]], #32\n"     /* load b2, b3 to q6, q7 */
+            "fmla v10.4s, v4.4s, v0.s[2]\n"     /* out2 = b0 * a00[2], b0 = q4 */
+            "fmla v11.4s, v4.4s, v0.s[3]\n"     /* out3 = b0 * a00[3], b0 = q4 */

+            "ldp q2, q3, [%[a_ptr]], #32\n"     /* load a20, a30 to q2, q3 */
+            "fmla v8.4s, v5.4s, v1.s[0]\n"      /* out0 = b1 * a10[0], b1 = q5 */
+            "fmla v9.4s, v5.4s, v1.s[1]\n"      /* out1 = b1 * a10[1], b1 = q5 */
+            "fmla v10.4s, v5.4s, v1.s[2]\n"     /* out2 = b1 * a10[2], b1 = q5 */
+            "fmla v11.4s, v5.4s, v1.s[3]\n"     /* out3 = b1 * a10[3], b1 = q5 */
+            "ldp q4, q5, [%[b_ptr]], #32\n"     /* load b0, b1 to q4, q5 */

+            "fmla v8.4s, v6.4s, v2.s[0]\n"      /* out0 = b2 * a20[0], b2 = q6 */
+            "fmla v9.4s, v6.4s, v2.s[1]\n"      /* out1 = b2 * a20[1], b2 = q6 */
+            "fmla v10.4s, v6.4s, v2.s[2]\n"     /* out2 = b2 * a20[2], b2 = q6 */
+            "fmla v11.4s, v6.4s, v2.s[3]\n"     /* out3 = b2 * a20[3], b2 = q6 */
+            "ldp q0, q1, [%[a_ptr]], #32\n"     /* load a00, a10 to q0, q1 */

+            "fmla v8.4s, v7.4s, v3.s[0]\n"      /* out0 = b3 * a30[0], b3 = q7 */
+            "fmla v9.4s, v7.4s, v3.s[1]\n"      /* out1 = b3 * a30[1], b3 = q7 */
+            "subs %w[k], %w[k], #1\n"           /* loop count - 1 */
+            "fmla v10.4s, v7.4s, v3.s[2]\n"     /* out2 = b3 * a30[2], b3 = q7 */
+            "fmla v11.4s, v7.4s, v3.s[3]\n"     /* out3 = b3 * a30[3], b3 = q7 */

+            "bne 1b\n"
+            "2:\n"                              /* process tail */
+            "subs %w[tail], %w[tail], #1\n"     /* tail-- */
+            "beq 3f\n"                          /* jump to tail == 1 */
+            /* unroll 0, tail > 1 */
+            "fmla v8.4s, v4.4s, v0.s[0]\n"      /* out0 = b0 * a00[0], b0 = q4 */
+            "fmla v9.4s, v4.4s, v0.s[1]\n"      /* out1 = b0 * a00[1], b0 = q4 */
+            "subs %w[tail], %w[tail], #1\n"     /* tail-- */
+            "fmla v10.4s, v4.4s, v0.s[2]\n"     /* out2 = b0 * a00[2], b0 = q4 */
+            "fmla v11.4s, v4.4s, v0.s[3]\n"     /* out3 = b0 * a00[3], b0 = q4 */

+            "beq 4f\n"                          /* jump to tail == 2 */
+            /* unroll 1, tail > 2 */
+            "ldp q6, q7, [%[b_ptr]], #32\n"     /* load b2, b3 to q6, q7 */

+            "fmla v8.4s, v5.4s, v1.s[0]\n"      /* out0 = b1 * a10[0], b1 = q5 */
+            "fmla v9.4s, v5.4s, v1.s[1]\n"      /* out1 = b1 * a10[1], b1 = q5 */
+            "subs %w[tail], %w[tail], #1\n"     /* tail-- */
+            "fmla v10.4s, v5.4s, v1.s[2]\n"     /* out2 = b1 * a10[2], b1 = q5 */
+            "fmla v11.4s, v5.4s, v1.s[3]\n"     /* out3 = b1 * a10[3], b1 = q5 */
+            "ldp q2, q3, [%[a_ptr]], #32\n"     /* load a20, a30 to q2, q3 */

+            "beq 5f\n"                          /* jump to tail == 3 */
+            /* unroll 2, tail == 4 */
+            "fmla v8.4s, v6.4s, v2.s[0]\n"      /* out0 = b2 * a20[0], b2 = q6 */
+            "fmla v9.4s, v6.4s, v2.s[1]\n"      /* out1 = b2 * a20[1], b2 = q6 */
+            "fmla v10.4s, v6.4s, v2.s[2]\n"     /* out2 = b2 * a20[2], b2 = q6 */
+            "fmla v11.4s, v6.4s, v2.s[3]\n"     /* out3 = b2 * a20[3], b2 = q6 */

+            /* unroll 3, tail == 4 */
+            "fmla v8.4s, v7.4s, v3.s[0]\n"      /* out0 = b3 * a30[0], b3 = q7 */
+            "fmla v9.4s, v7.4s, v3.s[1]\n"      /* out1 = b3 * a30[1], b3 = q7 */
+            "fmla v10.4s, v7.4s, v3.s[2]\n"     /* out2 = b3 * a30[2], b3 = q7 */
+            "fmla v11.4s, v7.4s, v3.s[3]\n"     /* out3 = b3 * a30[3], b3 = q7 */

+            "b 11f\n"
+            /* tail == 1: final tail */
+            "3: \n"
+            "fmla v8.4s, v4.4s, v0.s[0]\n"      /* out0 = b0 * a00[0], b0 = q4 */
+            "fmla v9.4s, v4.4s, v0.s[1]\n"      /* out1 = b0 * a00[1], b0 = q4 */
+            "fmla v10.4s, v4.4s, v0.s[2]\n"     /* out2 = b0 * a00[2], b0 = q4 */
+            "fmla v11.4s, v4.4s, v0.s[3]\n"     /* out3 = b0 * a00[3], b0 = q4 */

+            "b 11f\n"
+            /* tail == 2: final tail */
+            "4:\n"
+            "fmla v8.4s, v5.4s, v1.s[0]\n"      /* out0 = b1 * a10[0], b1 = q5 */
+            "fmla v9.4s, v5.4s, v1.s[1]\n"      /* out1 = b1 * a10[1], b1 = q5 */
+            "fmla v10.4s, v5.4s, v1.s[2]\n"     /* out2 = b1 * a10[2], b1 = q5 */
+            "fmla v11.4s, v5.4s, v1.s[3]\n"     /* out3 = b1 * a10[3], b1 = q5 */

+            "b 11f\n"
+            /* tail == 3: final tail */
+            "5:\n"
+            "fmla v8.4s, v6.4s, v2.s[0]\n"      /* out0 = b2 * a20[0], b2 = q6 */
+            "fmla v9.4s, v6.4s, v2.s[1]\n"      /* out1 = b2 * a20[1], b2 = q6 */
+            "fmla v10.4s, v6.4s, v2.s[2]\n"     /* out2 = b2 * a20[2], b2 = q6 */
+            "fmla v11.4s, v6.4s, v2.s[3]\n"     /* out3 = b2 * a20[3], b2 = q6 */

+            "11: \n"                            /* check if relu */
+            "cbz %w[relu], 12f\n"               /* skip relu */
+            "movi v2.4s, #0\n"                  /* zero for relu */
+            "fmax v8.4s, v8.4s, v2.4s\n"        /* relu */
+            "fmax v9.4s, v9.4s, v2.4s\n"        /* relu */
+            "fmax v10.4s, v10.4s, v2.4s\n"      /* relu */
+            "fmax v11.4s, v11.4s, v2.4s\n"      /* relu */
+            "12: \n"
+            "st1 {v8.4s}, [%[c_ptr0]], #16\n"   /* store r0 */
+            "st1 {v9.4s}, [%[c_ptr1]], #16\n"   /* store r1 */
+            "st1 {v10.4s}, [%[c_ptr2]], #16\n"  /* store r2 */
+            "st1 {v11.4s}, [%[c_ptr3]], #16\n"  /* store r3 */

+            : [a_ptr] "+r"(a_ptr),
+              [b_ptr] "+r"(b_ptr),
+              [k] "+r"(k),
+              [tail] "+r"(tail),
+              [c_ptr0] "+r"(c_ptr0),
+              [c_ptr1] "+r"(c_ptr1),
+              [c_ptr2] "+r"(c_ptr2),
+              [c_ptr3] "+r"(c_ptr3)
+            : [bias_ptr] "r"(bias_local),
+              [relu] "r"(has_relu),
+              [has_beta] "r"(has_beta),
+              [beta] "r"(beta)
+            : "cc", "memory",
+              "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+              "v8", "v9", "v10", "v11");
+        // clang-format on
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          for (int i = 0; i < remain; ++i) {
+            *pout0++ = cout0[i];
+            *pout1++ = cout1[i];
+            *pout2++ = cout2[i];
+            *pout3++ = cout3[i];
+          }
+        }
+      }
+    }
+  }
+}
 #else  // __aarch64__
 /**
  * \brief gemm with ablock = 6, bblock = 8, output 6x8
-- 
GitLab
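
Review notes on the new kernels follow. The element order pack_m4 produces is easier to check in scalar form. The sketch below uses a hypothetical name, pack_m4_reference, and assumes the layout implied by the trn1/trn2/stp sequence above: rows are grouped four at a time, column k of rows r0..r3 lands contiguously at out[4*k .. 4*k+3], rows past mmax read as zeros, and alpha scales every element when it differs from 1.

#include <cmath>

// Hypothetical scalar reference for pack_m4 (assumed semantics, review aid only).
// Each group of 4 rows occupies 4 * (kmax - k0) floats in the output, and
// column k of the group is stored contiguously, matching the NEON transpose.
void pack_m4_reference(float *out, const float *in, float alpha, int ldin,
                       int m0, int mmax, int k0, int kmax) {
  const int x_len = kmax - k0;
  const bool has_alpha = std::fabs(alpha - 1.f) > 1e-8f;
  for (int y = m0; y < mmax; y += 4) {
    // Group offset: stride * (y - m0) / 4 with stride = 4 * x_len.
    float *outptr = out + (y - m0) * x_len;
    for (int k = 0; k < x_len; ++k) {
      for (int r = 0; r < 4; ++r) {
        // Rows beyond mmax play the role of zerobuff in the asm version.
        const float v = (y + r < mmax) ? in[(y + r) * ldin + k0 + k] : 0.f;
        outptr[4 * k + r] = has_alpha ? v * alpha : v;
      }
    }
  }
}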
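
pack_trans_m4 writes the same 4-wide interleave starting from a row-major strip, splitting each row into 4-column tiles and zero-padding the last partial tile via the vbslq_f32 mask. A matching scalar sketch, again under assumed semantics and with a hypothetical name:

#include <cmath>

// Hypothetical scalar reference for pack_trans_m4 (assumed semantics):
// row y of the k0..kmax strip is cut into 4-wide tiles; tile t of row y is
// stored at out[t * 4 * y_len + y * 4], zero-padding the last partial tile.
void pack_trans_m4_reference(float *out, const float *in, float alpha,
                             int ldin, int m0, int mmax, int k0, int kmax) {
  const int x_len = mmax - m0;
  const int y_len = kmax - k0;
  const int stride_out = 4 * y_len;
  const bool has_alpha = std::fabs(alpha - 1.f) > 1e-8f;
  for (int y = 0; y < y_len; ++y) {
    const float *row = in + (k0 + y) * ldin + m0;
    for (int t = 0; t * 4 < x_len; ++t) {  // one 4-wide tile at a time
      float *o = out + t * stride_out + y * 4;
      for (int j = 0; j < 4; ++j) {
        // The NEON mask (vmask1) zeroes lanes past x_len in the last tile.
        const float v = (t * 4 + j < x_len) ? row[t * 4 + j] : 0.f;
        o[j] = has_alpha ? v * alpha : v;
      }
    }
  }
}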
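
The x_block arithmetic at the top of sgemm_prepacked_4x4 follows the footprint comment m_block * x (result) + m_block * K (A) + x * K (B) = l2. A worked instance with assumed figures (512 KiB is just the fallback llc_size from the patch; K = 256 is arbitrary), showing the value before the later x_num rebalancing:

#include <cstdio>

// Worked instance of the panel-width computation (assumed cache figures).
int main() {
  const int l2_cache = 512 * 1024;  // fallback when ctx->llc_size() <= 0
  const int K = 256;
  const int m_block = 4, n_block = 4;
  // Mirrors the patch: (524288 - 1024) / (4 * 260) = 503 by integer division.
  int x_block = (l2_cache - (m_block * K)) / (sizeof(float) * (K + m_block));
  x_block = x_block / n_block * n_block;  // round down to a multiple of n_block
  printf("x_block = %d\n", x_block);      // prints 500 with these figures
  return 0;
}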
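
Finally, the inner assembly of sgemm_prepacked_4x4 reduces to this scalar model of one 4x4 tile (hypothetical name sgemm_4x4_tile_reference; a sketch of the assumed semantics, not the shipped kernel): accumulators start at the broadcast bias, beta scales the existing output, K rank-1 updates follow, and an optional ReLU clamps the result.

#include <algorithm>

// Hypothetical scalar model of one 4x4 tile (assumed semantics, review aid):
// a_panel holds K groups of 4 interleaved A values (pack_m4 layout), b_panel
// holds K groups of 4 interleaved B values; c is the 4x4 output tile.
void sgemm_4x4_tile_reference(const float *a_panel, const float *b_panel,
                              int K, float beta, const float *bias,
                              bool has_bias, bool has_relu, float c[4][4]) {
  float acc[4][4];
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      acc[i][j] = has_bias ? bias[i] : 0.f;  // dup v8..v11 from the bias vector
      acc[i][j] += beta * c[i][j];           // the asm skips this when has_beta == 0
    }
  }
  for (int k = 0; k < K; ++k) {
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        // fmla vN.4s, b_k, a_k[i]: one rank-1 update per packed column k.
        acc[i][j] += a_panel[4 * k + i] * b_panel[4 * k + j];
      }
    }
  }
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      c[i][j] = has_relu ? std::max(acc[i][j], 0.f) : acc[i][j];  // fmax with 0
    }
  }
}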