未验证 提交 1ebac1c0 编写于 作者: T TianXiaogang 提交者: GitHub

Armv8 4x4 gemm (#2528)

* feat: add sgemm4x4 for armv8

* fix: fix armv7 gemm choose condition
上级 04d2b4eb
...@@ -53,6 +53,38 @@ void sgemm_prepacked_8x12(bool is_transB, ...@@ -53,6 +53,38 @@ void sgemm_prepacked_8x12(bool is_transB,
bool has_bias, bool has_bias,
bool has_relu, bool has_relu,
ARMContext *ctx); ARMContext *ctx);
void pack_m4(float *out,
const float *in,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax);
void pack_trans_m4(float *out,
const float *in,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax);
void sgemm_prepacked_4x4(bool is_transB,
int M,
int N,
int K,
const float *A_packed,
const float *B,
int ldb,
float beta,
float *C,
int ldc,
const float *bias,
bool has_bias,
bool has_relu,
ARMContext *ctx);
#else #else
// for kA72 // for kA72
void prepackA_6x8(float *out, void prepackA_6x8(float *out,
...@@ -139,13 +171,21 @@ void prepackA(float *out, ...@@ -139,13 +171,21 @@ void prepackA(float *out,
bool is_trans, bool is_trans,
ARMContext *ctx) { ARMContext *ctx) {
#ifdef __aarch64__ #ifdef __aarch64__
if (is_trans) { if (mmax <= 4) {
prepackA_trans_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax); if (is_trans) {
pack_trans_m4(out, in, alpha, ldin, m0, mmax, k0, kmax);
} else {
pack_m4(out, in, alpha, ldin, m0, mmax, k0, kmax);
}
} else { } else {
prepackA_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax); if (is_trans) {
prepackA_trans_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax);
} else {
prepackA_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax);
}
} }
#else #else
if (ctx->arch() == kA73) { if (ctx->arch() == kA73 || mmax <= 4) {
if (is_trans) { if (is_trans) {
prepackA_trans_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax); prepackA_trans_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
} else { } else {
...@@ -212,22 +252,39 @@ void sgemm_prepack(bool is_transB, ...@@ -212,22 +252,39 @@ void sgemm_prepack(bool is_transB,
bool has_relu, bool has_relu,
ARMContext *ctx) { ARMContext *ctx) {
#ifdef __aarch64__ #ifdef __aarch64__
sgemm_prepacked_8x12(is_transB, if (M <= 4) {
M, sgemm_prepacked_4x4(is_transB,
N, M,
K, N,
A_packed, K,
B, A_packed,
ldb, B,
beta, ldb,
C, beta,
ldc, C,
bias, ldc,
has_bias, bias,
has_relu, has_bias,
ctx); has_relu,
ctx);
} else {
sgemm_prepacked_8x12(is_transB,
M,
N,
K,
A_packed,
B,
ldb,
beta,
C,
ldc,
bias,
has_bias,
has_relu,
ctx);
}
#else // armv7 #else // armv7
if (ctx->arch() == kA73) { if (ctx->arch() == kA73 || M <= 4) {
sgemm_prepacked_4x8(is_transB, sgemm_prepacked_4x8(is_transB,
M, M,
N, N,
...@@ -522,6 +579,147 @@ void prepackA_8x12(float *dout, ...@@ -522,6 +579,147 @@ void prepackA_8x12(float *dout,
} }
} }
} }
void pack_m4(float *dout,
const float *inptr,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax) {
int x_len = kmax - k0;
int stride = x_len * 4;
float zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(float) * x_len);
bool has_alpha = fabsf(alpha - 1.f) > 1e-8f;
#pragma omp parallel for
for (int y = m0; y < mmax; y += 4) {
float *outptr = dout + stride * (y - m0) / 4;
const float *inptr0 = inptr + y * ldin + k0;
const float *inptr1 = inptr0 + ldin;
const float *inptr2 = inptr1 + ldin;
const float *inptr3 = inptr2 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(inptr0),
[ptr1] "r"(inptr1),
[ptr2] "r"(inptr2),
[ptr3] "r"(inptr3)
: "memory");
int x = x_len;
//! cope with row index exceed real size, set to zero buffer
if ((y + 3) >= mmax) {
switch ((y + 3) - mmax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
asm volatile(
"cbz %w[has_alpha], 0f\n" /* check alpha == 1.f? */
"dup v31.4s, %w[alpha]\n" /* alpha to vector */
"ldp q0, q1, [%[inptr0]], #32\n" /* load r0, a0~a7 */
"ldp q2, q3, [%[inptr1]], #32\n" /* load r1, b0~b7 */
"fmul v0.4s, v31.4s, v0.4s\n" /* mul alpha */
"fmul v1.4s, v31.4s, v1.4s\n" /* mul alpha */
"ldp q4, q5, [%[inptr2]], #32\n" /* load r2, c0~c7 */
"fmul v2.4s, v31.4s, v2.4s\n" /* mul alpha */
"fmul v3.4s, v31.4s, v3.4s\n" /* mul alpha */
"ldp q6, q7, [%[inptr3]], #32\n" /* load r3, d0~d7 */
"fmul v4.4s, v31.4s, v4.4s\n" /* mul alpha */
"fmul v5.4s, v31.4s, v5.4s\n" /* mul alpha */
"fmul v6.4s, v31.4s, v6.4s\n" /* mul alpha */
"fmul v7.4s, v31.4s, v7.4s\n" /* mul alpha */
"b 1f\n" /* to main process */
"0: \n" /* alpha == 1 */
"ldp q0, q1, [%[inptr0]], #32\n" /* load r0, a0~a7 */
"ldp q2, q3, [%[inptr1]], #32\n" /* load r1, b0~b7 */
"ldp q4, q5, [%[inptr2]], #32\n" /* load r2, c0~c7 */
"ldp q6, q7, [%[inptr3]], #32\n" /* load r3, d0~d7 */
"1: \n" /* main process */
"trn1 v8.4s, v0.4s, v2.4s\n" /* a0b0a2b2*/
"trn2 v9.4s, v0.4s, v2.4s\n" /* a1b1a3b3*/
"trn1 v10.4s, v1.4s, v3.4s\n" /* a4b4a6b6*/
"trn2 v11.4s, v1.4s, v3.4s\n" /* a5b5a7b7*/
"trn1 v12.4s, v4.4s, v6.4s\n" /* c0d0c2d2*/
"trn2 v13.4s, v4.4s, v6.4s\n" /* c1d1c3d3*/
"trn1 v14.4s, v5.4s, v7.4s\n" /* c4d4c6d6*/
"trn2 v15.4s, v5.4s, v7.4s\n" /* c5d5c7d7*/
"trn1 v0.2d, v8.2d, v12.2d\n" /* a0b0c0d0 */
"trn1 v1.2d, v9.2d, v13.2d\n" /* a1b1c1d1 */
"trn1 v2.2d, v10.2d, v14.2d\n" /* a4b4c4d4 */
"trn1 v3.2d, v11.2d, v15.2d\n" /* a5b5c5d5 */
"trn2 v4.2d, v8.2d, v12.2d\n" /* a2b2c2d2 */
"trn2 v5.2d, v9.2d, v13.2d\n" /* a3b3c3d3 */
"stp q0, q1, [%[outptr]], #32\n" /* save q0, q1, a0~h0*/
"trn2 v6.2d, v10.2d, v14.2d\n" /* a6b6c6d6 */
"trn2 v7.2d, v11.2d, v15.2d\n" /* a7b7c7d7 */
"stp q4, q5, [%[outptr]], #32\n" /* save q2, q3, a1~h1*/
"stp q2, q3, [%[outptr]], #32\n" /* save q4, q5, a2~h2*/
"stp q6, q7, [%[outptr]], #32\n" /* save q6, q7, a3~h3*/
: [inptr0] "+r"(inptr0),
[inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3),
[outptr] "+r"(outptr)
: [alpha] "r"(alpha), [has_alpha] "r"(has_alpha)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"cc",
"memory");
}
for (; x > 0; x--) {
if (has_alpha) {
*outptr++ = *inptr0++ * alpha;
*outptr++ = *inptr1++ * alpha;
*outptr++ = *inptr2++ * alpha;
*outptr++ = *inptr3++ * alpha;
} else {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
}
}
}
}
void prepackA_trans_8x12(float *outptr, void prepackA_trans_8x12(float *outptr,
const float *in, const float *in,
...@@ -682,6 +880,128 @@ void prepackA_trans_8x12(float *outptr, ...@@ -682,6 +880,128 @@ void prepackA_trans_8x12(float *outptr,
} }
} }
} }
void pack_trans_m4(float *outptr,
const float *in,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax) {
auto inptr = in + k0 * ldin + m0;
uint32_t mask_buffer[4] = {0, 1, 2, 3};
int x_len = mmax - m0;
int y_len = kmax - k0;
int right_remain = x_len - 4 * (x_len / 4);
int stride_out = 4 * y_len;
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
bool has_alpha = fabsf(alpha - 1.f) > 1e-8f;
float32x4_t valpha = vdupq_n_f32(alpha);
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const float *ptr0 = inptr + y * ldin;
const float *ptr1 = ptr0 + ldin;
const float *ptr2 = ptr1 + ldin;
const float *ptr3 = ptr2 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3)
: "memory");
float *outptr_row_col = outptr + y * 4;
int i = 0;
for (; i < x_len - 3; i += 4) {
float32x4_t vr00 = vld1q_f32(ptr0);
float32x4_t vr10 = vld1q_f32(ptr1);
float32x4_t vr20 = vld1q_f32(ptr2);
float32x4_t vr30 = vld1q_f32(ptr3);
if (has_alpha) {
vr00 = vmulq_f32(vr00, valpha);
vr10 = vmulq_f32(vr10, valpha);
vr20 = vmulq_f32(vr20, valpha);
vr30 = vmulq_f32(vr30, valpha);
}
vst1q_f32(outptr_row_col, vr00);
vst1q_f32(outptr_row_col + 4, vr10);
vst1q_f32(outptr_row_col + 8, vr20);
vst1q_f32(outptr_row_col + 12, vr30);
ptr0 += 4;
ptr1 += 4;
ptr2 += 4;
ptr3 += 4;
outptr_row_col += stride_out;
}
if (right_remain > 0) {
float32x4_t vr00 = vld1q_f32(ptr0);
float32x4_t vr10 = vld1q_f32(ptr1);
float32x4_t vr20 = vld1q_f32(ptr2);
float32x4_t vr30 = vld1q_f32(ptr3);
if (has_alpha) {
vr00 = vmulq_f32(vr00, valpha);
vr10 = vmulq_f32(vr10, valpha);
vr20 = vmulq_f32(vr20, valpha);
vr30 = vmulq_f32(vr30, valpha);
}
float32x4_t vr00_1 = vbslq_f32(vmask1, vr00, vzero);
float32x4_t vr10_1 = vbslq_f32(vmask1, vr10, vzero);
float32x4_t vr20_1 = vbslq_f32(vmask1, vr20, vzero);
float32x4_t vr30_1 = vbslq_f32(vmask1, vr30, vzero);
vst1q_f32(outptr_row_col, vr00_1);
vst1q_f32(outptr_row_col + 4, vr10_1);
vst1q_f32(outptr_row_col + 8, vr20_1);
vst1q_f32(outptr_row_col + 12, vr30_1);
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const float *ptr0 = inptr + y * ldin;
float *outptr_row_col = outptr + y * 4;
int i = 0;
for (; i < x_len - 3; i += 4) {
float32x4_t vr0 = vld1q_f32(ptr0);
if (has_alpha) {
vr0 = vmulq_f32(vr0, valpha);
}
vst1q_f32(outptr_row_col, vr0);
ptr0 += 4;
outptr_row_col += stride_out;
}
if (right_remain > 0) {
float32x4_t vr0 = vld1q_f32(ptr0);
if (has_alpha) {
vr0 = vmulq_f32(vr0, valpha);
}
float32x4_t vr0_1 = vbslq_f32(vmask1, vr0, vzero);
vst1q_f32(outptr_row_col, vr0_1);
}
}
}
#else // __aarch64__ #else // __aarch64__
void prepackA_6x8(float* outptr, void prepackA_6x8(float* outptr,
...@@ -2592,6 +2912,292 @@ void sgemm_prepacked_8x12(bool is_transB, ...@@ -2592,6 +2912,292 @@ void sgemm_prepacked_8x12(bool is_transB,
} }
} }
} }
void sgemm_prepacked_4x4(bool is_transB,
int M,
int N,
int K,
const float *A_packed,
const float *B,
int ldb,
float beta,
float *C,
int ldc,
const float *bias,
bool has_bias,
bool has_relu,
ARMContext *ctx) {
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
const int n_block = 4;
const int m_block = 4;
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int x_block = (l2_cache - (m_block * K)) / (sizeof(float) * (K + m_block));
x_block /= n_block;
x_block *= n_block;
int x_num = (N + (x_block - 1)) / x_block;
x_block = (N + x_num - 1) / x_num;
x_block = (x_block + n_block - 1) / n_block;
x_block *= n_block;
x_block = x_block < n_block ? n_block : x_block;
// unroll 2 loop
int tail_pre = (K & (KBLOCK - 1));
int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
if (tail_pre == 0) {
tail_pre = KBLOCK;
}
bool flag_p_remain = false;
int remain = 0;
int has_beta = fabsf(beta) > 1e-8f ? 1 : 0;
//! apanel is pre_compute outside gemm
for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
unsigned int xmax = x0 + x_block;
if (xmax > N) {
xmax = N;
}
int bblocks = (xmax - x0 + n_block - 1) / n_block;
remain = xmax - x0 - (bblocks - 1) * n_block;
if (remain > 0) {
flag_p_remain = true;
}
//! load bpanel
float *b_pannel = workspace;
if (is_transB) {
pack_m4(b_pannel, B, 1.0f, ldb, x0, xmax, 0, K);
} else {
pack_trans_m4(b_pannel, B, 1.0f, ldb, x0, xmax, 0, K);
}
#pragma omp parallel for num_threads(threads)
for (unsigned int y = 0; y < M; y += m_block) {
unsigned int ymax = y + m_block;
if (ymax > M) {
ymax = M;
}
float bias_local[4] = {0};
if (has_bias) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
}
float cout0[n_block]; // NOLINT
float cout1[n_block]; // NOLINT
float cout2[n_block]; // NOLINT
float cout3[n_block]; // NOLINT
float *c_ptr0 = C + y * ldc + x0;
float *c_ptr1 = c_ptr0 + ldc;
float *c_ptr2 = c_ptr1 + ldc;
float *c_ptr3 = c_ptr2 + ldc;
float *pout0 = c_ptr0;
float *pout1 = c_ptr1;
float *pout2 = c_ptr2;
float *pout3 = c_ptr3;
const float *a_ptr_l = A_packed + y * K;
const float *b_ptr_l = b_pannel;
for (int xb = 0; xb < bblocks; xb++) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
case 2:
c_ptr1 = cout1;
case 1:
c_ptr2 = cout2;
case 0:
c_ptr3 = cout3;
default:
break;
}
}
if (flag_p_remain && (xb == bblocks - 1)) {
pout0 = c_ptr0;
pout1 = c_ptr1;
pout2 = c_ptr2;
pout3 = c_ptr3;
c_ptr0 = cout0;
c_ptr1 = cout1;
c_ptr2 = cout2;
c_ptr3 = cout3;
if (has_beta) {
for (int i = 0; i < remain; ++i) {
cout0[i] = pout0[i];
cout1[i] = pout1[i];
cout2[i] = pout2[i];
cout3[i] = pout3[i];
}
}
}
const float *a_ptr = a_ptr_l;
const float *b_ptr = b_ptr_l + xb * K * 4;
int tail = tail_pre;
int k = k_pre;
// clang-format off
asm volatile(
"prfm pldl1keep, [%[a_ptr]]\n" /* preload a*/
"ld1 {v2.4s}, [%[bias_ptr]]\n" /* load bias to q2, q3*/
"dup v8.4s, v2.s[0]\n" /* out0 = 0 */
"prfm pldl1keep, [%[b_ptr]]\n" /* preload b*/
"dup v9.4s, v2.s[1]\n" /* out1 = 0*/
"prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/
"dup v10.4s, v2.s[2]\n" /* out2 = 0*/
"prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/
"dup v11.4s, v2.s[3]\n" /* out3 = 0*/
"cbz %w[has_beta], 0f\n" /* check beta == 0? */
/* process beta */
"dup v7.4s, %w[beta]\n" /* beta to vector */
"ld1 {v0.4s}, [%[c_ptr0]]\n" /* load output r0 */
"ld1 {v1.4s}, [%[c_ptr1]]\n" /* load output r1 */
"fmla v8.4s, v0.4s, v7.4s\n" /* cr00 += beta * c_r00*/
"fmla v9.4s, v1.4s, v7.4s\n" /* cr10 += beta * c_r10*/
"ld1 {v2.4s}, [%[c_ptr2]]\n"
"ld1 {v3.4s}, [%[c_ptr3]]\n"
"fmla v10.4s, v2.4s, v7.4s\n" /* cr20 += beta * c_r20*/
"fmla v11.4s, v3.4s, v7.4s\n" /* cr30 += beta * c_r30*/
"0: \n" /* check loop count */
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a10 to q0, q1*/
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/
"cbz %w[k], 2f\n" /* check loop count > 0 */
/* main loop */
/* unrool 0*/
"1:\n" /* main loop */
"fmla v8.4s, v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q4 */
"fmla v9.4s, v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q4 */
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b3 to q6, q7 */
"fmla v10.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4 */
"fmla v11.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4 */
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a20, a30 to q2, q3 */
"fmla v8.4s, v5.4s, v1.s[0]\n" /* out0 = b1 * a10[0], b1 =q5 */
"fmla v9.4s, v5.4s, v1.s[1]\n" /* out1 = b1 * a10[1], b1 =q5 */
"fmla v10.4s, v5.4s, v1.s[2]\n" /* out2 = b1 * a10[2], b1 =q5 */
"fmla v11.4s, v5.4s, v1.s[3]\n" /* out3 = b1 * a10[3], b1 =q5 */
"ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/
"fmla v8.4s, v6.4s, v2.s[0]\n" /* out0 = b2 * a20[0], b2 =q6 */
"fmla v9.4s, v6.4s, v2.s[1]\n" /* out1 = b2 * a20[1], b2 =q6 */
"fmla v10.4s, v6.4s, v2.s[2]\n" /* out2 = b2 * a20[2], b2 =q6*/
"fmla v11.4s, v6.4s, v2.s[3]\n" /* out3 = b2 * a20[3], b2 =q6*/
"ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a10 to q0, q1 */
"fmla v8.4s, v7.4s, v3.s[0]\n" /* out0 = b3 * a30[0], b3 =q7*/
"fmla v9.4s, v7.4s, v3.s[1]\n" /* out1 = b3 * a30[1], b3 =q7*/
"subs %w[k], %w[k], #1\n" /* loop count - 1*/
"fmla v10.4s, v7.4s, v3.s[2]\n" /* out2 = b3 * a30[2], b3 =q7*/
"fmla v11.4s, v7.4s, v3.s[3]\n" /* out3 = b3 * a30[3], b3 =q7*/
"bne 1b\n"
"2:\n" /* process tail*/
"subs %w[tail], %w[tail], #1\n" /* tail--*/
"beq 3f\n" /*jump to tail = 1*/
/* final unrool 0*/
/* unrool 0, tail > 1*/
"fmla v8.4s, v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q4 */
"fmla v9.4s, v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q4 */
"subs %w[tail], %w[tail], #1\n" /* tail--*/
"fmla v10.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4 */
"fmla v11.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4 */
"beq 4f\n" /*jump to tail = 2*/
/* unrool 1, tail > 2*/
"ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b3 to q6, q7 */
"fmla v8.4s, v5.4s, v1.s[0]\n" /* out0 = b1 * a10[0], b1 =q5 */
"fmla v9.4s, v5.4s, v1.s[1]\n" /* out1 = b1 * a10[1], b1 =q5*/
"subs %w[tail], %w[tail], #1\n" /* tail--*/
"fmla v10.4s, v5.4s, v1.s[2]\n" /* out2 = b1 * a10[2], b1 =q5 */
"fmla v11.4s, v5.4s, v1.s[3]\n" /* out3 = b1 * a10[3], b1 =q5 */
"ldp q2, q3, [%[a_ptr]], #32\n" /* load a20, a30 to q2, q3 */
"beq 5f\n" /*jump to tail = 3*/
/* unrool 2, tail = 4*/
"fmla v8.4s, v6.4s, v2.s[0]\n" /* out0 = b2 * a20[0], b1 =q6 */
"fmla v9.4s, v6.4s, v2.s[1]\n" /* out1 = b2 * a20[1], b1 =q6 */
"fmla v10.4s, v6.4s, v2.s[2]\n" /* out2 = b2 * a20[2], b1 =q6*/
"fmla v11.4s, v6.4s, v2.s[3]\n" /* out3 = b2 * a20[3], b1 =q6*/
/* unrool 3, tail = 4*/
"fmla v8.4s, v7.4s, v3.s[0]\n" /* out0 = b3 * a30[0], b3 =q7*/
"fmla v9.4s, v7.4s, v3.s[1]\n" /* out1 = b3 * a30[1], b3 =q7*/
"fmla v10.4s, v7.4s, v3.s[2]\n" /* out2 = b3 * a30[2], b3 =q7*/
"fmla v11.4s, v7.4s, v3.s[3]\n" /* out3 = b3 * a30[3], b3 =q7*/
"b 11f\n"
/* tails==1 final tail*/
"3: \n" /* tail=1*/
"fmla v8.4s, v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q4 */
"fmla v9.4s, v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q4 */
"fmla v10.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4 */
"fmla v11.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4 */
"b 11f\n"
/* tails==2 final tail*/
"4:\n" /* tail = 2*/
"fmla v8.4s, v5.4s, v1.s[0]\n" /* out0 = b1 * a10[0], b1 =q5 */
"fmla v9.4s, v5.4s, v1.s[1]\n" /* out1 = b1 * a10[1], b1 =q5*/
"fmla v10.4s, v5.4s, v1.s[2]\n" /* out2 = b1 * a10[2], b1 =q5 */
"fmla v11.4s, v5.4s, v1.s[3]\n" /* out3 = b1 * a10[3], b1 =q5 */
"b 11f\n"
/* tails==3 final tail*/
"5:\n" /* tail = 3*/
"fmla v8.4s, v6.4s, v2.s[0]\n" /* out0 = b2 * a20[0], b1 =q6 */
"fmla v9.4s, v6.4s, v2.s[1]\n" /* out1 = b2 * a20[1], b1 =q6 */
"fmla v10.4s, v6.4s, v2.s[2]\n" /* out2 = b2 * a20[2], b1 =q6*/
"fmla v11.4s, v6.4s, v2.s[3]\n" /* out3 = b2 * a20[3], b1 =q6*/
"11: \n" /* check if relu */
"cbz %w[relu], 12f\n" /* skip relu */
"movi v2.4s, #0\n" /* for relu*/
"fmax v8.4s, v8.4s, v2.4s\n" /* relu*/
"fmax v9.4s, v9.4s, v2.4s\n" /* relu*/
"fmax v10.4s, v10.4s, v2.4s\n" /* relu*/
"fmax v11.4s, v11.4s, v2.4s\n" /* relu*/
"12: \n"
"st1 {v8.4s}, [%[c_ptr0]], #16\n" /* store r0 */
"st1 {v9.4s}, [%[c_ptr1]], #16\n" /* store r1 */
"st1 {v10.4s}, [%[c_ptr2]], #16\n" /* store r2 */
"st1 {v11.4s}, [%[c_ptr3]], #16\n" /* store r3 */
: [a_ptr] "+r"(a_ptr),
[b_ptr] "+r"(b_ptr),
[k] "+r"(k),
[tail] "+r"(tail),
[c_ptr0] "+r"(c_ptr0),
[c_ptr1] "+r"(c_ptr1),
[c_ptr2] "+r"(c_ptr2),
[c_ptr3] "+r"(c_ptr3)
: [bias_ptr] "r"(bias_local),
[relu] "r"(has_relu),
[has_beta] "r"(has_beta),
[beta] "r"(beta)
: "cc","memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8","v9","v10","v11");
// clang-format on
if (flag_p_remain && (xb == bblocks - 1)) {
for (int i = 0; i < remain; ++i) {
*pout0++ = cout0[i];
*pout1++ = cout1[i];
*pout2++ = cout2[i];
*pout3++ = cout3[i];
}
}
}
}
}
}
#else // __aarch64__ #else // __aarch64__
/** /**
* \brief gemm with ablock = 6, bblock = 8, output 6x8 * \brief gemm with ablock = 6, bblock = 8, output 6x8
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册