Unverified commit 47d21d28, authored by HappyAngel, committed by GitHub

[arm] add v8-4x8 gemm implement (#4201)

* add v8 4x8 implement. test=develop

* fix run error

* change sgemm compute. test=develop

* fix format. test=develop
Parent 03951c1a
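For context, the core of this patch is a new small-M path on armv8: sgemm_prepack now routes M <= 4 to the new 4x8 micro-kernel instead of the old 4x4 one. A minimal sketch of that dispatch, with names taken from the hunks below and the surrounding setup assumed:

#ifdef __aarch64__
  if (M <= 4) {
    // added in this patch: 4-row x 8-column kernel
    sgemm_prepacked_4x8(is_transB, M, N, K, A_packed, B, ldb, beta,
                        C, ldc, bias, has_bias, act_param, ctx);
  } else {
    sgemm_prepacked_8x12(is_transB, M, N, K, A_packed, B, ldb, beta,
                         C, ldc, bias, has_bias, act_param, ctx);
  }
#endif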
......@@ -55,6 +55,39 @@ void sgemm_prepacked_8x12(bool is_transB,
const operators::ActivationParam act_param,
ARMContext *ctx);
void prepackA_4x8(float *out,
const float *in,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax);
void prepackA_trans_4x8(float *out,
const float *in,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax);
void sgemm_prepacked_4x8(bool is_transB,
int M,
int N,
int K,
const float *A_packed,
const float *B,
int ldb,
float beta,
float *C,
int ldc,
const float *bias,
bool has_bias,
const operators::ActivationParam act_param,
ARMContext *ctx);
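//! Note on the new kernel's contract (as implemented below): it computes
//! C = beta * C + A * B for the armv8 M <= 4 path. alpha is already folded
//! into A_packed by prepackA_4x8 / prepackA_trans_4x8, B is repacked into
//! k x 8 panels by loadb_eight / loadb_trans_eight, and bias plus the
//! activation described by act_param are fused into the write-back.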
void pack_m4(float *out,
const float *in,
float alpha,
......@@ -189,9 +222,9 @@ void prepackA(float *out,
#ifdef __aarch64__
if (mmax <= 4) {
if (is_trans) {
prepackA_trans_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
} else {
prepackA_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
}
} else {
if (is_trans) {
......@@ -269,7 +302,7 @@ void sgemm_prepack(bool is_transB,
ARMContext *ctx) {
#ifdef __aarch64__
if (M <= 4) {
sgemm_prepacked_4x8(is_transB,
M,
N,
K,
......@@ -633,6 +666,145 @@ void prepackA_8x12(float *dout,
}
}
}
void prepackA_4x8(float *outptr,
const float *inptr,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax) {
int x_len = kmax - k0;
float zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(float) * x_len);
bool has_alpha = fabsf(alpha - 1.f) > 1e-8f;
float32x4_t valpha = vdupq_n_f32(alpha);
#pragma omp parallel for
for (int y = m0; y < mmax; y += 4) {
const float *inptr0 = inptr + y * ldin + k0;
const float *inptr1 = inptr0 + ldin;
const float *inptr2 = inptr1 + ldin;
const float *inptr3 = inptr2 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(inptr0),
[ptr1] "r"(inptr1),
[ptr2] "r"(inptr2),
[ptr3] "r"(inptr3)
: "memory");
int x = x_len;
if ((y + 3) >= mmax) {
switch ((y + 3) - mmax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
// clang-format off
asm volatile(
"cbz %w[has_alpha], 0f\n"
"ldp q0, q1, [%[inptr0]], #32\n" // load r0, a0~a7
"ldp q2, q3, [%[inptr1]], #32\n" // load r1, b0~b7
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"fmul v1.4s, v1.4s, %[alpha].4s\n"
"ldp q4, q5, [%[inptr2]], #32\n" // load r2, c0~c7
"fmul v2.4s, v2.4s, %[alpha].4s\n"
"fmul v3.4s, v3.4s, %[alpha].4s\n"
"ldp q6, q7, [%[inptr3]], #32\n" // load r3, d0~d7
"fmul v4.4s, v4.4s, %[alpha].4s\n"
"fmul v5.4s, v5.4s, %[alpha].4s\n"
"fmul v6.4s, v6.4s, %[alpha].4s\n"
"fmul v7.4s, v7.4s, %[alpha].4s\n"
"b 1f\n" // to main process
"0: \n" // alpha == 1
"ldp q0, q1, [%[inptr0]], #32\n" // load r0, a0~a7
"ldp q2, q3, [%[inptr1]], #32\n" // load r1, b0~b7
"ldp q4, q5, [%[inptr2]], #32\n" // load r2, c0~c7
"ldp q6, q7, [%[inptr3]], #32\n" // load r3, d0~d7
"1: \n"
"trn1 v8.4s, v0.4s, v2.4s\n" // a0b0a2b2
"trn2 v9.4s, v0.4s, v2.4s\n" // a1b1a3b3
"trn1 v10.4s, v1.4s, v3.4s\n" // a4b4a6b6
"trn2 v11.4s, v1.4s, v3.4s\n" // a5b5a7b7
"trn1 v12.4s, v4.4s, v6.4s\n" // c0d0c2d2
"trn2 v13.4s, v4.4s, v6.4s\n" // c1d1c3d3
"trn1 v14.4s, v5.4s, v7.4s\n" // c4d4c6d6
"trn2 v15.4s, v5.4s, v7.4s\n" // c5d5c7d7
"trn1 v0.2d, v8.2d, v12.2d\n" // a0b0c0d0
"trn1 v1.2d, v9.2d, v13.2d\n" // a1b1c1d1
"trn2 v2.2d, v8.2d, v12.2d\n" // a2b2c2d2
"trn2 v3.2d, v9.2d, v13.2d\n" // a3b3c3d3
"st1 {v0.4s}, [%[outptr]], #16\n"
"trn1 v4.2d, v10.2d, v14.2d\n" // a4b4c4d4
"st1 {v1.4s}, [%[outptr]], #16\n"
"trn1 v5.2d, v11.2d, v15.2d\n" // a5b5c5d5
"st1 {v2.4s}, [%[outptr]], #16\n"
"trn2 v6.2d, v10.2d, v14.2d\n" // a6b6c6d6
"st1 {v3.4s}, [%[outptr]], #16\n"
"trn2 v7.2d, v11.2d, v15.2d\n" // a7b7c7d7
"stp q4, q5, [%[outptr]], #32\n"
"stp q6, q7, [%[outptr]], #32\n"
: [inptr0] "+r"(inptr0),
[inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3),
[outptr] "+r"(outptr)
: [has_alpha] "r"(has_alpha), [alpha] "w"(valpha)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
// clang-format on
}
for (; x > 0; x--) {
if (has_alpha) {
*outptr++ = *inptr0++ * alpha;
*outptr++ = *inptr1++ * alpha;
*outptr++ = *inptr2++ * alpha;
*outptr++ = *inptr3++ * alpha;
} else {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
}
}
}
}
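// A scalar reference sketch of the layout the NEON code above produces: rows
// are taken four at a time and interleaved column by column, so the packed
// stream reads a0 b0 c0 d0 a1 b1 c1 d1 ... with alpha folded in and missing
// rows zero-padded. Illustrative only; this helper name is hypothetical and
// not part of the patch.
static void prepackA_4x8_ref(float *out, const float *in, float alpha,
                             int ldin, int m0, int mmax, int k0, int kmax) {
  for (int y = m0; y < mmax; y += 4) {
    for (int k = k0; k < kmax; ++k) {
      for (int r = 0; r < 4; ++r) {
        int row = y + r;
        *out++ = (row < mmax) ? in[row * ldin + k] * alpha : 0.f;
      }
    }
  }
}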
void pack_m4(float *dout,
const float *inptr,
float alpha,
......@@ -934,6 +1106,159 @@ void prepackA_trans_8x12(float *outptr,
}
}
}
void prepackA_trans_4x8(float *outptr,
const float *in,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax) {
auto inptr = in + k0 * ldin + m0;
bool has_alpha = fabsf(alpha - 1.f) > 1e-8f;
float32x4_t valpha = vdupq_n_f32(alpha);
uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
int x_len = mmax - m0;
int y_len = kmax - k0;
int right_remain = x_len - 4 * (x_len / 4);
int right_pad = 4 - right_remain;
if (right_remain == 0) {
right_pad = 0;
}
int stride_out = 4 * y_len;
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const float *ptr0 = inptr + y * ldin;
const float *ptr1 = ptr0 + ldin;
const float *ptr2 = ptr1 + ldin;
const float *ptr3 = ptr2 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3)
: "memory");
float *outptr_row_col = outptr + y * 4;
int i = 0;
for (; i < x_len - 3; i += 4) {
float *ptr_out = outptr_row_col;
// clang-format off
asm volatile(
"cmp %w[has_alpha], #0\n"
"ld1 {v0.4s}, [%[ptr0]], #16\n"
"ld1 {v1.4s}, [%[ptr1]], #16\n"
"ld1 {v2.4s}, [%[ptr2]], #16\n"
"ld1 {v3.4s}, [%[ptr3]], #16\n"
"beq 0f\n"
"1: \n"
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"fmul v1.4s, v1.4s, %[alpha].4s\n"
"fmul v2.4s, v2.4s, %[alpha].4s\n"
"fmul v3.4s, v3.4s, %[alpha].4s\n"
"0: \n"
"stp q0, q1, [%[outptr]], #32\n"
"stp q2, q3, [%[outptr]], #32\n"
: [outptr] "+r"(ptr_out),
[ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2),
[ptr3] "+r"(ptr3)
: [has_alpha] "r"(has_alpha), [alpha] "w"(valpha)
: "v0", "v1", "v2", "v3", "cc", "memory");
// clang-format on
outptr_row_col += stride_out;
}
if (right_pad > 0) {
float *ptr_out = outptr_row_col;
// clang-format off
asm volatile(
"cmp %w[has_alpha], #0\n"
"ld1 {v0.4s}, [%[ptr0]], #16\n"
"ld1 {v1.4s}, [%[ptr1]], #16\n"
"ld1 {v2.4s}, [%[ptr2]], #16\n"
"ld1 {v3.4s}, [%[ptr3]], #16\n"
"beq 0f\n"
"1: \n"
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"fmul v1.4s, v1.4s, %[alpha].4s\n"
"fmul v2.4s, v2.4s, %[alpha].4s\n"
"fmul v3.4s, v3.4s, %[alpha].4s\n"
"0: \n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v1.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v2.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v3.16b, %[vzero].16b, %[vmask1].16b\n"
"stp q0, q1, [%[outptr]], #32\n"
"stp q2, q3, [%[outptr]], #32\n"
: [outptr] "+r"(ptr_out),
[ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2),
[ptr3] "+r"(ptr3)
: [vmask1] "w"(vmask1),
[vzero] "w"(vzero),
[has_alpha] "r"(has_alpha),
[alpha] "w"(valpha)
: "v0", "v1", "v2", "v3", "cc", "memory");
// clang-format on
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const float *ptr0 = inptr + y * ldin;
float *outptr_row_col = outptr + y * 4;
int i = 0;
for (; i < x_len - 3; i += 4) {
float *ptr_out = outptr_row_col;
asm volatile(
"cmp %[has_alpha], #0\n"
"ld1 {v0.4s}, [%[ptr0]], #16\n"
"beq 0f\n"
"1: \n"
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"0: \n"
"st1 {v0.4s}, [%[outptr]], #16\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [has_alpha] "r"(has_alpha), [alpha] "w"(valpha)
: "v0", "v1", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_pad > 0) {
float *ptr_out = outptr_row_col;
asm volatile(
"cmp %w[has_alpha], #0\n"
"ld1 {v0.4s}, [%[ptr0]], #16\n"
"beq 0f\n"
"1: \n"
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"0: \n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"st1 {v0.4s}, [%[outptr]], #16\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [vmask1] "w"(vmask1),
[vzero] "w"(vzero),
[has_alpha] "r"(has_alpha),
[alpha] "w"(valpha)
: "v0", "v1", "cc", "memory");
}
}
}
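// The right-edge handling above builds a lane mask once (vcltq_u32 over
// {0,1,2,3} against right_remain) and uses "bif" to keep loaded lanes below
// the valid count while zeroing the rest. An equivalent intrinsics sketch of
// that select, illustrative only; like the assembly, it still loads a full
// vector from p:
static inline float32x4_t load_with_right_pad(const float *p, uint32_t valid) {
  const uint32_t idx[4] = {0, 1, 2, 3};
  uint32x4_t mask = vcltq_u32(vld1q_u32(idx), vdupq_n_u32(valid));
  // vbslq_f32 keeps the loaded lane where the mask bit is set, else zero
  return vbslq_f32(mask, vld1q_f32(p), vdupq_n_f32(0.f));
}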
void pack_trans_m4(float *outptr,
const float *in,
float alpha,
......@@ -1587,6 +1912,7 @@ void prepackA_trans_4x8(float* outptr,
int i = 0;
for (; i < x_len - 3; i += 4) {
float* ptr_out = outptr_row_col;
// clang-format off
asm volatile(
"vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n"
"cmp %[has_alpha], #0\n"
......@@ -1597,10 +1923,12 @@ void prepackA_trans_4x8(float* outptr,
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [has_alpha] "r"(has_alpha), [alpha] "w"(valpha)
: "q0", "q1", "cc", "memory");
// clang-format on
outptr_row_col += stride_out;
}
if (right_pad > 0) {
float* ptr_out = outptr_row_col;
// clang-format off
asm volatile(
"vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n"
"cmp %[has_alpha], #0\n"
......@@ -1615,6 +1943,7 @@ void prepackA_trans_4x8(float* outptr,
[has_alpha] "r"(has_alpha),
[alpha] "w"(valpha)
: "q0", "q1", "cc", "memory");
// clang-format on
}
}
}
......@@ -1624,7 +1953,7 @@ void prepackA_trans_4x8(float* outptr,
/**
 * \brief input data is transposed
 * for arm-v7a, transform data to block x k x 8 layout
 * for arm-v8a, transform data to block x k x 12 layout or block x k x 8 layout
*/
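/**
 * A scalar sketch of the 8-wide B panel consumed by the new 4x8 kernel
 * (illustrative; the NEON implementation is loadb_eight below). Columns are
 * taken eight at a time; within a panel the k rows are stored consecutively,
 * and the right edge is zero-padded. The helper name is hypothetical.
 */
static void loadb_eight_ref(float *out, const float *in, int ldin,
                            int k0, int kmax, int n0, int nmax) {
  for (int x0 = n0; x0 < nmax; x0 += 8) {  // one k x 8 panel per 8 columns
    for (int k = k0; k < kmax; ++k) {
      for (int j = 0; j < 8; ++j) {
        int col = x0 + j;
        *out++ = (col < nmax) ? in[k * ldin + col] : 0.f;
      }
    }
  }
}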
#ifdef __aarch64__
void loadb(
......@@ -1996,9 +2325,288 @@ void loadb_trans(
"zip1 v18.4s, v16.4s, v17.4s\n" /* i6j6k6l6 */
"zip2 v19.4s, v16.4s, v17.4s\n" /* i7j7k7l7 */
"str q18, [%[outptr]], #16\n" /* save i6~l6 */
"stp q22, q23, [%[outptr]], #32\n" /* save a7~h7 */
"str q19, [%[outptr]], #16\n" /* save i7~l7 */
"str q18, [%[outptr]], #16\n" /* save i6~l6 */
"stp q22, q23, [%[outptr]], #32\n" /* save a7~h7 */
"str q19, [%[outptr]], #16\n" /* save i7~l7 */
: [inptr0] "+r"(inptr0),
[inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3),
[inptr4] "+r"(inptr4),
[inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6),
[inptr7] "+r"(inptr7),
[inptr8] "+r"(inptr8),
[inptr9] "+r"(inptr9),
[inptr10] "+r"(inptr10),
[inptr11] "+r"(inptr11),
[outptr] "+r"(outptr)
:
: "v0","v1","v2","v3","v4","v5",
"v6","v7","v8","v9","v10","v11","v12",
"v13","v14","v15","v16","v17","v18","v19",
"v20","v21","v22","v23","v24","v25","v26",
"v27","v28","v29","v30","v31","cc","memory");
// clang-format on
}
for (; x > 0; x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
*outptr++ = *inptr4++;
*outptr++ = *inptr5++;
*outptr++ = *inptr6++;
*outptr++ = *inptr7++;
*outptr++ = *inptr8++;
*outptr++ = *inptr9++;
*outptr++ = *inptr10++;
*outptr++ = *inptr11++;
}
}
}
void loadb_eight(
float *out, const float *in, int ldin, int k0, int kmax, int n0, int nmax) {
auto outptr = reinterpret_cast<uint32_t *>(out);
auto inptr = reinterpret_cast<const uint32_t *>(in) + k0 * ldin + n0;
uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
int x_len = nmax - n0;
int y_len = kmax - k0;
int right_remain = x_len - 8 * (x_len / 8);
int right_pad = 8 - right_remain;
uint32_t *outptr_row = outptr;
int stride_out = 8 * y_len;
uint32x4_t vzero = vdupq_n_u32(0);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
uint32x4_t vmask2 =
vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain));
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const uint32_t *ptr0 = inptr + y * ldin;
const uint32_t *ptr1 = ptr0 + ldin;
const uint32_t *ptr2 = ptr1 + ldin;
const uint32_t *ptr3 = ptr2 + ldin;
uint32_t *outptr_row_col = outptr_row + y * 8;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3)
: "memory");
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32_t *ptr_out = outptr_row_col;
asm volatile(
"ldp q0, q1, [%[ptr0]], #32\n" // load r0, 8 elements
"ldp q2, q3, [%[ptr1]], #32\n" // load r1, 8 elements
"stp q0, q1, [%[outptr]], #32\n" // write to output ptr
"stp q2, q3, [%[outptr]], #32\n" // write to output ptr
"ldp q0, q1, [%[ptr2]], #32\n" // load r0, 8 elements
"ldp q2, q3, [%[ptr3]], #32\n" // load r1, 8 elements
"stp q0, q1, [%[outptr]], #32\n" // write to output ptr
"stp q2, q3, [%[outptr]], #32\n" // write to output ptr
: [outptr] "+r"(ptr_out),
[ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2),
[ptr3] "+r"(ptr3)
:
: "v0", "v1", "v2", "v3", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32_t *ptr_out = outptr_row_col;
asm volatile(
"ldp q0, q1, [%[ptr0]], #32\n"
"ldp q2, q3, [%[ptr1]], #32\n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v1.16b, %[vzero].16b, %[vmask2].16b\n"
"bif v2.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v3.16b, %[vzero].16b, %[vmask2].16b\n"
"stp q0, q1, [%[outptr]], #32\n"
"ldp q0, q1, [%[ptr2]], #32\n"
"stp q2, q3, [%[outptr]], #32\n"
"ldp q2, q3, [%[ptr3]], #32\n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v1.16b, %[vzero].16b, %[vmask2].16b\n"
"bif v2.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v3.16b, %[vzero].16b, %[vmask2].16b\n"
"stp q0, q1, [%[outptr]], #32\n"
"stp q2, q3, [%[outptr]], #32\n"
: [outptr] "+r"(ptr_out),
[ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2),
[ptr3] "+r"(ptr3)
: [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
: "v0", "v1", "v2", "v3", "cc", "memory");
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const uint32_t *ptr0 = inptr + y * ldin;
uint32_t *outptr_row_col = outptr_row + y * 8;
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32_t *ptr_out = outptr_row_col;
asm volatile(
"ldp q0, q1, [%[ptr0]], #32\n"
"stp q0, q1, [%[outptr]], #32\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
:
: "v0", "v1", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32_t *ptr_out = outptr_row_col;
asm volatile(
"ldp q0, q1, [%[ptr0]], #32\n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v1.16b, %[vzero].16b, %[vmask2].16b\n"
"stp q0, q1, [%[outptr]], #32\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
: "v0", "v1", "cc", "memory");
}
}
}
void loadb_trans_eight(
float *out, const float *in, int ldin, int k0, int kmax, int n0, int nmax) {
int x_len = kmax - k0;
uint32_t zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(uint32_t) * x_len);
auto outptr = reinterpret_cast<uint32_t *>(out);
auto inptr = reinterpret_cast<const uint32_t *>(in);
//! data B is not transposed, transpose B to k * 8
for (int y = n0; y < nmax; y += 8) {
const uint32_t *inptr0 = inptr + y * ldin + k0;
const uint32_t *inptr1 = inptr0 + ldin;
const uint32_t *inptr2 = inptr1 + ldin;
const uint32_t *inptr3 = inptr2 + ldin;
const uint32_t *inptr4 = inptr3 + ldin;
const uint32_t *inptr5 = inptr4 + ldin;
const uint32_t *inptr6 = inptr5 + ldin;
const uint32_t *inptr7 = inptr6 + ldin;
int x = x_len;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
"prfm pldl1keep, [%[ptr4]] \n"
"prfm pldl1keep, [%[ptr4], #64] \n"
"prfm pldl1keep, [%[ptr5]] \n"
"prfm pldl1keep, [%[ptr5], #64] \n"
"prfm pldl1keep, [%[ptr6]] \n"
"prfm pldl1keep, [%[ptr6], #64] \n"
"prfm pldl1keep, [%[ptr7]] \n"
"prfm pldl1keep, [%[ptr7], #64] \n"
:
: [ptr0] "r"(inptr0),
[ptr1] "r"(inptr1),
[ptr2] "r"(inptr2),
[ptr3] "r"(inptr3),
[ptr4] "r"(inptr4),
[ptr5] "r"(inptr5),
[ptr6] "r"(inptr6),
[ptr7] "r"(inptr7)
: "memory");
//! cope with row indices exceeding the real size: redirect to zero buffer
if ((y + 7) >= nmax) {
switch ((y + 7) - nmax) {
case 6:
inptr1 = zerobuff;
case 5:
inptr2 = zerobuff;
case 4:
inptr3 = zerobuff;
case 3:
inptr4 = zerobuff;
case 2:
inptr5 = zerobuff;
case 1:
inptr6 = zerobuff;
case 0:
inptr7 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
// clang-format off
//! zip load 8 elements (2 neon Q registers) from each of 8 rows
asm volatile(
"ldp q0, q1, [%[inptr0]], #32\n" // load r0, a0~a7
"ldp q2, q3, [%[inptr1]], #32\n" // load r1, b0~b7
"ldp q4, q5, [%[inptr2]], #32\n" // load r2, c0~c7
"ldp q6, q7, [%[inptr3]], #32\n" // load r3, d0~d7
"trn1 v8.4s, v0.4s, v2.4s\n" // a0b0a2b2
"trn2 v9.4s, v0.4s, v2.4s\n" // a1b1a3b3
"trn1 v10.4s, v1.4s, v3.4s\n" // a4b4a6b6
"trn2 v11.4s, v1.4s, v3.4s\n" // a5b5a7b7
"ldp q16, q17, [%[inptr4]], #32\n"// load r4, e0~e7
"ldp q18, q19, [%[inptr5]], #32\n"// load r5, f0~f7
"ldp q20, q21, [%[inptr6]], #32\n"// load r6, g0~g7
"ldp q22, q23, [%[inptr7]], #32\n"// load r7, h0~h7
"trn1 v12.4s, v4.4s, v6.4s\n" // c0d0c2d2
"trn2 v13.4s, v4.4s, v6.4s\n" // c1d1c3d3
"trn1 v14.4s, v5.4s, v7.4s\n" // c4d4c6d6
"trn2 v15.4s, v5.4s, v7.4s\n" // c5d5c7d7
"trn1 v24.4s, v16.4s, v18.4s\n" // e0f0e2f2
"trn2 v25.4s, v16.4s, v18.4s\n" // e1f1e3f3
"trn1 v28.4s, v20.4s, v22.4s\n" // g0h0e2f2
"trn2 v29.4s, v20.4s, v22.4s\n" // g1h1e3f3
"trn1 v26.4s, v17.4s, v19.4s\n" // e4f4e6f6
"trn2 v27.4s, v17.4s, v19.4s\n" // e5f5e7f7
"trn1 v30.4s, v21.4s, v23.4s\n" // g4h4e6f6
"trn2 v31.4s, v21.4s, v23.4s\n" // g5h5e7f7
"trn1 v0.2d, v8.2d, v12.2d\n" // a0b0c0d0
"trn1 v1.2d, v24.2d, v28.2d\n" // e0f0g0h0
"trn1 v2.2d, v9.2d, v13.2d\n" // a1b1c1d1
"trn1 v3.2d, v25.2d, v29.2d\n" // e1f1g1h1
"trn2 v4.2d, v8.2d, v12.2d\n" // a2b2c2d2
"trn2 v5.2d, v24.2d, v28.2d\n" // e2f2g2h2
"stp q0, q1, [%[outptr]], #32\n" // save q0, q1, a0~h0
"trn2 v6.2d, v9.2d, v13.2d\n" // a3b3c3d3
"trn2 v7.2d, v25.2d, v29.2d\n" // e3f3g3h3
"stp q2, q3, [%[outptr]], #32\n" // save q0, q1, a1~h1
"trn1 v16.2d, v10.2d, v14.2d\n" // a4b4c4d4
"trn1 v17.2d, v26.2d, v30.2d\n" // e4f4g4h4
"stp q4, q5, [%[outptr]], #32\n" // save q0, q1, a2~h2
"trn1 v18.2d, v11.2d, v15.2d\n" // a5b5c5d5
"trn1 v19.2d, v27.2d, v31.2d\n" // e5f5g5h5
"stp q6, q7, [%[outptr]], #32\n" // save q0, q1, a3~h3
"trn2 v20.2d, v10.2d, v14.2d\n" // a6b6c6d6
"trn2 v21.2d, v26.2d, v30.2d\n" // e6f6g6h6
"stp q16, q17, [%[outptr]], #32\n" // save q0, q1, a4~h4
"trn2 v22.2d, v11.2d, v15.2d\n" // a7b7c7d7
"trn2 v23.2d, v27.2d, v31.2d\n" // e7f7g7h7
"stp q18, q19, [%[outptr]], #32\n" // save q0, q1, a5~h5
"stp q20, q21, [%[outptr]], #32\n" // save q0, q1, a6~h6
"stp q22, q23, [%[outptr]], #32\n" // save q0, q1, a7~h7
: [inptr0] "+r"(inptr0),
[inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2),
......@@ -2007,10 +2615,6 @@ void loadb_trans(
[inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6),
[inptr7] "+r"(inptr7),
[outptr] "+r"(outptr)
:
: "v0","v1","v2","v3","v4","v5",
......@@ -2030,10 +2634,6 @@ void loadb_trans(
*outptr++ = *inptr5++;
*outptr++ = *inptr6++;
*outptr++ = *inptr7++;
}
}
}
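// Design note on the transpose above: each 8x8 tile is transposed in two
// stages. trn1/trn2 on .4s lanes first interleaves adjacent row pairs
// (a0b0a2b2, a1b1a3b3, ...), then trn1/trn2 on .2d lanes merges those pair
// results into whole output rows (a0b0c0d0, e0f0g0h0, ...), so a single stp
// writes each transposed row of eight and no scalar shuffling is needed.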
......@@ -2943,11 +3543,11 @@ void sgemm_prepacked_8x12(bool is_transB,
"fmax v30.4s, v30.4s, v0.4s \n" /* relu*/
"fmax v31.4s, v31.4s, v0.4s \n" /* relu*/
"b 20f \n" /* relu end */
//! no act
"12: \n" /* no relu */
"cmp %w[flag_act], #0 \n" /* check no act */
"beq 20f \n" /* no act end */
//! relu6
"beq 20f \n" /* no act end */
//! relu6
"cmp %w[flag_act], #2 \n" /* check if has relu6 */
"bne 13f \n" /* jump if no relu6 */
"movi v0.4s, #0 \n" /* for relu6 */
......@@ -3005,77 +3605,77 @@ void sgemm_prepacked_8x12(bool is_transB,
"13: \n" /* otherwise is leakey relu */
"movi v0.4s, #0 \n" /* for leakey relu */
"ld1 {v1.4s}, [%[alpha]] \n" /* leakey relu alpha */
"fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */
"bif v8.16b, v3.16b, v2.16b \n" /* choose*/
"bif v9.16b, v5.16b, v4.16b \n" /* choose*/
"fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */
"bif v8.16b, v3.16b, v2.16b \n" /* choose*/
"bif v9.16b, v5.16b, v4.16b \n" /* choose*/
"bif v10.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v11.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v2.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v11.4s, v1.4s \n" /* vmulq_f32 */
"bif v11.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v12.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v12.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v13.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v13.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v14.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v14.4s, v1.4s \n" /* vmulq_f32 */
"bif v12.16b, v3.16b, v2.16b \n" /* choose*/
"bif v13.16b, v5.16b, v4.16b \n" /* choose*/
"fcmge v2.4s, v12.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v12.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v13.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v13.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v14.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v14.4s, v1.4s \n" /* vmulq_f32 */
"bif v12.16b, v3.16b, v2.16b \n" /* choose*/
"bif v13.16b, v5.16b, v4.16b \n" /* choose*/
"bif v14.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v15.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v15.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v2.4s, v15.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v15.4s, v1.4s \n" /* vmulq_f32 */
"bif v15.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v16.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v17.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v18.4s, v1.4s \n" /* vmulq_f32 */
"bif v16.16b, v3.16b, v2.16b \n" /* choose*/
"bif v17.16b, v5.16b, v4.16b \n" /* choose*/
"fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v16.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v17.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v18.4s, v1.4s \n" /* vmulq_f32 */
"bif v16.16b, v3.16b, v2.16b \n" /* choose*/
"bif v17.16b, v5.16b, v4.16b \n" /* choose*/
"bif v18.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v19.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v2.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v19.4s, v1.4s \n" /* vmulq_f32 */
"bif v19.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v20.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v21.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v22.4s, v1.4s \n" /* vmulq_f32 */
"bif v20.16b, v3.16b, v2.16b \n" /* choose*/
"bif v21.16b, v5.16b, v4.16b \n" /* choose*/
"bif v22.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v23.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v23.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v20.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v21.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v22.4s, v1.4s \n" /* vmulq_f32 */
"bif v20.16b, v3.16b, v2.16b \n" /* choose*/
"bif v21.16b, v5.16b, v4.16b \n" /* choose*/
"bif v22.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v23.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v23.4s, v1.4s \n" /* vmulq_f32 */
"bif v23.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v24.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v24.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v25.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v25.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v26.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v26.4s, v1.4s \n" /* vmulq_f32 */
"bif v24.16b, v3.16b, v2.16b \n" /* choose*/
"bif v25.16b, v5.16b, v4.16b \n" /* choose*/
"fcmge v2.4s, v24.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v24.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v25.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v25.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v26.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v26.4s, v1.4s \n" /* vmulq_f32 */
"bif v24.16b, v3.16b, v2.16b \n" /* choose*/
"bif v25.16b, v5.16b, v4.16b \n" /* choose*/
"bif v26.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v27.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v27.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v2.4s, v27.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v27.4s, v1.4s \n" /* vmulq_f32 */
"bif v27.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v28.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v28.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v29.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v29.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v30.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v30.4s, v1.4s \n" /* vmulq_f32 */
"bif v28.16b, v3.16b, v2.16b \n" /* choose*/
"bif v29.16b, v5.16b, v4.16b \n" /* choose*/
"fcmge v2.4s, v28.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v28.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v29.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v29.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v30.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v30.4s, v1.4s \n" /* vmulq_f32 */
"bif v28.16b, v3.16b, v2.16b \n" /* choose*/
"bif v29.16b, v5.16b, v4.16b \n" /* choose*/
"bif v30.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v31.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v31.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v2.4s, v31.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v31.4s, v1.4s \n" /* vmulq_f32 */
"bif v31.16b, v3.16b, v2.16b \n" /* choose*/
"20: \n" /* act end */
......@@ -3103,7 +3703,7 @@ void sgemm_prepacked_8x12(bool is_transB,
: [bias_ptr] "r"(bias_local),
[has_beta] "r"(has_beta),
[beta] "r"(beta),
[alpha] "r"(alpha),
[alpha] "r"(alpha),
[flag_act] "r"(flag_act)
: "cc","memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
......@@ -3129,6 +3729,419 @@ void sgemm_prepacked_8x12(bool is_transB,
}
}
void sgemm_prepacked_4x8(bool is_transB,
int M,
int N,
int K,
const float *A_packed,
const float *B,
int ldb,
float beta,
float *C,
int ldc,
const float *bias,
bool has_bias,
const operators::ActivationParam act_param,
ARMContext *ctx) {
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto *workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
auto act_type = act_param.active_type;
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
int flag_act = 0x00; // relu: 1, relu6: 2, leaky: 3
const int n_block = 8;
const int m_block = 4;
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 0x01;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 0x02;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 0x03;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
float32x4_t valpha = vld1q_f32(alpha);
float32x4_t vzero = vdupq_n_f32(0.f);
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int x_block = (l2_cache - (m_block * K)) / (sizeof(float) * (K + m_block));
x_block /= n_block;
x_block *= n_block;
int x_num = (N + (x_block - 1)) / x_block;
x_block = (N + x_num - 1) / x_num;
x_block = (x_block + n_block - 1) / n_block;
x_block *= n_block;
x_block = x_block < n_block ? n_block : x_block;
int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
int tail_pre = (K & (KBLOCK - 1));
if (tail_pre == 0) {
tail_pre = KBLOCK;
}
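//! Worked example of the blocking above (illustrative numbers): with
//! l2_cache = 512 KiB and K = 256, x_block = (524288 - 4 * 256) /
//! (4 * (256 + 4)) = 503, rounded down to a multiple of n_block, i.e. 496;
//! the N range is then split into x_num chunks of roughly equal width and
//! re-rounded up to a multiple of 8, never below 8. For the k loop, K = 256
//! with KBLOCK = 4 gives k_pre = 63 unrolled iterations plus a tail_pre = 4
//! step handled by the tail code.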
bool flag_p_remain = false;
int remain = 0;
int has_beta = fabsf(beta) > 1e-8f ? 1 : 0;
//! the A panel is precomputed outside of the gemm kernel
for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
unsigned int xmax = x0 + x_block;
if (xmax > N) {
xmax = N;
}
int bblocks = (xmax - x0 + n_block - 1) / n_block;
remain = xmax - x0 - (bblocks - 1) * n_block;
if (remain > 0) {
flag_p_remain = true;
}
//! load bpanel
auto b_pannel = static_cast<float *>(workspace);
if (is_transB) {
loadb_trans_eight(b_pannel, B, ldb, 0, K, x0, xmax);
} else {
loadb_eight(b_pannel, B, ldb, 0, K, x0, xmax);
}
#pragma omp parallel for num_threads(threads)
for (unsigned int y = 0; y < M; y += m_block) {
unsigned int ymax = y + m_block;
if (ymax > M) {
ymax = M;
}
float cout0[n_block]; // NOLINT
float cout1[n_block]; // NOLINT
float cout2[n_block]; // NOLINT
float cout3[n_block]; // NOLINT
float bias_local[4] = {0};
if (has_bias) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
}
float *c_ptr0 = C + y * ldc + x0;
float *c_ptr1 = c_ptr0 + ldc;
float *c_ptr2 = c_ptr1 + ldc;
float *c_ptr3 = c_ptr2 + ldc;
float *pout0 = c_ptr0;
float *pout1 = c_ptr1;
float *pout2 = c_ptr2;
float *pout3 = c_ptr3;
const float *a_ptr_l = A_packed + y * K;
const float *b_ptr = b_pannel;
for (int xb = 0; xb < bblocks; xb++) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
case 2:
c_ptr1 = cout1;
case 1:
c_ptr2 = cout2;
case 0:
c_ptr3 = cout3;
default:
break;
}
}
if (flag_p_remain && (xb == bblocks - 1)) {
pout0 = c_ptr0;
pout1 = c_ptr1;
pout2 = c_ptr2;
pout3 = c_ptr3;
c_ptr0 = cout0;
c_ptr1 = cout1;
c_ptr2 = cout2;
c_ptr3 = cout3;
if (has_beta) {
for (int i = 0; i < remain; ++i) {
cout0[i] = pout0[i];
cout1[i] = pout1[i];
cout2[i] = pout2[i];
cout3[i] = pout3[i];
}
}
}
const float *a_ptr = a_ptr_l;
int tails = tail_pre;
int k = k_pre;
// clang-format off
asm volatile(
"ld1 {v2.4s}, [%[bias_ptr]]\n"
"dup v8.4s, v2.s[0]\n"
"prfm pldl1keep, [%[a_ptr]]\n"
"dup v9.4s, v2.s[0]\n"
"prfm pldl1keep, [%[b_ptr]]\n"
"dup v10.4s, v2.s[1]\n"
"dup v11.4s, v2.s[1]\n"
"prfm pldl1keep, [%[a_ptr], #64]\n"
"dup v12.4s, v2.s[2]\n"
"dup v13.4s, v2.s[2]\n"
"prfm pldl1keep, [%[b_ptr], #64]\n"
"dup v14.4s, v2.s[3]\n"
"dup v15.4s, v2.s[3]\n"
"prfm pldl1keep, [%[a_ptr], #128]\n"
"cmp %w[beta], #0\n" // check beta == 0?
"prfm pldl1keep, [%[b_ptr], #128]\n"
"prfm pldl1keep, [%[b_ptr], #192]\n"
// process beta
"beq 11f\n"
"dup v7.4s, %w[beta]\n" // beta to vector
"ld1 {v0.4s, v1.4s}, [%[c_ptr0]]\n" // load output r0
"ld1 {v2.4s, v3.4s}, [%[c_ptr1]]\n" // load output r1
"fmla v8.4s, v0.4s, v7.4s\n" // cr00 += beta * c_r00
"fmla v9.4s, v1.4s, v7.4s\n" // cr01 += beta * c_r01
"ld1 {v0.4s, v1.4s}, [%[c_ptr2]]\n" // load output r2
"fmla v10.4s, v2.4s, v7.4s\n" // cr10 += beta * c_r10
"fmla v11.4s, v3.4s, v7.4s\n" // cr11 += beta * c_r11
"ld1 {v2.4s, v3.4s}, [%[c_ptr3]]\n" // load output r3
"fmla v12.4s, v0.4s, v7.4s\n" // cr20 += beta * c_r20
"fmla v13.4s, v1.4s, v7.4s\n" // cr21 += beta * c_r21
"fmla v14.4s, v2.4s, v7.4s\n" // cr30 += beta * c_r30
"fmla v15.4s, v3.4s, v7.4s\n" // cr31 += beta * c_r31
"11: \n" // check loop count
"ldp q0, q1, [%[a_ptr]], #32\n" // load a0~a3 to q0, q1
"ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5
"cbz %w[k], 0f\n" // check loop count > 0
// main loop
// Unroll 0
"1: \n"
"fmla v8.4s, v4.4s, v0.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v0.s[1]\n" // out1 += b0 * a0[1]
"ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2
"fmla v12.4s, v4.4s, v0.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v0.s[3]\n" // out3 += b0 * a0[3]
"ldp q2, q3, [%[a_ptr]], #32\n" // load next 2xa0~a3
"fmla v9.4s, v5.4s, v0.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v0.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v0.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v0.s[3]\n" // out7 += b1 * a0[3]
"ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5
// Unroll 1
"fmla v8.4s, v6.4s, v1.s[0]\n" // out0 += b0 * a0[0]
"prfm pldl1keep, [%[b_ptr], #192]\n"
"fmla v10.4s, v6.4s, v1.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v6.4s, v1.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v1.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v7.4s, v1.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v1.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v1.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v1.s[3]\n" // out7 += b1 * a0[3]
"ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2
// Unroll 2
"fmla v8.4s, v4.4s, v2.s[0]\n" // out0 += b0 * a0[0]
"ldp q0, q1, [%[a_ptr]], #32\n" // load a0~a3 to q0, q1
"fmla v10.4s, v4.4s, v2.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v4.4s, v2.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v4.4s, v2.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v5.4s, v2.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v2.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v2.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v2.s[3]\n" // out7 += b1 * a0[3]
"ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5
// Unroll 3
"fmla v8.4s, v6.4s, v3.s[0]\n" // out0 += b0 * a0[0]
"prfm pldl1keep, [%[a_ptr], #128]\n"
"fmla v10.4s, v6.4s, v3.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v6.4s, v3.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v3.s[3]\n" // out1 += b0 * a0[3]
"subs %w[k], %w[k], #1\n" // loop count - 1
"fmla v9.4s, v7.4s, v3.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v3.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v3.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v3.s[3]\n" // out7 += b1 * a0[3]
"bne 1b\n"
"0: \n"
"subs %w[tail], %w[tail], #1\n" // tail--
"beq 3f\n" // jump to tail = 1
// Unroll 0
"ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2
"fmla v8.4s, v4.4s, v0.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v0.s[1]\n" // out1 += b0 * a0[1]
"subs %w[tail], %w[tail], #1\n" // tail--
"fmla v12.4s, v4.4s, v0.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v0.s[3]\n" // out3 += b0 * a0[3]
"fmla v9.4s, v5.4s, v0.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v0.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v0.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v0.s[3]\n" // out7 += b1 * a0[3]
"beq 4f\n" // jump to tail = 2
// Unroll 1
"ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5
"fmla v8.4s, v6.4s, v1.s[0]\n" // out0 += b0 * a0[0]
"ldp q2, q3, [%[a_ptr]], #32\n" // load next 2xa0~a3
"fmla v10.4s, v6.4s, v1.s[1]\n" // out1 += b0 * a0[1]
"subs %w[tail], %w[tail], #1\n" // tail--*/
"fmla v12.4s, v6.4s, v1.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v1.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v7.4s, v1.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v1.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v1.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v1.s[3]\n" // out7 += b1 * a0[3]
"beq 5f\n" // jump to tail = 3
// Unroll 2
"ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2
"fmla v8.4s, v4.4s, v2.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v2.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v4.4s, v2.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v2.s[3]\n" // out3 += b0 * a0[3]
"fmla v9.4s, v5.4s, v2.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v2.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v2.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v2.s[3]\n" // out7 += b1 * a0[3]
// Unroll 3
"fmla v8.4s, v6.4s, v3.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v6.4s, v3.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v6.4s, v3.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v3.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v7.4s, v3.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v3.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v3.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v3.s[3]\n" // out7 += b1 * a0[3]
"b 2f\n"
// tails==1 final tail
"3: \n"
"fmla v8.4s, v4.4s, v0.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v0.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v4.4s, v0.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v0.s[3]\n" // out3 += b0 * a0[3]
"fmla v9.4s, v5.4s, v0.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v0.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v0.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v0.s[3]\n" // out7 += b1 * a0[3]
// aptr - 16
"sub %w[a_ptr], %w[a_ptr], #16\n"
"b 2f\n"
"4: \n"
// tails==2 final tail
"fmla v8.4s, v6.4s, v1.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v6.4s, v1.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v6.4s, v1.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v1.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v7.4s, v1.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v1.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v1.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v1.s[3]\n" // out7 += b1 * a0[3]
"b 2f\n"
// tails==3 final tail
"5: \n"
"fmla v8.4s, v4.4s, v2.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v2.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v4.4s, v2.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v2.s[3]\n" // out3 += b0 * a0[3]
"fmla v9.4s, v5.4s, v2.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v2.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v2.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v2.s[3]\n" // out7 += b1 * a0[3]
// aptr - 16
"sub %w[a_ptr], %w[a_ptr], #16\n"
"2: \n"
"cmp %w[flag_act], #0\n" // check if has act
"beq 10f\n"
"cmp %w[flag_act], #1\n" // check if has relu
"bne 6f\n"
"fmax v8.4s, v8.4s, %[vzero].4s\n"
"fmax v9.4s, v9.4s, %[vzero].4s\n"
"fmax v10.4s, v10.4s, %[vzero].4s\n"
"fmax v11.4s, v11.4s, %[vzero].4s\n"
"fmax v12.4s, v12.4s, %[vzero].4s\n"
"fmax v13.4s, v13.4s, %[vzero].4s\n"
"fmax v14.4s, v14.4s, %[vzero].4s\n"
"fmax v15.4s, v15.4s, %[vzero].4s\n"
"b 10f\n" // end
"6: \n"
"cmp %w[flag_act], #2\n" // check relu6
"bne 7f\n"
"fmax v8.4s, v8.4s, %[vzero].4s\n"
"fmax v9.4s, v9.4s, %[vzero].4s\n"
"fmax v10.4s, v10.4s, %[vzero].4s\n"
"fmax v11.4s, v11.4s, %[vzero].4s\n"
"fmax v12.4s, v12.4s, %[vzero].4s\n"
"fmax v13.4s, v13.4s, %[vzero].4s\n"
"fmax v14.4s, v14.4s, %[vzero].4s\n"
"fmax v15.4s, v15.4s, %[vzero].4s\n"
"fmin v8.4s, v8.4s, %[valpha].4s\n"
"fmin v9.4s, v9.4s, %[valpha].4s\n"
"fmin v10.4s, v10.4s, %[valpha].4s\n"
"fmin v11.4s, v11.4s, %[valpha].4s\n"
"fmin v12.4s, v12.4s, %[valpha].4s\n"
"fmin v13.4s, v13.4s, %[valpha].4s\n"
"fmin v14.4s, v14.4s, %[valpha].4s\n"
"fmin v15.4s, v15.4s, %[valpha].4s\n"
"b 10f\n"
"7: \n"
"fcmge v2.4s, v8.4s, %[vzero].4s\n"
"fmul v3.4s, v8.4s, %[valpha].4s\n"
"fcmge v4.4s, v9.4s, %[vzero].4s\n"
"fmul v5.4s, v9.4s, %[valpha].4s\n"
"fcmge v6.4s, v10.4s, %[vzero].4s\n"
"fmul v7.4s, v10.4s, %[valpha].4s\n"
"fcmge v0.4s, v11.4s, %[vzero].4s\n"
"fmul v1.4s, v11.4s, %[valpha].4s\n"
"bif v8.16b, v3.16b, v2.16b \n"
"bif v9.16b, v5.16b, v4.16b \n"
"bif v10.16b, v7.16b, v6.16b \n"
"bif v11.16b, v1.16b, v0.16b \n"
"fcmge v2.4s, v12.4s, %[vzero].4s\n"
"fmul v3.4s, v12.4s, %[valpha].4s\n"
"fcmge v4.4s, v13.4s, %[vzero].4s\n"
"fmul v5.4s, v13.4s, v1.4s \n"
"fcmge v6.4s, v14.4s, %[vzero].4s\n"
"fmul v7.4s, v14.4s, %[valpha].4s\n"
"fcmge v0.4s, v15.4s, %[vzero].4s\n"
"fmul v1.4s, v15.4s, %[valpha].4s\n"
"bif v12.16b, v3.16b, v2.16b \n"
"bif v13.16b, v5.16b, v4.16b \n"
"bif v14.16b, v7.16b, v6.16b \n"
"bif v15.16b, v1.16b, v0.16b \n"
"10: \n"
"st1 {v8.4s, v9.4s},[%[c_ptr0]], #32\n"
"st1 {v10.4s, v11.4s},[%[c_ptr1]], #32\n"
"st1 {v12.4s, v13.4s},[%[c_ptr2]], #32\n"
"st1 {v14.4s, v15.4s},[%[c_ptr3]], #32\n"
: [a_ptr] "+r"(a_ptr),
[b_ptr] "+r"(b_ptr),
[c_ptr0] "+r"(c_ptr0),
[c_ptr1] "+r"(c_ptr1),
[c_ptr2] "+r"(c_ptr2),
[c_ptr3] "+r"(c_ptr3),
[k] "+r"(k),
[tail] "+r"(tails)
: [bias_ptr] "r"(bias_local),
[beta] "r"(beta),
[alpha] "r"(alpha),
[flag_act] "r"(flag_act),
[vzero] "w"(vzero),
[valpha] "w"(valpha)
: "cc","memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8","v9","v10","v11","v12","v13",
"v14","v15");
// clang-format on
if (flag_p_remain && (xb == bblocks - 1)) {
for (int i = 0; i < remain; ++i) {
*pout0++ = cout0[i];
*pout1++ = cout1[i];
*pout2++ = cout2[i];
*pout3++ = cout3[i];
}
}
}
}
}
}
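// The fcmge/fmul/bif triplets in the activation tail above are a branchless
// leaky ReLU: keep x where x >= 0, otherwise substitute alpha * x. An
// equivalent intrinsics sketch (illustrative only, assuming <arm_neon.h> as
// already used in this file):
static inline float32x4_t leaky_relu_q(float32x4_t x, float32x4_t valpha) {
  uint32x4_t ge_zero = vcgeq_f32(x, vdupq_n_f32(0.f));  // per-lane x >= 0
  float32x4_t scaled = vmulq_f32(x, valpha);            // alpha * x
  return vbslq_f32(ge_zero, x, scaled);                 // lane-wise select
}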
void sgemm_prepacked_4x4(bool is_transB,
int M,
int N,
......@@ -3313,7 +4326,7 @@ void sgemm_prepacked_4x4(bool is_transB,
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b3 to q6, q7 */
"fmla v10.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4 */
"fmla v11.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4 */
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a20, a30 to q2, q3 */
"fmla v8.4s, v5.4s, v1.s[0]\n" /* out0 = b1 * a10[0], b1 =q5 */
"fmla v9.4s, v5.4s, v1.s[1]\n" /* out1 = b1 * a10[1], b1 =q5 */
......@@ -3404,11 +4417,11 @@ void sgemm_prepacked_4x4(bool is_transB,
"fmax v10.4s, v10.4s, v0.4s \n" /* relu*/
"fmax v11.4s, v11.4s, v0.4s \n" /* relu*/
"b 20f \n" /* relu end */
//! no act
"12: \n" /* no relu */
"cmp %w[flag_act], #0 \n" /* check no act */
"beq 20f \n" /* no act end */
//! relu6
"beq 20f \n" /* no act end */
//! relu6
"cmp %w[flag_act], #2 \n" /* check if has relu6 */
"bne 13f \n" /* jump if no relu6 */
"movi v0.4s, #0 \n" /* for relu6 */
......@@ -3427,17 +4440,17 @@ void sgemm_prepacked_4x4(bool is_transB,
"13: \n" /* otherwise is leakey relu */
"movi v0.4s, #0 \n" /* for leakey relu */
"ld1 {v1.4s}, [%[alpha]] \n" /* leakey relu alpha */
"fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v12.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v13.4s, v11.4s, v1.4s \n" /* vmulq_f32 */
"bif v8.16b, v3.16b, v2.16b \n" /* choose*/
"bif v9.16b, v5.16b, v4.16b \n" /* choose*/
"bif v10.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v12.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v13.4s, v11.4s, v1.4s \n" /* vmulq_f32 */
"bif v8.16b, v3.16b, v2.16b \n" /* choose*/
"bif v9.16b, v5.16b, v4.16b \n" /* choose*/
"bif v10.16b, v7.16b, v6.16b \n" /* choose*/
"bif v11.16b, v13.16b, v12.16b \n" /* choose*/
"20: \n" /* act end */
"st1 {v8.4s}, [%[c_ptr0]], #16\n" /* store r0 */
......@@ -3455,7 +4468,7 @@ void sgemm_prepacked_4x4(bool is_transB,
[c_ptr3] "+r"(c_ptr3)
: [bias_ptr] "r"(bias_local),
[has_beta] "r"(has_beta),
[beta] "r"(beta),
[beta] "r"(beta),
[alpha] "r"(alpha),
[flag_act] "r"(flag_act)
: "cc","memory",
......@@ -3926,7 +4939,7 @@ void sgemm_prepacked_6x8(bool is_transB,
"cmp %[tails], #0 @ check no act\n"
"beq 10f @ no act end \n"
//! relu6
"cmp %[tails], #2 @ check if has relu6\n"
"cmp %[tails], #2 @ check if has relu6\n"
"bne 7f @ jump if no relu6 \n"
"vmov.u32 q0, #0 @ for relu6\n"
"vmax.f32 q4, q4, q0 @ for relu6\n"
......@@ -3957,45 +4970,45 @@ void sgemm_prepacked_6x8(bool is_transB,
"vmin.f32 q15, q15, q1 @ for relu6\n"
"b 10f @ relu6 end \n"
//! leakey relu
"7: @ otherwise is leakey relu\n"
"7: @ otherwise is leakey relu\n"
"vmov.u32 q0, #0 @ for leakey relu \n"
"vld1.f32 {d2-d3}, [%[alpha]] @ load leakey relu alpha\n"
"vcge.f32 q2, q4, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q4, q1 @ vmulq_f32 \n"
"vbif q4, q3, q2 @ choose \n"
"vcge.f32 q2, q5, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q5, q1 @ vmulq_f32 \n"
"vcge.f32 q2, q4, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q4, q1 @ vmulq_f32 \n"
"vbif q4, q3, q2 @ choose \n"
"vcge.f32 q2, q5, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q5, q1 @ vmulq_f32 \n"
"vbif q5, q3, q2 @ choose \n"
"vcge.f32 q2, q6, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q6, q1 @ vmulq_f32 \n"
"vbif q6, q3, q2 @ choose \n"
"vcge.f32 q2, q7, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q7, q1 @ vmulq_f32 \n"
"vbif q7, q3, q2 @ choose \n"
"vcge.f32 q2, q8, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q8, q1 @ vmulq_f32 \n"
"vbif q8, q3, q2 @ choose \n"
"vcge.f32 q2, q9, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q9, q1 @ vmulq_f32 \n"
"vcge.f32 q2, q6, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q6, q1 @ vmulq_f32 \n"
"vbif q6, q3, q2 @ choose \n"
"vcge.f32 q2, q7, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q7, q1 @ vmulq_f32 \n"
"vbif q7, q3, q2 @ choose \n"
"vcge.f32 q2, q8, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q8, q1 @ vmulq_f32 \n"
"vbif q8, q3, q2 @ choose \n"
"vcge.f32 q2, q9, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q9, q1 @ vmulq_f32 \n"
"vbif q9, q3, q2 @ choose \n"
"vcge.f32 q2, q10, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q10, q1 @ vmulq_f32 \n"
"vbif q10, q3, q2 @ choose \n"
"vcge.f32 q2, q11, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q11, q1 @ vmulq_f32 \n"
"vbif q11, q3, q2 @ choose \n"
"vcge.f32 q2, q12, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q12, q1 @ vmulq_f32 \n"
"vbif q12, q3, q2 @ choose \n"
"vcge.f32 q2, q13, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q13, q1 @ vmulq_f32 \n"
"vbif q13, q3, q2 @ choose \n"
"vcge.f32 q2, q14, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q14, q1 @ vmulq_f32 \n"
"vbif q14, q3, q2 @ choose \n"
"vcge.f32 q2, q15, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q15, q1 @ vmulq_f32 \n"
"vbif q15, q3, q2 @ choose \n"
"vcge.f32 q2, q10, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q10, q1 @ vmulq_f32 \n"
"vbif q10, q3, q2 @ choose \n"
"vcge.f32 q2, q11, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q11, q1 @ vmulq_f32 \n"
"vbif q11, q3, q2 @ choose \n"
"vcge.f32 q2, q12, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q12, q1 @ vmulq_f32 \n"
"vbif q12, q3, q2 @ choose \n"
"vcge.f32 q2, q13, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q13, q1 @ vmulq_f32 \n"
"vbif q13, q3, q2 @ choose \n"
"vcge.f32 q2, q14, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q14, q1 @ vmulq_f32 \n"
"vbif q14, q3, q2 @ choose \n"
"vcge.f32 q2, q15, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q15, q1 @ vmulq_f32 \n"
"vbif q15, q3, q2 @ choose \n"
"10: @ act end \n"
"vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n"
"vst1.32 {d12-d15}, [%[c_ptr1]]! @ store r1\n"
......@@ -4014,7 +5027,7 @@ void sgemm_prepacked_6x8(bool is_transB,
[k] "+r"(k),
[tails] "+r"(tails)
: [bias_ptr] "r"(bias_local),
[beta] "r"(beta),
[beta] "r"(beta),
[alpha] "r" (alpha)
: "q0","q1","q2","q3","q4",
"q5","q6","q7","q8","q9","q10","q11",
......@@ -4311,7 +5324,7 @@ void sgemm_prepacked_6x8_a53(bool is_transB,
"vmla.f32 q7, q3, d1[1] \n" /* out11 += a13 * b3h */
"ldr r1, [%[b_ptr], #0x0C] \n" /* load b03 to r1 */
"vmla.f32 q9, q3, d2[0] \n" /* out21 += a23 * b3h */
"subs %[k], %[k], #1 \n" /* loop k -= 1 */
"subs %[k], %[k], #1 \n" /* loop k -= 1 */
"vldr d1, [%[a_ptr], #0x08] \n" /* load a20, a30 to d1 */
"vmov d5, r0, r1 \n" /* mov b02, b03 to d5 */
"vmla.f32 q11, q3, d2[1] \n" /* out31 += a33 * b3h */
......@@ -4322,131 +5335,131 @@ void sgemm_prepacked_6x8_a53(bool is_transB,
"bne 1b \n" /* branch to k loop */
"6:\n"
"sub %[tails], %[tails], #4 \n" /* tail -= 4 */
"cmp %[tails], #4 \n" /* cmp tail with 4 */
"blt 3f \n" /* branch to tail == 1 */
"cmp %[tails], #4 \n" /* cmp tail with 4 */
"blt 3f \n" /* branch to tail == 1 */
/* Tail Unroll 0 */
"vmov d2, r0, r1 \n" /* mov b02, b03 to d2 */
"add %[a_ptr], %[a_ptr], #0x18 \n" /* aptr += 24 */
"vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */
"vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */
"vld1.32 {d3}, [%[a_ptr] :64]! \n" /* load a01, a11 to d3 */
"vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */
"vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */
"add %[b_ptr], %[b_ptr], #0x10 \n" /* bptr += 16 */
"vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */
"vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b04-b07 to d6,d7 */
"vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */
"vmla.f32 q12, q2, d2[0] \n" /* out40 += a40 * b0l */
"vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */
"vmla.f32 q12, q2, d2[0] \n" /* out40 += a40 * b0l */
"sub %[tails], %[tails], #4 \n" /* tail -= 4 */
"vmla.f32 q14, q2, d2[1] \n" /* out50 += a50 * b0l */
"vmla.f32 q14, q2, d2[1] \n" /* out50 += a50 * b0l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b10-b13 to d4,d5 */
"vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */
"vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */
"vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */
"vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */
"vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */
"vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */
"vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */
"vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */
"vld1.32 {d0-d1}, [%[a_ptr] :64]! \n" /* load a21-a51 to d0,d1 */
"cmp %[tails], #4 \n" /* cmp tail with 4 */
"vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */
"vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b14-b17 to d6,d7 */
"blt 4f \n" /* branch to tail == 2 */
"vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */
"vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b14-b17 to d6,d7 */
"blt 4f \n" /* branch to tail == 2 */
/* Tail Unroll 1 */
"vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */
"vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */
"vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */
"vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */
"sub %[tails], %[tails], #4 \n" /* tail -= 4 */
"vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */
"vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */
"vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */
"vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b20-b23 to d4,d5 */
"vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */
"vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */
"vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */
"vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */
"vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */
"vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b20-b23 to d4,d5 */
"vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */
"vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */
"cmp %[tails], #4 \n" /* cmp tail with 4 */
"vld1.32 {d2-d3}, [%[a_ptr] :64]! \n" /* load a02-a32 to d2,d3 */
"vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */
"vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */
"vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */
"vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b24-b27 to d6,d7 */
"blt 5f \n" /* branch to tail == 3 */
"vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */
"vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */
"vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */
"vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b24-b27 to d6,d7 */
"blt 5f \n" /* branch to tail == 3 */
/* Tail Unroll 2 */
"sub %[tails], %[tails], #4 \n" /* tail -= 4 */
"vld1.32 {d0-d1}, [%[a_ptr] :64]! \n" /* a42a52a03a13 to d0,d1 */
"vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */
"vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */
"vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */
"vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */
"vmla.f32 q8, q2, d3[0] \n" /* out20 += a22 * b2l */
"vmla.f32 q10, q2, d3[1] \n" /* out30 += a32 * b2l */
"vmla.f32 q12, q2, d0[0] \n" /* out40 += a42 * b2l */
"vmla.f32 q14, q2, d0[1] \n" /* out50 += a52 * b2l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b30-b33 to d4,d5 */
"vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */
"vmla.f32 q7, q3, d2[1] \n" /* out11 += a12 * b2h */
"vmla.f32 q9, q3, d3[0] \n" /* out21 += a22 * b2h */
"vmla.f32 q11, q3, d3[1] \n" /* out31 += a32 * b2h */
"vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */
"vmla.f32 q7, q3, d2[1] \n" /* out11 += a12 * b2h */
"vmla.f32 q9, q3, d3[0] \n" /* out21 += a22 * b2h */
"vmla.f32 q11, q3, d3[1] \n" /* out31 += a32 * b2h */
"vld1.32 {d2-d3}, [%[a_ptr] :64]! \n" /* load a23-a53 to d2,d3 */
"vmla.f32 q13, q3, d0[0] \n" /* out41 += a42 * b2h */
"vmla.f32 q15, q3, d0[1] \n" /* out51 += a52 * b2h */
"vmla.f32 q13, q3, d0[0] \n" /* out41 += a42 * b2h */
"vmla.f32 q15, q3, d0[1] \n" /* out51 += a52 * b2h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b34-b37 to d6,d7 */
/* Tail Unroll 3 */
"vmla.f32 q4, q2, d1[0] \n" /* out00 += a03 * b3l */
"vmla.f32 q5, q3, d1[0] \n" /* out01 += a03 * b3h */
"vmla.f32 q6, q2, d1[1] \n" /* out10 += a13 * b3l */
"vmla.f32 q7, q3, d1[1] \n" /* out11 += a13 * b3h */
"vmla.f32 q8, q2, d2[0] \n" /* out20 += a23 * b3l */
"vmla.f32 q9, q3, d2[0] \n" /* out21 += a23 * b3h */
"vmla.f32 q10, q2, d2[1] \n" /* out30 += a33 * b3l */
"vmla.f32 q11, q3, d2[1] \n" /* out31 += a33 * b3h */
"vmla.f32 q12, q2, d3[0] \n" /* out40 += a43 * b3l */
"vmla.f32 q13, q3, d3[0] \n" /* out41 += a43 * b3h */
"vmla.f32 q14, q2, d3[1] \n" /* out50 += a53 * b3l */
"vmla.f32 q15, q3, d3[1] \n" /* out51 += a53 * b3h */
"vmla.f32 q4, q2, d1[0] \n" /* out00 += a03 * b3l */
"vmla.f32 q5, q3, d1[0] \n" /* out01 += a03 * b3h */
"vmla.f32 q6, q2, d1[1] \n" /* out10 += a13 * b3l */
"vmla.f32 q7, q3, d1[1] \n" /* out11 += a13 * b3h */
"vmla.f32 q8, q2, d2[0] \n" /* out20 += a23 * b3l */
"vmla.f32 q9, q3, d2[0] \n" /* out21 += a23 * b3h */
"vmla.f32 q10, q2, d2[1] \n" /* out30 += a33 * b3l */
"vmla.f32 q11, q3, d2[1] \n" /* out31 += a33 * b3h */
"vmla.f32 q12, q2, d3[0] \n" /* out40 += a43 * b3l */
"vmla.f32 q13, q3, d3[0] \n" /* out41 += a43 * b3h */
"vmla.f32 q14, q2, d3[1] \n" /* out50 += a53 * b3l */
"vmla.f32 q15, q3, d3[1] \n" /* out51 += a53 * b3h */
"b 2f \n" /* branch to check relu */
/* tails==1 final tail */
"3:\n"
"vmov d2, r0, r1 \n" /* mov b02, b03 to d2 */
"add %[b_ptr], %[b_ptr], #0x10 \n" /* bptr += 16 */
"vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */
"vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */
"add %[a_ptr], %[a_ptr], #0x18 \n" /* aptr += 24 */
"vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */
"vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b04-b07 to d6,d7 */
"vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */
"vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */
"vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */
"vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */
"vmla.f32 q12, q2, d2[0] \n" /* out40 += a40 * b0l */
"vmla.f32 q14, q2, d2[1] \n" /* out50 += a50 * b0l */
"vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */
"vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */
"vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */
"vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */
"vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */
"vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */
"vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */
"vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */
"vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */
"vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */
"vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */
"vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */
"b 2f \n" /* branch to check relu */
/* tails==2 final tail */
"4:\n"
"vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */
"vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */
"vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */
"vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */
"vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */
"vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */
"vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */
"vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */
"vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */
"vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */
"vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */
"vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */
"vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */
"vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */
"vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */
"vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */
"vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */
"vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */
"vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */
"vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */
"vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */
"vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */
"vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */
"vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */
"b 2f \n" /* branch to check relu */
/* tails==3 final tail */
"5:\n"
"vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */
"vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */
"vld1.32 {d0}, [%[a_ptr] :64]! \n" /* load a42, a52 to d0 */
"vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */
"vmla.f32 q8, q2, d3[0] \n" /* out20 += a22 * b2l */
"vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */
"vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */
"vmla.f32 q8, q2, d3[0] \n" /* out20 += a22 * b2l */
"vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */
"vmla.f32 q7, q3, d2[1] \n" /* out11 += a12 * b2h */
"vmla.f32 q9, q3, d3[0] \n" /* out21 += a22 * b2h */
"vmla.f32 q10, q2, d3[1] \n" /* out30 += a32 * b2l */
"vmla.f32 q10, q2, d3[1] \n" /* out30 += a32 * b2l */
"vmla.f32 q11, q3, d3[1] \n" /* out31 += a32 * b2h */
"vmla.f32 q12, q2, d0[0] \n" /* out40 += a42 * b2l */
"vmla.f32 q12, q2, d0[0] \n" /* out40 += a42 * b2l */
"vmla.f32 q13, q3, d0[0] \n" /* out41 += a42 * b2h */
"vmla.f32 q14, q2, d0[1] \n" /* out50 += a52 * b2l */
"vmla.f32 q14, q2, d0[1] \n" /* out50 += a52 * b2l */
"vmla.f32 q15, q3, d0[1] \n" /* out51 += a52 * b2h */
/* relu */
"2:\n"
......@@ -4838,7 +5851,7 @@ void sgemm_prepacked_4x8(bool is_transB,
"cmp %[flag_act], #0 @ check no act\n"
"beq 10f @ no act end \n"
//! relu6
"cmp %[flag_act], #2 @ check if has relu6\n"
"cmp %[flag_act], #2 @ check if has relu6\n"
"bne 7f @ jump if no relu6 \n"
"vmov.u32 q0, #0 @ for relu6\n"
"vld1.f32 {d2-d3}, [%[alpha]] @ load relu6 alpha\n"
......@@ -4861,33 +5874,33 @@ void sgemm_prepacked_4x8(bool is_transB,
"vmin.f32 q15, q15, q1 @ for relu6\n"
"b 10f @ relu6 end \n"
//! leakey relu
"7: @ otherwise is leakey relu\n"
"7: @ otherwise is leakey relu\n"
"vmov.u32 q0, #0 @ for leakey relu \n"
"vld1.f32 {d2-d3}, [%[alpha]] @ load leakey relu alpha\n"
"vcge.f32 q2, q8, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q8, q1 @ vmulq_f32 \n"
"vbif q8, q3, q2 @ choose \n"
"vcge.f32 q2, q9, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q9, q1 @ vmulq_f32 \n"
"vcge.f32 q2, q8, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q8, q1 @ vmulq_f32 \n"
"vbif q8, q3, q2 @ choose \n"
"vcge.f32 q2, q9, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q9, q1 @ vmulq_f32 \n"
"vbif q9, q3, q2 @ choose \n"
"vcge.f32 q2, q10, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q10, q1 @ vmulq_f32 \n"
"vbif q10, q3, q2 @ choose \n"
"vcge.f32 q2, q11, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q11, q1 @ vmulq_f32 \n"
"vbif q11, q3, q2 @ choose \n"
"vcge.f32 q2, q12, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q12, q1 @ vmulq_f32 \n"
"vbif q12, q3, q2 @ choose \n"
"vcge.f32 q2, q13, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q13, q1 @ vmulq_f32 \n"
"vbif q13, q3, q2 @ choose \n"
"vcge.f32 q2, q14, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q14, q1 @ vmulq_f32 \n"
"vbif q14, q3, q2 @ choose \n"
"vcge.f32 q2, q15, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q15, q1 @ vmulq_f32 \n"
"vbif q15, q3, q2 @ choose \n"
"vcge.f32 q2, q10, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q10, q1 @ vmulq_f32 \n"
"vbif q10, q3, q2 @ choose \n"
"vcge.f32 q2, q11, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q11, q1 @ vmulq_f32 \n"
"vbif q11, q3, q2 @ choose \n"
"vcge.f32 q2, q12, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q12, q1 @ vmulq_f32 \n"
"vbif q12, q3, q2 @ choose \n"
"vcge.f32 q2, q13, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q13, q1 @ vmulq_f32 \n"
"vbif q13, q3, q2 @ choose \n"
"vcge.f32 q2, q14, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q14, q1 @ vmulq_f32 \n"
"vbif q14, q3, q2 @ choose \n"
"vcge.f32 q2, q15, q0 @ vcgeq_u32 \n"
"vmul.f32 q3, q15, q1 @ vmulq_f32 \n"
"vbif q15, q3, q2 @ choose \n"
"10: @ act end \n"
"vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n"
"vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n"
......
......@@ -38,9 +38,10 @@ void sgemm(bool is_transA,
ARMContext* ctx) {
int hblock = get_hblock(ctx);
int m_roundup = hblock * ((M + hblock - 1) / hblock);
ctx->ExtendWorkspace(m_roundup * K * sizeof(float));
auto packed_A = static_cast<float*>(ctx->workspace_data<float>()) +
ctx->llc_size() / sizeof(float);
prepackA(packed_A, A, alpha, lda, 0, M, 0, K, is_transA, ctx);
......@@ -58,7 +59,6 @@ void sgemm(bool is_transA,
is_bias,
act_param,
ctx);
}
} // namespace math
......
......@@ -39,7 +39,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemm: M");
......