Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
3b2d3189
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3b2d3189
编写于
4月 20, 2020
作者:
Y
yiicy
提交者:
GitHub
4月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[arm] improve sgemm performance on A53, test=develop (#3439)
improve sgemm performance on A53
上级
afefe9cf
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
522 addition
and
3 deletion
+522
-3
lite/backends/arm/math/packed_sgemm.cc
lite/backends/arm/math/packed_sgemm.cc
+519
-0
lite/tests/kernels/topk_compute_test.cc
lite/tests/kernels/topk_compute_test.cc
+3
-3
未找到文件。
lite/backends/arm/math/packed_sgemm.cc
浏览文件 @
3b2d3189
...
...
@@ -72,6 +72,7 @@ void pack_trans_m4(float *out,
int
mmax
,
int
k0
,
int
kmax
);
void
sgemm_prepacked_4x4
(
bool
is_transB
,
int
M
,
int
N
,
...
...
@@ -154,6 +155,20 @@ void sgemm_prepacked_4x8(bool is_transB,
bool
has_bias
,
const
operators
::
ActivationParam
act_param
,
ARMContext
*
ctx
);
// for kA53
void
sgemm_prepacked_6x8_a53
(
bool
is_transB
,
int
M
,
int
N
,
int
K
,
const
float
*
A_packed
,
const
float
*
B
,
int
ldb
,
float
*
C
,
int
ldc
,
const
float
*
bias
,
bool
has_bias
,
int
is_relu
,
ARMContext
*
ctx
);
#endif // __aarch64__
/**
...
...
@@ -300,6 +315,44 @@ void sgemm_prepack(bool is_transB,
has_bias
,
act_param
,
ctx
);
}
else
if
(
ctx
->
arch
()
==
kA53
)
{
auto
act_type
=
act_param
.
active_type
;
bool
has_act
=
act_param
.
has_active
;
bool
act_flag
=
(
has_act
==
false
)
||
(
has_act
==
true
&&
act_type
==
lite_api
::
ActivationType
::
kRelu
);
bool
has_beta
=
fabsf
(
beta
)
>
1e-8
f
?
true
:
false
;
bool
a53_sgemm
=
act_flag
&&
!
has_beta
;
if
(
a53_sgemm
)
{
sgemm_prepacked_6x8_a53
(
is_transB
,
M
,
N
,
K
,
A_packed
,
B
,
ldb
,
C
,
ldc
,
bias
,
has_bias
,
static_cast
<
int
>
(
has_act
),
ctx
);
}
else
{
sgemm_prepacked_6x8
(
is_transB
,
M
,
N
,
K
,
A_packed
,
B
,
ldb
,
beta
,
C
,
ldc
,
bias
,
has_bias
,
act_param
,
ctx
);
}
}
else
{
sgemm_prepacked_6x8
(
is_transB
,
M
,
...
...
@@ -3983,6 +4036,472 @@ void sgemm_prepacked_6x8(bool is_transB,
}
}
/**
* \brief gemm with ablock = 6, bblock = 8, output 6x8, optimize for a53 arch
* @param A
* @param B
* @param C
* @param M
* @param N
* @param K
* @param threads
* @param workspace
*/
void
sgemm_prepacked_6x8_a53
(
bool
is_transB
,
int
M
,
int
N
,
int
K
,
const
float
*
A_packed
,
const
float
*
B
,
int
ldb
,
float
*
C
,
int
ldc
,
const
float
*
bias
,
bool
has_bias
,
int
is_relu
,
ARMContext
*
ctx
)
{
size_t
l2_cache
=
ctx
->
llc_size
()
>
0
?
ctx
->
llc_size
()
:
512
*
1024
;
auto
*
workspace
=
ctx
->
workspace_data
<
float
>
();
int
threads
=
ctx
->
threads
();
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int
x_block
=
(
l2_cache
-
(
MBLOCK_OTH
*
K
))
/
(
sizeof
(
float
)
*
(
K
+
MBLOCK_OTH
));
x_block
/=
NBLOCK
;
x_block
*=
NBLOCK
;
int
x_num
=
(
N
+
(
x_block
-
1
))
/
x_block
;
x_block
=
(
N
+
x_num
-
1
)
/
x_num
;
x_block
=
(
x_block
+
NBLOCK
-
1
)
/
NBLOCK
;
x_block
*=
NBLOCK
;
x_block
=
x_block
<
NBLOCK
?
NBLOCK
:
x_block
;
int
k_pre
=
((
K
+
KBLOCK
-
1
)
/
KBLOCK
)
-
1
;
int
tail_pre
=
(
K
&
(
KBLOCK
-
1
));
if
(
tail_pre
==
0
)
{
tail_pre
=
KBLOCK
;
}
//! merge tail_pre and flag_act
tail_pre
=
(
tail_pre
<<
2
|
is_relu
);
bool
flag_p_remain
=
false
;
int
remain
=
0
;
//! apanel is pre_compute outside gemm
for
(
unsigned
int
x0
=
0
;
x0
<
N
;
x0
+=
x_block
)
{
unsigned
int
xmax
=
x0
+
x_block
;
if
(
xmax
>
N
)
{
xmax
=
N
;
}
int
bblocks
=
(
xmax
-
x0
+
NBLOCK
-
1
)
/
NBLOCK
;
remain
=
xmax
-
x0
-
(
bblocks
-
1
)
*
NBLOCK
;
if
(
remain
>
0
)
{
flag_p_remain
=
true
;
}
//! load bpanel
auto
b_pannel
=
static_cast
<
float
*>
(
workspace
);
if
(
is_transB
)
{
loadb_trans
(
b_pannel
,
B
,
ldb
,
0
,
K
,
x0
,
xmax
);
}
else
{
loadb
(
b_pannel
,
B
,
ldb
,
0
,
K
,
x0
,
xmax
);
}
#pragma omp parallel for num_threads(threads)
for
(
unsigned
int
y
=
0
;
y
<
M
;
y
+=
MBLOCK_OTH
)
{
unsigned
int
ymax
=
y
+
MBLOCK_OTH
;
if
(
ymax
>
M
)
{
ymax
=
M
;
}
float
*
c_ptr0
=
C
+
y
*
ldc
+
x0
;
float
*
c_ptr1
=
c_ptr0
+
ldc
;
float
*
c_ptr2
=
c_ptr1
+
ldc
;
float
*
c_ptr3
=
c_ptr2
+
ldc
;
float
*
c_ptr4
=
c_ptr3
+
ldc
;
float
*
c_ptr5
=
c_ptr4
+
ldc
;
float
*
pout0
=
c_ptr0
;
float
*
pout1
=
c_ptr1
;
float
*
pout2
=
c_ptr2
;
float
*
pout3
=
c_ptr3
;
float
*
pout4
=
c_ptr4
;
float
*
pout5
=
c_ptr5
;
float
bias_local
[
6
]
=
{
0
};
if
(
has_bias
)
{
bias_local
[
0
]
=
bias
[
y
];
bias_local
[
1
]
=
bias
[
y
+
1
];
bias_local
[
2
]
=
bias
[
y
+
2
];
bias_local
[
3
]
=
bias
[
y
+
3
];
bias_local
[
4
]
=
bias
[
y
+
4
];
bias_local
[
5
]
=
bias
[
y
+
5
];
}
float
cout0
[
NBLOCK
];
float
cout1
[
NBLOCK
];
float
cout2
[
NBLOCK
];
float
cout3
[
NBLOCK
];
float
cout4
[
NBLOCK
];
float
cout5
[
NBLOCK
];
const
float
*
a_ptr_l
=
A_packed
+
y
*
K
;
const
float
*
b_ptr
=
b_pannel
;
for
(
int
xb
=
0
;
xb
<
bblocks
;
xb
++
)
{
if
((
y
+
5
)
>=
ymax
)
{
switch
((
y
+
5
)
-
ymax
)
{
case
4
:
c_ptr1
=
cout1
;
case
3
:
c_ptr2
=
cout2
;
case
2
:
c_ptr3
=
cout3
;
case
1
:
c_ptr4
=
cout4
;
case
0
:
c_ptr5
=
cout5
;
default:
break
;
}
}
if
(
flag_p_remain
&&
(
xb
==
bblocks
-
1
))
{
pout0
=
c_ptr0
;
pout1
=
c_ptr1
;
pout2
=
c_ptr2
;
pout3
=
c_ptr3
;
pout4
=
c_ptr4
;
pout5
=
c_ptr5
;
c_ptr0
=
cout0
;
c_ptr1
=
cout1
;
c_ptr2
=
cout2
;
c_ptr3
=
cout3
;
c_ptr4
=
cout4
;
c_ptr5
=
cout5
;
}
const
float
*
a_ptr
=
a_ptr_l
;
int
tails
=
tail_pre
;
int
k
=
k_pre
;
// clang-format off
asm
volatile
(
// sgemm 6x8 for a53
"vld1.32 {d2-d3}, [%[bias_ptr]]
\n
"
/* load bias0-3 to d2,d3 */
"vdup.i32 q4, d2[0]
\n
"
/* set out00 to bias0 */
"vld1.32 {d0-d1}, [%[a_ptr] :64]
\n
"
/* load a00-a30 to d0,d1 */
"vdup.i32 q5, d2[0]
\n
"
/* set out01 to bias0 */
"vld1.32 {d4-d5}, [%[b_ptr] :128]
\n
"
/* load b00-b03 to d4,d5 */
"vdup.i32 q6, d2[1]
\n
"
/* set out10 to bias1 */
"ldr r0, [%[a_ptr], #0x10]
\n
"
/* load a40 to r0 */
"vdup.i32 q7, d2[1]
\n
"
/* set out11 to bias1 */
"ldr r1, [%[a_ptr], #0x14]
\n
"
/* load a50 to r1 */
"vdup.i32 q8, d3[0]
\n
"
/* set out20 to bias2 */
"vldr d6, [%[bias_ptr], #0x10]
\n
"
/* load bias 4,5 to d6 */
"pld [%[a_ptr], #0x40]
\n
"
/* pre load apanel */
"vdup.i32 q9, d3[0]
\n
"
/* set out21 to bias2 */
"pld [%[b_ptr], #0x40]
\n
"
/* pre load bpanel */
"vdup.i32 q10, d3[1]
\n
"
/* set out30 to bias3 */
"pld [%[a_ptr], #0x80]
\n
"
/* pre load apanel */
"vdup.i32 q11, d3[1]
\n
"
/* set out31 to bias3 */
"pld [%[b_ptr], #0x80]
\n
"
/* pre load bpanel */
"vdup.i32 q12, d6[0]
\n
"
/* set out40 to bias4 */
"vdup.i32 q13, d6[0]
\n
"
/* set out41 to bias4 */
"pld [%[a_ptr], #0xC0]
\n
"
/* pre load apanel */
"vdup.i32 q14, d6[1]
\n
"
/* set out50 to bias5 */
"pld [%[b_ptr], #0XC0]
\n
"
/* pre load bpanel */
"vdup.i32 q15, d6[1]
\n
"
/* set out51 to bias5 */
"cmp %[k], #0
\n
"
/* check k loop */
"beq 6f
\n
"
/* k==0, branch to 6 */
"1:
\n
"
/* Unroll 0 */
"vldr d6, [%[b_ptr], #0x10]
\n
"
/* load b04, b05 to d6 */
"vmov d2, r0, r1
\n
"
/* mov a40, a50 to d2 */
"vmla.f32 q4, q2, d0[0]
\n
"
/* out00 += a00 * b0l */
"ldr r0, [%[b_ptr], #0x18]
\n
"
/* load b06 to r0 */
"vmla.f32 q6, q2, d0[1]
\n
"
/* out10 += a10 * b0l */
"ldr r1, [%[b_ptr], #0x1C]
\n
"
/* load b07 to r1 */
"vmla.f32 q8, q2, d1[0]
\n
"
/* out20 += a20 * b0l */
"vldr d3, [%[a_ptr], #0x18]
\n
"
/* load a01, a11 to d3 */
"vmov d7, r0, r1
\n
"
/* mov b06, b07 to d7 */
"vmla.f32 q10, q2, d1[1]
\n
"
/* out30 += a30 * b0l */
"pld [%[a_ptr], #0x100]
\n
"
/* pre load apanel */
"vmla.f32 q12, q2, d2[0]
\n
"
/* out40 += a40 * b0l */
"vmla.f32 q14, q2, d2[1]
\n
"
/* out50 += a50 * b0l */
"vldr d4, [%[b_ptr], #0x20]
\n
"
/* load b10, b11 to d4 */
"vmla.f32 q5, q3, d0[0]
\n
"
/* out01 += a00 * b0h */
"ldr r0, [%[b_ptr], #0x28]
\n
"
/* load b12 to r0 */
"vmla.f32 q7, q3, d0[1]
\n
"
/* out11 += a10 * b0h */
"ldr r1, [%[b_ptr], #0x2C]
\n
"
/* load b13 to r1 */
"vmla.f32 q9, q3, d1[0]
\n
"
/* out21 += a20 * b0h */
"vldr d0, [%[a_ptr], #0x20]
\n
"
/* load a21, a31 to d0 */
"vmov d5, r0, r1
\n
"
/* mov b12, b13 to d5 */
"vmla.f32 q11, q3, d1[1]
\n
"
/* out31 += a30 * b0h */
"ldr r0, [%[a_ptr], #0x28]
\n
"
/* load a41 to r0 */
"vmla.f32 q13, q3, d2[0]
\n
"
/* out41 += a40 * b0h */
"ldr r1, [%[a_ptr], #0x2C]
\n
"
/* load a51 to r1 */
"vmla.f32 q15, q3, d2[1]
\n
"
/* out51 += a50 * b0h */
/* Unroll 1 */
"vldr d6, [%[b_ptr], #0x30]
\n
"
/* load b14, b15 to d6 */
"vmov d1, r0, r1
\n
"
/* mov a41, a51 to d1 */
"vmla.f32 q4, q2, d3[0]
\n
"
/* out00 += a01 * b1l */
"ldr r0, [%[b_ptr], #0x38]
\n
"
/* load b16 to r0 */
"vmla.f32 q6, q2, d3[1]
\n
"
/* out10 += a11 * b1l */
"ldr r1, [%[b_ptr], #0x3C]
\n
"
/* load b17 to r1 */
"vmla.f32 q8, q2, d0[0]
\n
"
/* out20 += a21 * b1l */
"vldr d2, [%[a_ptr], #0x30]
\n
"
/* load a02, a12 to d0 */
"vmov d7, r0, r1
\n
"
/* mov b16, b17 to d7 */
"vmla.f32 q10, q2, d0[1]
\n
"
/* out30 += a31 * b1l */
"pld [%[b_ptr], #0x100]
\n
"
/* pre load apanel */
"vmla.f32 q12, q2, d1[0]
\n
"
/* out40 += a41 * b1l */
"vmla.f32 q14, q2, d1[1]
\n
"
/* out50 += a51 * b1l */
"vldr d4, [%[b_ptr], #0x40]
\n
"
/* load b20, b21 to d4 */
"vmla.f32 q5, q3, d3[0]
\n
"
/* out01 += a01 * b1h */
"ldr r0, [%[b_ptr], #0x48]
\n
"
/* load b22 to r0 */
"vmla.f32 q7, q3, d3[1]
\n
"
/* out11 += a11 * b1h */
"ldr r1, [%[b_ptr], #0x4C]
\n
"
/* load b23 to r1 */
"vmla.f32 q9, q3, d0[0]
\n
"
/* out21 += a21 * b1h */
"vldr d3, [%[a_ptr], #0x38]
\n
"
/* load a22, a32 to d3 */
"vmov d5, r0, r1
\n
"
/* mov b22, b23 to d5 */
"vmla.f32 q11, q3, d0[1]
\n
"
/* out31 += a31 * b1h */
"ldr r0, [%[a_ptr], #0x40]
\n
"
/* load a42 to r0 */
"vmla.f32 q13, q3, d1[0]
\n
"
/* out41 += a41 * b1h */
"ldr r1, [%[a_ptr], #0x44]
\n
"
/* load a52 to r1 */
"vmla.f32 q15, q3, d1[1]
\n
"
/* out51 += a51 * b1h */
/* Unroll 2 */
"vldr d6, [%[b_ptr], #0x50]
\n
"
/* load b24, b25 to d6 */
"vmov d0, r0, r1
\n
"
/* mov a42, a52 to d0 */
"vmla.f32 q4, q2, d2[0]
\n
"
/* out00 += a02 * b2l */
"ldr r0, [%[b_ptr], #0x58]
\n
"
/* load b26 to r0 */
"vmla.f32 q6, q2, d2[1]
\n
"
/* out10 += a12 * b2l */
"ldr r1, [%[b_ptr], #0x5C]
\n
"
/* load b27 to r1 */
"vmla.f32 q8, q2, d3[0]
\n
"
/* out20 += a22 * b2l */
"vldr d1, [%[a_ptr], #0x48]
\n
"
/* load a03, a13 to d1 */
"vmov d7, r0, r1
\n
"
/* mov b26, b27 to d7 */
"vmla.f32 q10, q2, d3[1]
\n
"
/* out30 += a32 * b2l */
"pld [%[a_ptr], #0x140]
\n
"
/* pre load apanel */
"vmla.f32 q12, q2, d0[0]
\n
"
/* out40 += a42 * b2l */
"vmla.f32 q14, q2, d0[1]
\n
"
/* out50 += a52 * b2l */
"vldr d4, [%[b_ptr], #0x60]
\n
"
/* load b30, b31 to d4 */
"vmla.f32 q5, q3, d2[0]
\n
"
/* out01 += a02 * b2h */
"ldr r0, [%[b_ptr], #0x68]
\n
"
/* load b32 to r0 */
"vmla.f32 q7, q3, d2[1]
\n
"
/* out11 += a12 * b2h */
"ldr r1, [%[b_ptr], #0x6C]
\n
"
/* load b33 to r1 */
"vmla.f32 q9, q3, d3[0]
\n
"
/* out21 += a22 * b2h */
"vldr d2, [%[a_ptr], #0x50]
\n
"
/* load a23, a33 to d2 */
"vmov d5, r0, r1
\n
"
/* mov b32, b33 to d5 */
"vmla.f32 q11, q3, d3[1]
\n
"
/* out31 += a32 * b2h */
"ldr r0, [%[a_ptr], #0x58]
\n
"
/* load a43 to r0 */
"vmla.f32 q13, q3, d0[0]
\n
"
/* out41 += a42 * b2h */
"ldr r1, [%[a_ptr], #0x5C]
\n
"
/* load a53 to r1 */
"vmla.f32 q15, q3, d0[1]
\n
"
/* out51 += a52 * b2h */
"add %[a_ptr], %[a_ptr], #0x60
\n
"
/* aptr += 96 */
/* Unroll 3 */
"vldr d6, [%[b_ptr], #0x70]
\n
"
/* load b34, b35 to d6 */
"vmov d3, r0, r1
\n
"
/* mov a43, a53 to d3 */
"vmla.f32 q4, q2, d1[0]
\n
"
/* out00 += a03 * b3l */
"ldr r0, [%[b_ptr], #0x78]
\n
"
/* load b36 to r0 */
"vmla.f32 q6, q2, d1[1]
\n
"
/* out10 += a13 * b3l */
"ldr r1, [%[b_ptr], #0x7C]
\n
"
/* load b37 to r1 */
"vmla.f32 q8, q2, d2[0]
\n
"
/* out20 += a23 * b3l */
"add %[b_ptr], %[b_ptr], #0x80
\n
"
/* bptr += 108 */
"vldr d0, [%[a_ptr], #0x00]
\n
"
/* load a00, a10 to d0 */
"vmov d7, r0, r1
\n
"
/* mov b36, b37 to d7 */
"vmla.f32 q10, q2, d2[1]
\n
"
/* out30 += a33 * b3l */
"pld [%[b_ptr], #0xC0]
\n
"
/* pre load bpanel */
"vmla.f32 q12, q2, d3[0]
\n
"
/* out40 += a43 * b3l */
"vmla.f32 q14, q2, d3[1]
\n
"
/* out50 += a53 * b3l */
"vldr d4, [%[b_ptr], #0x00]
\n
"
/* load b00, b01 to d4 */
"vmla.f32 q5, q3, d1[0]
\n
"
/* out01 += a03 * b3h */
"ldr r0, [%[b_ptr], #0x08]
\n
"
/* load b02 to r0 */
"vmla.f32 q7, q3, d1[1]
\n
"
/* out11 += a13 * b3h */
"ldr r1, [%[b_ptr], #0x0C]
\n
"
/* load b03 to r1 */
"vmla.f32 q9, q3, d2[0]
\n
"
/* out21 += a23 * b3h */
"subs %[k], %[k], #1
\n
"
/* loop k -= 1 */
"vldr d1, [%[a_ptr], #0x08]
\n
"
/* load a20, a30 to d1 */
"vmov d5, r0, r1
\n
"
/* mov b02, b03 to d5 */
"vmla.f32 q11, q3, d2[1]
\n
"
/* out31 += a33 * b3h */
"ldr r0, [%[a_ptr], #0x10]
\n
"
/* load a40 to r0 */
"vmla.f32 q13, q3, d3[0]
\n
"
/* out41 += a43 * b3h */
"ldr r1, [%[a_ptr], #0x14]
\n
"
/* load a50 to r1 */
"vmla.f32 q15, q3, d3[1]
\n
"
/* out51 += a53 * b3h */
"bne 1b
\n
"
/* branch to k loop */
"6:
\n
"
"sub %[tails], %[tails], #4
\n
"
/* tail -= 4 */
"cmp %[tails], #4
\n
"
/* cmp tail with 4 */
"blt 3f
\n
"
/* branch to tail == 1 */
/* Tail Unroll 0 */
"vmov d2, r0, r1
\n
"
/* mov b02, b03 to d2 */
"add %[a_ptr], %[a_ptr], #0x18
\n
"
/* aptr += 24 */
"vmla.f32 q4, q2, d0[0]
\n
"
/* out00 += a00 * b0l */
"vld1.32 {d3}, [%[a_ptr] :64]!
\n
"
/* load a01, a11 to d3 */
"vmla.f32 q6, q2, d0[1]
\n
"
/* out10 += a10 * b0l */
"add %[b_ptr], %[b_ptr], #0x10
\n
"
/* bptr += 16 */
"vmla.f32 q8, q2, d1[0]
\n
"
/* out20 += a20 * b0l */
"vld1.32 {d6-d7}, [%[b_ptr] :128]!
\n
"
/* load b04-b07 to d6,d7 */
"vmla.f32 q10, q2, d1[1]
\n
"
/* out30 += a30 * b0l */
"vmla.f32 q12, q2, d2[0]
\n
"
/* out40 += a40 * b0l */
"sub %[tails], %[tails], #4
\n
"
/* tail -= 4 */
"vmla.f32 q14, q2, d2[1]
\n
"
/* out50 += a50 * b0l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]!
\n
"
/* load b10-b13 to d4,d5 */
"vmla.f32 q5, q3, d0[0]
\n
"
/* out01 += a00 * b0h */
"vmla.f32 q7, q3, d0[1]
\n
"
/* out11 += a10 * b0h */
"vmla.f32 q9, q3, d1[0]
\n
"
/* out21 += a20 * b0h */
"vmla.f32 q11, q3, d1[1]
\n
"
/* out31 += a30 * b0h */
"vld1.32 {d0-d1}, [%[a_ptr] :64]!
\n
"
/* load a21-a51 to d0,d1 */
"cmp %[tails], #4
\n
"
/* cmp tail with 4 */
"vmla.f32 q13, q3, d2[0]
\n
"
/* out41 += a40 * b0h */
"vmla.f32 q15, q3, d2[1]
\n
"
/* out51 += a50 * b0h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]!
\n
"
/* load b14-b17 to d6,d7 */
"blt 4f
\n
"
/* branch to tail == 2 */
/* Tail Unroll 1 */
"vmla.f32 q4, q2, d3[0]
\n
"
/* out00 += a01 * b1l */
"vmla.f32 q6, q2, d3[1]
\n
"
/* out10 += a11 * b1l */
"sub %[tails], %[tails], #4
\n
"
/* tail -= 4 */
"vmla.f32 q8, q2, d0[0]
\n
"
/* out20 += a21 * b1l */
"vmla.f32 q10, q2, d0[1]
\n
"
/* out30 += a31 * b1l */
"vmla.f32 q12, q2, d1[0]
\n
"
/* out40 += a41 * b1l */
"vmla.f32 q14, q2, d1[1]
\n
"
/* out50 += a51 * b1l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]!
\n
"
/* load b20-b23 to d4,d5 */
"vmla.f32 q5, q3, d3[0]
\n
"
/* out01 += a01 * b1h */
"vmla.f32 q7, q3, d3[1]
\n
"
/* out11 += a11 * b1h */
"cmp %[tails], #4
\n
"
/* cmp tail with 4 */
"vld1.32 {d2-d3}, [%[a_ptr] :64]!
\n
"
/* load a02-a32 to d2,d3 */
"vmla.f32 q9, q3, d0[0]
\n
"
/* out21 += a21 * b1h */
"vmla.f32 q11, q3, d0[1]
\n
"
/* out31 += a31 * b1h */
"vmla.f32 q13, q3, d1[0]
\n
"
/* out41 += a41 * b1h */
"vmla.f32 q15, q3, d1[1]
\n
"
/* out51 += a51 * b1h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]!
\n
"
/* load b24-b27 to d6,d7 */
"blt 5f
\n
"
/* branch to tail == 3 */
/* Tail Unroll 2 */
"sub %[tails], %[tails], #4
\n
"
/* tail -= 4 */
"vld1.32 {d0-d1}, [%[a_ptr] :64]!
\n
"
/* a42a52a03a13 to d0,d1 */
"vmla.f32 q4, q2, d2[0]
\n
"
/* out00 += a02 * b2l */
"vmla.f32 q6, q2, d2[1]
\n
"
/* out10 += a12 * b2l */
"vmla.f32 q8, q2, d3[0]
\n
"
/* out20 += a22 * b2l */
"vmla.f32 q10, q2, d3[1]
\n
"
/* out30 += a32 * b2l */
"vmla.f32 q12, q2, d0[0]
\n
"
/* out40 += a42 * b2l */
"vmla.f32 q14, q2, d0[1]
\n
"
/* out50 += a52 * b2l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]!
\n
"
/* load b30-b33 to d4,d5 */
"vmla.f32 q5, q3, d2[0]
\n
"
/* out01 += a02 * b2h */
"vmla.f32 q7, q3, d2[1]
\n
"
/* out11 += a12 * b2h */
"vmla.f32 q9, q3, d3[0]
\n
"
/* out21 += a22 * b2h */
"vmla.f32 q11, q3, d3[1]
\n
"
/* out31 += a32 * b2h */
"vld1.32 {d2-d3}, [%[a_ptr] :64]!
\n
"
/* load a23-a53 to d2,d3 */
"vmla.f32 q13, q3, d0[0]
\n
"
/* out41 += a42 * b2h */
"vmla.f32 q15, q3, d0[1]
\n
"
/* out51 += a52 * b2h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]!
\n
"
/* load b34-b37 to d6,d7 */
/* Tail Unroll 3 */
"vmla.f32 q4, q2, d1[0]
\n
"
/* out00 += a03 * b3l */
"vmla.f32 q5, q3, d1[0]
\n
"
/* out01 += a03 * b3h */
"vmla.f32 q6, q2, d1[1]
\n
"
/* out10 += a13 * b3l */
"vmla.f32 q7, q3, d1[1]
\n
"
/* out11 += a13 * b3h */
"vmla.f32 q8, q2, d2[0]
\n
"
/* out20 += a23 * b3l */
"vmla.f32 q9, q3, d2[0]
\n
"
/* out21 += a23 * b3h */
"vmla.f32 q10, q2, d2[1]
\n
"
/* out30 += a33 * b3l */
"vmla.f32 q11, q3, d2[1]
\n
"
/* out31 += a33 * b3h */
"vmla.f32 q12, q2, d3[0]
\n
"
/* out40 += a43 * b3l */
"vmla.f32 q13, q3, d3[0]
\n
"
/* out41 += a43 * b3h */
"vmla.f32 q14, q2, d3[1]
\n
"
/* out50 += a53 * b3l */
"vmla.f32 q15, q3, d3[1]
\n
"
/* out51 += a53 * b3h */
"b 2f
\n
"
/* branch to check relu */
/* tails==1 final tail */
"3:
\n
"
"vmov d2, r0, r1
\n
"
/* mov b02, b03 to d2 */
"add %[b_ptr], %[b_ptr], #0x10
\n
"
/* bptr += 16 */
"vmla.f32 q4, q2, d0[0]
\n
"
/* out00 += a00 * b0l */
"add %[a_ptr], %[a_ptr], #0x18
\n
"
/* aptr += 24 */
"vmla.f32 q6, q2, d0[1]
\n
"
/* out10 += a10 * b0l */
"vld1.32 {d6-d7}, [%[b_ptr] :128]!
\n
"
/* load b04-b07 to d6,d7 */
"vmla.f32 q8, q2, d1[0]
\n
"
/* out20 += a20 * b0l */
"vmla.f32 q10, q2, d1[1]
\n
"
/* out30 += a30 * b0l */
"vmla.f32 q12, q2, d2[0]
\n
"
/* out40 += a40 * b0l */
"vmla.f32 q14, q2, d2[1]
\n
"
/* out50 += a50 * b0l */
"vmla.f32 q5, q3, d0[0]
\n
"
/* out01 += a00 * b0h */
"vmla.f32 q7, q3, d0[1]
\n
"
/* out11 += a10 * b0h */
"vmla.f32 q9, q3, d1[0]
\n
"
/* out21 += a20 * b0h */
"vmla.f32 q11, q3, d1[1]
\n
"
/* out31 += a30 * b0h */
"vmla.f32 q13, q3, d2[0]
\n
"
/* out41 += a40 * b0h */
"vmla.f32 q15, q3, d2[1]
\n
"
/* out51 += a50 * b0h */
"b 2f
\n
"
/* branch to check relu */
/* tails==2 final tail */
"4:
\n
"
"vmla.f32 q4, q2, d3[0]
\n
"
/* out00 += a01 * b1l */
"vmla.f32 q5, q3, d3[0]
\n
"
/* out01 += a01 * b1h */
"vmla.f32 q6, q2, d3[1]
\n
"
/* out10 += a11 * b1l */
"vmla.f32 q7, q3, d3[1]
\n
"
/* out11 += a11 * b1h */
"vmla.f32 q8, q2, d0[0]
\n
"
/* out20 += a21 * b1l */
"vmla.f32 q9, q3, d0[0]
\n
"
/* out21 += a21 * b1h */
"vmla.f32 q10, q2, d0[1]
\n
"
/* out30 += a31 * b1l */
"vmla.f32 q11, q3, d0[1]
\n
"
/* out31 += a31 * b1h */
"vmla.f32 q12, q2, d1[0]
\n
"
/* out40 += a41 * b1l */
"vmla.f32 q13, q3, d1[0]
\n
"
/* out41 += a41 * b1h */
"vmla.f32 q14, q2, d1[1]
\n
"
/* out50 += a51 * b1l */
"vmla.f32 q15, q3, d1[1]
\n
"
/* out51 += a51 * b1h */
"b 2f
\n
"
/* branch to check relu */
/* tails==3 final tail */
"5:
\n
"
"vmla.f32 q4, q2, d2[0]
\n
"
/* out00 += a02 * b2l */
"vld1.32 {d0}, [%[a_ptr] :64]!
\n
"
/* load a42, a52 to d0 */
"vmla.f32 q6, q2, d2[1]
\n
"
/* out10 += a12 * b2l */
"vmla.f32 q8, q2, d3[0]
\n
"
/* out20 += a22 * b2l */
"vmla.f32 q5, q3, d2[0]
\n
"
/* out01 += a02 * b2h */
"vmla.f32 q7, q3, d2[1]
\n
"
/* out11 += a12 * b2h */
"vmla.f32 q9, q3, d3[0]
\n
"
/* out21 += a22 * b2h */
"vmla.f32 q10, q2, d3[1]
\n
"
/* out30 += a32 * b2l */
"vmla.f32 q11, q3, d3[1]
\n
"
/* out31 += a32 * b2h */
"vmla.f32 q12, q2, d0[0]
\n
"
/* out40 += a42 * b2l */
"vmla.f32 q13, q3, d0[0]
\n
"
/* out41 += a42 * b2h */
"vmla.f32 q14, q2, d0[1]
\n
"
/* out50 += a52 * b2l */
"vmla.f32 q15, q3, d0[1]
\n
"
/* out51 += a52 * b2h */
/* relu */
"2:
\n
"
"cmp %[tails], #1
\n
"
/* cmp tail is relu */
"bne 0f
\n
"
/* no relu branch to end */
"vmov.i32 q0, #0
\n
"
/* mov 0.f to q0 */
"vmax.f32 q4, q4, q0
\n
"
/* out00 relu */
"vmax.f32 q5, q5, q0
\n
"
/* out01 relu */
"vmax.f32 q6, q6, q0
\n
"
/* out10 relu */
"vmax.f32 q7, q7, q0
\n
"
/* out11 relu */
"vmax.f32 q8, q8, q0
\n
"
/* out20 relu */
"vmax.f32 q9, q9, q0
\n
"
/* out21 relu */
"vmax.f32 q10, q10, q0
\n
"
/* out30 relu */
"vmax.f32 q11, q11, q0
\n
"
/* out31 relu */
"vmax.f32 q12, q12, q0
\n
"
/* out40 relu */
"vmax.f32 q13, q13, q0
\n
"
/* out41 relu */
"vmax.f32 q14, q14, q0
\n
"
/* out50 relu */
"vmax.f32 q15, q15, q0
\n
"
/* out51 relu */
"0:
\n
"
"vst1.32 {d8-d11}, [%[c_ptr0]]!
\n
"
/* store out0 to cptr0 */
"vst1.32 {d12-d15}, [%[c_ptr1]]!
\n
"
/* store out1 to cptr1 */
"vst1.32 {d16-d19}, [%[c_ptr2]]!
\n
"
/* store out2 to cptr2 */
"vst1.32 {d20-d23}, [%[c_ptr3]]!
\n
"
/* store out3 to cptr3 */
"vst1.32 {d24-d27}, [%[c_ptr4]]!
\n
"
/* store out4 to cptr4 */
"vst1.32 {d28-d31}, [%[c_ptr5]]!
\n
"
/* store out5 to cptr5 */
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
c_ptr0
]
"+r"
(
c_ptr0
),
[
c_ptr1
]
"+r"
(
c_ptr1
),
[
c_ptr2
]
"+r"
(
c_ptr2
),
[
c_ptr3
]
"+r"
(
c_ptr3
),
[
c_ptr4
]
"+r"
(
c_ptr4
),
[
c_ptr5
]
"+r"
(
c_ptr5
),
[
k
]
"+r"
(
k
),
[
tails
]
"+r"
(
tails
)
:
[
bias_ptr
]
"r"
(
bias_local
)
:
"r0"
,
"r1"
,
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
,
"cc"
,
"memory"
);
// clang-format on
if
(
flag_p_remain
&&
(
xb
==
bblocks
-
1
))
{
for
(
int
i
=
0
;
i
<
remain
;
++
i
)
{
*
pout0
++
=
cout0
[
i
];
*
pout1
++
=
cout1
[
i
];
*
pout2
++
=
cout2
[
i
];
*
pout3
++
=
cout3
[
i
];
*
pout4
++
=
cout4
[
i
];
*
pout5
++
=
cout5
[
i
];
}
}
}
}
}
}
void
sgemm_prepacked_4x8
(
bool
is_transB
,
int
M
,
int
N
,
...
...
lite/tests/kernels/topk_compute_test.cc
浏览文件 @
3b2d3189
...
...
@@ -50,11 +50,11 @@ class TopkComputeTester : public arena::TestCase {
out_dims
[
out_dims
.
size
()
-
1
]
=
k_
;
out_val
->
Resize
(
out_dims
);
out_ind
->
Resize
(
out_dims
);
auto
*
out_val_data
=
out_val
->
mutable_data
<
T1
>
();
auto
*
out_ind_data
=
out_ind
->
mutable_data
<
T2
>
();
auto
*
out_val_data
=
out_val
->
template
mutable_data
<
T1
>();
auto
*
out_ind_data
=
out_ind
->
template
mutable_data
<
T2
>();
auto
*
x
=
scope
->
FindTensor
(
x_
);
const
auto
*
x_data
=
x
->
data
<
T1
>
();
const
auto
*
x_data
=
x
->
template
data
<
T1
>();
int
m
=
out_dims
.
production
()
/
k_
;
int
n
=
x_dims_
[
x_dims_
.
size
()
-
1
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录