Commit ac461fd2 authored by H hjchen2

Refine gemm to support arm64

Parent 79618219
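This commit moves the arm64 SGEMM path to a 6x16 micro-kernel (sgemm_6x16) with a matching 16-column RHS packing routine (pack_rhs_16c) and a NEON write-back, while the armv7 path keeps the 6x8 kernel and 8-column packing. The sketch below is not part of the commit; it only illustrates, under assumed names and pre-padded buffer sizes, how a blocked GEMM driver would typically compose the pieces that follow (pack_lhs_6r, pack_rhs_16c/pack_rhs_8c, sgemm_6x16/sgemm_6x8, write_back). The actual executor in paddle-mobile may differ.

// Hypothetical driver (assumption, not in this commit). m and n are assumed
// already padded to multiples of the tile sizes, and the packed_* buffers are
// assumed large enough to hold the padded panels.
void gemm_example(int m, int n, int k, const float *A, int lda,
                  const float *B, int ldb, float *C, int ldc,
                  float *packed_a, float *packed_b, float *packed_c) {
  SgemmStrategy strategy;
  const int mr = strategy.out_height();  // 6
  const int nr = strategy.out_width();   // 16 on arm64, 8 elsewhere
  strategy.pack_lhs(m, k, A, lda, packed_a, true);
  strategy.pack_rhs(k, n, B, ldb, packed_b, true);
  for (int i = 0; i < m; i += mr) {
    for (int j = 0; j < n; j += nr) {
      // Each call produces one mr x nr tile of the packed output buffer.
      strategy.kernel(packed_a + i * k, packed_b + j * k, k,
                      packed_c + i * n + j, n);
    }
  }
  // Copy the finished m x n block into the caller's C.
  strategy.write(m, n, packed_c, n, C, ldc);
}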
@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>

@@ -22,10 +22,95 @@ namespace paddle_mobile {
namespace operators {
namespace math {

#if __aarch64__
void sgemm_6x16(const float *lhs, const float *rhs, const int k, float *output,
                const int ldc) {
int kc1 = k;
int step = 4 * ldc;
int step1 = 4 * 6;
asm volatile(
"dup v6.4s, wzr \n\t"
"dup v7.4s, wzr \n\t"
"dup v8.4s, wzr \n\t"
"dup v9.4s, wzr \n\t"
"dup v10.4s, wzr \n\t"
"dup v11.4s, wzr \n\t"
"dup v12.4s, wzr \n\t"
"dup v13.4s, wzr \n\t"
"dup v14.4s, wzr \n\t"
"dup v15.4s, wzr \n\t"
"dup v16.4s, wzr \n\t"
"dup v17.4s, wzr \n\t"
"dup v18.4s, wzr \n\t"
"dup v19.4s, wzr \n\t"
"dup v20.4s, wzr \n\t"
"dup v21.4s, wzr \n\t"
"dup v22.4s, wzr \n\t"
"dup v23.4s, wzr \n\t"
"dup v24.4s, wzr \n\t"
"dup v25.4s, wzr \n\t"
"dup v26.4s, wzr \n\t"
"dup v27.4s, wzr \n\t"
"dup v28.4s, wzr \n\t"
"dup v29.4s, wzr \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt 2f \n\t"
"1: \n\t"
"prfm pldl1keep, [%[lhs], #24] \n\t"
"prfm pldl1keep, [%[rhs], #64] \n\t"
"ld1 {v0.4s, v1.4s}, [%[lhs]], %[step1] \n\t"
"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[rhs]], #64 \n\t"
"fmla v6.4s, v2.4s, v0.s[0] \n\t"
"fmla v7.4s, v3.4s, v0.s[0] \n\t"
"fmla v8.4s, v4.4s, v0.s[0] \n\t"
"fmla v9.4s, v5.4s, v0.s[0] \n\t"
"fmla v10.4s, v2.4s, v0.s[1] \n\t"
"fmla v11.4s, v3.4s, v0.s[1] \n\t"
"fmla v12.4s, v4.4s, v0.s[1] \n\t"
"fmla v13.4s, v5.4s, v0.s[1] \n\t"
"fmla v14.4s, v2.4s, v0.s[2] \n\t"
"fmla v15.4s, v3.4s, v0.s[2] \n\t"
"fmla v16.4s, v4.4s, v0.s[2] \n\t"
"fmla v17.4s, v5.4s, v0.s[2] \n\t"
"fmla v18.4s, v2.4s, v0.s[3] \n\t"
"fmla v19.4s, v3.4s, v0.s[3] \n\t"
"fmla v20.4s, v4.4s, v0.s[3] \n\t"
"fmla v21.4s, v5.4s, v0.s[3] \n\t"
"fmla v22.4s, v2.4s, v1.s[0] \n\t"
"fmla v23.4s, v3.4s, v1.s[0] \n\t"
"fmla v24.4s, v4.4s, v1.s[0] \n\t"
"fmla v25.4s, v5.4s, v1.s[0] \n\t"
"fmla v26.4s, v2.4s, v1.s[1] \n\t"
"fmla v27.4s, v3.4s, v1.s[1] \n\t"
"fmla v28.4s, v4.4s, v1.s[1] \n\t"
"fmla v29.4s, v5.4s, v1.s[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge 1b \n\t"
"2: \n\t"
"st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t"
"st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t"
"st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t"
"st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t"
"st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
"st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t"
: [lhs] "+r"(lhs), [rhs] "+r"(rhs), [c] "+r"(output), [kc1] "+r"(kc1)
: [step] "r"(step), [step1] "r"(step1)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29");
}
#else
void sgemm_6x8(const float *lhs, const float *rhs, const int k, float *output,
...
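For reference, the 6x16 micro-kernel above keeps the whole 6x16 output tile in v6-v29 (6 rows of 4 vectors with 4 lanes each), streams 6 packed LHS values and 16 packed RHS values per k step, and stores the finished tile row by row with a stride of ldc floats. A scalar sketch of the same computation, assuming the packed layouts produced by pack_lhs_6r and pack_rhs_16c (illustration only, not in the commit):

void sgemm_6x16_ref(const float *lhs, const float *rhs, int k, float *output,
                    int ldc) {
  float acc[6][16] = {{0.f}};
  for (int p = 0; p < k; ++p) {        // one k step = 6 lhs + 16 rhs floats
    for (int r = 0; r < 6; ++r) {
      for (int c = 0; c < 16; ++c) {
        acc[r][c] += lhs[p * 6 + r] * rhs[p * 16 + c];
      }
    }
  }
  for (int r = 0; r < 6; ++r) {        // tile is overwritten, not accumulated
    for (int c = 0; c < 16; ++c) {
      output[r * ldc + c] = acc[r][c];
    }
  }
}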
@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#ifdef _OPENMP

@@ -29,19 +29,14 @@ inline float32x4_t vandq_f32_u32(float32x4_t x, uint32x4_t mask) {
  return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
}

void pack_lhs_6r(const int m, const int k, const float *A, const int lda,
                 float *output, const bool unroll) {
  uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5};
  int remain_k = k & 0x3;
  uint32x4_t vzero = vdupq_n_u32(0);
  uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_k));

#pragma omp parallel for if (unroll)
  for (int i = 0; i < m - 5; i += 6) {
    const float *a0 = A + i * lda;
    const float *a1 = A + (i + 1) * lda;
@@ -307,298 +302,316 @@ void pack_lhs_6r(const int m, const int k, const float *A, const int lda,
  }
}

#if __aarch64__
void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output,
                  const bool unroll) {
  uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  uint32_t remain_n = n & 0x7;
  float32x4_t vzero = vdupq_n_f32(0.f);
  uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n));
  uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_n));

#pragma omp parallel for if (unroll)
  for (int i = 0; i < k - 3; i += 4) {
    const float *b0 = B + i * ldb;
    const float *b1 = b0 + ldb;
    const float *b2 = b1 + ldb;
    const float *b3 = b2 + ldb;
    int j = 0;
    asm volatile(
        "prfm pldl1keep, [%[b0]]                 \n"
        "prfm pldl1keep, [%[b1]]                 \n"
        "prfm pldl1keep, [%[b2]]                 \n"
        "prfm pldl1keep, [%[b3]]                 \n"
        :
        : [b0] "r"(b0), [b1] "r"(b1), [b2] "r"(b2), [b3] "r"(b3));
    for (; j < n - 15; j += 16) {
      float *out_ptr0 = output + j * k + 16 * i;
      asm volatile(
          "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]], #64       \n"
          "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b1]], #64       \n"
          "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[out_ptr0]], #64 \n"
          "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[out_ptr0]], #64 \n"
          "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b2]], #64       \n"
          "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b3]], #64       \n"
          "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[out_ptr0]], #64 \n"
          "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[out_ptr0]], #64 \n"
          : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0), [b1] "+r"(b1),
            [b2] "+r"(b2), [b3] "+r"(b3)
          :
          : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
    }
    for (; j < n - 7; j += 8) {
      float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
      int step = 64;
      asm volatile(
          "ld1 {v0.4s, v1.4s}, [%[b0]], #32            \n"
          "ld1 {v2.4s, v3.4s}, [%[b1]], #32            \n"
          "ld1 {v4.4s, v5.4s}, [%[b2]], #32            \n"
          "ld1 {v6.4s, v7.4s}, [%[b3]], #32            \n"
          "st1 {v0.4s, v1.4s}, [%[out_ptr0]], %[step]  \n"
          "st1 {v2.4s, v3.4s}, [%[out_ptr0]], %[step]  \n"
          "st1 {v4.4s, v5.4s}, [%[out_ptr0]], %[step]  \n"
          "st1 {v6.4s, v7.4s}, [%[out_ptr0]], %[step]  \n"
          : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0), [b1] "+r"(b1),
            [b2] "+r"(b2), [b3] "+r"(b3)
          : [step] "r"(step)
          : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
    }
    if (j < n) {
      float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
      int step = 64;
      asm volatile(
          "ld1 {v0.4s, v1.4s}, [%[b0]]                 \n"
          "ld1 {v2.4s, v3.4s}, [%[b1]]                 \n"
          "ld1 {v4.4s, v5.4s}, [%[b2]]                 \n"
          "ld1 {v6.4s, v7.4s}, [%[b3]]                 \n"
          "and v0.16b, v0.16b, %[vmask1].16b           \n"
          "and v1.16b, v1.16b, %[vmask2].16b           \n"
          "and v2.16b, v2.16b, %[vmask1].16b           \n"
          "and v3.16b, v3.16b, %[vmask2].16b           \n"
          "and v4.16b, v4.16b, %[vmask1].16b           \n"
          "and v5.16b, v5.16b, %[vmask2].16b           \n"
          "and v6.16b, v6.16b, %[vmask1].16b           \n"
          "and v7.16b, v7.16b, %[vmask2].16b           \n"
          "st1 {v0.4s, v1.4s}, [%[out_ptr0]], %[step]  \n"
          "st1 {v2.4s, v3.4s}, [%[out_ptr0]], %[step]  \n"
          "st1 {v4.4s, v5.4s}, [%[out_ptr0]], %[step]  \n"
          "st1 {v6.4s, v7.4s}, [%[out_ptr0]], %[step]  \n"
          : [out_ptr0] "+r"(out_ptr0)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [b0] "r"(b0),
            [b1] "r"(b1), [b2] "r"(b2), [b3] "r"(b3), [step] "r"(step)
          : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
      j += 8;
    }
    if (j & 0xf) {
      float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
      vst1q_f32(out_ptr0, vzero);
      vst1q_f32(out_ptr0 + 4, vzero);
      out_ptr0 += 16;
      vst1q_f32(out_ptr0, vzero);
      vst1q_f32(out_ptr0 + 4, vzero);
      out_ptr0 += 16;
      vst1q_f32(out_ptr0, vzero);
      vst1q_f32(out_ptr0 + 4, vzero);
      out_ptr0 += 16;
      vst1q_f32(out_ptr0, vzero);
      vst1q_f32(out_ptr0 + 4, vzero);
    }
  }
  // remain k
  for (int i = (k & 0xFFFC); i < k; ++i) {
    const float *b0 = B + i * ldb;
    int j = 0;
    asm volatile("prfm pldl1keep, [%[b0]]        \n"
                 :
                 : [b0] "r"(b0));
    for (; j < n - 15; j += 16) {
      float *out_ptr0 = output + j * k + 16 * i;
      asm volatile(
          "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]], #64       \n"
          "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[out_ptr0]], #64 \n"
          : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0)
          :
          : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
    }
    for (; j < n - 7; j += 8) {
      float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
      int step = 64;
      asm volatile(
          "ld1 {v0.4s, v1.4s}, [%[b0]], #32            \n"
          "st1 {v0.4s, v1.4s}, [%[out_ptr0]], %[step]  \n"
          : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0)
          : [step] "r"(step)
          : "memory", "v0", "v1");
    }
    if (j < n) {
      float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
      asm volatile(
          "ld1 {v0.4s, v1.4s}, [%[b0]]           \n"
          "and v0.16b, v0.16b, %[vmask1].16b     \n"
          "and v1.16b, v1.16b, %[vmask2].16b     \n"
          "st1 {v0.4s, v1.4s}, [%[out_ptr0]]     \n"
          : [out_ptr0] "+r"(out_ptr0)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [b0] "r"(b0)
          : "memory", "v0", "v1");
      j += 8;
    }
    if (j & 0xf) {
      float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF);
      vst1q_f32(out_ptr0, vzero);
      vst1q_f32(out_ptr0 + 4, vzero);
    }
  }
}
#else
void pack_rhs_8c(int k, int n, const float *B, int ldb, float *output,
                 const bool unroll) {
  uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  uint32_t remain_n = n & 0x7;
  uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n));
  uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_n));

#pragma omp parallel for if (unroll)
  for (int i = 0; i < k - 3; i += 4) {
    const float *b0 = B + i * ldb;
    const float *b1 = b0 + ldb;
    const float *b2 = b1 + ldb;
    const float *b3 = b2 + ldb;
    int j = 0;
    for (; j < n - 15; j += 16) {
      float *out_ptr0 = output + j * k + 8 * i;
      float *out_ptr1 = out_ptr0 + 8 * k;
      asm volatile(
          "vld1.32 {q0, q1}, [%[b0]]!        \n"
          "vld1.32 {q2, q3}, [%[b1]]!        \n"
          "vld1.32 {q4, q5}, [%[b0]]!        \n"
          "vld1.32 {q6, q7}, [%[b1]]!        \n"
          "vst1.32 {q0, q1}, [%[out_ptr0]]!  \n"
          "vst1.32 {q2, q3}, [%[out_ptr0]]!  \n"
          "vst1.32 {q4, q5}, [%[out_ptr1]]!  \n"
          "vst1.32 {q6, q7}, [%[out_ptr1]]!  \n"

          "vld1.32 {q0, q1}, [%[b2]]!        \n"
          "vld1.32 {q2, q3}, [%[b3]]!        \n"
          "vld1.32 {q4, q5}, [%[b2]]!        \n"
          "vld1.32 {q6, q7}, [%[b3]]!        \n"
          "vst1.32 {q0, q1}, [%[out_ptr0]]!  \n"
          "vst1.32 {q2, q3}, [%[out_ptr0]]!  \n"
          "vst1.32 {q4, q5}, [%[out_ptr1]]!  \n"
          "vst1.32 {q6, q7}, [%[out_ptr1]]!  \n"
          : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [b0] "+r"(b0),
            [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3)
          :
          : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
    }
    for (; j < n - 7; j += 8) {
      float *out_ptr0 = output + j * k + 8 * i;
      asm volatile(
          "vld1.32 {q0, q1}, [%[b0]]!        \n"
          "vld1.32 {q2, q3}, [%[b1]]!        \n"
          "vld1.32 {q4, q5}, [%[b2]]!        \n"
          "vld1.32 {q6, q7}, [%[b3]]!        \n"
          "vst1.32 {q0, q1}, [%[out_ptr0]]!  \n"
          "vst1.32 {q2, q3}, [%[out_ptr0]]!  \n"
          "vst1.32 {q4, q5}, [%[out_ptr0]]!  \n"
          "vst1.32 {q6, q7}, [%[out_ptr0]]!  \n"
          : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0), [b1] "+r"(b1),
            [b2] "+r"(b2), [b3] "+r"(b3)
          :
          : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
    }
    if (j < n) {
      float *out_ptr0 = output + j * k + 8 * i;
      asm volatile(
          "vld1.32 {q0, q1}, [%[b0]]         \n"
          "vld1.32 {q2, q3}, [%[b1]]         \n"
          "vld1.32 {q4, q5}, [%[b2]]         \n"
          "vld1.32 {q6, q7}, [%[b3]]         \n"
          "vand q0, q0, %q[vmask1]           \n"
          "vand q1, q1, %q[vmask2]           \n"
          "vand q2, q2, %q[vmask1]           \n"
          "vand q3, q3, %q[vmask2]           \n"
          "vand q4, q4, %q[vmask1]           \n"
          "vand q5, q5, %q[vmask2]           \n"
          "vand q6, q6, %q[vmask1]           \n"
          "vand q7, q7, %q[vmask2]           \n"
          "vst1.32 {q0, q1}, [%[out_ptr0]]!  \n"
          "vst1.32 {q2, q3}, [%[out_ptr0]]!  \n"
          "vst1.32 {q4, q5}, [%[out_ptr0]]!  \n"
          "vst1.32 {q6, q7}, [%[out_ptr0]]!  \n"
          : [out_ptr0] "+r"(out_ptr0)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [b0] "r"(b0),
            [b1] "r"(b1), [b2] "r"(b2), [b3] "r"(b3)
          : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
    }
  }
  // remain k
  for (int i = (k & 0xFFFC); i < k; ++i) {
    const float *b0 = B + i * ldb;
    int j = 0;
    for (; j < n - 15; j += 16) {
      float *out_ptr0 = output + j * k + 8 * i;
      float *out_ptr1 = out_ptr0 + 8 * k;
      asm volatile(
          "vld1.32 {q0, q1}, [%[b0]]!        \n"
          "vld1.32 {q2, q3}, [%[b0]]!        \n"
          "vst1.32 {q0, q1}, [%[out_ptr0]]!  \n"
          "vst1.32 {q2, q3}, [%[out_ptr1]]!  \n"
          : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [b0] "+r"(b0)
          :
          : "memory", "q0", "q1", "q2", "q3");
    }
    for (; j < n - 7; j += 8) {
      float *out_ptr0 = output + j * k + 8 * i;
      asm volatile(
          "vld1.32 {q0, q1}, [%[b0]]!        \n"
          "vst1.32 {q0, q1}, [%[out_ptr0]]!  \n"
          : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0)
          :
          : "memory", "q0", "q1");
    }
    if (j < n) {
      float *out_ptr0 = output + j * k + 8 * i;
      asm volatile(
          "vld1.32 {q0, q1}, [%[b0]]         \n"
          "vand q0, q0, %q[vmask1]           \n"
          "vand q1, q1, %q[vmask2]           \n"
          "vst1.32 {q0, q1}, [%[out_ptr0]]   \n"
          : [out_ptr0] "+r"(out_ptr0)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [b0] "r"(b0)
          : "memory", "q0", "q1");
    }
  }
}
#endif  // __aarch64__
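Both packing routines above lay B out in fixed-width column panels stored k-major, so the kernel can read one contiguous run of out_width() floats per k step: 16-wide panels from pack_rhs_16c on arm64, 8-wide panels from pack_rhs_8c on armv7, with tail columns padded by the masked and zero stores. A scalar picture of that layout, as an illustration only (w stands for the panel width, 16 or 8; this helper is not in the commit):

void pack_rhs_ref(int w, int k, int n, const float *B, int ldb, float *output) {
  int panels = (n + w - 1) / w;
  for (int p = 0; p < panels; ++p) {     // one panel = w consecutive columns
    for (int i = 0; i < k; ++i) {
      for (int c = 0; c < w; ++c) {
        int j = p * w + c;
        // Columns beyond n are padding, matching the masked/zero stores above.
        output[(p * k + i) * w + c] = (j < n) ? B[i * ldb + j] : 0.f;
      }
    }
  }
}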
#if __aarch64__
void write_back(const int mc, const int nc, const float *c, const int ldc1,
                float *C, const int ldc2) {
  int nc1 = nc / 4;
  int _nc1 = nc % 4;

  const float *c_ptr;
  float *C_ptr;
  float32x4_t cv;
  for (int i = 0; i < mc; ++i) {
    c_ptr = c + i * ldc1;
    C_ptr = C + i * ldc2;
    for (int j = 0; j < nc1; ++j) {
      cv = vld1q_f32(c_ptr);
      vst1q_f32(C_ptr, cv);
      c_ptr += 4;
      C_ptr += 4;
    }
    if (_nc1 != 0) {
      cv = vld1q_f32(c_ptr);
      if (_nc1 >= 1) {
        vst1q_lane_f32(C_ptr, cv, 0);
        C_ptr++;
      }
      if (_nc1 >= 2) {
        vst1q_lane_f32(C_ptr, cv, 1);
        C_ptr++;
      }
      if (_nc1 >= 3) {
        vst1q_lane_f32(C_ptr, cv, 2);
      }
    }
  }
}
#else
void write_back(const int mc, const int nc, const float *c, const int ldc1,
                float *C, const int ldc2) {
  int nc1 = nc / 16;
  int nc2 = nc % 16;
  int step1 = 4 * (ldc1 - 16 * nc1);
...
@@ -650,6 +663,7 @@ void write_back(const int mc, const int nc, const float *c, const int ldc1,
    }
  }
}
#endif
}  // namespace math
}  // namespace operators
...
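The new arm64 write_back above is a plain strided copy of the mc x nc result block from the packed output buffer (row stride ldc1) into the user matrix C (row stride ldc2), handling the last one to three columns lane by lane. A scalar equivalent, for illustration only:

void write_back_ref(int mc, int nc, const float *c, int ldc1, float *C,
                    int ldc2) {
  for (int i = 0; i < mc; ++i) {
    for (int j = 0; j < nc; ++j) {
      C[i * ldc2 + j] = c[i * ldc1 + j];
    }
  }
}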
@@ -39,23 +39,22 @@ struct SgemmStrategy {
  kernelFunc kernel;
  WriteFunc write;

  static int out_width() {
#if __aarch64__
    return 16;
#else
    return 8;
#endif
  }

  static int out_height() { return 6; }

  SgemmStrategy() {
    pack_lhs = pack_lhs_6r;
#if __aarch64__
    pack_rhs = pack_rhs_16c;
    kernel = sgemm_6x16;
#else
    pack_rhs = pack_rhs_8c;
    kernel = sgemm_6x8;
#endif

@@ -74,7 +73,7 @@ struct I8o32gemmStrategy {
  static int out_width() { return 8; }

  static int out_height() {
#if __aarch64__
    return 12;
#else
    return 6;

@@ -95,7 +94,7 @@ struct SgemvStrategy {
  static int out_width() { return 1; }

  static int out_height() {
#if __aarch64__
    return 12;
#else
    return 6;

@@ -114,7 +113,7 @@ struct I8o32gemvStrategy {
  static int out_width() { return 1; }

  static int out_height() {
#if __aarch64__
    return 12;
#else
    return 6;
...