Unverified Commit 47d21d28 authored by: H HappyAngel committed by: GitHub

[arm] add v8-4x8 gemm implement (#4201)

* add v8 4x8 implement. test=develop

* fix run error

* change sgemm compute. test=develop

* fix format. test=develop
Parent 03951c1a
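For orientation, here is a minimal scalar sketch of the computation the new 4x8 micro-kernel performs. The name ref_sgemm_tile_4x8 is hypothetical and not part of this patch; the real kernel in the diff below consumes the same prepacked panels but does the work with NEON fmla on v8~v15, plus bias and activation handling.

// Reference (illustration only): one 4x8 output tile, C = beta * C + A * B,
// with A prepacked 4-wide k-major and B prepacked 8-wide k-major, matching
// the layouts produced by prepackA_4x8 / loadb_eight in this patch.
static void ref_sgemm_tile_4x8(const float *A_packed,  // A_packed[k * 4 + m]
                               const float *B_packed,  // B_packed[k * 8 + n]
                               float *C, int ldc, int K, float beta) {
  float acc[4][8] = {{0.f}};
  for (int k = 0; k < K; ++k) {
    for (int m = 0; m < 4; ++m) {
      for (int n = 0; n < 8; ++n) {
        acc[m][n] += A_packed[k * 4 + m] * B_packed[k * 8 + n];
      }
    }
  }
  for (int m = 0; m < 4; ++m) {
    for (int n = 0; n < 8; ++n) {
      C[m * ldc + n] = beta * C[m * ldc + n] + acc[m][n];
    }
  }
}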
@@ -55,6 +55,39 @@ void sgemm_prepacked_8x12(bool is_transB,
const operators::ActivationParam act_param,
ARMContext *ctx);
void prepackA_4x8(float *out,
const float *in,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax);
void prepackA_trans_4x8(float *out,
const float *in,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax);
void sgemm_prepacked_4x8(bool is_transB,
int M,
int N,
int K,
const float *A_packed,
const float *B,
int ldb,
float beta,
float *C,
int ldc,
const float *bias,
bool has_bias,
const operators::ActivationParam act_param,
ARMContext *ctx);
void pack_m4(float *out,
const float *in,
float alpha,
@@ -189,9 +222,9 @@ void prepackA(float *out,
#ifdef __aarch64__
if (mmax <= 4) {
if (is_trans) {
pack_trans_m4(out, in, alpha, ldin, m0, mmax, k0, kmax);
prepackA_trans_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
} else {
pack_m4(out, in, alpha, ldin, m0, mmax, k0, kmax);
prepackA_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
}
} else {
if (is_trans) {
@@ -269,7 +302,7 @@ void sgemm_prepack(bool is_transB,
ARMContext *ctx) {
#ifdef __aarch64__
if (M <= 4) {
sgemm_prepacked_4x4(is_transB,
sgemm_prepacked_4x8(is_transB,
M,
N,
K,
@@ -633,6 +666,145 @@ void prepackA_8x12(float *dout,
}
}
}
void prepackA_4x8(float *outptr,
const float *inptr,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax) {
int x_len = kmax - k0;
float zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(float) * x_len);
bool has_alpha = fabsf(alpha - 1.f) > 1e-8f;
float32x4_t valpha = vdupq_n_f32(alpha);
#pragma omp parallel for
for (int y = m0; y < mmax; y += 4) {
const float *inptr0 = inptr + y * ldin + k0;
const float *inptr1 = inptr0 + ldin;
const float *inptr2 = inptr1 + ldin;
const float *inptr3 = inptr2 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(inptr0),
[ptr1] "r"(inptr1),
[ptr2] "r"(inptr2),
[ptr3] "r"(inptr3)
: "memory");
int x = x_len;
if ((y + 3) >= mmax) {
switch ((y + 3) - mmax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
// clang-format off
asm volatile(
"cbz %w[has_alpha], 0f\n"
"ldp q0, q1, [%[inptr0]], #32\n" // load r0, a0~a7
"ldp q2, q3, [%[inptr1]], #32\n" // load r1, b0~b7
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"fmul v1.4s, v1.4s, %[alpha].4s\n"
"ldp q4, q5, [%[inptr2]], #32\n" // load r2, c0~c7
"fmul v2.4s, v2.4s, %[alpha].4s\n"
"fmul v3.4s, v3.4s, %[alpha].4s\n"
"ldp q6, q7, [%[inptr3]], #32\n" // load r3, d0~d7
"fmul v4.4s, v4.4s, %[alpha].4s\n"
"fmul v5.4s, v5.4s, %[alpha].4s\n"
"fmul v6.4s, v6.4s, %[alpha].4s\n"
"fmul v7.4s, v7.4s, %[alpha].4s\n"
"b 1f\n" // to main process
"0: \n" // alpha == 1
"ldp q0, q1, [%[inptr0]], #32\n" // load r0, a0~a7
"ldp q2, q3, [%[inptr1]], #32\n" // load r1, b0~b7
"ldp q4, q5, [%[inptr2]], #32\n" // load r2, c0~c7
"ldp q6, q7, [%[inptr3]], #32\n" // load r3, d0~d7
"1: \n"
"trn1 v8.4s, v0.4s, v2.4s\n" // a0b0a2b2
"trn2 v9.4s, v0.4s, v2.4s\n" // a1b1a3b3
"trn1 v10.4s, v1.4s, v3.4s\n" // a4b4a6b6
"trn2 v11.4s, v1.4s, v3.4s\n" // a5b5a7b7
"trn1 v12.4s, v4.4s, v6.4s\n" // c0d0c2d2
"trn2 v13.4s, v4.4s, v6.4s\n" // c1d1c3d3
"trn1 v14.4s, v5.4s, v7.4s\n" // c4d4c6d6
"trn2 v15.4s, v5.4s, v7.4s\n" // c5d5c7d7
"trn1 v0.2d, v8.2d, v12.2d\n" // a0b0c0d0
"trn1 v1.2d, v9.2d, v13.2d\n" // a1b1c1d1
"trn2 v2.2d, v8.2d, v12.2d\n" // a2b2c2d2
"trn2 v3.2d, v9.2d, v13.2d\n" // a3b3c3d3
"st1 {v0.4s}, [%[outptr]], #16\n"
"trn1 v4.2d, v10.2d, v14.2d\n" // a4b4c4d4
"st1 {v1.4s}, [%[outptr]], #16\n"
"trn1 v5.2d, v11.2d, v15.2d\n" // a5b5c5d5
"st1 {v2.4s}, [%[outptr]], #16\n"
"trn2 v6.2d, v10.2d, v14.2d\n" // a6b6c6d6
"st1 {v3.4s}, [%[outptr]], #16\n"
"trn2 v7.2d, v11.2d, v15.2d\n" // a7b7c7d7
"stp q4, q5, [%[outptr]], #32\n"
"stp q6, q7, [%[outptr]], #32\n"
: [inptr0] "+r"(inptr0),
[inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3),
[outptr] "+r"(outptr)
: [has_alpha] "r"(has_alpha), [alpha] "w"(valpha)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
// clang-format on
}
for (; x > 0; x--) {
if (has_alpha) {
*outptr++ = *inptr0++ * alpha;
*outptr++ = *inptr1++ * alpha;
*outptr++ = *inptr2++ * alpha;
*outptr++ = *inptr3++ * alpha;
} else {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
}
}
}
}
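// Scalar equivalent of prepackA_4x8 above (illustration, not part of the
// patch): each block of 4 rows is written k-major, so the kernel can pull
// one 4-element column of A per load:
//   for (int k = 0; k < x_len; ++k)
//     for (int m = 0; m < 4; ++m)
//       *outptr++ = inptr[(y + m) * ldin + k0 + k];  // a_k b_k c_k d_k
// which is exactly what the trn1/trn2 4x8 transpose realizes eight columns
// at a time, with rows past mmax redirected to zerobuff.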
void pack_m4(float *dout,
const float *inptr,
float alpha,
@@ -934,6 +1106,159 @@ void prepackA_trans_8x12(float *outptr,
}
}
}
void prepackA_trans_4x8(float *outptr,
const float *in,
float alpha,
int ldin,
int m0,
int mmax,
int k0,
int kmax) {
auto inptr = in + k0 * ldin + m0;
bool has_alpha = fabsf(alpha - 1.f) > 1e-8f;
float32x4_t valpha = vdupq_n_f32(alpha);
uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
int x_len = mmax - m0;
int y_len = kmax - k0;
int right_remain = x_len - 4 * (x_len / 4);
int right_pad = 4 - right_remain;
if (right_remain == 0) {
right_pad = 0;
}
int stride_out = 4 * y_len;
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
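// Mask construction note (illustration): vmask1 lane i is all-ones iff
// i < right_remain, e.g. right_remain == 3 gives {~0u, ~0u, ~0u, 0}. The
// tail asm below uses "bif vX.16b, vzero.16b, vmask1.16b" to keep in-range
// lanes of vX and replace out-of-range lanes with zero, so the last panel
// is zero-padded instead of carrying garbage from the over-read.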
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const float *ptr0 = inptr + y * ldin;
const float *ptr1 = ptr0 + ldin;
const float *ptr2 = ptr1 + ldin;
const float *ptr3 = ptr2 + ldin;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3)
: "memory");
float *outptr_row_col = outptr + y * 4;
int i = 0;
for (; i < x_len - 3; i += 4) {
float *ptr_out = outptr_row_col;
// clang-format off
asm volatile(
"cmp %w[has_alpha], #0\n"
"ld1 {v0.4s}, [%[ptr0]], #16\n"
"ld1 {v1.4s}, [%[ptr1]], #16\n"
"ld1 {v2.4s}, [%[ptr2]], #16\n"
"ld1 {v3.4s}, [%[ptr3]], #16\n"
"beq 0f\n"
"1: \n"
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"fmul v1.4s, v1.4s, %[alpha].4s\n"
"fmul v2.4s, v2.4s, %[alpha].4s\n"
"fmul v3.4s, v3.4s, %[alpha].4s\n"
"0: \n"
"stp q0, q1, [%[outptr]], #32\n"
"stp q2, q3, [%[outptr]], #32\n"
: [outptr] "+r"(ptr_out),
[ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2),
[ptr3] "+r"(ptr3)
: [has_alpha] "r"(has_alpha), [alpha] "w"(valpha)
: "v0", "v1", "v2", "v3", "cc", "memory");
// clang-format on
outptr_row_col += stride_out;
}
if (right_pad > 0) {
float *ptr_out = outptr_row_col;
// clang-format off
asm volatile(
"cmp %w[has_alpha], #0\n"
"ld1 {v0.4s}, [%[ptr0]], #16\n"
"ld1 {v1.4s}, [%[ptr1]], #16\n"
"ld1 {v2.4s}, [%[ptr2]], #16\n"
"ld1 {v3.4s}, [%[ptr3]], #16\n"
"beq 0f\n"
"1: \n"
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"fmul v1.4s, v1.4s, %[alpha].4s\n"
"fmul v2.4s, v2.4s, %[alpha].4s\n"
"fmul v3.4s, v3.4s, %[alpha].4s\n"
"0: \n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v1.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v2.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v3.16b, %[vzero].16b, %[vmask1].16b\n"
"stp q0, q1, [%[outptr]], #32\n"
"stp q2, q3, [%[outptr]], #32\n"
: [outptr] "+r"(ptr_out),
[ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2),
[ptr3] "+r"(ptr3)
: [vmask1] "w"(vmask1),
[vzero] "w"(vzero),
[has_alpha] "r"(has_alpha),
[alpha] "w"(valpha)
: "v0", "v1", "v2", "v3", "cc", "memory");
// clang-format on
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const float *ptr0 = inptr + y * ldin;
float *outptr_row_col = outptr + y * 4;
int i = 0;
for (; i < x_len - 3; i += 4) {
float *ptr_out = outptr_row_col;
asm volatile(
"cmp %[has_alpha], #0\n"
"ld1 {v0.4s}, [%[ptr0]], #16\n"
"beq 0f\n"
"1: \n"
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"0: \n"
"st1 {v0.4s}, [%[outptr]], #16\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [has_alpha] "r"(has_alpha), [alpha] "w"(valpha)
: "v0", "v1", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_pad > 0) {
float *ptr_out = outptr_row_col;
asm volatile(
"cmp %w[has_alpha], #0\n"
"ld1 {v0.4s}, [%[ptr0]], #16\n"
"beq 0f\n"
"1: \n"
"fmul v0.4s, v0.4s, %[alpha].4s\n"
"0: \n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"st1 {v0.4s}, [%[outptr]], #16\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [vmask1] "w"(vmask1),
[vzero] "w"(vzero),
[has_alpha] "r"(has_alpha),
[alpha] "w"(valpha)
: "v0", "v1", "cc", "memory");
}
}
}
void pack_trans_m4(float *outptr,
const float *in,
float alpha,
@@ -1587,6 +1912,7 @@ void prepackA_trans_4x8(float* outptr,
int i = 0;
for (; i < x_len - 3; i += 4) {
float* ptr_out = outptr_row_col;
// clang-format off
asm volatile(
"vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n"
"cmp %[has_alpha], #0\n"
@@ -1597,10 +1923,12 @@ void prepackA_trans_4x8(float* outptr,
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [has_alpha] "r"(has_alpha), [alpha] "w"(valpha)
: "q0", "q1", "cc", "memory");
// clang-format on
outptr_row_col += stride_out;
}
if (right_pad > 0) {
float* ptr_out = outptr_row_col;
// clang-format off
asm volatile(
"vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n"
"cmp %[has_alpha], #0\n"
@@ -1615,6 +1943,7 @@ void prepackA_trans_4x8(float* outptr,
[has_alpha] "r"(has_alpha),
[alpha] "w"(valpha)
: "q0", "q1", "cc", "memory");
// clang-format on
}
}
}
@@ -1624,7 +1953,7 @@ void prepackA_trans_4x8(float* outptr,
/**
* \brief input data is transposed
* for arm-v7a, transform data to block x k x 8 layout
* for arm-v8a, transform data to block x k x 12 layout
* for arm-v8a, transform data to block x k x 12 layout or block x k x 8 layout
*/
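// Packed-B layout for the 8-wide path (illustration, mirroring the comment
// above): loadb_eight / loadb_trans_eight emit consecutive 8-column panels,
// k-major inside each panel:
//   packed[(x / 8) * 8 * K + k * 8 + (x % 8)] == B[k][n0 + x]
// Columns beyond nmax are zero-filled through the bif masks.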
#ifdef __aarch64__
void loadb(
@@ -2037,19 +2366,17 @@ void loadb_trans(
}
}
}
#else // __aarch64__
void loadb(
float* out, const float* in, int ldin, int k0, int kmax, int n0, int nmax) {
auto outptr = reinterpret_cast<uint32_t*>(out);
auto inptr = reinterpret_cast<const uint32_t*>(in) + k0 * ldin + n0;
void loadb_eight(
float *out, const float *in, int ldin, int k0, int kmax, int n0, int nmax) {
auto outptr = reinterpret_cast<uint32_t *>(out);
auto inptr = reinterpret_cast<const uint32_t *>(in) + k0 * ldin + n0;
uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
int x_len = nmax - n0;
int y_len = kmax - k0;
int right_remain = x_len - 8 * (x_len / 8);
int right_pad = 8 - right_remain;
uint32_t* outptr_row = outptr;
uint32_t *outptr_row = outptr;
int stride_out = 8 * y_len;
uint32x4_t vzero = vdupq_n_u32(0);
@@ -2060,14 +2387,287 @@ void loadb(
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const uint32_t* ptr0 = inptr + y * ldin;
const uint32_t* ptr1 = ptr0 + ldin;
const uint32_t* ptr2 = ptr1 + ldin;
const uint32_t* ptr3 = ptr2 + ldin;
uint32_t* outptr_row_col = outptr_row + y * 8;
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32_t* ptr_out = outptr_row_col;
const uint32_t *ptr0 = inptr + y * ldin;
const uint32_t *ptr1 = ptr0 + ldin;
const uint32_t *ptr2 = ptr1 + ldin;
const uint32_t *ptr3 = ptr2 + ldin;
uint32_t *outptr_row_col = outptr_row + y * 8;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
:
: [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3)
: "memory");
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32_t *ptr_out = outptr_row_col;
asm volatile(
"ldp q0, q1, [%[ptr0]], #32\n" // load r0, 8 elements
"ldp q2, q3, [%[ptr1]], #32\n" // load r1, 8 elements
"stp q0, q1, [%[outptr]], #32\n" // write to output ptr
"stp q2, q3, [%[outptr]], #32\n" // write to output ptr
"ldp q0, q1, [%[ptr2]], #32\n" // load r0, 8 elements
"ldp q2, q3, [%[ptr3]], #32\n" // load r1, 8 elements
"stp q0, q1, [%[outptr]], #32\n" // write to output ptr
"stp q2, q3, [%[outptr]], #32\n" // write to output ptr
: [outptr] "+r"(ptr_out),
[ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2),
[ptr3] "+r"(ptr3)
:
: "v0", "v1", "v2", "v3", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32_t *ptr_out = outptr_row_col;
asm volatile(
"ldp q0, q1, [%[ptr0]], #32\n"
"ldp q2, q3, [%[ptr1]], #32\n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v1.16b, %[vzero].16b, %[vmask2].16b\n"
"bif v2.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v3.16b, %[vzero].16b, %[vmask2].16b\n"
"stp q0, q1, [%[outptr]], #32\n"
"ldp q0, q1, [%[ptr2]], #32\n"
"stp q2, q3, [%[outptr]], #32\n"
"ldp q2, q3, [%[ptr3]], #32\n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v1.16b, %[vzero].16b, %[vmask2].16b\n"
"bif v2.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v3.16b, %[vzero].16b, %[vmask2].16b\n"
"stp q0, q1, [%[outptr]], #32\n"
"stp q2, q3, [%[outptr]], #32\n"
: [outptr] "+r"(ptr_out),
[ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2),
[ptr3] "+r"(ptr3)
: [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
: "v0", "v1", "v2", "v3", "cc", "memory");
}
}
#pragma omp parallel for
for (int y = 4 * (y_len / 4); y < y_len; ++y) {
const uint32_t *ptr0 = inptr + y * ldin;
uint32_t *outptr_row_col = outptr_row + y * 8;
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32_t *ptr_out = outptr_row_col;
asm volatile(
"ldp q0, q1, [%[ptr0]], #32\n"
"stp q0, q1, [%[outptr]], #32\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
:
: "v0", "v1", "cc", "memory");
outptr_row_col += stride_out;
}
if (right_remain > 0) {
uint32_t *ptr_out = outptr_row_col;
asm volatile(
"ldp q0, q1, [%[ptr0]], #32\n"
"bif v0.16b, %[vzero].16b, %[vmask1].16b\n"
"bif v1.16b, %[vzero].16b, %[vmask2].16b\n"
"stp q0, q1, [%[outptr]], #32\n"
: [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
: [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
: "v0", "v1", "cc", "memory");
}
}
}
void loadb_trans_eight(
float *out, const float *in, int ldin, int k0, int kmax, int n0, int nmax) {
int x_len = kmax - k0;
uint32_t zerobuff[x_len]; // NOLINT
memset(zerobuff, 0, sizeof(uint32_t) * x_len);
auto outptr = reinterpret_cast<uint32_t *>(out);
auto inptr = reinterpret_cast<const uint32_t *>(in);
//! data B is transposed, transpose B back into k * 8 panels
for (int y = n0; y < nmax; y += 8) {
const uint32_t *inptr0 = inptr + y * ldin + k0;
const uint32_t *inptr1 = inptr0 + ldin;
const uint32_t *inptr2 = inptr1 + ldin;
const uint32_t *inptr3 = inptr2 + ldin;
const uint32_t *inptr4 = inptr3 + ldin;
const uint32_t *inptr5 = inptr4 + ldin;
const uint32_t *inptr6 = inptr5 + ldin;
const uint32_t *inptr7 = inptr6 + ldin;
int x = x_len;
asm volatile(
"prfm pldl1keep, [%[ptr0]] \n"
"prfm pldl1keep, [%[ptr0], #64] \n"
"prfm pldl1keep, [%[ptr1]] \n"
"prfm pldl1keep, [%[ptr1], #64] \n"
"prfm pldl1keep, [%[ptr2]] \n"
"prfm pldl1keep, [%[ptr2], #64] \n"
"prfm pldl1keep, [%[ptr3]] \n"
"prfm pldl1keep, [%[ptr3], #64] \n"
"prfm pldl1keep, [%[ptr4]] \n"
"prfm pldl1keep, [%[ptr4], #64] \n"
"prfm pldl1keep, [%[ptr5]] \n"
"prfm pldl1keep, [%[ptr5], #64] \n"
"prfm pldl1keep, [%[ptr6]] \n"
"prfm pldl1keep, [%[ptr6], #64] \n"
"prfm pldl1keep, [%[ptr7]] \n"
"prfm pldl1keep, [%[ptr7], #64] \n"
:
: [ptr0] "r"(inptr0),
[ptr1] "r"(inptr1),
[ptr2] "r"(inptr2),
[ptr3] "r"(inptr3),
[ptr4] "r"(inptr4),
[ptr5] "r"(inptr5),
[ptr6] "r"(inptr6),
[ptr7] "r"(inptr7)
: "memory");
//! if the row index exceeds the real size, point it at the zero buffer
if ((y + 7) >= nmax) {
switch ((y + 7) - nmax) {
case 6:
inptr1 = zerobuff;
case 5:
inptr2 = zerobuff;
case 4:
inptr3 = zerobuff;
case 3:
inptr4 = zerobuff;
case 2:
inptr5 = zerobuff;
case 1:
inptr6 = zerobuff;
case 0:
inptr7 = zerobuff;
default:
break;
}
}
for (; x > 7; x -= 8) {
// clang-format off
//! zip load 8 elements (2 neon Q registers) from each of 8 rows
asm volatile(
"ldp q0, q1, [%[inptr0]], #32\n" // load r0, a0~a7
"ldp q2, q3, [%[inptr1]], #32\n" // load r1, b0~b7
"ldp q4, q5, [%[inptr2]], #32\n" // load r2, c0~c7
"ldp q6, q7, [%[inptr3]], #32\n" // load r3, d0~d7
"trn1 v8.4s, v0.4s, v2.4s\n" // a0b0a2b2
"trn2 v9.4s, v0.4s, v2.4s\n" // a1b1a3b3
"trn1 v10.4s, v1.4s, v3.4s\n" // a4b4a6b6
"trn2 v11.4s, v1.4s, v3.4s\n" // a5b5a7b7
"ldp q16, q17, [%[inptr4]], #32\n"// load r4, e0~e7
"ldp q18, q19, [%[inptr5]], #32\n"// load r5, f0~f7
"ldp q20, q21, [%[inptr6]], #32\n"// load r6, g0~g7
"ldp q22, q23, [%[inptr7]], #32\n"// load r7, h0~h7
"trn1 v12.4s, v4.4s, v6.4s\n" // c0d0c2d2
"trn2 v13.4s, v4.4s, v6.4s\n" // c1d1c3d3
"trn1 v14.4s, v5.4s, v7.4s\n" // c4d4c6d6
"trn2 v15.4s, v5.4s, v7.4s\n" // c5d5c7d7
"trn1 v24.4s, v16.4s, v18.4s\n" // e0f0e2f2
"trn2 v25.4s, v16.4s, v18.4s\n" // e1f1e3f3
"trn1 v28.4s, v20.4s, v22.4s\n" // g0h0e2f2
"trn2 v29.4s, v20.4s, v22.4s\n" // g1h1e3f3
"trn1 v26.4s, v17.4s, v19.4s\n" // e4f4e6f6
"trn2 v27.4s, v17.4s, v19.4s\n" // e5f5e7f7
"trn1 v30.4s, v21.4s, v23.4s\n" // g4h4e6f6
"trn2 v31.4s, v21.4s, v23.4s\n" // g5h5e7f7
"trn1 v0.2d, v8.2d, v12.2d\n" // a0b0c0d0
"trn1 v1.2d, v24.2d, v28.2d\n" // e0f0g0h0
"trn1 v2.2d, v9.2d, v13.2d\n" // a1b1c1d1
"trn1 v3.2d, v25.2d, v29.2d\n" // e1f1g1h1
"trn2 v4.2d, v8.2d, v12.2d\n" // a2b2c2d2
"trn2 v5.2d, v24.2d, v28.2d\n" // e2f2g2h2
"stp q0, q1, [%[outptr]], #32\n" // save q0, q1, a0~h0
"trn2 v6.2d, v9.2d, v13.2d\n" // a3b3c3d3
"trn2 v7.2d, v25.2d, v29.2d\n" // e3f3g3h3
"stp q2, q3, [%[outptr]], #32\n" // save q0, q1, a1~h1
"trn1 v16.2d, v10.2d, v14.2d\n" // a4b4c4d4
"trn1 v17.2d, v26.2d, v30.2d\n" // e4f4g4h4
"stp q4, q5, [%[outptr]], #32\n" // save q0, q1, a2~h2
"trn1 v18.2d, v11.2d, v15.2d\n" // a5b5c5d5
"trn1 v19.2d, v27.2d, v31.2d\n" // e5f5g5h5
"stp q6, q7, [%[outptr]], #32\n" // save q0, q1, a3~h3
"trn2 v20.2d, v10.2d, v14.2d\n" // a6b6c6d6
"trn2 v21.2d, v26.2d, v30.2d\n" // e6f6g6h6
"stp q16, q17, [%[outptr]], #32\n" // save q0, q1, a4~h4
"trn2 v22.2d, v11.2d, v15.2d\n" // a7b7c7d7
"trn2 v23.2d, v27.2d, v31.2d\n" // e7f7g7h7
"stp q18, q19, [%[outptr]], #32\n" // save q0, q1, a5~h5
"stp q20, q21, [%[outptr]], #32\n" // save q0, q1, a6~h6
"stp q22, q23, [%[outptr]], #32\n" // save q0, q1, a7~h7
: [inptr0] "+r"(inptr0),
[inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3),
[inptr4] "+r"(inptr4),
[inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6),
[inptr7] "+r"(inptr7),
[outptr] "+r"(outptr)
:
: "v0","v1","v2","v3","v4","v5",
"v6","v7","v8","v9","v10","v11","v12",
"v13","v14","v15","v16","v17","v18","v19",
"v20","v21","v22","v23","v24","v25","v26",
"v27","v28","v29","v30","v31","cc","memory");
// clang-format on
}
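// Net effect of the trn1/trn2 sequence above (illustration): an 8x8
// transpose per iteration, i.e. out[k * 8 + r] == row_r[k] for r in
// [0, 8) -- the same ordering the scalar tail below produces one k at
// a time.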
for (; x > 0; x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
*outptr++ = *inptr4++;
*outptr++ = *inptr5++;
*outptr++ = *inptr6++;
*outptr++ = *inptr7++;
}
}
}
#else // __aarch64__
void loadb(
float* out, const float* in, int ldin, int k0, int kmax, int n0, int nmax) {
auto outptr = reinterpret_cast<uint32_t*>(out);
auto inptr = reinterpret_cast<const uint32_t*>(in) + k0 * ldin + n0;
uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
int x_len = nmax - n0;
int y_len = kmax - k0;
int right_remain = x_len - 8 * (x_len / 8);
int right_pad = 8 - right_remain;
uint32_t* outptr_row = outptr;
int stride_out = 8 * y_len;
uint32x4_t vzero = vdupq_n_u32(0);
uint32x4_t vmask1 =
vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
uint32x4_t vmask2 =
vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain));
#pragma omp parallel for
for (int y = 0; y < y_len - 3; y += 4) {
const uint32_t* ptr0 = inptr + y * ldin;
const uint32_t* ptr1 = ptr0 + ldin;
const uint32_t* ptr2 = ptr1 + ldin;
const uint32_t* ptr3 = ptr2 + ldin;
uint32_t* outptr_row_col = outptr_row + y * 8;
int i = 0;
for (; i < x_len - 7; i += 8) {
uint32_t* ptr_out = outptr_row_col;
asm volatile(
"vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n"
"vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n"
@@ -3129,6 +3729,419 @@ void sgemm_prepacked_8x12(bool is_transB,
}
}
void sgemm_prepacked_4x8(bool is_transB,
int M,
int N,
int K,
const float *A_packed,
const float *B,
int ldb,
float beta,
float *C,
int ldc,
const float *bias,
bool has_bias,
const operators::ActivationParam act_param,
ARMContext *ctx) {
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto *workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
auto act_type = act_param.active_type;
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
int flag_act = 0x00; // relu: 1, relu6: 2, leaky: 3
const int n_block = 8;
const int m_block = 4;
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 0x01;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 0x02;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 0x03;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
float32x4_t valpha = vld1q_f32(alpha);
float32x4_t vzero = vdupq_n_f32(0.f);
//! m_block * x (result) + m_block * k (A) + x * k (B) = l2
int x_block = (l2_cache - (m_block * K)) / (sizeof(float) * (K + m_block));
x_block /= n_block;
x_block *= n_block;
int x_num = (N + (x_block - 1)) / x_block;
x_block = (N + x_num - 1) / x_num;
x_block = (x_block + n_block - 1) / n_block;
x_block *= n_block;
x_block = x_block < n_block ? n_block : x_block;
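// Worked example (illustrative numbers): with l2_cache == 512 KB and
// K == 256, x_block starts at (524288 - 4 * 256) / (4 * (256 + 4)) ~= 503,
// is rounded down to a multiple of n_block == 8, and then re-balanced so
// the x-blocks split N roughly evenly, never dropping below one n_block.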
int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
int tail_pre = (K & (KBLOCK - 1));
if (tail_pre == 0) {
tail_pre = KBLOCK;
}
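// Example (assuming KBLOCK == 4, as in the 8x12 path): K == 10 gives
// k_pre == (13 / 4) - 1 == 2 four-deep main-loop iterations plus
// tail_pre == 2 trailing k-steps; K == 8 gives k_pre == 1 and
// tail_pre == 4 (a full final block handled by the tail code).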
bool flag_p_remain = false;
int remain = 0;
int has_beta = fabsf(beta) > 1e-8f ? 1 : 0;
//! apanel is pre-computed outside gemm
for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
unsigned int xmax = x0 + x_block;
if (xmax > N) {
xmax = N;
}
int bblocks = (xmax - x0 + n_block - 1) / n_block;
remain = xmax - x0 - (bblocks - 1) * n_block;
if (remain > 0) {
flag_p_remain = true;
}
//! load bpanel
auto b_pannel = static_cast<float *>(workspace);
if (is_transB) {
loadb_trans_eight(b_pannel, B, ldb, 0, K, x0, xmax);
} else {
loadb_eight(b_pannel, B, ldb, 0, K, x0, xmax);
}
#pragma omp parallel for num_threads(threads)
for (unsigned int y = 0; y < M; y += m_block) {
unsigned int ymax = y + m_block;
if (ymax > M) {
ymax = M;
}
float cout0[n_block]; // NOLINT
float cout1[n_block]; // NOLINT
float cout2[n_block]; // NOLINT
float cout3[n_block]; // NOLINT
float bias_local[4] = {0};
if (has_bias) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
}
float *c_ptr0 = C + y * ldc + x0;
float *c_ptr1 = c_ptr0 + ldc;
float *c_ptr2 = c_ptr1 + ldc;
float *c_ptr3 = c_ptr2 + ldc;
float *pout0 = c_ptr0;
float *pout1 = c_ptr1;
float *pout2 = c_ptr2;
float *pout3 = c_ptr3;
const float *a_ptr_l = A_packed + y * K;
const float *b_ptr = b_pannel;
for (int xb = 0; xb < bblocks; xb++) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
case 2:
c_ptr1 = cout1;
case 1:
c_ptr2 = cout2;
case 0:
c_ptr3 = cout3;
default:
break;
}
}
if (flag_p_remain && (xb == bblocks - 1)) {
pout0 = c_ptr0;
pout1 = c_ptr1;
pout2 = c_ptr2;
pout3 = c_ptr3;
c_ptr0 = cout0;
c_ptr1 = cout1;
c_ptr2 = cout2;
c_ptr3 = cout3;
if (has_beta) {
for (int i = 0; i < remain; ++i) {
cout0[i] = pout0[i];
cout1[i] = pout1[i];
cout2[i] = pout2[i];
cout3[i] = pout3[i];
}
}
}
const float *a_ptr = a_ptr_l;
int tails = tail_pre;
int k = k_pre;
// clang-format off
asm volatile(
"ld1 {v2.4s}, [%[bias_ptr]]\n"
"dup v8.4s, v2.s[0]\n"
"prfm pldl1keep, [%[a_ptr]]\n"
"dup v9.4s, v2.s[0]\n"
"prfm pldl1keep, [%[b_ptr]]\n"
"dup v10.4s, v2.s[1]\n"
"dup v11.4s, v2.s[1]\n"
"prfm pldl1keep, [%[a_ptr], #64]\n"
"dup v12.4s, v2.s[2]\n"
"dup v13.4s, v2.s[2]\n"
"prfm pldl1keep, [%[b_ptr], #64]\n"
"dup v14.4s, v2.s[3]\n"
"dup v15.4s, v2.s[3]\n"
"prfm pldl1keep, [%[a_ptr], #128]\n"
"cmp %w[beta], #0\n" // check beta == 0?
"prfm pldl1keep, [%[b_ptr], #128]\n"
"prfm pldl1keep, [%[b_ptr], #192]\n"
// process beta
"beq 11f\n"
"dup v7.4s, %w[beta]\n" // beta to vector
"ld1 {v0.4s, v1.4s}, [%[c_ptr0]]\n" // load output r0
"ld1 {v2.4s, v3.4s}, [%[c_ptr1]]\n" // load output r1
"fmla v8.4s, v0.4s, v7.4s\n" // cr00 += beta * c_r00
"fmla v9.4s, v1.4s, v7.4s\n" // cr01 += beta * c_r01
"ld1 {v0.4s, v1.4s}, [%[c_ptr2]]\n" // load output r2
"fmla v10.4s, v2.4s, v7.4s\n" // cr10 += beta * c_r10
"fmla v11.4s, v3.4s, v7.4s\n" // cr11 += beta * c_r11
"ld1 {v2.4s, v3.4s}, [%[c_ptr3]]\n" // load output r3
"fmla v12.4s, v0.4s, v7.4s\n" // cr20 += beta * c_r20
"fmla v13.4s, v1.4s, v7.4s\n" // cr21 += beta * c_r21
"fmla v14.4s, v2.4s, v7.4s\n" // cr30 += beta * c_r30
"fmla v15.4s, v3.4s, v7.4s\n" // cr31 += beta * c_r31
"11: \n" // check loop count
"ldp q0, q1, [%[a_ptr]], #32\n" // load a0~a3 to q0, q1
"ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5
"cbz %w[k], 0f\n" // check loop count > 0
// main loop
// Unroll 0
"1: \n"
"fmla v8.4s, v4.4s, v0.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v0.s[1]\n" // out1 += b0 * a0[1]
"ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2
"fmla v12.4s, v4.4s, v0.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v0.s[3]\n" // out3 += b0 * a0[3]
"ldp q2, q3, [%[a_ptr]], #32\n" // load next 2xa0~a3
"fmla v9.4s, v5.4s, v0.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v0.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v0.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v0.s[3]\n" // out7 += b1 * a0[3]
"ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5
// Unroll 1
"fmla v8.4s, v6.4s, v1.s[0]\n" // out0 += b0 * a0[0]
"prfm pldl1keep, [%[b_ptr], #192]\n"
"fmla v10.4s, v6.4s, v1.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v6.4s, v1.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v1.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v7.4s, v1.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v1.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v1.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v1.s[3]\n" // out7 += b1 * a0[3]
"ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2
// Unroll 2
"fmla v8.4s, v4.4s, v2.s[0]\n" // out0 += b0 * a0[0]
"ldp q0, q1, [%[a_ptr]], #32\n" // load a0~a3 to q0, q1
"fmla v10.4s, v4.4s, v2.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v4.4s, v2.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v4.4s, v2.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v5.4s, v2.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v2.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v2.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v2.s[3]\n" // out7 += b1 * a0[3]
"ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5
// Unroll 3
"fmla v8.4s, v6.4s, v3.s[0]\n" // out0 += b0 * a0[0]
"prfm pldl1keep, [%[a_ptr], #128]\n"
"fmla v10.4s, v6.4s, v3.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v6.4s, v3.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v3.s[3]\n" // out1 += b0 * a0[3]
"subs %w[k], %w[k], #1\n" // loop count - 1
"fmla v9.4s, v7.4s, v3.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v3.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v3.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v3.s[3]\n" // out7 += b1 * a0[3]
"bne 1b\n"
"0: \n"
"subs %w[tail], %w[tail], #1\n" // tail--
"beq 3f\n" // jump to tail = 1
// Unroll 0
"ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2
"fmla v8.4s, v4.4s, v0.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v0.s[1]\n" // out1 += b0 * a0[1]
"subs %w[tail], %w[tail], #1\n" // tail--
"fmla v12.4s, v4.4s, v0.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v0.s[3]\n" // out3 += b0 * a0[3]
"fmla v9.4s, v5.4s, v0.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v0.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v0.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v0.s[3]\n" // out7 += b1 * a0[3]
"beq 4f\n" // jump to tail = 2
// Unroll 1
"ldp q4, q5, [%[b_ptr]], #32\n" // load b0~b3 to q4, q5
"fmla v8.4s, v6.4s, v1.s[0]\n" // out0 += b0 * a0[0]
"ldp q2, q3, [%[a_ptr]], #32\n" // load next 2xa0~a3
"fmla v10.4s, v6.4s, v1.s[1]\n" // out1 += b0 * a0[1]
"subs %w[tail], %w[tail], #1\n" // tail--*/
"fmla v12.4s, v6.4s, v1.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v1.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v7.4s, v1.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v1.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v1.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v1.s[3]\n" // out7 += b1 * a0[3]
"beq 5f\n" // jump to tail = 3
// Unroll 2
"ldp q6, q7, [%[b_ptr]], #32\n" // load next b1, b2
"fmla v8.4s, v4.4s, v2.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v2.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v4.4s, v2.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v2.s[3]\n" // out3 += b0 * a0[3]
"fmla v9.4s, v5.4s, v2.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v2.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v2.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v2.s[3]\n" // out7 += b1 * a0[3]
// Unroll 3
"fmla v8.4s, v6.4s, v3.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v6.4s, v3.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v6.4s, v3.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v3.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v7.4s, v3.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v3.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v3.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v3.s[3]\n" // out7 += b1 * a0[3]
"b 2f\n"
// tails==1 final tail
"3: \n"
"fmla v8.4s, v4.4s, v0.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v0.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v4.4s, v0.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v0.s[3]\n" // out3 += b0 * a0[3]
"fmla v9.4s, v5.4s, v0.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v0.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v0.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v0.s[3]\n" // out7 += b1 * a0[3]
// aptr - 16
"sub %w[a_ptr], %w[a_ptr], #16\n"
"b 2f\n"
"4: \n"
// tails==2 final tail
"fmla v8.4s, v6.4s, v1.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v6.4s, v1.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v6.4s, v1.s[2]\n" // out1 += b0 * a0[2]
"fmla v14.4s, v6.4s, v1.s[3]\n" // out1 += b0 * a0[3]
"fmla v9.4s, v7.4s, v1.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v7.4s, v1.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v7.4s, v1.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v7.4s, v1.s[3]\n" // out7 += b1 * a0[3]
"b 2f\n"
// tails==3 final tail
"5: \n"
"fmla v8.4s, v4.4s, v2.s[0]\n" // out0 += b0 * a0[0]
"fmla v10.4s, v4.4s, v2.s[1]\n" // out1 += b0 * a0[1]
"fmla v12.4s, v4.4s, v2.s[2]\n" // out2 += b0 * a0[2]
"fmla v14.4s, v4.4s, v2.s[3]\n" // out3 += b0 * a0[3]
"fmla v9.4s, v5.4s, v2.s[0]\n" // out4 += b1 * a0[0]
"fmla v11.4s, v5.4s, v2.s[1]\n" // out5 += b1 * a0[1]
"fmla v13.4s, v5.4s, v2.s[2]\n" // out6 += b1 * a0[2]
"fmla v15.4s, v5.4s, v2.s[3]\n" // out7 += b1 * a0[3]
// aptr - 16
"sub %w[a_ptr], %w[a_ptr], #16\n"
"2: \n"
"cmp %w[flag_act], #0\n" // check if has act
"beq 10f\n"
"cmp %w[flag_act], #1\n" // check if has relu
"bne 6f\n"
"fmax v8.4s, v8.4s, %[vzero].4s\n"
"fmax v9.4s, v9.4s, %[vzero].4s\n"
"fmax v10.4s, v10.4s, %[vzero].4s\n"
"fmax v11.4s, v11.4s, %[vzero].4s\n"
"fmax v12.4s, v12.4s, %[vzero].4s\n"
"fmax v13.4s, v13.4s, %[vzero].4s\n"
"fmax v14.4s, v14.4s, %[vzero].4s\n"
"fmax v15.4s, v15.4s, %[vzero].4s\n"
"b 10f\n" // end
"6: \n"
"cmp %w[flag_act], #2\n" // check relu6
"bne 7f\n"
"fmax v8.4s, v8.4s, %[vzero].4s\n"
"fmax v9.4s, v9.4s, %[vzero].4s\n"
"fmax v10.4s, v10.4s, %[vzero].4s\n"
"fmax v11.4s, v11.4s, %[vzero].4s\n"
"fmax v12.4s, v12.4s, %[vzero].4s\n"
"fmax v13.4s, v13.4s, %[vzero].4s\n"
"fmax v14.4s, v14.4s, %[vzero].4s\n"
"fmax v15.4s, v15.4s, %[vzero].4s\n"
"fmin v8.4s, v8.4s, %[valpha].4s\n"
"fmin v9.4s, v9.4s, %[valpha].4s\n"
"fmin v10.4s, v10.4s, %[valpha].4s\n"
"fmin v11.4s, v11.4s, %[valpha].4s\n"
"fmin v12.4s, v12.4s, %[valpha].4s\n"
"fmin v13.4s, v13.4s, %[valpha].4s\n"
"fmin v14.4s, v14.4s, %[valpha].4s\n"
"fmin v15.4s, v15.4s, %[valpha].4s\n"
"b 10f\n"
"7: \n"
"fcmge v2.4s, v8.4s, %[vzero].4s\n"
"fmul v3.4s, v8.4s, %[valpha].4s\n"
"fcmge v4.4s, v9.4s, %[vzero].4s\n"
"fmul v5.4s, v9.4s, %[valpha].4s\n"
"fcmge v6.4s, v10.4s, %[vzero].4s\n"
"fmul v7.4s, v10.4s, %[valpha].4s\n"
"fcmge v0.4s, v11.4s, %[vzero].4s\n"
"fmul v1.4s, v11.4s, %[valpha].4s\n"
"bif v8.16b, v3.16b, v2.16b \n"
"bif v9.16b, v5.16b, v4.16b \n"
"bif v10.16b, v7.16b, v6.16b \n"
"bif v11.16b, v1.16b, v0.16b \n"
"fcmge v2.4s, v12.4s, %[vzero].4s\n"
"fmul v3.4s, v12.4s, %[valpha].4s\n"
"fcmge v4.4s, v13.4s, %[vzero].4s\n"
"fmul v5.4s, v13.4s, v1.4s \n"
"fcmge v6.4s, v14.4s, %[vzero].4s\n"
"fmul v7.4s, v14.4s, %[valpha].4s\n"
"fcmge v0.4s, v15.4s, %[vzero].4s\n"
"fmul v1.4s, v15.4s, %[valpha].4s\n"
"bif v12.16b, v3.16b, v2.16b \n"
"bif v13.16b, v5.16b, v4.16b \n"
"bif v14.16b, v7.16b, v6.16b \n"
"bif v15.16b, v1.16b, v0.16b \n"
"10: \n"
"st1 {v8.4s, v9.4s},[%[c_ptr0]], #32\n"
"st1 {v10.4s, v11.4s},[%[c_ptr1]], #32\n"
"st1 {v12.4s, v13.4s},[%[c_ptr2]], #32\n"
"st1 {v14.4s, v15.4s},[%[c_ptr3]], #32\n"
: [a_ptr] "+r"(a_ptr),
[b_ptr] "+r"(b_ptr),
[c_ptr0] "+r"(c_ptr0),
[c_ptr1] "+r"(c_ptr1),
[c_ptr2] "+r"(c_ptr2),
[c_ptr3] "+r"(c_ptr3),
[k] "+r"(k),
[tail] "+r"(tails)
: [bias_ptr] "r"(bias_local),
[beta] "r"(beta),
[alpha] "r"(alpha),
[flag_act] "r"(flag_act),
[vzero] "w"(vzero),
[valpha] "w"(valpha)
: "cc","memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8","v9","v10","v11","v12","v13",
"v14","v15");
// clang-format on
if (flag_p_remain && (xb == bblocks - 1)) {
for (int i = 0; i < remain; ++i) {
*pout0++ = cout0[i];
*pout1++ = cout1[i];
*pout2++ = cout2[i];
*pout3++ = cout3[i];
}
}
}
}
}
}
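// Usage sketch (hypothetical shapes; signatures follow the declarations
// above -- ctx and act_param are assumed to be set up by the caller):
//   float packed_A[4 * 64];
//   prepackA(packed_A, A, 1.f, 64 /*lda*/, 0, 4 /*M*/, 0, 64 /*K*/,
//            false /*is_transA*/, ctx);
//   sgemm_prepack(false /*is_transB*/, 4, 128, 64, packed_A, B,
//                 128 /*ldb*/, 0.f /*beta*/, C, 128 /*ldc*/,
//                 nullptr /*bias*/, false /*has_bias*/, act_param, ctx);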
void sgemm_prepacked_4x4(bool is_transB,
int M,
int N,
@@ -38,9 +38,10 @@ void sgemm(bool is_transA,
ARMContext* ctx) {
int hblock = get_hblock(ctx);
int m_roundup = hblock * ((M + hblock - 1) / hblock);
ctx->ExtendWorkspace(m_roundup * K * sizeof(float));
auto packed_A = static_cast<float*>(
TargetMalloc(TargetType::kARM, m_roundup * K * sizeof(float)));
auto packed_A = static_cast<float*>(ctx->workspace_data<float>()) +
ctx->llc_size() / sizeof(float);
prepackA(packed_A, A, alpha, lda, 0, M, 0, K, is_transA, ctx);
@@ -58,7 +59,6 @@ void sgemm(bool is_transA,
is_bias,
act_param,
ctx);
TargetFree(TargetType::kARM, packed_A);
}
} // namespace math
@@ -39,7 +39,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemm: M");