Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
1ebac1c0
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1ebac1c0
编写于
12月 03, 2019
作者:
T
TianXiaogang
提交者:
GitHub
12月 03, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Armv8 4x4 gemm (#2528)
* feat: add sgemm4x4 for armv8 * fix: fix armv7 gemm choose condition
上级
04d2b4eb
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
625 addition
and
19 deletion
+625
-19
lite/backends/arm/math/packed_sgemm.cc
lite/backends/arm/math/packed_sgemm.cc
+625
-19
未找到文件。
lite/backends/arm/math/packed_sgemm.cc
浏览文件 @
1ebac1c0
...
...
@@ -53,6 +53,38 @@ void sgemm_prepacked_8x12(bool is_transB,
bool
has_bias
,
bool
has_relu
,
ARMContext
*
ctx
);
void
pack_m4
(
float
*
out
,
const
float
*
in
,
float
alpha
,
int
ldin
,
int
m0
,
int
mmax
,
int
k0
,
int
kmax
);
void
pack_trans_m4
(
float
*
out
,
const
float
*
in
,
float
alpha
,
int
ldin
,
int
m0
,
int
mmax
,
int
k0
,
int
kmax
);
void
sgemm_prepacked_4x4
(
bool
is_transB
,
int
M
,
int
N
,
int
K
,
const
float
*
A_packed
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
const
float
*
bias
,
bool
has_bias
,
bool
has_relu
,
ARMContext
*
ctx
);
#else
// for kA72
void
prepackA_6x8
(
float
*
out
,
...
...
@@ -139,13 +171,21 @@ void prepackA(float *out,
bool
is_trans
,
ARMContext
*
ctx
)
{
#ifdef __aarch64__
if
(
mmax
<=
4
)
{
if
(
is_trans
)
{
pack_trans_m4
(
out
,
in
,
alpha
,
ldin
,
m0
,
mmax
,
k0
,
kmax
);
}
else
{
pack_m4
(
out
,
in
,
alpha
,
ldin
,
m0
,
mmax
,
k0
,
kmax
);
}
}
else
{
if
(
is_trans
)
{
prepackA_trans_8x12
(
out
,
in
,
alpha
,
ldin
,
m0
,
mmax
,
k0
,
kmax
);
}
else
{
prepackA_8x12
(
out
,
in
,
alpha
,
ldin
,
m0
,
mmax
,
k0
,
kmax
);
}
}
#else
if
(
ctx
->
arch
()
==
kA73
)
{
if
(
ctx
->
arch
()
==
kA73
||
mmax
<=
4
)
{
if
(
is_trans
)
{
prepackA_trans_4x8
(
out
,
in
,
alpha
,
ldin
,
m0
,
mmax
,
k0
,
kmax
);
}
else
{
...
...
@@ -212,6 +252,22 @@ void sgemm_prepack(bool is_transB,
bool
has_relu
,
ARMContext
*
ctx
)
{
#ifdef __aarch64__
if
(
M
<=
4
)
{
sgemm_prepacked_4x4
(
is_transB
,
M
,
N
,
K
,
A_packed
,
B
,
ldb
,
beta
,
C
,
ldc
,
bias
,
has_bias
,
has_relu
,
ctx
);
}
else
{
sgemm_prepacked_8x12
(
is_transB
,
M
,
N
,
...
...
@@ -226,8 +282,9 @@ void sgemm_prepack(bool is_transB,
has_bias
,
has_relu
,
ctx
);
}
#else // armv7
if
(
ctx
->
arch
()
==
kA73
)
{
if
(
ctx
->
arch
()
==
kA73
||
M
<=
4
)
{
sgemm_prepacked_4x8
(
is_transB
,
M
,
N
,
...
...
@@ -522,6 +579,147 @@ void prepackA_8x12(float *dout,
}
}
}
void
pack_m4
(
float
*
dout
,
const
float
*
inptr
,
float
alpha
,
int
ldin
,
int
m0
,
int
mmax
,
int
k0
,
int
kmax
)
{
int
x_len
=
kmax
-
k0
;
int
stride
=
x_len
*
4
;
float
zerobuff
[
x_len
];
// NOLINT
memset
(
zerobuff
,
0
,
sizeof
(
float
)
*
x_len
);
bool
has_alpha
=
fabsf
(
alpha
-
1.
f
)
>
1e-8
f
;
#pragma omp parallel for
for
(
int
y
=
m0
;
y
<
mmax
;
y
+=
4
)
{
float
*
outptr
=
dout
+
stride
*
(
y
-
m0
)
/
4
;
const
float
*
inptr0
=
inptr
+
y
*
ldin
+
k0
;
const
float
*
inptr1
=
inptr0
+
ldin
;
const
float
*
inptr2
=
inptr1
+
ldin
;
const
float
*
inptr3
=
inptr2
+
ldin
;
asm
volatile
(
"prfm pldl1keep, [%[ptr0]]
\n
"
"prfm pldl1keep, [%[ptr0], #64]
\n
"
"prfm pldl1keep, [%[ptr1]]
\n
"
"prfm pldl1keep, [%[ptr1], #64]
\n
"
"prfm pldl1keep, [%[ptr2]]
\n
"
"prfm pldl1keep, [%[ptr2], #64]
\n
"
"prfm pldl1keep, [%[ptr3]]
\n
"
"prfm pldl1keep, [%[ptr3], #64]
\n
"
:
:
[
ptr0
]
"r"
(
inptr0
),
[
ptr1
]
"r"
(
inptr1
),
[
ptr2
]
"r"
(
inptr2
),
[
ptr3
]
"r"
(
inptr3
)
:
"memory"
);
int
x
=
x_len
;
//! cope with row index exceed real size, set to zero buffer
if
((
y
+
3
)
>=
mmax
)
{
switch
((
y
+
3
)
-
mmax
)
{
case
2
:
inptr1
=
zerobuff
;
case
1
:
inptr2
=
zerobuff
;
case
0
:
inptr3
=
zerobuff
;
default:
break
;
}
}
for
(;
x
>
7
;
x
-=
8
)
{
asm
volatile
(
"cbz %w[has_alpha], 0f
\n
"
/* check alpha == 1.f? */
"dup v31.4s, %w[alpha]
\n
"
/* alpha to vector */
"ldp q0, q1, [%[inptr0]], #32
\n
"
/* load r0, a0~a7 */
"ldp q2, q3, [%[inptr1]], #32
\n
"
/* load r1, b0~b7 */
"fmul v0.4s, v31.4s, v0.4s
\n
"
/* mul alpha */
"fmul v1.4s, v31.4s, v1.4s
\n
"
/* mul alpha */
"ldp q4, q5, [%[inptr2]], #32
\n
"
/* load r2, c0~c7 */
"fmul v2.4s, v31.4s, v2.4s
\n
"
/* mul alpha */
"fmul v3.4s, v31.4s, v3.4s
\n
"
/* mul alpha */
"ldp q6, q7, [%[inptr3]], #32
\n
"
/* load r3, d0~d7 */
"fmul v4.4s, v31.4s, v4.4s
\n
"
/* mul alpha */
"fmul v5.4s, v31.4s, v5.4s
\n
"
/* mul alpha */
"fmul v6.4s, v31.4s, v6.4s
\n
"
/* mul alpha */
"fmul v7.4s, v31.4s, v7.4s
\n
"
/* mul alpha */
"b 1f
\n
"
/* to main process */
"0:
\n
"
/* alpha == 1 */
"ldp q0, q1, [%[inptr0]], #32
\n
"
/* load r0, a0~a7 */
"ldp q2, q3, [%[inptr1]], #32
\n
"
/* load r1, b0~b7 */
"ldp q4, q5, [%[inptr2]], #32
\n
"
/* load r2, c0~c7 */
"ldp q6, q7, [%[inptr3]], #32
\n
"
/* load r3, d0~d7 */
"1:
\n
"
/* main process */
"trn1 v8.4s, v0.4s, v2.4s
\n
"
/* a0b0a2b2*/
"trn2 v9.4s, v0.4s, v2.4s
\n
"
/* a1b1a3b3*/
"trn1 v10.4s, v1.4s, v3.4s
\n
"
/* a4b4a6b6*/
"trn2 v11.4s, v1.4s, v3.4s
\n
"
/* a5b5a7b7*/
"trn1 v12.4s, v4.4s, v6.4s
\n
"
/* c0d0c2d2*/
"trn2 v13.4s, v4.4s, v6.4s
\n
"
/* c1d1c3d3*/
"trn1 v14.4s, v5.4s, v7.4s
\n
"
/* c4d4c6d6*/
"trn2 v15.4s, v5.4s, v7.4s
\n
"
/* c5d5c7d7*/
"trn1 v0.2d, v8.2d, v12.2d
\n
"
/* a0b0c0d0 */
"trn1 v1.2d, v9.2d, v13.2d
\n
"
/* a1b1c1d1 */
"trn1 v2.2d, v10.2d, v14.2d
\n
"
/* a4b4c4d4 */
"trn1 v3.2d, v11.2d, v15.2d
\n
"
/* a5b5c5d5 */
"trn2 v4.2d, v8.2d, v12.2d
\n
"
/* a2b2c2d2 */
"trn2 v5.2d, v9.2d, v13.2d
\n
"
/* a3b3c3d3 */
"stp q0, q1, [%[outptr]], #32
\n
"
/* save q0, q1, a0~h0*/
"trn2 v6.2d, v10.2d, v14.2d
\n
"
/* a6b6c6d6 */
"trn2 v7.2d, v11.2d, v15.2d
\n
"
/* a7b7c7d7 */
"stp q4, q5, [%[outptr]], #32
\n
"
/* save q2, q3, a1~h1*/
"stp q2, q3, [%[outptr]], #32
\n
"
/* save q4, q5, a2~h2*/
"stp q6, q7, [%[outptr]], #32
\n
"
/* save q6, q7, a3~h3*/
:
[
inptr0
]
"+r"
(
inptr0
),
[
inptr1
]
"+r"
(
inptr1
),
[
inptr2
]
"+r"
(
inptr2
),
[
inptr3
]
"+r"
(
inptr3
),
[
outptr
]
"+r"
(
outptr
)
:
[
alpha
]
"r"
(
alpha
),
[
has_alpha
]
"r"
(
has_alpha
)
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"cc"
,
"memory"
);
}
for
(;
x
>
0
;
x
--
)
{
if
(
has_alpha
)
{
*
outptr
++
=
*
inptr0
++
*
alpha
;
*
outptr
++
=
*
inptr1
++
*
alpha
;
*
outptr
++
=
*
inptr2
++
*
alpha
;
*
outptr
++
=
*
inptr3
++
*
alpha
;
}
else
{
*
outptr
++
=
*
inptr0
++
;
*
outptr
++
=
*
inptr1
++
;
*
outptr
++
=
*
inptr2
++
;
*
outptr
++
=
*
inptr3
++
;
}
}
}
}
void
prepackA_trans_8x12
(
float
*
outptr
,
const
float
*
in
,
...
...
@@ -682,6 +880,128 @@ void prepackA_trans_8x12(float *outptr,
}
}
}
void
pack_trans_m4
(
float
*
outptr
,
const
float
*
in
,
float
alpha
,
int
ldin
,
int
m0
,
int
mmax
,
int
k0
,
int
kmax
)
{
auto
inptr
=
in
+
k0
*
ldin
+
m0
;
uint32_t
mask_buffer
[
4
]
=
{
0
,
1
,
2
,
3
};
int
x_len
=
mmax
-
m0
;
int
y_len
=
kmax
-
k0
;
int
right_remain
=
x_len
-
4
*
(
x_len
/
4
);
int
stride_out
=
4
*
y_len
;
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
uint32x4_t
vmask1
=
vcltq_u32
(
vld1q_u32
(
mask_buffer
),
vdupq_n_u32
(
right_remain
));
bool
has_alpha
=
fabsf
(
alpha
-
1.
f
)
>
1e-8
f
;
float32x4_t
valpha
=
vdupq_n_f32
(
alpha
);
#pragma omp parallel for
for
(
int
y
=
0
;
y
<
y_len
-
3
;
y
+=
4
)
{
const
float
*
ptr0
=
inptr
+
y
*
ldin
;
const
float
*
ptr1
=
ptr0
+
ldin
;
const
float
*
ptr2
=
ptr1
+
ldin
;
const
float
*
ptr3
=
ptr2
+
ldin
;
asm
volatile
(
"prfm pldl1keep, [%[ptr0]]
\n
"
"prfm pldl1keep, [%[ptr0], #64]
\n
"
"prfm pldl1keep, [%[ptr1]]
\n
"
"prfm pldl1keep, [%[ptr1], #64]
\n
"
"prfm pldl1keep, [%[ptr2]]
\n
"
"prfm pldl1keep, [%[ptr2], #64]
\n
"
"prfm pldl1keep, [%[ptr3]]
\n
"
"prfm pldl1keep, [%[ptr3], #64]
\n
"
:
:
[
ptr0
]
"r"
(
ptr0
),
[
ptr1
]
"r"
(
ptr1
),
[
ptr2
]
"r"
(
ptr2
),
[
ptr3
]
"r"
(
ptr3
)
:
"memory"
);
float
*
outptr_row_col
=
outptr
+
y
*
4
;
int
i
=
0
;
for
(;
i
<
x_len
-
3
;
i
+=
4
)
{
float32x4_t
vr00
=
vld1q_f32
(
ptr0
);
float32x4_t
vr10
=
vld1q_f32
(
ptr1
);
float32x4_t
vr20
=
vld1q_f32
(
ptr2
);
float32x4_t
vr30
=
vld1q_f32
(
ptr3
);
if
(
has_alpha
)
{
vr00
=
vmulq_f32
(
vr00
,
valpha
);
vr10
=
vmulq_f32
(
vr10
,
valpha
);
vr20
=
vmulq_f32
(
vr20
,
valpha
);
vr30
=
vmulq_f32
(
vr30
,
valpha
);
}
vst1q_f32
(
outptr_row_col
,
vr00
);
vst1q_f32
(
outptr_row_col
+
4
,
vr10
);
vst1q_f32
(
outptr_row_col
+
8
,
vr20
);
vst1q_f32
(
outptr_row_col
+
12
,
vr30
);
ptr0
+=
4
;
ptr1
+=
4
;
ptr2
+=
4
;
ptr3
+=
4
;
outptr_row_col
+=
stride_out
;
}
if
(
right_remain
>
0
)
{
float32x4_t
vr00
=
vld1q_f32
(
ptr0
);
float32x4_t
vr10
=
vld1q_f32
(
ptr1
);
float32x4_t
vr20
=
vld1q_f32
(
ptr2
);
float32x4_t
vr30
=
vld1q_f32
(
ptr3
);
if
(
has_alpha
)
{
vr00
=
vmulq_f32
(
vr00
,
valpha
);
vr10
=
vmulq_f32
(
vr10
,
valpha
);
vr20
=
vmulq_f32
(
vr20
,
valpha
);
vr30
=
vmulq_f32
(
vr30
,
valpha
);
}
float32x4_t
vr00_1
=
vbslq_f32
(
vmask1
,
vr00
,
vzero
);
float32x4_t
vr10_1
=
vbslq_f32
(
vmask1
,
vr10
,
vzero
);
float32x4_t
vr20_1
=
vbslq_f32
(
vmask1
,
vr20
,
vzero
);
float32x4_t
vr30_1
=
vbslq_f32
(
vmask1
,
vr30
,
vzero
);
vst1q_f32
(
outptr_row_col
,
vr00_1
);
vst1q_f32
(
outptr_row_col
+
4
,
vr10_1
);
vst1q_f32
(
outptr_row_col
+
8
,
vr20_1
);
vst1q_f32
(
outptr_row_col
+
12
,
vr30_1
);
}
}
#pragma omp parallel for
for
(
int
y
=
4
*
(
y_len
/
4
);
y
<
y_len
;
++
y
)
{
const
float
*
ptr0
=
inptr
+
y
*
ldin
;
float
*
outptr_row_col
=
outptr
+
y
*
4
;
int
i
=
0
;
for
(;
i
<
x_len
-
3
;
i
+=
4
)
{
float32x4_t
vr0
=
vld1q_f32
(
ptr0
);
if
(
has_alpha
)
{
vr0
=
vmulq_f32
(
vr0
,
valpha
);
}
vst1q_f32
(
outptr_row_col
,
vr0
);
ptr0
+=
4
;
outptr_row_col
+=
stride_out
;
}
if
(
right_remain
>
0
)
{
float32x4_t
vr0
=
vld1q_f32
(
ptr0
);
if
(
has_alpha
)
{
vr0
=
vmulq_f32
(
vr0
,
valpha
);
}
float32x4_t
vr0_1
=
vbslq_f32
(
vmask1
,
vr0
,
vzero
);
vst1q_f32
(
outptr_row_col
,
vr0_1
);
}
}
}
#else // __aarch64__
void
prepackA_6x8
(
float
*
outptr
,
...
...
@@ -2592,6 +2912,292 @@ void sgemm_prepacked_8x12(bool is_transB,
}
}
}
void
sgemm_prepacked_4x4
(
bool
is_transB
,
int
M
,
int
N
,
int
K
,
const
float
*
A_packed
,
const
float
*
B
,
int
ldb
,
float
beta
,
float
*
C
,
int
ldc
,
const
float
*
bias
,
bool
has_bias
,
bool
has_relu
,
ARMContext
*
ctx
)
{
size_t
l2_cache
=
ctx
->
llc_size
()
>
0
?
ctx
->
llc_size
()
:
512
*
1024
;
auto
workspace
=
ctx
->
workspace_data
<
float
>
();
int
threads
=
ctx
->
threads
();
const
int
n_block
=
4
;
const
int
m_block
=
4
;
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int
x_block
=
(
l2_cache
-
(
m_block
*
K
))
/
(
sizeof
(
float
)
*
(
K
+
m_block
));
x_block
/=
n_block
;
x_block
*=
n_block
;
int
x_num
=
(
N
+
(
x_block
-
1
))
/
x_block
;
x_block
=
(
N
+
x_num
-
1
)
/
x_num
;
x_block
=
(
x_block
+
n_block
-
1
)
/
n_block
;
x_block
*=
n_block
;
x_block
=
x_block
<
n_block
?
n_block
:
x_block
;
// unroll 2 loop
int
tail_pre
=
(
K
&
(
KBLOCK
-
1
));
int
k_pre
=
((
K
+
KBLOCK
-
1
)
/
KBLOCK
)
-
1
;
if
(
tail_pre
==
0
)
{
tail_pre
=
KBLOCK
;
}
bool
flag_p_remain
=
false
;
int
remain
=
0
;
int
has_beta
=
fabsf
(
beta
)
>
1e-8
f
?
1
:
0
;
//! apanel is pre_compute outside gemm
for
(
unsigned
int
x0
=
0
;
x0
<
N
;
x0
+=
x_block
)
{
unsigned
int
xmax
=
x0
+
x_block
;
if
(
xmax
>
N
)
{
xmax
=
N
;
}
int
bblocks
=
(
xmax
-
x0
+
n_block
-
1
)
/
n_block
;
remain
=
xmax
-
x0
-
(
bblocks
-
1
)
*
n_block
;
if
(
remain
>
0
)
{
flag_p_remain
=
true
;
}
//! load bpanel
float
*
b_pannel
=
workspace
;
if
(
is_transB
)
{
pack_m4
(
b_pannel
,
B
,
1.0
f
,
ldb
,
x0
,
xmax
,
0
,
K
);
}
else
{
pack_trans_m4
(
b_pannel
,
B
,
1.0
f
,
ldb
,
x0
,
xmax
,
0
,
K
);
}
#pragma omp parallel for num_threads(threads)
for
(
unsigned
int
y
=
0
;
y
<
M
;
y
+=
m_block
)
{
unsigned
int
ymax
=
y
+
m_block
;
if
(
ymax
>
M
)
{
ymax
=
M
;
}
float
bias_local
[
4
]
=
{
0
};
if
(
has_bias
)
{
bias_local
[
0
]
=
bias
[
y
];
bias_local
[
1
]
=
bias
[
y
+
1
];
bias_local
[
2
]
=
bias
[
y
+
2
];
bias_local
[
3
]
=
bias
[
y
+
3
];
}
float
cout0
[
n_block
];
// NOLINT
float
cout1
[
n_block
];
// NOLINT
float
cout2
[
n_block
];
// NOLINT
float
cout3
[
n_block
];
// NOLINT
float
*
c_ptr0
=
C
+
y
*
ldc
+
x0
;
float
*
c_ptr1
=
c_ptr0
+
ldc
;
float
*
c_ptr2
=
c_ptr1
+
ldc
;
float
*
c_ptr3
=
c_ptr2
+
ldc
;
float
*
pout0
=
c_ptr0
;
float
*
pout1
=
c_ptr1
;
float
*
pout2
=
c_ptr2
;
float
*
pout3
=
c_ptr3
;
const
float
*
a_ptr_l
=
A_packed
+
y
*
K
;
const
float
*
b_ptr_l
=
b_pannel
;
for
(
int
xb
=
0
;
xb
<
bblocks
;
xb
++
)
{
if
((
y
+
3
)
>=
ymax
)
{
switch
((
y
+
3
)
-
ymax
)
{
case
2
:
c_ptr1
=
cout1
;
case
1
:
c_ptr2
=
cout2
;
case
0
:
c_ptr3
=
cout3
;
default:
break
;
}
}
if
(
flag_p_remain
&&
(
xb
==
bblocks
-
1
))
{
pout0
=
c_ptr0
;
pout1
=
c_ptr1
;
pout2
=
c_ptr2
;
pout3
=
c_ptr3
;
c_ptr0
=
cout0
;
c_ptr1
=
cout1
;
c_ptr2
=
cout2
;
c_ptr3
=
cout3
;
if
(
has_beta
)
{
for
(
int
i
=
0
;
i
<
remain
;
++
i
)
{
cout0
[
i
]
=
pout0
[
i
];
cout1
[
i
]
=
pout1
[
i
];
cout2
[
i
]
=
pout2
[
i
];
cout3
[
i
]
=
pout3
[
i
];
}
}
}
const
float
*
a_ptr
=
a_ptr_l
;
const
float
*
b_ptr
=
b_ptr_l
+
xb
*
K
*
4
;
int
tail
=
tail_pre
;
int
k
=
k_pre
;
// clang-format off
asm
volatile
(
"prfm pldl1keep, [%[a_ptr]]
\n
"
/* preload a*/
"ld1 {v2.4s}, [%[bias_ptr]]
\n
"
/* load bias to q2, q3*/
"dup v8.4s, v2.s[0]
\n
"
/* out0 = 0 */
"prfm pldl1keep, [%[b_ptr]]
\n
"
/* preload b*/
"dup v9.4s, v2.s[1]
\n
"
/* out1 = 0*/
"prfm pldl1keep, [%[a_ptr], #64]
\n
"
/* preload a*/
"dup v10.4s, v2.s[2]
\n
"
/* out2 = 0*/
"prfm pldl1keep, [%[b_ptr], #64]
\n
"
/* preload b*/
"dup v11.4s, v2.s[3]
\n
"
/* out3 = 0*/
"cbz %w[has_beta], 0f
\n
"
/* check beta == 0? */
/* process beta */
"dup v7.4s, %w[beta]
\n
"
/* beta to vector */
"ld1 {v0.4s}, [%[c_ptr0]]
\n
"
/* load output r0 */
"ld1 {v1.4s}, [%[c_ptr1]]
\n
"
/* load output r1 */
"fmla v8.4s, v0.4s, v7.4s
\n
"
/* cr00 += beta * c_r00*/
"fmla v9.4s, v1.4s, v7.4s
\n
"
/* cr10 += beta * c_r10*/
"ld1 {v2.4s}, [%[c_ptr2]]
\n
"
"ld1 {v3.4s}, [%[c_ptr3]]
\n
"
"fmla v10.4s, v2.4s, v7.4s
\n
"
/* cr20 += beta * c_r20*/
"fmla v11.4s, v3.4s, v7.4s
\n
"
/* cr30 += beta * c_r30*/
"0:
\n
"
/* check loop count */
"ldp q0, q1, [%[a_ptr]], #32
\n
"
/* load a00,a10 to q0, q1*/
"ldp q4, q5, [%[b_ptr]], #32
\n
"
/* load b0, b1 to q4, q5*/
"cbz %w[k], 2f
\n
"
/* check loop count > 0 */
/* main loop */
/* unrool 0*/
"1:
\n
"
/* main loop */
"fmla v8.4s, v4.4s, v0.s[0]
\n
"
/* out0 = b0 * a00[0], b0 =q4 */
"fmla v9.4s, v4.4s, v0.s[1]
\n
"
/* out1 = b0 * a00[1], b0 =q4 */
"ldp q6, q7, [%[b_ptr]], #32
\n
"
/* load b2, b3 to q6, q7 */
"fmla v10.4s, v4.4s, v0.s[2]
\n
"
/* out2 = b0 * a00[2], b0 =q4 */
"fmla v11.4s, v4.4s, v0.s[3]
\n
"
/* out3 = b0 * a00[3], b0 =q4 */
"ldp q2, q3, [%[a_ptr]], #32
\n
"
/* load a20, a30 to q2, q3 */
"fmla v8.4s, v5.4s, v1.s[0]
\n
"
/* out0 = b1 * a10[0], b1 =q5 */
"fmla v9.4s, v5.4s, v1.s[1]
\n
"
/* out1 = b1 * a10[1], b1 =q5 */
"fmla v10.4s, v5.4s, v1.s[2]
\n
"
/* out2 = b1 * a10[2], b1 =q5 */
"fmla v11.4s, v5.4s, v1.s[3]
\n
"
/* out3 = b1 * a10[3], b1 =q5 */
"ldp q4, q5, [%[b_ptr]], #32
\n
"
/* load b0, b1 to q4, q5*/
"fmla v8.4s, v6.4s, v2.s[0]
\n
"
/* out0 = b2 * a20[0], b2 =q6 */
"fmla v9.4s, v6.4s, v2.s[1]
\n
"
/* out1 = b2 * a20[1], b2 =q6 */
"fmla v10.4s, v6.4s, v2.s[2]
\n
"
/* out2 = b2 * a20[2], b2 =q6*/
"fmla v11.4s, v6.4s, v2.s[3]
\n
"
/* out3 = b2 * a20[3], b2 =q6*/
"ldp q0, q1, [%[a_ptr]], #32
\n
"
/* load a00, a10 to q0, q1 */
"fmla v8.4s, v7.4s, v3.s[0]
\n
"
/* out0 = b3 * a30[0], b3 =q7*/
"fmla v9.4s, v7.4s, v3.s[1]
\n
"
/* out1 = b3 * a30[1], b3 =q7*/
"subs %w[k], %w[k], #1
\n
"
/* loop count - 1*/
"fmla v10.4s, v7.4s, v3.s[2]
\n
"
/* out2 = b3 * a30[2], b3 =q7*/
"fmla v11.4s, v7.4s, v3.s[3]
\n
"
/* out3 = b3 * a30[3], b3 =q7*/
"bne 1b
\n
"
"2:
\n
"
/* process tail*/
"subs %w[tail], %w[tail], #1
\n
"
/* tail--*/
"beq 3f
\n
"
/*jump to tail = 1*/
/* final unrool 0*/
/* unrool 0, tail > 1*/
"fmla v8.4s, v4.4s, v0.s[0]
\n
"
/* out0 = b0 * a00[0], b0 =q4 */
"fmla v9.4s, v4.4s, v0.s[1]
\n
"
/* out1 = b0 * a00[1], b0 =q4 */
"subs %w[tail], %w[tail], #1
\n
"
/* tail--*/
"fmla v10.4s, v4.4s, v0.s[2]
\n
"
/* out2 = b0 * a00[2], b0 =q4 */
"fmla v11.4s, v4.4s, v0.s[3]
\n
"
/* out3 = b0 * a00[3], b0 =q4 */
"beq 4f
\n
"
/*jump to tail = 2*/
/* unrool 1, tail > 2*/
"ldp q6, q7, [%[b_ptr]], #32
\n
"
/* load b2, b3 to q6, q7 */
"fmla v8.4s, v5.4s, v1.s[0]
\n
"
/* out0 = b1 * a10[0], b1 =q5 */
"fmla v9.4s, v5.4s, v1.s[1]
\n
"
/* out1 = b1 * a10[1], b1 =q5*/
"subs %w[tail], %w[tail], #1
\n
"
/* tail--*/
"fmla v10.4s, v5.4s, v1.s[2]
\n
"
/* out2 = b1 * a10[2], b1 =q5 */
"fmla v11.4s, v5.4s, v1.s[3]
\n
"
/* out3 = b1 * a10[3], b1 =q5 */
"ldp q2, q3, [%[a_ptr]], #32
\n
"
/* load a20, a30 to q2, q3 */
"beq 5f
\n
"
/*jump to tail = 3*/
/* unrool 2, tail = 4*/
"fmla v8.4s, v6.4s, v2.s[0]
\n
"
/* out0 = b2 * a20[0], b1 =q6 */
"fmla v9.4s, v6.4s, v2.s[1]
\n
"
/* out1 = b2 * a20[1], b1 =q6 */
"fmla v10.4s, v6.4s, v2.s[2]
\n
"
/* out2 = b2 * a20[2], b1 =q6*/
"fmla v11.4s, v6.4s, v2.s[3]
\n
"
/* out3 = b2 * a20[3], b1 =q6*/
/* unrool 3, tail = 4*/
"fmla v8.4s, v7.4s, v3.s[0]
\n
"
/* out0 = b3 * a30[0], b3 =q7*/
"fmla v9.4s, v7.4s, v3.s[1]
\n
"
/* out1 = b3 * a30[1], b3 =q7*/
"fmla v10.4s, v7.4s, v3.s[2]
\n
"
/* out2 = b3 * a30[2], b3 =q7*/
"fmla v11.4s, v7.4s, v3.s[3]
\n
"
/* out3 = b3 * a30[3], b3 =q7*/
"b 11f
\n
"
/* tails==1 final tail*/
"3:
\n
"
/* tail=1*/
"fmla v8.4s, v4.4s, v0.s[0]
\n
"
/* out0 = b0 * a00[0], b0 =q4 */
"fmla v9.4s, v4.4s, v0.s[1]
\n
"
/* out1 = b0 * a00[1], b0 =q4 */
"fmla v10.4s, v4.4s, v0.s[2]
\n
"
/* out2 = b0 * a00[2], b0 =q4 */
"fmla v11.4s, v4.4s, v0.s[3]
\n
"
/* out3 = b0 * a00[3], b0 =q4 */
"b 11f
\n
"
/* tails==2 final tail*/
"4:
\n
"
/* tail = 2*/
"fmla v8.4s, v5.4s, v1.s[0]
\n
"
/* out0 = b1 * a10[0], b1 =q5 */
"fmla v9.4s, v5.4s, v1.s[1]
\n
"
/* out1 = b1 * a10[1], b1 =q5*/
"fmla v10.4s, v5.4s, v1.s[2]
\n
"
/* out2 = b1 * a10[2], b1 =q5 */
"fmla v11.4s, v5.4s, v1.s[3]
\n
"
/* out3 = b1 * a10[3], b1 =q5 */
"b 11f
\n
"
/* tails==3 final tail*/
"5:
\n
"
/* tail = 3*/
"fmla v8.4s, v6.4s, v2.s[0]
\n
"
/* out0 = b2 * a20[0], b1 =q6 */
"fmla v9.4s, v6.4s, v2.s[1]
\n
"
/* out1 = b2 * a20[1], b1 =q6 */
"fmla v10.4s, v6.4s, v2.s[2]
\n
"
/* out2 = b2 * a20[2], b1 =q6*/
"fmla v11.4s, v6.4s, v2.s[3]
\n
"
/* out3 = b2 * a20[3], b1 =q6*/
"11:
\n
"
/* check if relu */
"cbz %w[relu], 12f
\n
"
/* skip relu */
"movi v2.4s, #0
\n
"
/* for relu*/
"fmax v8.4s, v8.4s, v2.4s
\n
"
/* relu*/
"fmax v9.4s, v9.4s, v2.4s
\n
"
/* relu*/
"fmax v10.4s, v10.4s, v2.4s
\n
"
/* relu*/
"fmax v11.4s, v11.4s, v2.4s
\n
"
/* relu*/
"12:
\n
"
"st1 {v8.4s}, [%[c_ptr0]], #16
\n
"
/* store r0 */
"st1 {v9.4s}, [%[c_ptr1]], #16
\n
"
/* store r1 */
"st1 {v10.4s}, [%[c_ptr2]], #16
\n
"
/* store r2 */
"st1 {v11.4s}, [%[c_ptr3]], #16
\n
"
/* store r3 */
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
k
]
"+r"
(
k
),
[
tail
]
"+r"
(
tail
),
[
c_ptr0
]
"+r"
(
c_ptr0
),
[
c_ptr1
]
"+r"
(
c_ptr1
),
[
c_ptr2
]
"+r"
(
c_ptr2
),
[
c_ptr3
]
"+r"
(
c_ptr3
)
:
[
bias_ptr
]
"r"
(
bias_local
),
[
relu
]
"r"
(
has_relu
),
[
has_beta
]
"r"
(
has_beta
),
[
beta
]
"r"
(
beta
)
:
"cc"
,
"memory"
,
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
);
// clang-format on
if
(
flag_p_remain
&&
(
xb
==
bblocks
-
1
))
{
for
(
int
i
=
0
;
i
<
remain
;
++
i
)
{
*
pout0
++
=
cout0
[
i
];
*
pout1
++
=
cout1
[
i
];
*
pout2
++
=
cout2
[
i
];
*
pout3
++
=
cout3
[
i
];
}
}
}
}
}
}
#else // __aarch64__
/**
* \brief gemm with ablock = 6, bblock = 8, output 6x8
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录