提交 ac461fd2 编写于 作者: H hjchen2

Refine gemm to support arm64

上级 79618219
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#ifdef __ARM_NEON__
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
......@@ -22,10 +22,95 @@ namespace paddle_mobile {
namespace operators {
namespace math {
#ifdef __aarch64__
void sgemm_12x8(const float *lhs, const float *rhs, const int k, float *output,
#if __aarch64__
void sgemm_6x16(const float *lhs, const float *rhs, const int k, float *output,
const int ldc) {
// TODO(hjchen2)
int kc1 = k;
int step = 4 * ldc;
int step1 = 4 * 6;
asm volatile(
"dup v6.4s, wzr \n\t"
"dup v7.4s, wzr \n\t"
"dup v8.4s, wzr \n\t"
"dup v9.4s, wzr \n\t"
"dup v10.4s, wzr \n\t"
"dup v11.4s, wzr \n\t"
"dup v12.4s, wzr \n\t"
"dup v13.4s, wzr \n\t"
"dup v14.4s, wzr \n\t"
"dup v15.4s, wzr \n\t"
"dup v16.4s, wzr \n\t"
"dup v17.4s, wzr \n\t"
"dup v18.4s, wzr \n\t"
"dup v19.4s, wzr \n\t"
"dup v20.4s, wzr \n\t"
"dup v21.4s, wzr \n\t"
"dup v22.4s, wzr \n\t"
"dup v23.4s, wzr \n\t"
"dup v24.4s, wzr \n\t"
"dup v25.4s, wzr \n\t"
"dup v26.4s, wzr \n\t"
"dup v27.4s, wzr \n\t"
"dup v28.4s, wzr \n\t"
"dup v29.4s, wzr \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt 2f \n\t"
"1: \n\t"
"prfm pldl1keep, [%[lhs], #24] \n\t"
"prfm pldl1keep, [%[rhs], #64] \n\t"
"ld1 {v0.4s, v1.4s}, [%[lhs]], %[step1] \n\t"
"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[rhs]], #64 \n\t"
"fmla v6.4s, v2.4s, v0.s[0] \n\t"
"fmla v7.4s, v3.4s, v0.s[0] \n\t"
"fmla v8.4s, v4.4s, v0.s[0] \n\t"
"fmla v9.4s, v5.4s, v0.s[0] \n\t"
"fmla v10.4s, v2.4s, v0.s[1] \n\t"
"fmla v11.4s, v3.4s, v0.s[1] \n\t"
"fmla v12.4s, v4.4s, v0.s[1] \n\t"
"fmla v13.4s, v5.4s, v0.s[1] \n\t"
"fmla v14.4s, v2.4s, v0.s[2] \n\t"
"fmla v15.4s, v3.4s, v0.s[2] \n\t"
"fmla v16.4s, v4.4s, v0.s[2] \n\t"
"fmla v17.4s, v5.4s, v0.s[2] \n\t"
"fmla v18.4s, v2.4s, v0.s[3] \n\t"
"fmla v19.4s, v3.4s, v0.s[3] \n\t"
"fmla v20.4s, v4.4s, v0.s[3] \n\t"
"fmla v21.4s, v5.4s, v0.s[3] \n\t"
"fmla v22.4s, v2.4s, v1.s[0] \n\t"
"fmla v23.4s, v3.4s, v1.s[0] \n\t"
"fmla v24.4s, v4.4s, v1.s[0] \n\t"
"fmla v25.4s, v5.4s, v1.s[0] \n\t"
"fmla v26.4s, v2.4s, v1.s[1] \n\t"
"fmla v27.4s, v3.4s, v1.s[1] \n\t"
"fmla v28.4s, v4.4s, v1.s[1] \n\t"
"fmla v29.4s, v5.4s, v1.s[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge 1b \n\t"
"2: \n\t"
"st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t"
"st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t"
"st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t"
"st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t"
"st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
"st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t"
: [lhs] "+r"(lhs), [rhs] "+r"(rhs), [c] "+r"(output), [kc1] "+r"(kc1)
: [step] "r"(step), [step1] "r"(step1)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29");
}
#else
void sgemm_6x8(const float *lhs, const float *rhs, const int k, float *output,
......
......@@ -39,23 +39,22 @@ struct SgemmStrategy {
kernelFunc kernel;
WriteFunc write;
static int out_width() { return 8; }
static int out_height() {
#ifdef __aarch64__
return 12;
static int out_width() {
#if __aarch64__
return 16;
#else
return 6;
return 8;
#endif
}
static int out_height() { return 6; }
SgemmStrategy() {
#ifdef __aarch64__
pack_lhs = pack_lhs_12r;
pack_rhs = pack_rhs_8c;
kernel = sgemm_12x8;
#else
pack_lhs = pack_lhs_6r;
#if __aarch64__
pack_rhs = pack_rhs_16c;
kernel = sgemm_6x16;
#else
pack_rhs = pack_rhs_8c;
kernel = sgemm_6x8;
#endif
......@@ -74,7 +73,7 @@ struct I8o32gemmStrategy {
static int out_width() { return 8; }
static int out_height() {
#ifdef __aarch64__
#if __aarch64__
return 12;
#else
return 6;
......@@ -95,7 +94,7 @@ struct SgemvStrategy {
static int out_width() { return 1; }
static int out_height() {
#ifdef __aarch64__
#if __aarch64__
return 12;
#else
return 6;
......@@ -114,7 +113,7 @@ struct I8o32gemvStrategy {
static int out_width() { return 1; }
static int out_height() {
#ifdef __aarch64__
#if __aarch64__
return 12;
#else
return 6;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册