From 47d21d28dcc78a728ac05571034286bd96b3b13d Mon Sep 17 00:00:00 2001
From: HappyAngel
Date: Wed, 2 Sep 2020 14:31:42 +0800
Subject: [PATCH] [arm] add v8-4x8 gemm implementation (#4201)

* add v8 4x8 implementation. test=develop

* fix run error

* change sgemm compute. test=develop

* fix format. test=develop
---
 lite/backends/arm/math/packed_sgemm.cc | 1491 ++++++++++++++++++++----
 lite/backends/arm/math/sgemm.cc        |    6 +-
 lite/tests/math/sgemm_compute_test.cc  |    2 +-
 3 files changed, 1256 insertions(+), 243 deletions(-)

diff --git a/lite/backends/arm/math/packed_sgemm.cc b/lite/backends/arm/math/packed_sgemm.cc
index 2e869f2df3..c431c5651d 100644
--- a/lite/backends/arm/math/packed_sgemm.cc
+++ b/lite/backends/arm/math/packed_sgemm.cc
@@ -55,6 +55,39 @@ void sgemm_prepacked_8x12(bool is_transB,
                           const operators::ActivationParam act_param,
                           ARMContext *ctx);
 
+void prepackA_4x8(float *out,
+                  const float *in,
+                  float alpha,
+                  int ldin,
+                  int m0,
+                  int mmax,
+                  int k0,
+                  int kmax);
+
+void prepackA_trans_4x8(float *out,
+                        const float *in,
+                        float alpha,
+                        int ldin,
+                        int m0,
+                        int mmax,
+                        int k0,
+                        int kmax);
+
+void sgemm_prepacked_4x8(bool is_transB,
+                         int M,
+                         int N,
+                         int K,
+                         const float *A_packed,
+                         const float *B,
+                         int ldb,
+                         float beta,
+                         float *C,
+                         int ldc,
+                         const float *bias,
+                         bool has_bias,
+                         const operators::ActivationParam act_param,
+                         ARMContext *ctx);
+
 void pack_m4(float *out,
              const float *in,
              float alpha,
@@ -189,9 +222,9 @@ void prepackA(float *out,
 #ifdef __aarch64__
   if (mmax <= 4) {
     if (is_trans) {
-      pack_trans_m4(out, in, alpha, ldin, m0, mmax, k0, kmax);
+      prepackA_trans_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
     } else {
-      pack_m4(out, in, alpha, ldin, m0, mmax, k0, kmax);
+      prepackA_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
     }
   } else {
     if (is_trans) {
@@ -269,7 +302,7 @@ void sgemm_prepack(bool is_transB,
                    ARMContext *ctx) {
 #ifdef __aarch64__
   if (M <= 4) {
-    sgemm_prepacked_4x4(is_transB,
+    sgemm_prepacked_4x8(is_transB,
                         M,
                         N,
                         K,
@@ -633,6 +666,145 @@ void prepackA_8x12(float *dout,
       }
     }
   }
 }
+
+void prepackA_4x8(float *outptr,
+                  const float *inptr,
+                  float alpha,
+                  int ldin,
+                  int m0,
+                  int mmax,
+                  int k0,
+                  int kmax) {
+  int x_len = kmax - k0;
+  float zerobuff[x_len];  // NOLINT
+  memset(zerobuff, 0, sizeof(float) * x_len);
+
+  bool has_alpha = fabsf(alpha - 1.f) > 1e-8f;
+  float32x4_t valpha = vdupq_n_f32(alpha);
+#pragma omp parallel for
+  for (int y = m0; y < mmax; y += 4) {
+    const float *inptr0 = inptr + y * ldin + k0;
+    const float *inptr1 = inptr0 + ldin;
+    const float *inptr2 = inptr1 + ldin;
+    const float *inptr3 = inptr2 + ldin;
+
+    asm volatile(
+        "prfm pldl1keep, [%[ptr0]]       \n"
+        "prfm pldl1keep, [%[ptr0], #64]  \n"
+        "prfm pldl1keep, [%[ptr1]]       \n"
+        "prfm pldl1keep, [%[ptr1], #64]  \n"
+        "prfm pldl1keep, [%[ptr2]]       \n"
+        "prfm pldl1keep, [%[ptr2], #64]  \n"
+        "prfm pldl1keep, [%[ptr3]]       \n"
+        "prfm pldl1keep, [%[ptr3], #64]  \n"
+        :
+        : [ptr0] "r"(inptr0),
+          [ptr1] "r"(inptr1),
+          [ptr2] "r"(inptr2),
+          [ptr3] "r"(inptr3)
+        : "memory");
+    int x = x_len;
+    if ((y + 3) >= mmax) {
+      switch ((y + 3) - mmax) {
+        case 2:
+          inptr1 = zerobuff;
+        case 1:
+          inptr2 = zerobuff;
+        case 0:
+          inptr3 = zerobuff;
+        default:
+          break;
+      }
+    }
+
+    for (; x > 7; x -= 8) {
+      // clang-format off
+      asm volatile(
+          "cbz %w[has_alpha], 0f\n"
+          "ldp q0, q1, [%[inptr0]], #32\n"  // load r0, a0~a7
+          "ldp q2, q3, [%[inptr1]], #32\n"  // load r1, b0~b7
+          "fmul v0.4s, v0.4s, %[alpha].4s\n"
+          "fmul v1.4s, v1.4s, %[alpha].4s\n"
+          "ldp q4, q5, [%[inptr2]], #32\n"  // load r2, c0~c7
+          "fmul v2.4s, v2.4s, 
%[alpha].4s\n" + "fmul v3.4s, v3.4s, %[alpha].4s\n" + "ldp q6, q7, [%[inptr3]], #32\n" // load r3, d0~d7 + "fmul v4.4s, v4.4s, %[alpha].4s\n" + "fmul v5.4s, v5.4s, %[alpha].4s\n" + "fmul v6.4s, v6.4s, %[alpha].4s\n" + "fmul v7.4s, v7.4s, %[alpha].4s\n" + "b 1f\n" // to main process + "0: \n" // alpha == 1 + "ldp q0, q1, [%[inptr0]], #32\n" // load r0, a0~a7 + "ldp q2, q3, [%[inptr1]], #32\n" // load r1, b0~b7 + "ldp q4, q5, [%[inptr2]], #32\n" // load r2, c0~c7 + "ldp q6, q7, [%[inptr3]], #32\n" // load r3, d0~d7 + "1: \n" + "trn1 v8.4s, v0.4s, v2.4s\n" // a0b0a2b2 + "trn2 v9.4s, v0.4s, v2.4s\n" // a1b1a3b3 + "trn1 v10.4s, v1.4s, v3.4s\n" // a4b4a6b6 + "trn2 v11.4s, v1.4s, v3.4s\n" // a5b5a7b7 + "trn1 v12.4s, v4.4s, v6.4s\n" // c0d0c2d2 + "trn2 v13.4s, v4.4s, v6.4s\n" // c1d1c3d3 + "trn1 v14.4s, v5.4s, v7.4s\n" // c4d4c6d6 + "trn2 v15.4s, v5.4s, v7.4s\n" // c5d5c7d7 + + "trn1 v0.2d, v8.2d, v12.2d\n" // a0b0c0d0 + "trn1 v1.2d, v9.2d, v13.2d\n" // a1b1c1d1 + "trn2 v2.2d, v8.2d, v12.2d\n" // a2b2c2d2 + "trn2 v3.2d, v9.2d, v13.2d\n" // a3b3c3d3 + "st1 {v0.4s}, [%[outptr]], #16\n" + "trn1 v4.2d, v10.2d, v14.2d\n" // a4b4c4d4 + "st1 {v1.4s}, [%[outptr]], #16\n" + "trn1 v5.2d, v11.2d, v15.2d\n" // a5b5c5d5 + "st1 {v2.4s}, [%[outptr]], #16\n" + "trn2 v6.2d, v10.2d, v14.2d\n" // a6b6c6d6 + "st1 {v3.4s}, [%[outptr]], #16\n" + "trn2 v7.2d, v11.2d, v15.2d\n" // a7b7c7d7 + "stp q4, q5, [%[outptr]], #32\n" + "stp q6, q7, [%[outptr]], #32\n" + : [inptr0] "+r"(inptr0), + [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + // clang-format on + } + + for (; x > 0; x--) { + if (has_alpha) { + *outptr++ = *inptr0++ * alpha; + *outptr++ = *inptr1++ * alpha; + *outptr++ = *inptr2++ * alpha; + *outptr++ = *inptr3++ * alpha; + } else { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + } + } +} void pack_m4(float *dout, const float *inptr, float alpha, @@ -934,6 +1106,159 @@ void prepackA_trans_8x12(float *outptr, } } } + +void prepackA_trans_4x8(float *outptr, + const float *in, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + auto inptr = in + k0 * ldin + m0; + bool has_alpha = fabsf(alpha - 1.f) > 1e-8f; + float32x4_t valpha = vdupq_n_f32(alpha); + + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 4 * (x_len / 4); + int right_pad = 4 - right_remain; + if (right_remain == 0) { + right_pad = 0; + } + + int stride_out = 4 * y_len; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const float *ptr0 = inptr + y * ldin; + const float *ptr1 = ptr0 + ldin; + const float *ptr2 = ptr1 + ldin; + const float *ptr3 = ptr2 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + float *outptr_row_col = outptr + y * 4; + 
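// layout note: this transposed pack emits m-panels 4 columns wide;
+    // inside a panel the 4 m-values of one k step stay contiguous
+    // (k-major), and consecutive panels sit stride_out = 4 * y_len
+    // floats apart, which is the order sgemm_prepacked_4x8 walks its
+    // packed A block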
int i = 0; + for (; i < x_len - 3; i += 4) { + float *ptr_out = outptr_row_col; + // clang-format off + asm volatile( + "cmp %w[has_alpha], #0\n" + "ld1 {v0.4s}, [%[ptr0]], #16\n" + "ld1 {v1.4s}, [%[ptr1]], #16\n" + "ld1 {v2.4s}, [%[ptr2]], #16\n" + "ld1 {v3.4s}, [%[ptr3]], #16\n" + "beq 0f\n" + "1: \n" + "fmul v0.4s, v0.4s, %[alpha].4s\n" + "fmul v1.4s, v1.4s, %[alpha].4s\n" + "fmul v2.4s, v2.4s, %[alpha].4s\n" + "fmul v3.4s, v3.4s, %[alpha].4s\n" + "0: \n" + "stp q0, q1, [%[outptr]], #32\n" + "stp q2, q3, [%[outptr]], #32\n" + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) + : "v0", "v1", "v2", "v3", "cc", "memory"); + // clang-format on + outptr_row_col += stride_out; + } + if (right_pad > 0) { + float *ptr_out = outptr_row_col; + // clang-format off + asm volatile( + "cmp %w[has_alpha], #0\n" + "ld1 {v0.4s}, [%[ptr0]], #16\n" + "ld1 {v1.4s}, [%[ptr1]], #16\n" + "ld1 {v2.4s}, [%[ptr2]], #16\n" + "ld1 {v3.4s}, [%[ptr3]], #16\n" + "beq 0f\n" + "1: \n" + "fmul v0.4s, v0.4s, %[alpha].4s\n" + "fmul v1.4s, v1.4s, %[alpha].4s\n" + "fmul v2.4s, v2.4s, %[alpha].4s\n" + "fmul v3.4s, v3.4s, %[alpha].4s\n" + "0: \n" + "bif v0.16b, %[vzero].16b, %[vmask1].16b\n" + "bif v1.16b, %[vzero].16b, %[vmask1].16b\n" + "bif v2.16b, %[vzero].16b, %[vmask1].16b\n" + "bif v3.16b, %[vzero].16b, %[vmask1].16b\n" + "stp q0, q1, [%[outptr]], #32\n" + "stp q2, q3, [%[outptr]], #32\n" + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), + [vzero] "w"(vzero), + [has_alpha] "r"(has_alpha), + [alpha] "w"(valpha) + : "v0", "v1", "v2", "v3", "cc", "memory"); + // clang-format on + } + } +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const float *ptr0 = inptr + y * ldin; + float *outptr_row_col = outptr + y * 4; + int i = 0; + for (; i < x_len - 3; i += 4) { + float *ptr_out = outptr_row_col; + asm volatile( + "cmp %[has_alpha], #0\n" + "ld1 {v0.4s}, [%[ptr0]], #16\n" + "beq 0f\n" + "1: \n" + "fmul v0.4s, v0.4s, %[alpha].4s\n" + "0: \n" + "st1 {v0.4s}, [%[outptr]], #16\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) + : "v0", "v1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + float *ptr_out = outptr_row_col; + asm volatile( + "cmp %w[has_alpha], #0\n" + "ld1 {v0.4s}, [%[ptr0]], #16\n" + "beq 0f\n" + "1: \n" + "fmul v0.4s, v0.4s, %[alpha].4s\n" + "0: \n" + "bif v0.16b, %[vzero].16b, %[vmask1].16b\n" + "st1 {v0.4s}, [%[outptr]], #16\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), + [vzero] "w"(vzero), + [has_alpha] "r"(has_alpha), + [alpha] "w"(valpha) + : "v0", "v1", "cc", "memory"); + } + } +} + void pack_trans_m4(float *outptr, const float *in, float alpha, @@ -1587,6 +1912,7 @@ void prepackA_trans_4x8(float* outptr, int i = 0; for (; i < x_len - 3; i += 4) { float* ptr_out = outptr_row_col; + // clang-format off asm volatile( "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" "cmp %[has_alpha], #0\n" @@ -1597,10 +1923,12 @@ void prepackA_trans_4x8(float* outptr, : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) : "q0", "q1", "cc", "memory"); + // clang-format on outptr_row_col += stride_out; } if (right_pad > 0) { float* ptr_out = outptr_row_col; + // clang-format off asm volatile( "vld1.32 {d0-d1}, [%[ptr0]]! 
@ load r0, 4 elements\n" "cmp %[has_alpha], #0\n" @@ -1615,6 +1943,7 @@ void prepackA_trans_4x8(float* outptr, [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) : "q0", "q1", "cc", "memory"); + // clang-format on } } } @@ -1624,7 +1953,7 @@ void prepackA_trans_4x8(float* outptr, /** * \brief input data is transpose * for arm-v7a, transform data to block x k x 8 layout -* for arm-v8a, transform data to block x k x 12 layout +* for arm-v8a, transform data to block x k x 12 layout or block x k x 8 layout */ #ifdef __aarch64__ void loadb( @@ -1996,9 +2325,288 @@ void loadb_trans( "zip1 v18.4s, v16.4s, v17.4s\n" /* i6j6k6l6 */ "zip2 v19.4s, v16.4s, v17.4s\n" /* i7j7k7l7 */ - "str q18, [%[outptr]], #16\n" /* save i6~l6 */ - "stp q22, q23, [%[outptr]], #32\n" /* save a7~h7 */ - "str q19, [%[outptr]], #16\n" /* save i7~l7 */ + "str q18, [%[outptr]], #16\n" /* save i6~l6 */ + "stp q22, q23, [%[outptr]], #32\n" /* save a7~h7 */ + "str q19, [%[outptr]], #16\n" /* save i7~l7 */ + : [inptr0] "+r"(inptr0), + [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), + [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), + [inptr7] "+r"(inptr7), + [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), + [inptr10] "+r"(inptr10), + [inptr11] "+r"(inptr11), + [outptr] "+r"(outptr) + : + : "v0","v1","v2","v3","v4","v5", + "v6","v7","v8","v9","v10","v11","v12", + "v13","v14","v15","v16","v17","v18","v19", + "v20","v21","v22","v23","v24","v25","v26", + "v27","v28","v29","v30","v31","cc","memory"); + // clang-format on + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + *outptr++ = *inptr8++; + *outptr++ = *inptr9++; + *outptr++ = *inptr10++; + *outptr++ = *inptr11++; + } + } +} +void loadb_eight( + float *out, const float *in, int ldin, int k0, int kmax, int n0, int nmax) { + auto outptr = reinterpret_cast(out); + auto inptr = reinterpret_cast(in) + k0 * ldin + n0; + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int right_remain = x_len - 8 * (x_len / 8); + int right_pad = 8 - right_remain; + + uint32_t *outptr_row = outptr; + int stride_out = 8 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t *ptr0 = inptr + y * ldin; + const uint32_t *ptr1 = ptr0 + ldin; + const uint32_t *ptr2 = ptr1 + ldin; + const uint32_t *ptr3 = ptr2 + ldin; + uint32_t *outptr_row_col = outptr_row + y * 8; + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t *ptr_out = outptr_row_col; + asm volatile( + "ldp q0, q1, [%[ptr0]], #32\n" // load r0, 8 elements + "ldp q2, q3, [%[ptr1]], #32\n" // load r1, 8 elements + "stp q0, q1, [%[outptr]], #32\n" // write to output ptr + "stp q2, q3, [%[outptr]], #32\n" // write to output ptr + "ldp 
q0, q1, [%[ptr2]], #32\n" // load r0, 8 elements + "ldp q2, q3, [%[ptr3]], #32\n" // load r1, 8 elements + "stp q0, q1, [%[outptr]], #32\n" // write to output ptr + "stp q2, q3, [%[outptr]], #32\n" // write to output ptr + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : + : "v0", "v1", "v2", "v3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t *ptr_out = outptr_row_col; + asm volatile( + "ldp q0, q1, [%[ptr0]], #32\n" + "ldp q2, q3, [%[ptr1]], #32\n" + "bif v0.16b, %[vzero].16b, %[vmask1].16b\n" + "bif v1.16b, %[vzero].16b, %[vmask2].16b\n" + "bif v2.16b, %[vzero].16b, %[vmask1].16b\n" + "bif v3.16b, %[vzero].16b, %[vmask2].16b\n" + "stp q0, q1, [%[outptr]], #32\n" + "ldp q0, q1, [%[ptr2]], #32\n" + "stp q2, q3, [%[outptr]], #32\n" + "ldp q2, q3, [%[ptr3]], #32\n" + "bif v0.16b, %[vzero].16b, %[vmask1].16b\n" + "bif v1.16b, %[vzero].16b, %[vmask2].16b\n" + "bif v2.16b, %[vzero].16b, %[vmask1].16b\n" + "bif v3.16b, %[vzero].16b, %[vmask2].16b\n" + "stp q0, q1, [%[outptr]], #32\n" + "stp q2, q3, [%[outptr]], #32\n" + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "v0", "v1", "v2", "v3", "cc", "memory"); + } + } +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t *ptr0 = inptr + y * ldin; + uint32_t *outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t *ptr_out = outptr_row_col; + asm volatile( + "ldp q0, q1, [%[ptr0]], #32\n" + "stp q0, q1, [%[outptr]], #32\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : + : "v0", "v1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t *ptr_out = outptr_row_col; + asm volatile( + "ldp q0, q1, [%[ptr0]], #32\n" + "bif v0.16b, %[vzero].16b, %[vmask1].16b\n" + "bif v1.16b, %[vzero].16b, %[vmask2].16b\n" + "stp q0, q1, [%[outptr]], #32\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "v0", "v1", "cc", "memory"); + } + } +} + +void loadb_trans_eight( + float *out, const float *in, int ldin, int k0, int kmax, int n0, int nmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + auto outptr = reinterpret_cast(out); + auto inptr = reinterpret_cast(in); + //! 
data B is not transposed, transpose B to k * 8 + for (int y = n0; y < nmax; y += 8) { + const uint32_t *inptr0 = inptr + y * ldin + k0; + const uint32_t *inptr1 = inptr0 + ldin; + const uint32_t *inptr2 = inptr1 + ldin; + const uint32_t *inptr3 = inptr2 + ldin; + const uint32_t *inptr4 = inptr3 + ldin; + const uint32_t *inptr5 = inptr4 + ldin; + const uint32_t *inptr6 = inptr5 + ldin; + const uint32_t *inptr7 = inptr6 + ldin; + + int x = x_len; + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + "prfm pldl1keep, [%[ptr4]] \n" + "prfm pldl1keep, [%[ptr4], #64] \n" + "prfm pldl1keep, [%[ptr5]] \n" + "prfm pldl1keep, [%[ptr5], #64] \n" + "prfm pldl1keep, [%[ptr6]] \n" + "prfm pldl1keep, [%[ptr6], #64] \n" + "prfm pldl1keep, [%[ptr7]] \n" + "prfm pldl1keep, [%[ptr7], #64] \n" + : + : [ptr0] "r"(inptr0), + [ptr1] "r"(inptr1), + [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3), + [ptr4] "r"(inptr4), + [ptr5] "r"(inptr5), + [ptr6] "r"(inptr6), + [ptr7] "r"(inptr7) + : "memory"); + + //! cope with row index exceed real size, set to zero buffer + if ((y + 7) >= nmax) { + switch ((y + 7) - nmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + // clang-format off + //! zip load 8 elements (2 neon Q registers) from each of 8 rows + asm volatile( + "ldp q0, q1, [%[inptr0]], #32\n" // load r0, a0~a7 + "ldp q2, q3, [%[inptr1]], #32\n" // load r1, b0~b7 + "ldp q4, q5, [%[inptr2]], #32\n" // load r2, c0~c7 + "ldp q6, q7, [%[inptr3]], #32\n" // load r3, d0~d7 + "trn1 v8.4s, v0.4s, v2.4s\n" // a0b0a2b2 + "trn2 v9.4s, v0.4s, v2.4s\n" // a1b1a3b3 + "trn1 v10.4s, v1.4s, v3.4s\n" // a4b4a6b6 + "trn2 v11.4s, v1.4s, v3.4s\n" // a5b5a7b7 + "ldp q16, q17, [%[inptr4]], #32\n"// load r4, e0~e7 + "ldp q18, q19, [%[inptr5]], #32\n"// load r5, f0~f7 + "ldp q20, q21, [%[inptr6]], #32\n"// load r6, g0~g7 + "ldp q22, q23, [%[inptr7]], #32\n"// load r7, h0~h7 + "trn1 v12.4s, v4.4s, v6.4s\n" // c0d0c2d2 + "trn2 v13.4s, v4.4s, v6.4s\n" // c1d1c3d3 + "trn1 v14.4s, v5.4s, v7.4s\n" // c4d4c6d6 + "trn2 v15.4s, v5.4s, v7.4s\n" // c5d5c7d7 + + "trn1 v24.4s, v16.4s, v18.4s\n" // e0f0e2f2 + "trn2 v25.4s, v16.4s, v18.4s\n" // e1f1e3f3 + "trn1 v28.4s, v20.4s, v22.4s\n" // g0h0e2f2 + "trn2 v29.4s, v20.4s, v22.4s\n" // g1h1e3f3 + "trn1 v26.4s, v17.4s, v19.4s\n" // e4f4e6f6 + "trn2 v27.4s, v17.4s, v19.4s\n" // e5f5e7f7 + "trn1 v30.4s, v21.4s, v23.4s\n" // g4h4e6f6 + "trn2 v31.4s, v21.4s, v23.4s\n" // g5h5e7f7 + + "trn1 v0.2d, v8.2d, v12.2d\n" // a0b0c0d0 + "trn1 v1.2d, v24.2d, v28.2d\n" // e0f0g0h0 + "trn1 v2.2d, v9.2d, v13.2d\n" // a1b1c1d1 + "trn1 v3.2d, v25.2d, v29.2d\n" // e1f1g1h1 + "trn2 v4.2d, v8.2d, v12.2d\n" // a2b2c2d2 + "trn2 v5.2d, v24.2d, v28.2d\n" // e2f2g2h2 + "stp q0, q1, [%[outptr]], #32\n" // save q0, q1, a0~h0 + "trn2 v6.2d, v9.2d, v13.2d\n" // a3b3c3d3 + "trn2 v7.2d, v25.2d, v29.2d\n" // e3f3g3h3 + "stp q2, q3, [%[outptr]], #32\n" // save q0, q1, a1~h1 + + "trn1 v16.2d, v10.2d, v14.2d\n" // a4b4c4d4 + "trn1 v17.2d, v26.2d, v30.2d\n" // e4f4g4h4 + "stp q4, q5, [%[outptr]], #32\n" // save q0, q1, a2~h2 + "trn1 v18.2d, v11.2d, v15.2d\n" // a5b5c5d5 + "trn1 
v19.2d, v27.2d, v31.2d\n" // e5f5g5h5 + "stp q6, q7, [%[outptr]], #32\n" // save q0, q1, a3~h3 + "trn2 v20.2d, v10.2d, v14.2d\n" // a6b6c6d6 + "trn2 v21.2d, v26.2d, v30.2d\n" // e6f6g6h6 + "stp q16, q17, [%[outptr]], #32\n" // save q0, q1, a4~h4 + "trn2 v22.2d, v11.2d, v15.2d\n" // a7b7c7d7 + "trn2 v23.2d, v27.2d, v31.2d\n" // e7f7g7h7 + "stp q18, q19, [%[outptr]], #32\n" // save q0, q1, a5~h5 + "stp q20, q21, [%[outptr]], #32\n" // save q0, q1, a6~h6 + "stp q22, q23, [%[outptr]], #32\n" // save q0, q1, a7~h7 : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), @@ -2007,10 +2615,6 @@ void loadb_trans( [inptr5] "+r"(inptr5), [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), - [inptr8] "+r"(inptr8), - [inptr9] "+r"(inptr9), - [inptr10] "+r"(inptr10), - [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) : : "v0","v1","v2","v3","v4","v5", @@ -2030,10 +2634,6 @@ void loadb_trans( *outptr++ = *inptr5++; *outptr++ = *inptr6++; *outptr++ = *inptr7++; - *outptr++ = *inptr8++; - *outptr++ = *inptr9++; - *outptr++ = *inptr10++; - *outptr++ = *inptr11++; } } } @@ -2943,11 +3543,11 @@ void sgemm_prepacked_8x12(bool is_transB, "fmax v30.4s, v30.4s, v0.4s \n" /* relu*/ "fmax v31.4s, v31.4s, v0.4s \n" /* relu*/ "b 20f \n" /* relu end */ - //! no act + //! no act "12: \n" /* no relu */ "cmp %w[flag_act], #0 \n" /* check no act */ - "beq 20f \n" /* no act end */ - //! relu6 + "beq 20f \n" /* no act end */ + //! relu6 "cmp %w[flag_act], #2 \n" /* check if has relu6 */ "bne 13f \n" /* jump if no relu6 */ "movi v0.4s, #0 \n" /* for relu6 */ @@ -3005,77 +3605,77 @@ void sgemm_prepacked_8x12(bool is_transB, "13: \n" /* otherwise is leakey relu */ "movi v0.4s, #0 \n" /* for leakey relu */ "ld1 {v1.4s}, [%[alpha]] \n" /* leakey relu alpha */ - "fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */ - "bif v8.16b, v3.16b, v2.16b \n" /* choose*/ - "bif v9.16b, v5.16b, v4.16b \n" /* choose*/ + "fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */ + "bif v8.16b, v3.16b, v2.16b \n" /* choose*/ + "bif v9.16b, v5.16b, v4.16b \n" /* choose*/ "bif v10.16b, v7.16b, v6.16b \n" /* choose*/ - "fcmge v2.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v11.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v2.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v11.4s, v1.4s \n" /* vmulq_f32 */ "bif v11.16b, v3.16b, v2.16b \n" /* choose*/ - "fcmge v2.4s, v12.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v12.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v4.4s, v13.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v5.4s, v13.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v6.4s, v14.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v7.4s, v14.4s, v1.4s \n" /* vmulq_f32 */ - "bif v12.16b, v3.16b, v2.16b \n" /* choose*/ - "bif v13.16b, v5.16b, v4.16b \n" /* choose*/ + "fcmge v2.4s, v12.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v12.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v4.4s, v13.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v5.4s, v13.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v6.4s, v14.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v7.4s, v14.4s, v1.4s \n" /* vmulq_f32 */ + "bif v12.16b, v3.16b, v2.16b \n" /* choose*/ + 
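/* leaky relu idiom: fcmge builds the x >= 0 mask, fmul  */
+          /* computes alpha * x, and bif keeps x where the mask   */
+          /* is set, taking the fmul result elsewhere             */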
"bif v13.16b, v5.16b, v4.16b \n" /* choose*/ "bif v14.16b, v7.16b, v6.16b \n" /* choose*/ - "fcmge v2.4s, v15.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v15.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v2.4s, v15.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v15.4s, v1.4s \n" /* vmulq_f32 */ "bif v15.16b, v3.16b, v2.16b \n" /* choose*/ - "fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v16.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v4.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v5.4s, v17.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v6.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v7.4s, v18.4s, v1.4s \n" /* vmulq_f32 */ - "bif v16.16b, v3.16b, v2.16b \n" /* choose*/ - "bif v17.16b, v5.16b, v4.16b \n" /* choose*/ + "fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v16.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v4.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v5.4s, v17.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v6.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v7.4s, v18.4s, v1.4s \n" /* vmulq_f32 */ + "bif v16.16b, v3.16b, v2.16b \n" /* choose*/ + "bif v17.16b, v5.16b, v4.16b \n" /* choose*/ "bif v18.16b, v7.16b, v6.16b \n" /* choose*/ - "fcmge v2.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v19.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v2.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v19.4s, v1.4s \n" /* vmulq_f32 */ "bif v19.16b, v3.16b, v2.16b \n" /* choose*/ - "fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v20.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v4.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v5.4s, v21.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v6.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v7.4s, v22.4s, v1.4s \n" /* vmulq_f32 */ - "bif v20.16b, v3.16b, v2.16b \n" /* choose*/ - "bif v21.16b, v5.16b, v4.16b \n" /* choose*/ - "bif v22.16b, v7.16b, v6.16b \n" /* choose*/ - "fcmge v2.4s, v23.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v23.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v20.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v4.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v5.4s, v21.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v6.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v7.4s, v22.4s, v1.4s \n" /* vmulq_f32 */ + "bif v20.16b, v3.16b, v2.16b \n" /* choose*/ + "bif v21.16b, v5.16b, v4.16b \n" /* choose*/ + "bif v22.16b, v7.16b, v6.16b \n" /* choose*/ + "fcmge v2.4s, v23.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v23.4s, v1.4s \n" /* vmulq_f32 */ "bif v23.16b, v3.16b, v2.16b \n" /* choose*/ - "fcmge v2.4s, v24.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v24.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v4.4s, v25.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v5.4s, v25.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v6.4s, v26.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v7.4s, v26.4s, v1.4s \n" /* vmulq_f32 */ - "bif v24.16b, v3.16b, v2.16b \n" /* choose*/ - "bif v25.16b, v5.16b, v4.16b \n" /* choose*/ + "fcmge v2.4s, v24.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v24.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v4.4s, v25.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v5.4s, v25.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v6.4s, v26.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v7.4s, v26.4s, v1.4s \n" /* vmulq_f32 */ + "bif v24.16b, v3.16b, v2.16b \n" /* choose*/ + "bif v25.16b, v5.16b, v4.16b \n" /* choose*/ "bif v26.16b, v7.16b, v6.16b \n" /* choose*/ - "fcmge v2.4s, v27.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v27.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v2.4s, v27.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v27.4s, v1.4s \n" /* vmulq_f32 */ "bif 
v27.16b, v3.16b, v2.16b \n" /* choose*/ - "fcmge v2.4s, v28.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v28.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v4.4s, v29.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v5.4s, v29.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v6.4s, v30.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v7.4s, v30.4s, v1.4s \n" /* vmulq_f32 */ - "bif v28.16b, v3.16b, v2.16b \n" /* choose*/ - "bif v29.16b, v5.16b, v4.16b \n" /* choose*/ + "fcmge v2.4s, v28.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v28.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v4.4s, v29.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v5.4s, v29.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v6.4s, v30.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v7.4s, v30.4s, v1.4s \n" /* vmulq_f32 */ + "bif v28.16b, v3.16b, v2.16b \n" /* choose*/ + "bif v29.16b, v5.16b, v4.16b \n" /* choose*/ "bif v30.16b, v7.16b, v6.16b \n" /* choose*/ - "fcmge v2.4s, v31.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v31.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v2.4s, v31.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v31.4s, v1.4s \n" /* vmulq_f32 */ "bif v31.16b, v3.16b, v2.16b \n" /* choose*/ "20: \n" /* act end */ @@ -3103,7 +3703,7 @@ void sgemm_prepacked_8x12(bool is_transB, : [bias_ptr] "r"(bias_local), [has_beta] "r"(has_beta), [beta] "r"(beta), - [alpha] "r"(alpha), + [alpha] "r"(alpha), [flag_act] "r"(flag_act) : "cc","memory", "v0","v1","v2","v3","v4","v5","v6","v7", @@ -3129,6 +3729,419 @@ void sgemm_prepacked_8x12(bool is_transB, } } +void sgemm_prepacked_4x8(bool is_transB, + int M, + int N, + int K, + const float *A_packed, + const float *B, + int ldb, + float beta, + float *C, + int ldc, + const float *bias, + bool has_bias, + const operators::ActivationParam act_param, + ARMContext *ctx) { + size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; + auto *workspace = ctx->workspace_data(); + int threads = ctx->threads(); + auto act_type = act_param.active_type; + float alpha[4] = {0.f, 0.f, 0.f, 0.f}; + int flag_act = 0x00; // relu: 1, relu6: 2, leaky: 4 + const int n_block = 8; + const int m_block = 4; + if (act_param.has_active) { + if (act_type == lite_api::ActivationType::kRelu) { + flag_act = 0x01; + } else if (act_type == lite_api::ActivationType::kRelu6) { + flag_act = 0x02; + float local_alpha = act_param.Relu_clipped_coef; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } else if (act_type == lite_api::ActivationType::kLeakyRelu) { + flag_act = 0x03; + float local_alpha = act_param.Leaky_relu_alpha; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } + } + float32x4_t valpha = vld1q_f32(alpha); + float32x4_t vzero = vdupq_n_f32(0.f); + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = (l2_cache - (m_block * K)) / (sizeof(float) * (K + m_block)); + x_block /= n_block; + x_block *= n_block; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + n_block - 1) / n_block; + x_block *= n_block; + x_block = x_block < n_block ? n_block : x_block; + + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + int tail_pre = (K & (KBLOCK - 1)); + if (tail_pre == 0) { + tail_pre = KBLOCK; + } + bool flag_p_remain = false; + int remain = 0; + + int has_beta = fabsf(beta) > 1e-8f ? 1 : 0; + //! 
apanel is precomputed outside the gemm
+  for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
+    unsigned int xmax = x0 + x_block;
+    if (xmax > N) {
+      xmax = N;
+    }
+    int bblocks = (xmax - x0 + n_block - 1) / n_block;
+    remain = xmax - x0 - (bblocks - 1) * n_block;
+    if (remain > 0) {
+      flag_p_remain = true;
+    }
+    //! load bpanel
+    auto b_pannel = static_cast<float *>(workspace);
+    if (is_transB) {
+      loadb_trans_eight(b_pannel, B, ldb, 0, K, x0, xmax);
+    } else {
+      loadb_eight(b_pannel, B, ldb, 0, K, x0, xmax);
+    }
+#pragma omp parallel for num_threads(threads)
+    for (unsigned int y = 0; y < M; y += m_block) {
+      unsigned int ymax = y + m_block;
+      if (ymax > M) {
+        ymax = M;
+      }
+      float cout0[n_block];  // NOLINT
+      float cout1[n_block];  // NOLINT
+      float cout2[n_block];  // NOLINT
+      float cout3[n_block];  // NOLINT
+
+      float bias_local[4] = {0};
+      if (has_bias) {
+        bias_local[0] = bias[y];
+        bias_local[1] = bias[y + 1];
+        bias_local[2] = bias[y + 2];
+        bias_local[3] = bias[y + 3];
+      }
+
+      float *c_ptr0 = C + y * ldc + x0;
+      float *c_ptr1 = c_ptr0 + ldc;
+      float *c_ptr2 = c_ptr1 + ldc;
+      float *c_ptr3 = c_ptr2 + ldc;
+
+      float *pout0 = c_ptr0;
+      float *pout1 = c_ptr1;
+      float *pout2 = c_ptr2;
+      float *pout3 = c_ptr3;
+
+      const float *a_ptr_l = A_packed + y * K;
+      const float *b_ptr = b_pannel;
+      for (int xb = 0; xb < bblocks; xb++) {
+        //! rows beyond ymax are redirected to dummy buffers and dropped
+        if ((y + 3) >= ymax) {
+          switch ((y + 3) - ymax) {
+            case 2:
+              c_ptr1 = cout1;
+            case 1:
+              c_ptr2 = cout2;
+            case 0:
+              c_ptr3 = cout3;
+            default:
+              break;
+          }
+        }
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          pout0 = c_ptr0;
+          pout1 = c_ptr1;
+          pout2 = c_ptr2;
+          pout3 = c_ptr3;
+
+          c_ptr0 = cout0;
+          c_ptr1 = cout1;
+          c_ptr2 = cout2;
+          c_ptr3 = cout3;
+
+          if (has_beta) {
+            for (int i = 0; i < remain; ++i) {
+              cout0[i] = pout0[i];
+              cout1[i] = pout1[i];
+              cout2[i] = pout2[i];
+              cout3[i] = pout3[i];
+            }
+          }
+        }
+        const float *a_ptr = a_ptr_l;
+        int tails = tail_pre;
+        int k = k_pre;
+        // clang-format off
+        asm volatile(
+          "ld1 {v2.4s}, [%[bias_ptr]]\n"
+          "dup v8.4s, v2.s[0]\n"
+          "prfm pldl1keep, [%[a_ptr]]\n"
+          "dup v9.4s, v2.s[0]\n"
+          "prfm pldl1keep, [%[b_ptr]]\n"
+          "dup v10.4s, v2.s[1]\n"
+          "dup v11.4s, v2.s[1]\n"
+          "prfm pldl1keep, [%[a_ptr], #64]\n"
+          "dup v12.4s, v2.s[2]\n"
+          "dup v13.4s, v2.s[2]\n"
+          "prfm pldl1keep, [%[b_ptr], #64]\n"
+          "dup v14.4s, v2.s[3]\n"
+          "dup v15.4s, v2.s[3]\n"
+          "prfm pldl1keep, [%[a_ptr], #128]\n"
+          "cmp %w[beta], #0\n"              // check beta == 0?
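+          // accumulator layout: v8/v9 hold output row 0 (columns 0-3
+          // and 4-7), v10/v11 row 1, v12/v13 row 2, v14/v15 row 3;
+          // each register pair starts from the per-row bias broadcast
+          // by the dup instructions above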
+ "prfm pldl1keep, [%[b_ptr], #128]\n" + "prfm pldl1keep, [%[b_ptr], #192]\n" + // process beta + "beq 11f\n" + "dup v7.4s, %w[beta]\n" // beta to vector + "ld1 {v0.4s, v1.4s}, [%[c_ptr0]]\n" // load output r0 + "ld1 {v2.4s, v3.4s}, [%[c_ptr1]]\n" // load output r1 + "fmla v8.4s, v0.4s, v7.4s\n" // cr00 += beta * c_r00 + "fmla v9.4s, v1.4s, v7.4s\n" // cr01 += beta * c_r01 + "ld1 {v0.4s, v1.4s}, [%[c_ptr2]]\n" // load output r2 + "fmla v10.4s, v2.4s, v7.4s\n" // cr10 += beta * c_r10 + "fmla v11.4s, v3.4s, v7.4s\n" // cr11 += beta * c_r11 + "ld1 {v2.4s, v3.4s}, [%[c_ptr3]]\n" // load output r3 + "fmla v12.4s, v0.4s, v7.4s\n" // cr20 += beta * c_r20 + "fmla v13.4s, v1.4s, v7.4s\n" // cr21 += beta * c_r21 + "fmla v14.4s, v2.4s, v7.4s\n" // cr30 += beta * c_r30 + "fmla v15.4s, v3.4s, v7.4s\n" // cr31 += beta * c_r31 + "11: \n" // check loop count + "ldp q0, q1, [%[a_ptr]], #32\n" // load a0~a3 to q0, q1 + "ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5 + "cbz %w[k], 0f\n" // check loop count > 0 + // main loop + // Unroll 0 + "1: \n" + "fmla v8.4s, v4.4s, v0.s[0]\n" // out0 += b0 * a0[0] + "fmla v10.4s, v4.4s, v0.s[1]\n" // out1 += b0 * a0[1] + "ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2 + "fmla v12.4s, v4.4s, v0.s[2]\n" // out2 += b0 * a0[2] + "fmla v14.4s, v4.4s, v0.s[3]\n" // out3 += b0 * a0[3] + "ldp q2, q3, [%[a_ptr]], #32\n" // load next 2xa0~a3 + "fmla v9.4s, v5.4s, v0.s[0]\n" // out4 += b1 * a0[0] + "fmla v11.4s, v5.4s, v0.s[1]\n" // out5 += b1 * a0[1] + "fmla v13.4s, v5.4s, v0.s[2]\n" // out6 += b1 * a0[2] + "fmla v15.4s, v5.4s, v0.s[3]\n" // out7 += b1 * a0[3] + "ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5 + // Unroll 1 + "fmla v8.4s, v6.4s, v1.s[0]\n" // out0 += b0 * a0[0] + "prfm pldl1keep, [%[b_ptr], #192]\n" + "fmla v10.4s, v6.4s, v1.s[1]\n" // out1 += b0 * a0[1] + "fmla v12.4s, v6.4s, v1.s[2]\n" // out1 += b0 * a0[2] + "fmla v14.4s, v6.4s, v1.s[3]\n" // out1 += b0 * a0[3] + "fmla v9.4s, v7.4s, v1.s[0]\n" // out4 += b1 * a0[0] + "fmla v11.4s, v7.4s, v1.s[1]\n" // out5 += b1 * a0[1] + "fmla v13.4s, v7.4s, v1.s[2]\n" // out6 += b1 * a0[2] + "fmla v15.4s, v7.4s, v1.s[3]\n" // out7 += b1 * a0[3] + "ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2 + // Unroll 2 + "fmla v8.4s, v4.4s, v2.s[0]\n" // out0 += b0 * a0[0] + "ldp q0, q1, [%[a_ptr]], #32\n" // load a0~a3 to q0, q1 + "fmla v10.4s, v4.4s, v2.s[1]\n" // out1 += b0 * a0[1] + "fmla v12.4s, v4.4s, v2.s[2]\n" // out1 += b0 * a0[2] + "fmla v14.4s, v4.4s, v2.s[3]\n" // out1 += b0 * a0[3] + "fmla v9.4s, v5.4s, v2.s[0]\n" // out4 += b1 * a0[0] + "fmla v11.4s, v5.4s, v2.s[1]\n" // out5 += b1 * a0[1] + "fmla v13.4s, v5.4s, v2.s[2]\n" // out6 += b1 * a0[2] + "fmla v15.4s, v5.4s, v2.s[3]\n" // out7 += b1 * a0[3] + "ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5 + // Unroll 3 + "fmla v8.4s, v6.4s, v3.s[0]\n" // out0 += b0 * a0[0] + "prfm pldl1keep, [%[a_ptr], #128]\n" + "fmla v10.4s, v6.4s, v3.s[1]\n" // out1 += b0 * a0[1] + "fmla v12.4s, v6.4s, v3.s[2]\n" // out1 += b0 * a0[2] + "fmla v14.4s, v6.4s, v3.s[3]\n" // out1 += b0 * a0[3] + "subs %w[k], %w[k], #1\n" // loop count - 1 + "fmla v9.4s, v7.4s, v3.s[0]\n" // out4 += b1 * a0[0] + "fmla v11.4s, v7.4s, v3.s[1]\n" // out5 += b1 * a0[1] + "fmla v13.4s, v7.4s, v3.s[2]\n" // out6 += b1 * a0[2] + "fmla v15.4s, v7.4s, v3.s[3]\n" // out7 += b1 * a0[3] + "bne 1b\n" + "0: \n" + "subs %w[tail], %w[tail], #1\n" // tail-- + "beq 3f\n" // jump to tail = 1 + // Unroll 0 + "ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2 + "fmla v8.4s, v4.4s, v0.s[0]\n" // out0 += 
b0 * a0[0]
+          "fmla v10.4s, v4.4s, v0.s[1]\n"  // out1 += b0 * a0[1]
+          "subs %w[tail], %w[tail], #1\n"  // tail--
+          "fmla v12.4s, v4.4s, v0.s[2]\n"  // out2 += b0 * a0[2]
+          "fmla v14.4s, v4.4s, v0.s[3]\n"  // out3 += b0 * a0[3]
+          "fmla v9.4s, v5.4s, v0.s[0]\n"   // out4 += b1 * a0[0]
+          "fmla v11.4s, v5.4s, v0.s[1]\n"  // out5 += b1 * a0[1]
+          "fmla v13.4s, v5.4s, v0.s[2]\n"  // out6 += b1 * a0[2]
+          "fmla v15.4s, v5.4s, v0.s[3]\n"  // out7 += b1 * a0[3]
+          "beq 4f\n"                       // jump to tail = 2
+          // Unroll 1
+          "ldp q4, q5, [%[b_ptr]], #32\n"  // load b0~b3 to q4, q5
+          "fmla v8.4s, v6.4s, v1.s[0]\n"   // out0 += b0 * a0[0]
+          "ldp q2, q3, [%[a_ptr]], #32\n"  // load next 2xa0~a3
+          "fmla v10.4s, v6.4s, v1.s[1]\n"  // out1 += b0 * a0[1]
+          "subs %w[tail], %w[tail], #1\n"  // tail--
+          "fmla v12.4s, v6.4s, v1.s[2]\n"  // out2 += b0 * a0[2]
+          "fmla v14.4s, v6.4s, v1.s[3]\n"  // out3 += b0 * a0[3]
+          "fmla v9.4s, v7.4s, v1.s[0]\n"   // out4 += b1 * a0[0]
+          "fmla v11.4s, v7.4s, v1.s[1]\n"  // out5 += b1 * a0[1]
+          "fmla v13.4s, v7.4s, v1.s[2]\n"  // out6 += b1 * a0[2]
+          "fmla v15.4s, v7.4s, v1.s[3]\n"  // out7 += b1 * a0[3]
+          "beq 5f\n"                       // jump to tail = 3
+          // Unroll 2
+          "ldp q6, q7, [%[b_ptr]], #32\n"  // load next b1, b2
+          "fmla v8.4s, v4.4s, v2.s[0]\n"   // out0 += b0 * a0[0]
+          "fmla v10.4s, v4.4s, v2.s[1]\n"  // out1 += b0 * a0[1]
+          "fmla v12.4s, v4.4s, v2.s[2]\n"  // out2 += b0 * a0[2]
+          "fmla v14.4s, v4.4s, v2.s[3]\n"  // out3 += b0 * a0[3]
+          "fmla v9.4s, v5.4s, v2.s[0]\n"   // out4 += b1 * a0[0]
+          "fmla v11.4s, v5.4s, v2.s[1]\n"  // out5 += b1 * a0[1]
+          "fmla v13.4s, v5.4s, v2.s[2]\n"  // out6 += b1 * a0[2]
+          "fmla v15.4s, v5.4s, v2.s[3]\n"  // out7 += b1 * a0[3]
+          // Unroll 3
+          "fmla v8.4s, v6.4s, v3.s[0]\n"   // out0 += b0 * a0[0]
+          "fmla v10.4s, v6.4s, v3.s[1]\n"  // out1 += b0 * a0[1]
+          "fmla v12.4s, v6.4s, v3.s[2]\n"  // out2 += b0 * a0[2]
+          "fmla v14.4s, v6.4s, v3.s[3]\n"  // out3 += b0 * a0[3]
+          "fmla v9.4s, v7.4s, v3.s[0]\n"   // out4 += b1 * a0[0]
+          "fmla v11.4s, v7.4s, v3.s[1]\n"  // out5 += b1 * a0[1]
+          "fmla v13.4s, v7.4s, v3.s[2]\n"  // out6 += b1 * a0[2]
+          "fmla v15.4s, v7.4s, v3.s[3]\n"  // out7 += b1 * a0[3]
+          "b 2f\n"
+          // tails==1 final tail
+          "3: \n"
+          "fmla v8.4s, v4.4s, v0.s[0]\n"   // out0 += b0 * a0[0]
+          "fmla v10.4s, v4.4s, v0.s[1]\n"  // out1 += b0 * a0[1]
+          "fmla v12.4s, v4.4s, v0.s[2]\n"  // out2 += b0 * a0[2]
+          "fmla v14.4s, v4.4s, v0.s[3]\n"  // out3 += b0 * a0[3]
+          "fmla v9.4s, v5.4s, v0.s[0]\n"   // out4 += b1 * a0[0]
+          "fmla v11.4s, v5.4s, v0.s[1]\n"  // out5 += b1 * a0[1]
+          "fmla v13.4s, v5.4s, v0.s[2]\n"  // out6 += b1 * a0[2]
+          "fmla v15.4s, v5.4s, v0.s[3]\n"  // out7 += b1 * a0[3]
+          // roll a_ptr back 16 bytes: only one of the two loaded
+          // a-vectors was consumed (64-bit pointer arithmetic)
+          "sub %[a_ptr], %[a_ptr], #16\n"
+          "b 2f\n"
+          "4: \n"
+          // tails==2 final tail
+          "fmla v8.4s, v6.4s, v1.s[0]\n"   // out0 += b0 * a0[0]
+          "fmla v10.4s, v6.4s, v1.s[1]\n"  // out1 += b0 * a0[1]
+          "fmla v12.4s, v6.4s, v1.s[2]\n"  // out2 += b0 * a0[2]
+          "fmla v14.4s, v6.4s, v1.s[3]\n"  // out3 += b0 * a0[3]
+          "fmla v9.4s, v7.4s, v1.s[0]\n"   // out4 += b1 * a0[0]
+          "fmla v11.4s, v7.4s, v1.s[1]\n"  // out5 += b1 * a0[1]
+          "fmla v13.4s, v7.4s, v1.s[2]\n"  // out6 += b1 * a0[2]
+          "fmla v15.4s, v7.4s, v1.s[3]\n"  // out7 += b1 * a0[3]
+          "b 2f\n"
+          // tails==3 final tail
+          "5: \n"
+          "fmla v8.4s, v4.4s, v2.s[0]\n"   // out0 += b0 * a0[0]
+          "fmla v10.4s, v4.4s, v2.s[1]\n"  // out1 += b0 * a0[1]
+          "fmla v12.4s, v4.4s, v2.s[2]\n"  // out2 += b0 * a0[2]
+          "fmla v14.4s, v4.4s, v2.s[3]\n"  // out3 += b0 * a0[3]
+          "fmla v9.4s, v5.4s, v2.s[0]\n"   // out4 += b1 * a0[0]
+          "fmla v11.4s, v5.4s, v2.s[1]\n"  // out5 += b1 * a0[1]
+          "fmla v13.4s, v5.4s, v2.s[2]\n"  // out6 += b1 * a0[2]
+          "fmla v15.4s, v5.4s, v2.s[3]\n"  // out7 += b1 * a0[3]
+          // roll a_ptr back 16 bytes: v3 was loaded but not consumed
+          "sub %[a_ptr], %[a_ptr], #16\n"
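+          // activation epilogue: flag_act == 0 falls through to the
+          // stores at label 10; 1 applies relu (fmax against vzero);
+          // 2 applies relu6 (fmax against vzero, then fmin against
+          // valpha, which carries the clip value); any other value
+          // takes the leaky relu path at label 7, computing
+          // x >= 0 ? x : alpha * x with the fcmge/fmul/bif select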
+          "2: \n"
+          "cmp %w[flag_act], #0\n"         // check if has act
+          "beq 10f\n"
+          "cmp %w[flag_act], #1\n"         // check if has relu
+          "bne 6f\n"
+          "fmax v8.4s, v8.4s, %[vzero].4s\n"
+          "fmax v9.4s, v9.4s, %[vzero].4s\n"
+          "fmax v10.4s, v10.4s, %[vzero].4s\n"
+          "fmax v11.4s, v11.4s, %[vzero].4s\n"
+          "fmax v12.4s, v12.4s, %[vzero].4s\n"
+          "fmax v13.4s, v13.4s, %[vzero].4s\n"
+          "fmax v14.4s, v14.4s, %[vzero].4s\n"
+          "fmax v15.4s, v15.4s, %[vzero].4s\n"
+          "b 10f\n"                        // end
+          "6: \n"
+          "cmp %w[flag_act], #2\n"         // check relu6
+          "bne 7f\n"
+          "fmax v8.4s, v8.4s, %[vzero].4s\n"
+          "fmax v9.4s, v9.4s, %[vzero].4s\n"
+          "fmax v10.4s, v10.4s, %[vzero].4s\n"
+          "fmax v11.4s, v11.4s, %[vzero].4s\n"
+          "fmax v12.4s, v12.4s, %[vzero].4s\n"
+          "fmax v13.4s, v13.4s, %[vzero].4s\n"
+          "fmax v14.4s, v14.4s, %[vzero].4s\n"
+          "fmax v15.4s, v15.4s, %[vzero].4s\n"
+          "fmin v8.4s, v8.4s, %[valpha].4s\n"
+          "fmin v9.4s, v9.4s, %[valpha].4s\n"
+          "fmin v10.4s, v10.4s, %[valpha].4s\n"
+          "fmin v11.4s, v11.4s, %[valpha].4s\n"
+          "fmin v12.4s, v12.4s, %[valpha].4s\n"
+          "fmin v13.4s, v13.4s, %[valpha].4s\n"
+          "fmin v14.4s, v14.4s, %[valpha].4s\n"
+          "fmin v15.4s, v15.4s, %[valpha].4s\n"
+          "b 10f\n"
+          "7: \n"
+          "fcmge v2.4s, v8.4s, %[vzero].4s\n"
+          "fmul v3.4s, v8.4s, %[valpha].4s\n"
+          "fcmge v4.4s, v9.4s, %[vzero].4s\n"
+          "fmul v5.4s, v9.4s, %[valpha].4s\n"
+          "fcmge v6.4s, v10.4s, %[vzero].4s\n"
+          "fmul v7.4s, v10.4s, %[valpha].4s\n"
+          "fcmge v0.4s, v11.4s, %[vzero].4s\n"
+          "fmul v1.4s, v11.4s, %[valpha].4s\n"
+          "bif v8.16b, v3.16b, v2.16b \n"
+          "bif v9.16b, v5.16b, v4.16b \n"
+          "bif v10.16b, v7.16b, v6.16b \n"
+          "bif v11.16b, v1.16b, v0.16b \n"
+          "fcmge v2.4s, v12.4s, %[vzero].4s\n"
+          "fmul v3.4s, v12.4s, %[valpha].4s\n"
+          "fcmge v4.4s, v13.4s, %[vzero].4s\n"
+          "fmul v5.4s, v13.4s, %[valpha].4s\n"
+          "fcmge v6.4s, v14.4s, %[vzero].4s\n"
+          "fmul v7.4s, v14.4s, %[valpha].4s\n"
+          "fcmge v0.4s, v15.4s, %[vzero].4s\n"
+          "fmul v1.4s, v15.4s, %[valpha].4s\n"
+          "bif v12.16b, v3.16b, v2.16b \n"
+          "bif v13.16b, v5.16b, v4.16b \n"
+          "bif v14.16b, v7.16b, v6.16b \n"
+          "bif v15.16b, v1.16b, v0.16b \n"
+          "10: \n"
+          "st1 {v8.4s, v9.4s}, [%[c_ptr0]], #32\n"
+          "st1 {v10.4s, v11.4s}, [%[c_ptr1]], #32\n"
+          "st1 {v12.4s, v13.4s}, [%[c_ptr2]], #32\n"
+          "st1 {v14.4s, v15.4s}, [%[c_ptr3]], #32\n"
+          : [a_ptr] "+r"(a_ptr),
+            [b_ptr] "+r"(b_ptr),
+            [c_ptr0] "+r"(c_ptr0),
+            [c_ptr1] "+r"(c_ptr1),
+            [c_ptr2] "+r"(c_ptr2),
+            [c_ptr3] "+r"(c_ptr3),
+            [k] "+r"(k),
+            [tail] "+r"(tails)
+          : [bias_ptr] "r"(bias_local),
+            [beta] "r"(beta),
+            [alpha] "r"(alpha),
+            [flag_act] "r"(flag_act),
+            [vzero] "w"(vzero),
+            [valpha] "w"(valpha)
+          : "cc","memory",
+            "v0","v1","v2","v3","v4","v5","v6","v7",
+            "v8","v9","v10","v11","v12","v13",
+            "v14","v15");
+        // clang-format on
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          for (int i = 0; i < remain; ++i) {
+            *pout0++ = cout0[i];
+            *pout1++ = cout1[i];
+            *pout2++ = cout2[i];
+            *pout3++ = cout3[i];
+          }
+        }
+      }
+    }
+  }
+}
+
 void sgemm_prepacked_4x4(bool is_transB,
                          int M,
                          int N,
@@ -3313,7 +4326,7 @@ void sgemm_prepacked_4x4(bool is_transB,
           "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b3 to q6, q7 */
           "fmla v10.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4 */
           "fmla v11.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4 */
-
+
           "ldp q2, q3, [%[a_ptr]], #32\n" /* load a20, a30 to q2, q3 */
           "fmla v8.4s, v5.4s, v1.s[0]\n"  /* out0 = b1 * a10[0], b1 =q5 */
           "fmla v9.4s, v5.4s, v1.s[1]\n"  /* out1 = b1 * a10[1], b1 =q5 */
@@ 
-3404,11 +4417,11 @@ void sgemm_prepacked_4x4(bool is_transB, "fmax v10.4s, v10.4s, v0.4s \n" /* relu*/ "fmax v11.4s, v11.4s, v0.4s \n" /* relu*/ "b 20f \n" /* relu end */ - //! no act + //! no act "12: \n" /* no relu */ "cmp %w[flag_act], #0 \n" /* check no act */ - "beq 20f \n" /* no act end */ - //! relu6 + "beq 20f \n" /* no act end */ + //! relu6 "cmp %w[flag_act], #2 \n" /* check if has relu6 */ "bne 13f \n" /* jump if no relu6 */ "movi v0.4s, #0 \n" /* for relu6 */ @@ -3427,17 +4440,17 @@ void sgemm_prepacked_4x4(bool is_transB, "13: \n" /* otherwise is leakey relu */ "movi v0.4s, #0 \n" /* for leakey relu */ "ld1 {v1.4s}, [%[alpha]] \n" /* leakey relu alpha */ - "fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */ - "fcmge v12.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */ - "fmul v13.4s, v11.4s, v1.4s \n" /* vmulq_f32 */ - "bif v8.16b, v3.16b, v2.16b \n" /* choose*/ - "bif v9.16b, v5.16b, v4.16b \n" /* choose*/ - "bif v10.16b, v7.16b, v6.16b \n" /* choose*/ + "fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */ + "fcmge v12.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */ + "fmul v13.4s, v11.4s, v1.4s \n" /* vmulq_f32 */ + "bif v8.16b, v3.16b, v2.16b \n" /* choose*/ + "bif v9.16b, v5.16b, v4.16b \n" /* choose*/ + "bif v10.16b, v7.16b, v6.16b \n" /* choose*/ "bif v11.16b, v13.16b, v12.16b \n" /* choose*/ "20: \n" /* act end */ "st1 {v8.4s}, [%[c_ptr0]], #16\n" /* store r0 */ @@ -3455,7 +4468,7 @@ void sgemm_prepacked_4x4(bool is_transB, [c_ptr3] "+r"(c_ptr3) : [bias_ptr] "r"(bias_local), [has_beta] "r"(has_beta), - [beta] "r"(beta), + [beta] "r"(beta), [alpha] "r"(alpha), [flag_act] "r"(flag_act) : "cc","memory", @@ -3926,7 +4939,7 @@ void sgemm_prepacked_6x8(bool is_transB, "cmp %[tails], #0 @ check no act\n" "beq 10f @ no act end \n" //! relu6 - "cmp %[tails], #2 @ check if has relu6\n" + "cmp %[tails], #2 @ check if has relu6\n" "bne 7f @ jump if no relu6 \n" "vmov.u32 q0, #0 @ for relu6\n" "vmax.f32 q4, q4, q0 @ for relu6\n" @@ -3957,45 +4970,45 @@ void sgemm_prepacked_6x8(bool is_transB, "vmin.f32 q15, q15, q1 @ for relu6\n" "b 10f @ relu6 end \n" //! 
leakey relu - "7: @ otherwise is leakey relu\n" + "7: @ otherwise is leakey relu\n" "vmov.u32 q0, #0 @ for leakey relu \n" "vld1.f32 {d2-d3}, [%[alpha]] @ load leakey relu alpha\n" - "vcge.f32 q2, q4, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q4, q1 @ vmulq_f32 \n" - "vbif q4, q3, q2 @ choose \n" - "vcge.f32 q2, q5, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q5, q1 @ vmulq_f32 \n" + "vcge.f32 q2, q4, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q4, q1 @ vmulq_f32 \n" + "vbif q4, q3, q2 @ choose \n" + "vcge.f32 q2, q5, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q5, q1 @ vmulq_f32 \n" "vbif q5, q3, q2 @ choose \n" - "vcge.f32 q2, q6, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q6, q1 @ vmulq_f32 \n" - "vbif q6, q3, q2 @ choose \n" - "vcge.f32 q2, q7, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q7, q1 @ vmulq_f32 \n" - "vbif q7, q3, q2 @ choose \n" - "vcge.f32 q2, q8, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q8, q1 @ vmulq_f32 \n" - "vbif q8, q3, q2 @ choose \n" - "vcge.f32 q2, q9, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q9, q1 @ vmulq_f32 \n" + "vcge.f32 q2, q6, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q6, q1 @ vmulq_f32 \n" + "vbif q6, q3, q2 @ choose \n" + "vcge.f32 q2, q7, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q7, q1 @ vmulq_f32 \n" + "vbif q7, q3, q2 @ choose \n" + "vcge.f32 q2, q8, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q8, q1 @ vmulq_f32 \n" + "vbif q8, q3, q2 @ choose \n" + "vcge.f32 q2, q9, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q9, q1 @ vmulq_f32 \n" "vbif q9, q3, q2 @ choose \n" - "vcge.f32 q2, q10, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q10, q1 @ vmulq_f32 \n" - "vbif q10, q3, q2 @ choose \n" - "vcge.f32 q2, q11, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q11, q1 @ vmulq_f32 \n" - "vbif q11, q3, q2 @ choose \n" - "vcge.f32 q2, q12, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q12, q1 @ vmulq_f32 \n" - "vbif q12, q3, q2 @ choose \n" - "vcge.f32 q2, q13, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q13, q1 @ vmulq_f32 \n" - "vbif q13, q3, q2 @ choose \n" - "vcge.f32 q2, q14, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q14, q1 @ vmulq_f32 \n" - "vbif q14, q3, q2 @ choose \n" - "vcge.f32 q2, q15, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q15, q1 @ vmulq_f32 \n" - "vbif q15, q3, q2 @ choose \n" + "vcge.f32 q2, q10, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q10, q1 @ vmulq_f32 \n" + "vbif q10, q3, q2 @ choose \n" + "vcge.f32 q2, q11, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q11, q1 @ vmulq_f32 \n" + "vbif q11, q3, q2 @ choose \n" + "vcge.f32 q2, q12, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q12, q1 @ vmulq_f32 \n" + "vbif q12, q3, q2 @ choose \n" + "vcge.f32 q2, q13, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q13, q1 @ vmulq_f32 \n" + "vbif q13, q3, q2 @ choose \n" + "vcge.f32 q2, q14, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q14, q1 @ vmulq_f32 \n" + "vbif q14, q3, q2 @ choose \n" + "vcge.f32 q2, q15, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q15, q1 @ vmulq_f32 \n" + "vbif q15, q3, q2 @ choose \n" "10: @ act end \n" "vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n" "vst1.32 {d12-d15}, [%[c_ptr1]]! 
@ store r1\n" @@ -4014,7 +5027,7 @@ void sgemm_prepacked_6x8(bool is_transB, [k] "+r"(k), [tails] "+r"(tails) : [bias_ptr] "r"(bias_local), - [beta] "r"(beta), + [beta] "r"(beta), [alpha] "r" (alpha) : "q0","q1","q2","q3","q4", "q5","q6","q7","q8","q9","q10","q11", @@ -4311,7 +5324,7 @@ void sgemm_prepacked_6x8_a53(bool is_transB, "vmla.f32 q7, q3, d1[1] \n" /* out11 += a13 * b3h */ "ldr r1, [%[b_ptr], #0x0C] \n" /* load b03 to r1 */ "vmla.f32 q9, q3, d2[0] \n" /* out21 += a23 * b3h */ - "subs %[k], %[k], #1 \n" /* loop k -= 1 */ + "subs %[k], %[k], #1 \n" /* loop k -= 1 */ "vldr d1, [%[a_ptr], #0x08] \n" /* load a20, a30 to d1 */ "vmov d5, r0, r1 \n" /* mov b02, b03 to d5 */ "vmla.f32 q11, q3, d2[1] \n" /* out31 += a33 * b3h */ @@ -4322,131 +5335,131 @@ void sgemm_prepacked_6x8_a53(bool is_transB, "bne 1b \n" /* branch to k loop */ "6:\n" "sub %[tails], %[tails], #4 \n" /* tail -= 4 */ - "cmp %[tails], #4 \n" /* cmp tail with 4 */ - "blt 3f \n" /* branch to tail == 1 */ + "cmp %[tails], #4 \n" /* cmp tail with 4 */ + "blt 3f \n" /* branch to tail == 1 */ /* Tail Unroll 0 */ "vmov d2, r0, r1 \n" /* mov b02, b03 to d2 */ "add %[a_ptr], %[a_ptr], #0x18 \n" /* aptr += 24 */ - "vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */ + "vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */ "vld1.32 {d3}, [%[a_ptr] :64]! \n" /* load a01, a11 to d3 */ - "vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */ + "vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */ "add %[b_ptr], %[b_ptr], #0x10 \n" /* bptr += 16 */ - "vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */ + "vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */ "vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b04-b07 to d6,d7 */ - "vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */ - "vmla.f32 q12, q2, d2[0] \n" /* out40 += a40 * b0l */ + "vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */ + "vmla.f32 q12, q2, d2[0] \n" /* out40 += a40 * b0l */ "sub %[tails], %[tails], #4 \n" /* tail -= 4 */ - "vmla.f32 q14, q2, d2[1] \n" /* out50 += a50 * b0l */ + "vmla.f32 q14, q2, d2[1] \n" /* out50 += a50 * b0l */ "vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b10-b13 to d4,d5 */ - "vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */ - "vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */ - "vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */ - "vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */ + "vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */ + "vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */ + "vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */ + "vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */ "vld1.32 {d0-d1}, [%[a_ptr] :64]! \n" /* load a21-a51 to d0,d1 */ "cmp %[tails], #4 \n" /* cmp tail with 4 */ - "vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */ - "vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */ - "vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b14-b17 to d6,d7 */ - "blt 4f \n" /* branch to tail == 2 */ + "vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */ + "vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */ + "vld1.32 {d6-d7}, [%[b_ptr] :128]! 
\n" /* load b14-b17 to d6,d7 */ + "blt 4f \n" /* branch to tail == 2 */ /* Tail Unroll 1 */ - "vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */ - "vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */ + "vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */ + "vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */ "sub %[tails], %[tails], #4 \n" /* tail -= 4 */ - "vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */ - "vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */ - "vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */ - "vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */ - "vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b20-b23 to d4,d5 */ - "vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */ - "vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */ + "vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */ + "vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */ + "vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */ + "vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */ + "vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b20-b23 to d4,d5 */ + "vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */ + "vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */ "cmp %[tails], #4 \n" /* cmp tail with 4 */ "vld1.32 {d2-d3}, [%[a_ptr] :64]! \n" /* load a02-a32 to d2,d3 */ - "vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */ - "vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */ - "vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */ - "vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */ - "vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b24-b27 to d6,d7 */ - "blt 5f \n" /* branch to tail == 3 */ + "vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */ + "vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */ + "vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */ + "vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */ + "vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b24-b27 to d6,d7 */ + "blt 5f \n" /* branch to tail == 3 */ /* Tail Unroll 2 */ "sub %[tails], %[tails], #4 \n" /* tail -= 4 */ "vld1.32 {d0-d1}, [%[a_ptr] :64]! \n" /* a42a52a03a13 to d0,d1 */ - "vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */ - "vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */ + "vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */ + "vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */ "vmla.f32 q8, q2, d3[0] \n" /* out20 += a22 * b2l */ "vmla.f32 q10, q2, d3[1] \n" /* out30 += a32 * b2l */ "vmla.f32 q12, q2, d0[0] \n" /* out40 += a42 * b2l */ "vmla.f32 q14, q2, d0[1] \n" /* out50 += a52 * b2l */ "vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b30-b33 to d4,d5 */ - "vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */ - "vmla.f32 q7, q3, d2[1] \n" /* out11 += a12 * b2h */ - "vmla.f32 q9, q3, d3[0] \n" /* out21 += a22 * b2h */ - "vmla.f32 q11, q3, d3[1] \n" /* out31 += a32 * b2h */ + "vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */ + "vmla.f32 q7, q3, d2[1] \n" /* out11 += a12 * b2h */ + "vmla.f32 q9, q3, d3[0] \n" /* out21 += a22 * b2h */ + "vmla.f32 q11, q3, d3[1] \n" /* out31 += a32 * b2h */ "vld1.32 {d2-d3}, [%[a_ptr] :64]! \n" /* load a23-a53 to d2,d3 */ - "vmla.f32 q13, q3, d0[0] \n" /* out41 += a42 * b2h */ - "vmla.f32 q15, q3, d0[1] \n" /* out51 += a52 * b2h */ + "vmla.f32 q13, q3, d0[0] \n" /* out41 += a42 * b2h */ + "vmla.f32 q15, q3, d0[1] \n" /* out51 += a52 * b2h */ "vld1.32 {d6-d7}, [%[b_ptr] :128]! 
\n" /* load b34-b37 to d6,d7 */ /* Tail Unroll 3 */ - "vmla.f32 q4, q2, d1[0] \n" /* out00 += a03 * b3l */ - "vmla.f32 q5, q3, d1[0] \n" /* out01 += a03 * b3h */ - "vmla.f32 q6, q2, d1[1] \n" /* out10 += a13 * b3l */ - "vmla.f32 q7, q3, d1[1] \n" /* out11 += a13 * b3h */ - "vmla.f32 q8, q2, d2[0] \n" /* out20 += a23 * b3l */ - "vmla.f32 q9, q3, d2[0] \n" /* out21 += a23 * b3h */ - "vmla.f32 q10, q2, d2[1] \n" /* out30 += a33 * b3l */ - "vmla.f32 q11, q3, d2[1] \n" /* out31 += a33 * b3h */ - "vmla.f32 q12, q2, d3[0] \n" /* out40 += a43 * b3l */ - "vmla.f32 q13, q3, d3[0] \n" /* out41 += a43 * b3h */ - "vmla.f32 q14, q2, d3[1] \n" /* out50 += a53 * b3l */ - "vmla.f32 q15, q3, d3[1] \n" /* out51 += a53 * b3h */ + "vmla.f32 q4, q2, d1[0] \n" /* out00 += a03 * b3l */ + "vmla.f32 q5, q3, d1[0] \n" /* out01 += a03 * b3h */ + "vmla.f32 q6, q2, d1[1] \n" /* out10 += a13 * b3l */ + "vmla.f32 q7, q3, d1[1] \n" /* out11 += a13 * b3h */ + "vmla.f32 q8, q2, d2[0] \n" /* out20 += a23 * b3l */ + "vmla.f32 q9, q3, d2[0] \n" /* out21 += a23 * b3h */ + "vmla.f32 q10, q2, d2[1] \n" /* out30 += a33 * b3l */ + "vmla.f32 q11, q3, d2[1] \n" /* out31 += a33 * b3h */ + "vmla.f32 q12, q2, d3[0] \n" /* out40 += a43 * b3l */ + "vmla.f32 q13, q3, d3[0] \n" /* out41 += a43 * b3h */ + "vmla.f32 q14, q2, d3[1] \n" /* out50 += a53 * b3l */ + "vmla.f32 q15, q3, d3[1] \n" /* out51 += a53 * b3h */ "b 2f \n" /* branch to check relu */ /* tails==1 final tail */ "3:\n" "vmov d2, r0, r1 \n" /* mov b02, b03 to d2 */ "add %[b_ptr], %[b_ptr], #0x10 \n" /* bptr += 16 */ - "vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */ + "vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */ "add %[a_ptr], %[a_ptr], #0x18 \n" /* aptr += 24 */ - "vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */ + "vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */ "vld1.32 {d6-d7}, [%[b_ptr] :128]! 
\n" /* load b04-b07 to d6,d7 */ - "vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */ - "vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */ + "vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */ + "vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */ "vmla.f32 q12, q2, d2[0] \n" /* out40 += a40 * b0l */ "vmla.f32 q14, q2, d2[1] \n" /* out50 += a50 * b0l */ - "vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */ - "vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */ - "vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */ - "vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */ - "vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */ - "vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */ + "vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */ + "vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */ + "vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */ + "vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */ + "vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */ + "vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */ "b 2f \n" /* branch to check relu */ /* tails==2 final tail */ "4:\n" - "vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */ - "vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */ - "vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */ - "vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */ - "vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */ - "vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */ - "vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */ - "vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */ - "vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */ - "vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */ - "vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */ - "vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */ + "vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */ + "vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */ + "vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */ + "vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */ + "vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */ + "vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */ + "vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */ + "vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */ + "vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */ + "vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */ + "vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */ + "vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */ "b 2f \n" /* branch to check relu */ /* tails==3 final tail */ "5:\n" - "vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */ + "vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */ "vld1.32 {d0}, [%[a_ptr] :64]! 
\n" /* load a42, a52 to d0 */ - "vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */ - "vmla.f32 q8, q2, d3[0] \n" /* out20 += a22 * b2l */ - "vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */ + "vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */ + "vmla.f32 q8, q2, d3[0] \n" /* out20 += a22 * b2l */ + "vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */ "vmla.f32 q7, q3, d2[1] \n" /* out11 += a12 * b2h */ "vmla.f32 q9, q3, d3[0] \n" /* out21 += a22 * b2h */ - "vmla.f32 q10, q2, d3[1] \n" /* out30 += a32 * b2l */ + "vmla.f32 q10, q2, d3[1] \n" /* out30 += a32 * b2l */ "vmla.f32 q11, q3, d3[1] \n" /* out31 += a32 * b2h */ - "vmla.f32 q12, q2, d0[0] \n" /* out40 += a42 * b2l */ + "vmla.f32 q12, q2, d0[0] \n" /* out40 += a42 * b2l */ "vmla.f32 q13, q3, d0[0] \n" /* out41 += a42 * b2h */ - "vmla.f32 q14, q2, d0[1] \n" /* out50 += a52 * b2l */ + "vmla.f32 q14, q2, d0[1] \n" /* out50 += a52 * b2l */ "vmla.f32 q15, q3, d0[1] \n" /* out51 += a52 * b2h */ /* relu */ "2:\n" @@ -4838,7 +5851,7 @@ void sgemm_prepacked_4x8(bool is_transB, "cmp %[flag_act], #0 @ check no act\n" "beq 10f @ no act end \n" //! relu6 - "cmp %[flag_act], #2 @ check if has relu6\n" + "cmp %[flag_act], #2 @ check if has relu6\n" "bne 7f @ jump if no relu6 \n" "vmov.u32 q0, #0 @ for relu6\n" "vld1.f32 {d2-d3}, [%[alpha]] @ load relu6 alpha\n" @@ -4861,33 +5874,33 @@ void sgemm_prepacked_4x8(bool is_transB, "vmin.f32 q15, q15, q1 @ for relu6\n" "b 10f @ relu6 end \n" //! leakey relu - "7: @ otherwise is leakey relu\n" + "7: @ otherwise is leakey relu\n" "vmov.u32 q0, #0 @ for leakey relu \n" "vld1.f32 {d2-d3}, [%[alpha]] @ load leakey relu alpha\n" - "vcge.f32 q2, q8, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q8, q1 @ vmulq_f32 \n" - "vbif q8, q3, q2 @ choose \n" - "vcge.f32 q2, q9, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q9, q1 @ vmulq_f32 \n" + "vcge.f32 q2, q8, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q8, q1 @ vmulq_f32 \n" + "vbif q8, q3, q2 @ choose \n" + "vcge.f32 q2, q9, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q9, q1 @ vmulq_f32 \n" "vbif q9, q3, q2 @ choose \n" - "vcge.f32 q2, q10, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q10, q1 @ vmulq_f32 \n" - "vbif q10, q3, q2 @ choose \n" - "vcge.f32 q2, q11, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q11, q1 @ vmulq_f32 \n" - "vbif q11, q3, q2 @ choose \n" - "vcge.f32 q2, q12, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q12, q1 @ vmulq_f32 \n" - "vbif q12, q3, q2 @ choose \n" - "vcge.f32 q2, q13, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q13, q1 @ vmulq_f32 \n" - "vbif q13, q3, q2 @ choose \n" - "vcge.f32 q2, q14, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q14, q1 @ vmulq_f32 \n" - "vbif q14, q3, q2 @ choose \n" - "vcge.f32 q2, q15, q0 @ vcgeq_u32 \n" - "vmul.f32 q3, q15, q1 @ vmulq_f32 \n" - "vbif q15, q3, q2 @ choose \n" + "vcge.f32 q2, q10, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q10, q1 @ vmulq_f32 \n" + "vbif q10, q3, q2 @ choose \n" + "vcge.f32 q2, q11, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q11, q1 @ vmulq_f32 \n" + "vbif q11, q3, q2 @ choose \n" + "vcge.f32 q2, q12, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q12, q1 @ vmulq_f32 \n" + "vbif q12, q3, q2 @ choose \n" + "vcge.f32 q2, q13, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q13, q1 @ vmulq_f32 \n" + "vbif q13, q3, q2 @ choose \n" + "vcge.f32 q2, q14, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q14, q1 @ vmulq_f32 \n" + "vbif q14, q3, q2 @ choose \n" + "vcge.f32 q2, q15, q0 @ vcgeq_u32 \n" + "vmul.f32 q3, q15, q1 @ vmulq_f32 \n" + "vbif q15, q3, q2 @ choose \n" "10: @ act end \n" "vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n" "vst1.32 {d20-d23}, [%[c_ptr1]]! 
@ store r1\n" diff --git a/lite/backends/arm/math/sgemm.cc b/lite/backends/arm/math/sgemm.cc index f2ba090222..5929e6f6bd 100644 --- a/lite/backends/arm/math/sgemm.cc +++ b/lite/backends/arm/math/sgemm.cc @@ -38,9 +38,10 @@ void sgemm(bool is_transA, ARMContext* ctx) { int hblock = get_hblock(ctx); int m_roundup = hblock * ((M + hblock - 1) / hblock); + ctx->ExtendWorkspace(m_roundup * K * sizeof(float)); - auto packed_A = static_cast( - TargetMalloc(TargetType::kARM, m_roundup * K * sizeof(float))); + auto packed_A = static_cast(ctx->workspace_data()) + + ctx->llc_size() / sizeof(float); prepackA(packed_A, A, alpha, lda, 0, M, 0, K, is_transA, ctx); @@ -58,7 +59,6 @@ void sgemm(bool is_transA, is_bias, act_param, ctx); - TargetFree(TargetType::kARM, packed_A); } } // namespace math diff --git a/lite/tests/math/sgemm_compute_test.cc b/lite/tests/math/sgemm_compute_test.cc index 8295cef341..941afb2702 100644 --- a/lite/tests/math/sgemm_compute_test.cc +++ b/lite/tests/math/sgemm_compute_test.cc @@ -39,7 +39,7 @@ DEFINE_int32(power_mode, DEFINE_int32(threads, 1, "threads num"); DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(repeats, 1, "repeats times"); -DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(basic_test, true, "do all tests"); DEFINE_bool(check_result, true, "check the result"); DEFINE_int32(M, 512, "gemm: M"); -- GitLab