Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
dde12f0d
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dde12f0d
编写于
11月 20, 2019
作者:
Y
yiicy
提交者:
GitHub
11月 20, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ARM] sgemv support transA, test=develop (#2453)
* [ARM] sgemv support transA, test=develop * add sgemv ut, test=develop
上级
b094b2b6
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
739 addition
and
51 deletion
+739
-51
lite/backends/arm/math/conv_impl.cc
lite/backends/arm/math/conv_impl.cc
+4
-2
lite/backends/arm/math/sgemv.cc
lite/backends/arm/math/sgemv.cc
+529
-41
lite/backends/arm/math/sgemv.h
lite/backends/arm/math/sgemv.h
+6
-3
lite/kernels/arm/fc_compute.cc
lite/kernels/arm/fc_compute.cc
+2
-1
lite/kernels/arm/matmul_compute.cc
lite/kernels/arm/matmul_compute.cc
+1
-1
lite/kernels/arm/mul_compute.cc
lite/kernels/arm/mul_compute.cc
+2
-3
lite/tests/math/CMakeLists.txt
lite/tests/math/CMakeLists.txt
+1
-0
lite/tests/math/sgemv_compute_test.cc
lite/tests/math/sgemv_compute_test.cc
+194
-0
未找到文件。
lite/backends/arm/math/conv_impl.cc
浏览文件 @
dde12f0d
...
...
@@ -202,7 +202,8 @@ void conv1x1s1_gemm(const float* i_data,
k
,
flag_bias
,
bias_group
,
flag_relu
);
flag_relu
,
ctx
);
}
else
{
sgemm_prepack
(
false
,
m
,
...
...
@@ -395,7 +396,8 @@ void conv_im2col_gemm(const float* i_data,
k
,
flag_bias
,
bias_group
,
flag_relu
);
flag_relu
,
ctx
);
}
else
{
int
ldb
=
n
;
sgemm_prepack
(
false
,
...
...
lite/backends/arm/math/sgemv.cc
浏览文件 @
dde12f0d
...
...
@@ -14,6 +14,7 @@
#include "lite/backends/arm/math/sgemv.h"
#include <arm_neon.h>
#include <algorithm>
#include "lite/utils/cp_logging.h"
namespace
paddle
{
...
...
@@ -50,6 +51,495 @@ void sgemv_bias_relu(const bool transA,
const
float
*
x
,
float
*
y
,
const
float
*
bias
);
#ifdef __aarch64__
void
sgemv_trans
(
const
int
M
,
const
int
N
,
const
float
*
A
,
const
float
*
x
,
float
*
y
,
bool
flag_bias
,
const
float
*
bias
,
bool
flag_relu
,
const
ARMContext
*
ctx
)
{
int
m_cnt16
=
M
>>
4
;
int
m_cnt8
=
(
M
&
15
)
>>
3
;
int
m_cnt4
=
(
M
&
15
&
7
)
>>
2
;
int
m_remain
=
M
&
15
&
7
&
3
;
int
ths
=
ctx
->
threads
();
int
valid_ths
=
std
::
min
((
N
+
3
)
/
4
,
ths
);
int
valid_block
=
std
::
max
(
4
,
(
N
/
valid_ths
+
3
)
/
4
*
4
);
valid_ths
=
(
N
+
valid_block
-
1
)
/
valid_block
;
int
block_cnt
=
valid_block
/
4
;
float
zero_buf
[
M
];
// NOLINT
float
y_buf
[
valid_ths
*
M
];
// NOLINT
memset
(
zero_buf
,
0
,
M
*
sizeof
(
float
));
if
(
flag_bias
)
{
memcpy
(
y_buf
,
bias
,
M
*
sizeof
(
float
));
memset
(
y_buf
+
M
,
0
,
(
valid_ths
-
1
)
*
M
*
sizeof
(
float
));
}
else
{
memset
(
y_buf
,
0
,
valid_ths
*
M
*
sizeof
(
float
));
}
#pragma omp parallel for
for
(
int
t
=
0
;
t
<
valid_ths
;
++
t
)
{
float
*
block_y
=
y_buf
+
t
*
M
;
const
float
*
block_x
=
x
+
t
*
valid_block
;
const
float
*
block_A
=
A
+
t
*
valid_block
*
M
;
for
(
int
i
=
0
;
i
<
block_cnt
;
++
i
)
{
float
*
y_ptr
=
block_y
;
const
float
*
x_ptr
=
block_x
+
i
*
4
;
const
float
*
in0_ptr
=
block_A
+
i
*
4
*
M
;
const
float
*
in1_ptr
=
in0_ptr
+
M
;
const
float
*
in2_ptr
=
in1_ptr
+
M
;
const
float
*
in3_ptr
=
in2_ptr
+
M
;
int
offset
=
t
*
valid_block
+
(
i
+
1
)
*
4
-
N
;
if
(
offset
>
0
)
{
if
(
offset
>
3
)
{
in0_ptr
=
zero_buf
;
in1_ptr
=
zero_buf
;
in2_ptr
=
zero_buf
;
in3_ptr
=
zero_buf
;
}
else
{
switch
(
offset
)
{
case
3
:
in1_ptr
=
zero_buf
;
case
2
:
in2_ptr
=
zero_buf
;
case
1
:
in3_ptr
=
zero_buf
;
default:
break
;
}
}
}
// clang-format off
if
(
m_cnt16
>
0
)
{
int
cnt16
=
m_cnt16
;
asm
volatile
(
"ld1 {v4.4s}, [%[x]]
\n
"
/* load x to v4 */
"ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [%[in0]], #64
\n
"
/* load in0 to v5, v6, v7, v8 */
"ld1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[in1]], #64
\n
"
/* load in1 to v9, v10, v11, v12 */
"ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [%[in2]], #64
\n
"
/* load in2 to v13, v14, v15, v16 */
"ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [%[in3]], #64
\n
"
/* load in3 to v17, v18, v19, v20 */
"1:
\n
"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[y]]
\n
"
/*load y to v0, v1, v2, v3 */
"fmla v0.4s, v5.4s, v4.s[0]
\n
"
/* v0 += v5 * v4[0] */
"fmla v1.4s, v6.4s, v4.s[0]
\n
"
/* v1 += v6 * v4[0] */
"fmla v2.4s, v7.4s, v4.s[0]
\n
"
/* v2 += v7 * v4[0] */
"fmla v3.4s, v8.4s, v4.s[0]
\n
"
/* v3 += v8 * v4[0] */
"ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [%[in0]], #64
\n
"
/* load in0 to v5, v6, v7, v8 */
"fmla v0.4s, v9.4s, v4.s[1]
\n
"
/* v0 += v9 * v4[1] */
"fmla v1.4s, v10.4s, v4.s[1]
\n
"
/* v1 += v10 * v4[1] */
"fmla v2.4s, v11.4s, v4.s[1]
\n
"
/* v2 += v11 * v4[1] */
"fmla v3.4s, v12.4s, v4.s[1]
\n
"
/* v3 += v12 * v4[1] */
"ld1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[in1]], #64
\n
"
/* load in1 to v9, v10, v11, v12 */
"fmla v0.4s, v13.4s, v4.s[2]
\n
"
/* v0 += v13 * v4[2] */
"fmla v1.4s, v14.4s, v4.s[2]
\n
"
/* v1 += v14 * v4[2] */
"fmla v2.4s, v15.4s, v4.s[2]
\n
"
/* v2 += v15 * v4[2] */
"fmla v3.4s, v16.4s, v4.s[2]
\n
"
/* v3 += v16 * v4[2] */
"ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [%[in2]], #64
\n
"
/* load in2 to v13, v14, v15, v16 */
"fmla v0.4s, v17.4s, v4.s[3]
\n
"
/* v0 += v17 * v4[3] */
"fmla v1.4s, v18.4s, v4.s[3]
\n
"
/* v1 += v18 * v4[3] */
"fmla v2.4s, v19.4s, v4.s[3]
\n
"
/* v2 += v19 * v4[3] */
"fmla v3.4s, v20.4s, v4.s[3]
\n
"
/* v3 += v20 * v4[3] */
"ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [%[in3]], #64
\n
"
/* load in3 to v17, v18, v19, v20 */
"subs %w[cnt], %w[cnt], #1
\n
"
/* sub cnt */
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[y]], #64
\n
"
/* store v0, v1, v2, v3 to y */
"bne 1b
\n
"
/* branch to label 1 */
"sub %[in0], %[in0], #64
\n
"
/* restore in0 address */
"sub %[in1], %[in1], #64
\n
"
/* restore in1 address */
"sub %[in2], %[in2], #64
\n
"
/* restore in2 address */
"sub %[in3], %[in3], #64
\n
"
/* restore in3 address */
:
[
cnt
]
"+r"
(
cnt16
),
[
in0
]
"+r"
(
in0_ptr
),
[
in1
]
"+r"
(
in1_ptr
),
[
in2
]
"+r"
(
in2_ptr
),
[
in3
]
"+r"
(
in3_ptr
),
[
y
]
"+r"
(
y_ptr
)
:
[
x
]
"r"
(
x_ptr
)
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"cc"
,
"memory"
);
}
if
(
m_cnt8
>
0
)
{
int
cnt8
=
m_cnt8
;
asm
volatile
(
"ld1 {v2.4s}, [%[x]]
\n
"
/* load x to v2 */
"ld1 {v3.4s, v4.4s}, [%[in0]], #32
\n
"
/* load in0 to v3, v4 */
"ld1 {v5.4s, v6.4s}, [%[in1]], #32
\n
"
/* load in1 to v5, v6 */
"ld1 {v7.4s, v8.4s}, [%[in2]], #32
\n
"
/* load in2 to v7, v8 */
"ld1 {v9.4s, v10.4s}, [%[in3]], #32
\n
"
/* load in3 to v9, v10*/
"1:
\n
"
"ld1 {v0.4s, v1.4s}, [%[y]]
\n
"
/* load y to v0, v1 */
"fmla v0.4s, v3.4s, v2.s[0]
\n
"
/* v0 += v3 * v2[0] */
"fmla v1.4s, v4.4s, v2.s[0]
\n
"
/* v1 += v4 * v2[0] */
"prfm pldl1keep, [%[in0]]
\n
"
/* preload in0 */
"ld1 {v3.4s, v4.4s}, [%[in0]], #32
\n
"
/* load in0 to v3, v4 */
"fmla v0.4s, v5.4s, v2.s[1]
\n
"
/* v0 += v5 * v2[1] */
"fmla v1.4s, v6.4s, v2.s[1]
\n
"
/* v1 += v6 * v2[1] */
"prfm pldl1keep, [%[in1]]
\n
"
/* preload in1 */
"ld1 {v5.4s, v6.4s}, [%[in1]], #32
\n
"
/* load in0 to v5, v6 */
"fmla v0.4s, v7.4s, v2.s[2]
\n
"
/* v0 += v7 * v2[2] */
"fmla v1.4s, v8.4s, v2.s[2]
\n
"
/* v1 += v8 * v2[2] */
"prfm pldl1keep, [%[in2]]
\n
"
/* preload in2 */
"ld1 {v7.4s, v8.4s}, [%[in2]], #32
\n
"
/* load in0 to v7, v8 */
"fmla v0.4s, v9.4s, v2.s[3]
\n
"
/* v0 += v9 * v2[3] */
"fmla v1.4s, v10.4s, v2.s[3]
\n
"
/* v1 += v10 * v2[3] */
"subs %w[cnt], %w[cnt], #1
\n
"
/* sub cnt */
"prfm pldl1keep, [%[in3]]
\n
"
/* preload in3 */
"st1 {v0.4s, v1.4s}, [%[y]], #32
\n
"
/* store v0, v1 to y */
"ld1 {v9.4s, v10.4s},[%[in3]], #32
\n
"
/* load in0 to v9, v10*/
"bne 1b
\n
"
/* branch to label 1 */
"sub %[in0], %[in0], #32
\n
"
/* restore in0 address */
"sub %[in1], %[in1], #32
\n
"
/* restore in1 address */
"sub %[in2], %[in2], #32
\n
"
/* restore in2 address */
"sub %[in3], %[in3], #32
\n
"
/* restore in3 address */
:
[
cnt
]
"+r"
(
cnt8
),
[
in0
]
"+r"
(
in0_ptr
),
[
in1
]
"+r"
(
in1_ptr
),
[
in2
]
"+r"
(
in2_ptr
),
[
in3
]
"+r"
(
in3_ptr
),
[
y
]
"+r"
(
y_ptr
)
:
[
x
]
"r"
(
x_ptr
)
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"cc"
,
"memory"
);
}
if
(
m_cnt4
>
0
)
{
int
cnt4
=
m_cnt4
;
asm
volatile
(
"ld1 {v1.4s}, [%[in0]], #16
\n
"
/* load in0 to v1 */
"ld1 {v2.4s}, [%[in1]], #16
\n
"
/* load in1 to v2 */
"ld1 {v3.4s}, [%[in2]], #16
\n
"
/* load in2 to v3 */
"ld1 {v4.4s}, [%[in3]], #16
\n
"
/* load in3 to v4 */
"ld1 {v5.4s}, [%[x]]
\n
"
/* load x to v5 */
"1:
\n
"
"ld1 {v0.4s}, [%[y]]
\n
"
/* load y to v0 */
"fmla v0.4s, v1.4s, v5.s[0]
\n
"
/* v0 += v1 * v5[0] */
"prfm pldl1keep, [%[in0]]
\n
"
/* preload in0 */
"ld1 {v1.4s}, [%[in0]], #16
\n
"
/* load in0 to v1 */
"fmla v0.4s, v2.4s, v5.s[1]
\n
"
/* v0 += v2 * v5[1] */
"prfm pldl1keep, [%[in1]]
\n
"
/* preload in1 */
"ld1 {v2.4s}, [%[in1]], #16
\n
"
/* load in1 to v2 */
"fmla v0.4s, v3.4s, v5.s[2]
\n
"
/* v0 += v3 * v5[2] */
"prfm pldl1keep, [%[in2]]
\n
"
/* preload in2 */
"ld1 {v3.4s}, [%[in2]], #16
\n
"
/* load in2 to v3 */
"fmla v0.4s, v4.4s, v5.s[3]
\n
"
/* v0 += v4 * v5[3] */
"subs %w[cnt], %w[cnt], #1
\n
"
/* sub cnt */
"prfm pldl1keep, [%[in3]]
\n
"
/* preload in3 */
"st1 {v0.4s}, [%[y]], #16
\n
"
/* store v0 to y */
"ld1 {v4.4s}, [%[in3]], #16
\n
"
/* load in3 to v4 */
"bne 1b
\n
"
/* branch to label 1 */
"sub %[in0], %[in0], #16
\n
"
/* restore in0 address*/
"sub %[in1], %[in1], #16
\n
"
/* restore in1 address*/
"sub %[in2], %[in2], #16
\n
"
/* restore in2 address*/
"sub %[in3], %[in3], #16
\n
"
/* restore in3 address*/
:
[
cnt
]
"+r"
(
cnt4
),
[
in0
]
"+r"
(
in0_ptr
),
[
in1
]
"+r"
(
in1_ptr
),
[
in2
]
"+r"
(
in2_ptr
),
[
in3
]
"+r"
(
in3_ptr
),
[
y
]
"+r"
(
y_ptr
)
:
[
x
]
"r"
(
x_ptr
)
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"cc"
,
"memory"
);
}
// clang-format on
for
(
int
r
=
0
;
r
<
m_remain
;
++
r
)
{
float
val0
=
x_ptr
[
0
]
*
in0_ptr
[
r
];
float
val1
=
x_ptr
[
1
]
*
in1_ptr
[
r
];
float
val2
=
x_ptr
[
2
]
*
in2_ptr
[
r
];
float
val3
=
x_ptr
[
3
]
*
in3_ptr
[
r
];
y_ptr
[
r
]
+=
val0
+
val1
+
val2
+
val3
;
}
}
}
int
cnt4
=
M
>>
2
;
int
remain
=
M
&
3
;
//! do reduction
int
rdc_ths
=
valid_ths
>>
1
;
while
(
rdc_ths
>
0
)
{
#pragma omp parallel for
for
(
int
t
=
0
;
t
<
rdc_ths
;
++
t
)
{
float
*
y0
=
y_buf
+
t
*
M
;
for
(
int
i
=
t
+
rdc_ths
;
i
<
valid_ths
;
i
+=
rdc_ths
)
{
float
*
y0_ptr
=
y0
;
float
*
y_ptr
=
y_buf
+
i
*
M
;
for
(
int
j
=
0
;
j
<
cnt4
;
++
j
)
{
float32x4_t
val0
=
vld1q_f32
(
y0_ptr
+
j
*
4
);
float32x4_t
val1
=
vld1q_f32
(
y_ptr
+
j
*
4
);
float32x4_t
val
=
vaddq_f32
(
val0
,
val1
);
vst1q_f32
(
y0_ptr
+
j
*
4
,
val
);
}
y0_ptr
+=
cnt4
*
4
;
y_ptr
+=
cnt4
*
4
;
for
(
int
j
=
0
;
j
<
remain
;
++
j
)
{
y0_ptr
[
j
]
+=
y_ptr
[
j
];
}
}
}
valid_ths
=
rdc_ths
;
rdc_ths
=
rdc_ths
>>
1
;
}
if
(
flag_relu
)
{
float
*
in_y
=
y_buf
;
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
if
(
cnt4
>
0
)
{
int
cnt
=
cnt4
;
asm
volatile
(
"ld1 {v0.4s}, [%[in_y]], #16
\n
"
/* load y to v0 */
"1:
\n
"
"fmax v1.4s, v0.4s, %[vzero].4s
\n
"
/* v0 relu */
"ld1 {v0.4s}, [%[in_y]], #16
\n
"
/* load y to v0 */
"subs %w[cnt], %w[cnt], #1
\n
"
/* sub cnt */
"st1 {v1.4s}, [%[out_y]], #16
\n
"
/* store v1 to y */
"bne 1b
\n
"
/* branch to label 1*/
"sub %[in_y], %[in_y], #16
\n
"
/* restore in_y */
:
[
cnt
]
"+r"
(
cnt
),
[
in_y
]
"+r"
(
in_y
),
[
out_y
]
"+r"
(
y
)
:
[
vzero
]
"w"
(
vzero
)
:
"v0"
,
"v1"
,
"cc"
,
"memory"
);
}
for
(
int
r
=
0
;
r
<
remain
;
++
r
)
{
y
[
r
]
=
in_y
[
r
]
>
0.
f
?
in_y
[
r
]
:
0.
f
;
}
}
else
{
memcpy
(
y
,
y_buf
,
M
*
sizeof
(
float
));
}
}
#else
void
sgemv_trans
(
const
int
M
,
const
int
N
,
const
float
*
A
,
const
float
*
x
,
float
*
y
,
bool
flag_bias
,
const
float
*
bias
,
bool
flag_relu
,
const
ARMContext
*
ctx
)
{
int
m_cnt8
=
M
>>
3
;
int
m_cnt4
=
(
M
&
7
)
>>
2
;
int
m_remain
=
M
&
7
&
3
;
int
ths
=
ctx
->
threads
();
int
valid_ths
=
std
::
min
((
N
+
3
)
/
4
,
ths
);
int
valid_block
=
std
::
max
(
4
,
(
N
/
valid_ths
+
3
)
/
4
*
4
);
valid_ths
=
(
N
+
valid_block
-
1
)
/
valid_block
;
int
block_cnt
=
valid_block
/
4
;
float
zero_buf
[
M
];
// NOLINT
float
y_buf
[
valid_ths
*
M
];
// NOLINT
memset
(
zero_buf
,
0
,
M
*
sizeof
(
float
));
if
(
flag_bias
)
{
memcpy
(
y_buf
,
bias
,
M
*
sizeof
(
float
));
memset
(
y_buf
+
M
,
0
,
(
valid_ths
-
1
)
*
M
*
sizeof
(
float
));
}
else
{
memset
(
y_buf
,
0
,
valid_ths
*
M
*
sizeof
(
float
));
}
#pragma omp parallel for
for
(
int
t
=
0
;
t
<
valid_ths
;
++
t
)
{
float
*
block_y
=
y_buf
+
t
*
M
;
const
float
*
block_x
=
x
+
t
*
valid_block
;
const
float
*
block_A
=
A
+
t
*
valid_block
*
M
;
for
(
int
i
=
0
;
i
<
block_cnt
;
++
i
)
{
float
*
y_ptr
=
block_y
;
const
float
*
x_ptr
=
block_x
+
i
*
4
;
const
float
*
in0_ptr
=
block_A
+
i
*
4
*
M
;
const
float
*
in1_ptr
=
in0_ptr
+
M
;
const
float
*
in2_ptr
=
in1_ptr
+
M
;
const
float
*
in3_ptr
=
in2_ptr
+
M
;
int
offset
=
t
*
valid_block
+
(
i
+
1
)
*
4
-
N
;
if
(
offset
>
0
)
{
if
(
offset
>
3
)
{
in0_ptr
=
zero_buf
;
in1_ptr
=
zero_buf
;
in2_ptr
=
zero_buf
;
in3_ptr
=
zero_buf
;
}
else
{
switch
(
offset
)
{
case
3
:
in1_ptr
=
zero_buf
;
case
2
:
in2_ptr
=
zero_buf
;
case
1
:
in3_ptr
=
zero_buf
;
default:
break
;
}
}
}
// clang-format off
if
(
m_cnt8
>
0
)
{
int
cnt8
=
m_cnt8
;
asm
volatile
(
"vld1.32 {d4-d5}, [%[x]]
\n
"
/* load x to q2 */
"vld1.32 {d6-d9}, [%[in0]]!
\n
"
/* load in0 to q3, q4 */
"vld1.32 {d10-d13},[%[in1]]!
\n
"
/* load in1 to q5, q6 */
"vld1.32 {d14-d17},[%[in2]]!
\n
"
/* load in2 to q7, q8 */
"vld1.32 {d18-d21},[%[in3]]!
\n
"
/* load in3 to q9, q10*/
"1:
\n
"
"vld1.32 {d0-d3}, [%[y]]
\n
"
/* load y to q0, q1 */
"vmla.f32 q0, q3, d4[0]
\n
"
/* q0 += q3 * q2[0] */
"vmla.f32 q1, q4, d4[0]
\n
"
/* q1 += q4 * q2[0] */
"pld [%[in0]]
\n
"
/* preload in0 */
"vld1.32 {d6-d9}, [%[in0]]!
\n
"
/* load in0 to q3, q4 */
"vmla.f32 q0, q5, d4[1]
\n
"
/* q0 += q5 * q2[1] */
"vmla.f32 q1, q6, d4[1]
\n
"
/* q1 += q6 * q2[1] */
"pld [%[in1]]
\n
"
/* preload in1 */
"vld1.32 {d10-d13},[%[in1]]!
\n
"
/* load in0 to q5, q6 */
"vmla.f32 q0, q7, d5[0]
\n
"
/* q0 += q7 * q2[2] */
"vmla.f32 q1, q8, d5[0]
\n
"
/* q1 += q8 * q2[2] */
"pld [%[in2]]
\n
"
/* preload in2 */
"vld1.32 {d14-d17},[%[in2]]!
\n
"
/* load in0 to q7, q8 */
"vmla.f32 q0, q9, d5[1]
\n
"
/* q0 += q9 * q2[3] */
"vmla.f32 q1, q10, d5[1]
\n
"
/* q1 += q10 * q2[3] */
"subs %[cnt], %[cnt], #1
\n
"
/* sub cnt */
"pld [%[in3]]
\n
"
/* preload in3 */
"vst1.32 {d0-d3}, [%[y]]!
\n
"
/* store q0, q1 to y */
"vld1.32 {d18-d21},[%[in3]]!
\n
"
/* load in0 to q9, q10*/
"pld [%[y], #32]
\n
"
/* preload y */
"bne 1b
\n
"
/* branch to label 1 */
"sub %[in0], %[in0], #32
\n
"
/* restore in0 address */
"sub %[in1], %[in1], #32
\n
"
/* restore in1 address */
"sub %[in2], %[in2], #32
\n
"
/* restore in2 address */
"sub %[in3], %[in3], #32
\n
"
/* restore in3 address */
:
[
cnt
]
"+r"
(
cnt8
),
[
in0
]
"+r"
(
in0_ptr
),
[
in1
]
"+r"
(
in1_ptr
),
[
in2
]
"+r"
(
in2_ptr
),
[
in3
]
"+r"
(
in3_ptr
),
[
y
]
"+r"
(
y_ptr
)
:
[
x
]
"r"
(
x_ptr
)
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"cc"
,
"memory"
);
}
if
(
m_cnt4
>
0
)
{
int
cnt4
=
m_cnt4
;
asm
volatile
(
"vld1.32 {d2-d3}, [%[in0]]!
\n
"
/* load in0 to q1 */
"vld1.32 {d4-d5}, [%[in1]]!
\n
"
/* load in1 to q2 */
"vld1.32 {d6-d7}, [%[in2]]!
\n
"
/* load in2 to q3 */
"vld1.32 {d8-d9}, [%[in3]]!
\n
"
/* load in3 to q4 */
"vld1.32 {d10-d11},[%[x]]
\n
"
/* load x to q5 */
"1:
\n
"
"vld1.32 {d0-d1}, [%[y]]
\n
"
/* load y to q0 */
"vmla.f32 q0, q1, d10[0]
\n
"
/* q0 += q1 * q5[0] */
"pld [%[in0]]
\n
"
/* preload in0 */
"vld1.32 {d2-d3}, [%[in0]]!
\n
"
/* load in0 to q1 */
"vmla.f32 q0, q2, d10[1]
\n
"
/* q0 += q2 * q5[1] */
"pld [%[in1]]
\n
"
/* preload in1 */
"vld1.32 {d4-d5}, [%[in1]]!
\n
"
/* load in0 to q2 */
"vmla.f32 q0, q3, d11[0]
\n
"
/* q0 += q3 * q5[2] */
"pld [%[in2]]
\n
"
/* preload in2 */
"vld1.32 {d6-d7}, [%[in2]]!
\n
"
/* load in0 to q3 */
"vmla.f32 q0, q4, d11[1]
\n
"
/* q0 += q4 * q5[3] */
"subs %[cnt], %[cnt], #1
\n
"
/* sub cnt */
"pld [%[in3]]
\n
"
/* preload in3 */
"vst1.32 {d0-d1}, [%[y]]!
\n
"
/* store q0 to y */
"vld1.32 {d8-d9}, [%[in3]]!
\n
"
/* load in0 to q4 */
"bne 1b
\n
"
/* branch to label 1 */
"sub %[in0], %[in0], #16
\n
"
/* restore in0 address*/
"sub %[in1], %[in1], #16
\n
"
/* restore in1 address*/
"sub %[in2], %[in2], #16
\n
"
/* restore in2 address*/
"sub %[in3], %[in3], #16
\n
"
/* restore in3 address*/
:
[
cnt
]
"+r"
(
cnt4
),
[
in0
]
"+r"
(
in0_ptr
),
[
in1
]
"+r"
(
in1_ptr
),
[
in2
]
"+r"
(
in2_ptr
),
[
in3
]
"+r"
(
in3_ptr
),
[
y
]
"+r"
(
y_ptr
)
:
[
x
]
"r"
(
x_ptr
)
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"cc"
,
"memory"
);
}
// clang-format on
for
(
int
r
=
0
;
r
<
m_remain
;
++
r
)
{
float
val0
=
x_ptr
[
0
]
*
in0_ptr
[
r
];
float
val1
=
x_ptr
[
1
]
*
in1_ptr
[
r
];
float
val2
=
x_ptr
[
2
]
*
in2_ptr
[
r
];
float
val3
=
x_ptr
[
3
]
*
in3_ptr
[
r
];
y_ptr
[
r
]
+=
val0
+
val1
+
val2
+
val3
;
}
}
}
//! do reduction
int
rdc_ths
=
valid_ths
>>
1
;
while
(
rdc_ths
>
0
)
{
#pragma omp parallel for
for
(
int
t
=
0
;
t
<
rdc_ths
;
++
t
)
{
float
*
y0
=
y_buf
+
t
*
M
;
for
(
int
i
=
t
+
rdc_ths
;
i
<
valid_ths
;
i
+=
rdc_ths
)
{
float
*
y0_ptr
=
y0
;
float
*
y_ptr
=
y_buf
+
i
*
M
;
for
(
int
j
=
0
;
j
<
m_cnt8
;
++
j
)
{
float32x4_t
val00
=
vld1q_f32
(
y0_ptr
+
j
*
8
);
float32x4_t
val01
=
vld1q_f32
(
y0_ptr
+
j
*
8
+
4
);
float32x4_t
val10
=
vld1q_f32
(
y_ptr
+
j
*
8
);
float32x4_t
val11
=
vld1q_f32
(
y_ptr
+
j
*
8
+
4
);
float32x4_t
val0
=
vaddq_f32
(
val00
,
val10
);
float32x4_t
val1
=
vaddq_f32
(
val01
,
val11
);
vst1q_f32
(
y0_ptr
+
j
*
8
,
val0
);
vst1q_f32
(
y0_ptr
+
j
*
8
+
4
,
val1
);
}
y0_ptr
+=
m_cnt8
*
8
;
y_ptr
+=
m_cnt8
*
8
;
for
(
int
j
=
0
;
j
<
m_cnt4
;
++
j
)
{
float32x4_t
val0
=
vld1q_f32
(
y0_ptr
+
j
*
4
);
float32x4_t
val1
=
vld1q_f32
(
y_ptr
+
j
*
4
);
float32x4_t
val
=
vaddq_f32
(
val0
,
val1
);
vst1q_f32
(
y0_ptr
+
j
*
4
,
val
);
}
y0_ptr
+=
m_cnt4
*
4
;
y_ptr
+=
m_cnt4
*
4
;
for
(
int
j
=
0
;
j
<
m_remain
;
++
j
)
{
y0_ptr
[
j
]
+=
y_ptr
[
j
];
}
}
}
valid_ths
=
rdc_ths
;
rdc_ths
=
rdc_ths
>>
1
;
}
if
(
flag_relu
)
{
float
*
in_y
=
y_buf
;
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
if
(
m_cnt8
>
0
)
{
int
cnt8
=
m_cnt8
;
asm
volatile
(
"vld1.32 {d0-d3}, [%[in_y]]!
\n
"
/* load y to q0, q1 */
"1:
\n
"
"vmax.f32 q2, q0, %q[vzero]
\n
"
/* q0 relu */
"vld1.32 {d0-d1}, [%[in_y]]!
\n
"
/* load y to q0 */
"vmax.f32 q3, q1, %q[vzero]
\n
"
/* q1 relu */
"subs %[cnt], %[cnt], #1
\n
"
/* sub cnt */
"vst1.32 {d4-d7}, [%[out_y]]!
\n
"
/* store q0, q1 to y*/
"vld1.32 {d2-d3}, [%[in_y]]!
\n
"
/* load y to q0 */
"bne 1b
\n
"
/* branch to label 1*/
"sub %[in_y], %[in_y], #32
\n
"
/* restore in_y */
:
[
cnt
]
"+r"
(
cnt8
),
[
in_y
]
"+r"
(
in_y
),
[
out_y
]
"+r"
(
y
)
:
[
vzero
]
"w"
(
vzero
)
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"cc"
,
"memory"
);
}
if
(
m_cnt4
>
0
)
{
int
cnt4
=
m_cnt4
;
asm
volatile
(
"vld1.32 {d0-d1}, [%[in_y]]!
\n
"
/* load y to q0 */
"1:
\n
"
"vmax.f32 q1, q0, %q[vzero]
\n
"
/* q0 relu */
"vld1.32 {d0-d1}, [%[in_y]]!
\n
"
/* load y to q0 */
"subs %[cnt], %[cnt], #1
\n
"
/* sub cnt */
"vst1.32 {d2-d3}, [%[out_y]]!
\n
"
/* store q1 to y */
"bne 1b
\n
"
/* branch to label 1*/
"sub %[in_y], %[in_y], #16
\n
"
/* restore in_y */
:
[
cnt
]
"+r"
(
cnt4
),
[
in_y
]
"+r"
(
in_y
),
[
out_y
]
"+r"
(
y
)
:
[
vzero
]
"w"
(
vzero
)
:
"q0"
,
"q1"
,
"cc"
,
"memory"
);
}
for
(
int
r
=
0
;
r
<
m_remain
;
++
r
)
{
y
[
r
]
=
in_y
[
r
]
>
0.
f
?
in_y
[
r
]
:
0.
f
;
}
}
else
{
memcpy
(
y
,
y_buf
,
M
*
sizeof
(
float
));
}
}
#endif // __aarch64__
bool
sgemv
(
const
float
*
A
,
const
float
*
x
,
...
...
@@ -59,33 +549,34 @@ bool sgemv(const float *A,
int
N
,
bool
is_bias
,
const
float
*
bias
,
bool
is_relu
)
{
bool
is_relu
,
const
ARMContext
*
ctx
)
{
if
(
transA
)
{
LOG
(
ERROR
)
<<
" sgemv, transA is not supported now"
;
return
false
;
}
if
(
is_bias
)
{
//! with bias
if
(
is_relu
)
{
//! with relu
sgemv_bias_relu
(
transA
,
M
,
N
,
A
,
x
,
y
,
bias
);
}
else
{
//! without relu
sgemv_bias
(
transA
,
M
,
N
,
A
,
x
,
y
,
bias
);
}
sgemv_trans
(
M
,
N
,
A
,
x
,
y
,
is_bias
,
bias
,
is_relu
,
ctx
);
}
else
{
//! without bias
if
(
is_relu
)
{
//! with relu
sgemv_relu
(
transA
,
M
,
N
,
A
,
x
,
y
);
if
(
is_bias
)
{
//! with bias
if
(
is_relu
)
{
//! with relu
sgemv_bias_relu
(
transA
,
M
,
N
,
A
,
x
,
y
,
bias
);
}
else
{
//! without relu
sgemv_bias
(
transA
,
M
,
N
,
A
,
x
,
y
,
bias
);
}
}
else
{
//! without relu
sgemv
(
transA
,
M
,
N
,
A
,
x
,
y
);
//! without bias
if
(
is_relu
)
{
//! with relu
sgemv_relu
(
transA
,
M
,
N
,
A
,
x
,
y
);
}
else
{
//! without relu
sgemv
(
transA
,
M
,
N
,
A
,
x
,
y
);
}
}
}
return
true
;
}
// clang-format off
//! define compute kernel
#ifdef __aarch64__
#define SGEMV_IN_8 \
...
...
@@ -179,8 +670,8 @@ bool sgemv(const float *A,
"fmla v5.4s, v9.4s, v21.4s \n"
/* mul + add*/
\
"fmla v6.4s, v9.4s, v23.4s \n"
/* mul + add*/
\
"fmla v7.4s, v9.4s, v25.4s \n"
/* mul + add*/
\
"bne 1b \n"
/* jump to main loop */
/* pair add to final
\
result */
\
"bne 1b \n"
/* jump to main loop */
\
/* pair add to final result */
\
"2: \n"
/* reduce to scale */
\
"faddp v16.4s, v0.4s, v0.4s\n"
/* pair add to vector */
\
"faddp s8, v16.2s \n"
/* pair add to scale */
\
...
...
@@ -231,8 +722,8 @@ bool sgemv(const float *A,
"fmla v0.4s, v8.4s, v10.4s \n"
/* mul + add*/
\
"subs %w[cnt], %w[cnt], #1 \n"
/* sub main loop count */
\
"fmla v1.4s, v9.4s, v11.4s \n"
/* mul + add*/
\
"bne 1b \n"
/* jump to main loop */
/* pair add to final
\
result */
\
"bne 1b \n"
/* jump to main loop */
\
/* pair add to final result */
\
"2: \n"
/* reduce to scale */
\
"fadd v9.4s, v0.4s, v1.4s \n"
/* add 2 vector */
\
"faddp v10.4s, v9.4s, v9.4s\n"
/* pair add to vector */
\
...
...
@@ -283,7 +774,7 @@ bool sgemv(const float *A,
"fmax s8, s8, s0 \n"
/* relu */
\
"str s8, [%[out]] \n"
/* save result */
#else //__aarch64__
#else //
__aarch64__
#define SGEMV_IN_4 \
"pld [%[in]] @ preload cache line, input\n" \
...
...
@@ -349,8 +840,8 @@ bool sgemv(const float *A,
"vmla.f32 q1, q5, q9 @ mul add\n" \
"vmla.f32 q2, q5, q11 @ mul add\n" \
"vmla.f32 q3, q5, q13 @ mul add\n" \
"bne 1b @ jump to main loop\n"
/* pair add to final
\
result */
\
"bne 1b @ jump to main loop\n"
\
/* pair add to final result */
\
"2: @ pair add \n" \
"vpadd.f32 d8, d0, d1 @ pair add, first step\n" \
"vpadd.f32 d9, d2, d3 @ pair add, first step\n" \
...
...
@@ -382,13 +873,10 @@ bool sgemv(const float *A,
"vmla.f32 q0, q12, q14 @ mul add\n" \
"vmla.f32 q0, q13, q15 @ mul add\n" \
"subs %[cnt] , #1 @ sub loop count \n" \
"bne 1b @ jump to main loop\n"
/* pair add to \
final result \
*/
\
"bne 1b @ jump to main loop\n" \
"2: @ end processing\n" \
"vpadd.f32 d2, d0, d1 @ pair add, first step\n" \
"vpadd.f32 d0, d2, d2 @ pair add, final step\n"
/* check tails \
*/
\
"vpadd.f32 d0, d2, d2 @ pair add, final step\n"
/*check tails*/
\
"cmp %[tail], #1 @ check whether has mid cols\n" \
"blt 4f @ jump to end\n" \
"3: @ tail loop\n" \
...
...
@@ -422,7 +910,7 @@ bool sgemv(const float *A,
"vmax.f32 d0, d0, d1 @ relu\n" \
"vst1.32 {d0[0]}, [%[out]] @ save result\n"
#endif
// clang-format on
void
sgemv
(
const
bool
transA
,
const
int
M
,
const
int
N
,
...
...
@@ -523,7 +1011,7 @@ void sgemv(const bool transA,
[
tmp4
]
"r"
(
tmp4
)
:
"v0"
,
"v1"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v16"
,
"v17"
,
"cc"
,
"memory"
);
}
#else //__aarch64__
#else //
__aarch64__
int
out_cnt
=
M
>>
2
;
#pragma omp parallel for
for
(
int
j
=
0
;
j
<
out_cnt
;
j
++
)
{
...
...
@@ -579,7 +1067,7 @@ void sgemv(const bool transA,
:
[
out
]
"r"
(
ptr_out
)
:
"q0"
,
"q1"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
,
"cc"
,
"memory"
);
}
#endif //__aarch64__
#endif //
__aarch64__
}
void
sgemv_relu
(
const
bool
transA
,
...
...
@@ -671,7 +1159,7 @@ void sgemv_relu(const bool transA,
:
[
out
]
"r"
(
ptr_out
)
:
"v0"
,
"v1"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v16"
,
"v17"
,
"cc"
,
"memory"
);
}
#else //__aarch64__
#else //
__aarch64__
int
out_cnt
=
M
>>
2
;
#pragma omp parallel for
for
(
int
j
=
0
;
j
<
out_cnt
;
j
++
)
{
...
...
@@ -727,7 +1215,7 @@ void sgemv_relu(const bool transA,
:
[
out
]
"r"
(
ptr_out
)
:
"q0"
,
"q1"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
,
"cc"
,
"memory"
);
}
#endif //__aarch64__
#endif //
__aarch64__
}
void
sgemv_bias
(
const
bool
transA
,
...
...
@@ -822,7 +1310,7 @@ void sgemv_bias(const bool transA,
:
[
out
]
"r"
(
ptr_out
),
[
bias0
]
"r"
(
bias0
)
:
"v0"
,
"v1"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v16"
,
"v17"
,
"cc"
,
"memory"
);
}
#else //__aarch64__
#else //
__aarch64__
int
out_cnt
=
M
>>
2
;
#pragma omp parallel for
for
(
int
j
=
0
;
j
<
out_cnt
;
j
++
)
{
...
...
@@ -887,7 +1375,7 @@ void sgemv_bias(const bool transA,
:
[
out
]
"r"
(
ptr_out
),
[
bias0
]
"r"
(
bias0
)
:
"q0"
,
"q1"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
,
"cc"
,
"memory"
);
}
#endif //__aarch64__
#endif //
__aarch64__
}
void
sgemv_bias_relu
(
const
bool
transA
,
...
...
@@ -980,7 +1468,7 @@ void sgemv_bias_relu(const bool transA,
:
[
out
]
"r"
(
ptr_out
),
[
bias0
]
"r"
(
bias0
)
:
"v0"
,
"v1"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v16"
,
"v17"
,
"cc"
,
"memory"
);
}
#else //__aarch64__
#else //
__aarch64__
int
out_cnt
=
M
>>
2
;
#pragma omp parallel for
for
(
int
j
=
0
;
j
<
out_cnt
;
j
++
)
{
...
...
@@ -1045,7 +1533,7 @@ void sgemv_bias_relu(const bool transA,
:
[
out
]
"r"
(
ptr_out
),
[
bias0
]
"r"
(
bias0
)
:
"q0"
,
"q1"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
,
"cc"
,
"memory"
);
}
#endif //__aarch64__
#endif //
__aarch64__
}
}
// namespace math
...
...
lite/backends/arm/math/sgemv.h
浏览文件 @
dde12f0d
...
...
@@ -15,6 +15,8 @@
#pragma once
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/device_info.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -28,9 +30,10 @@ bool sgemv(const float* A,
bool
transA
,
int
M
,
int
N
,
bool
is_bias
=
false
,
const
float
*
bias
=
nullptr
,
bool
is_relu
=
false
);
bool
is_bias
,
const
float
*
bias
,
bool
is_relu
,
const
ARMContext
*
ctx
);
}
// namespace math
}
// namespace arm
...
...
lite/kernels/arm/fc_compute.cc
浏览文件 @
dde12f0d
...
...
@@ -127,7 +127,8 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
k_
,
param
.
bias
!=
nullptr
,
b_data
,
false
);
false
,
&
ctx
);
}
}
}
...
...
lite/kernels/arm/matmul_compute.cc
浏览文件 @
dde12f0d
...
...
@@ -232,7 +232,7 @@ void MatMulCompute::Run() {
int
ldc
=
n_
;
if
(
n_
==
1
)
{
lite
::
arm
::
math
::
sgemv
(
x_data
,
y_data
,
o_data
,
false
,
m_
,
k_
,
false
,
nullptr
,
false
);
x_data
,
y_data
,
o_data
,
false
,
m_
,
k_
,
false
,
nullptr
,
false
,
&
ctx
);
if
(
fabsf
(
alpha
-
1.
f
)
>
1e-8
f
)
{
for
(
size_t
i
=
0
;
i
<
param
.
Out
->
dims
().
production
();
++
i
)
{
o_data
[
i
]
*=
alpha
;
...
...
lite/kernels/arm/mul_compute.cc
浏览文件 @
dde12f0d
...
...
@@ -48,14 +48,13 @@ void MulCompute::Run() {
CHECK_EQ
(
x_w
,
y_h
)
<<
"x_w must be equal with y_h"
;
k_
=
x_w
;
auto
&
ctx
=
this
->
ctx_
->
template
As
<
ARMContext
>();
if
(
n_
==
1
)
{
lite
::
arm
::
math
::
sgemv
(
x_data
,
y_data
,
o_data
,
false
,
m_
,
k_
,
false
,
nullptr
,
false
);
x_data
,
y_data
,
o_data
,
false
,
m_
,
k_
,
false
,
nullptr
,
false
,
&
ctx
);
}
else
{
constexpr
bool
is_tranposed_y
=
false
;
auto
&
ctx
=
this
->
ctx_
->
template
As
<
ARMContext
>();
int
hblock
=
lite
::
arm
::
math
::
get_hblock
(
&
ctx
);
int
m_round
=
hblock
*
((
m_
+
hblock
-
1
)
/
hblock
);
ctx
.
ExtendWorkspace
(
m_round
*
k_
*
sizeof
(
float
));
...
...
lite/tests/math/CMakeLists.txt
浏览文件 @
dde12f0d
if
((
NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA
)
AND
(
LITE_WITH_X86 OR LITE_WITH_ARM
))
lite_cc_test
(
sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
sgemv_compute_test SRCS sgemv_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
...
...
lite/tests/math/sgemv_compute_test.cc
0 → 100644
浏览文件 @
dde12f0d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
#ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h"
#endif // LITE_WITH_ARM
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"
typedef
paddle
::
lite
::
Tensor
Tensor
;
DEFINE_int32
(
cluster
,
3
,
"cluster id"
);
DEFINE_int32
(
threads
,
1
,
"threads num"
);
DEFINE_int32
(
warmup
,
0
,
"warmup times"
);
DEFINE_int32
(
repeats
,
1
,
"repeats times"
);
DEFINE_bool
(
basic_test
,
true
,
"do all tests"
);
DEFINE_bool
(
check_result
,
true
,
"check the result"
);
DEFINE_int32
(
M
,
512
,
"sgemv: M"
);
DEFINE_int32
(
K
,
512
,
"sgemv: K"
);
DEFINE_bool
(
traA
,
false
,
"gemv: A transpose"
);
DEFINE_bool
(
flag_relu
,
false
,
"do relu"
);
DEFINE_bool
(
flag_bias
,
false
,
"with bias"
);
bool
test_sgemv
(
bool
tra
,
int
m
,
int
k
,
bool
has_bias
,
bool
has_relu
,
int
cls
,
int
ths
)
{
Tensor
ta
;
Tensor
tb
;
Tensor
tc
;
Tensor
tc_basic
;
Tensor
tbias
;
ta
.
Resize
({
m
,
k
});
tb
.
Resize
({
k
,
1
});
tc
.
Resize
({
m
,
1
});
tc_basic
.
Resize
({
m
,
1
});
tbias
.
Resize
({
m
});
ta
.
set_precision
(
PRECISION
(
kFloat
));
tb
.
set_precision
(
PRECISION
(
kFloat
));
tc
.
set_precision
(
PRECISION
(
kFloat
));
tc_basic
.
set_precision
(
PRECISION
(
kFloat
));
tbias
.
set_precision
(
PRECISION
(
kFloat
));
fill_tensor_rand
(
ta
,
-
1.
f
,
1.
f
);
// fill_tensor_const(ta, 1.f);
fill_tensor_rand
(
tb
,
-
1.
f
,
1.
f
);
// fill_tensor_const(tb, 1.f);
fill_tensor_rand
(
tbias
,
-
1.
f
,
1.
f
);
LOG
(
INFO
)
<<
"sgemv M: "
<<
m
<<
", K: "
<<
k
<<
", transA: "
<<
(
tra
?
"true"
:
"false"
)
<<
", relu: "
<<
(
has_relu
?
"true"
:
"false"
)
<<
", bias: "
<<
(
has_bias
?
"true"
:
"false"
);
#ifdef LITE_WITH_ARM
auto
da
=
ta
.
mutable_data
<
float
>
();
auto
db
=
tb
.
mutable_data
<
float
>
();
auto
dc
=
tc
.
mutable_data
<
float
>
();
auto
dc_basic
=
tc_basic
.
mutable_data
<
float
>
();
auto
dbias
=
tbias
.
mutable_data
<
float
>
();
if
(
FLAGS_check_result
)
{
basic_gemv
(
m
,
k
,
da
,
db
,
dbias
,
dc_basic
,
1.
f
,
0.
f
,
tra
,
has_bias
,
has_relu
);
}
paddle
::
lite
::
Timer
t0
;
//! compute
double
ops
=
2.0
*
m
*
k
;
std
::
unique_ptr
<
paddle
::
lite
::
KernelContext
>
ctx1
(
new
paddle
::
lite
::
KernelContext
);
auto
&
ctx
=
ctx1
->
As
<
paddle
::
lite
::
ARMContext
>
();
ctx
.
SetRunMode
(
static_cast
<
paddle
::
lite_api
::
PowerMode
>
(
cls
),
ths
);
/// warmup
for
(
int
j
=
0
;
j
<
FLAGS_warmup
;
++
j
)
{
paddle
::
lite
::
arm
::
math
::
sgemv
(
da
,
db
,
dc
,
tra
,
m
,
k
,
has_bias
,
dbias
,
has_relu
,
&
ctx
);
}
t0
.
clear
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
t0
.
start
();
paddle
::
lite
::
arm
::
math
::
sgemv
(
da
,
db
,
dc
,
tra
,
m
,
k
,
has_bias
,
dbias
,
has_relu
,
&
ctx
);
t0
.
end
();
}
LOG
(
INFO
)
<<
"gemv output: M: "
<<
m
<<
", K: "
<<
k
<<
", cluster: "
<<
cls
<<
", threads: "
<<
ths
<<
", GOPS: "
<<
ops
*
1e-9
f
<<
" GOPS, avg time: "
<<
t0
.
get_average_ms
()
<<
" ms, min time: "
<<
t0
.
get_min_time
()
<<
" ms, mean GOPs: "
<<
ops
*
1e-6
f
/
t0
.
get_average_ms
()
<<
" GOPs, max GOPs: "
<<
ops
*
1e-6
f
/
t0
.
get_min_time
()
<<
" GOPs"
;
if
(
FLAGS_check_result
)
{
double
max_ratio
=
0
;
double
max_diff
=
0
;
/// fp32 result
tensor_cmp_host
(
tc_basic
,
tc
,
max_ratio
,
max_diff
);
LOG
(
INFO
)
<<
"compare result, max diff: "
<<
max_diff
<<
", max ratio: "
<<
max_ratio
;
if
(
std
::
abs
(
max_ratio
)
>
1e-4
f
&&
std
::
abs
(
max_diff
)
>
5e-5
f
)
{
Tensor
tdiff
;
tdiff
.
set_precision
(
PRECISION
(
kFloat
));
tdiff
.
Resize
(
tc
.
dims
());
tensor_diff
(
tc_basic
,
tc
,
tdiff
);
LOG
(
INFO
)
<<
"basic result: "
;
print_tensor
(
tc_basic
);
LOG
(
INFO
)
<<
"saber result: "
;
print_tensor
(
tc
);
LOG
(
INFO
)
<<
"diff result: "
;
print_tensor
(
tdiff
);
return
false
;
}
}
#endif
return
true
;
}
TEST
(
TestLiteSgemv
,
Sgemv
)
{
if
(
FLAGS_basic_test
)
{
#ifdef LITE_WITH_ARM
paddle
::
lite
::
DeviceInfo
::
Init
();
#endif
LOG
(
INFO
)
<<
"run basic sgemv test"
;
for
(
auto
&
m
:
{
1
,
3
,
8
,
21
,
32
,
397
})
{
for
(
auto
&
k
:
{
1
,
3
,
8
,
17
,
59
,
234
})
{
for
(
auto
&
tra
:
{
true
,
false
})
{
for
(
auto
&
has_bias
:
{
false
,
true
})
{
for
(
auto
&
has_relu
:
{
false
,
true
})
{
for
(
auto
&
th
:
{
1
,
2
,
4
})
{
auto
flag
=
test_sgemv
(
tra
,
m
,
k
,
has_bias
,
has_relu
,
FLAGS_cluster
,
th
);
if
(
flag
)
{
LOG
(
INFO
)
<<
"test m = "
<<
m
<<
", k="
<<
k
<<
", bias: "
<<
(
has_bias
?
"true"
:
"false"
)
<<
", relu: "
<<
(
has_relu
?
"true"
:
"false"
)
<<
", trans A: "
<<
(
tra
?
"true"
:
"false"
)
<<
", threads: "
<<
th
<<
" passed
\n
"
;
}
else
{
LOG
(
FATAL
)
<<
"test m = "
<<
m
<<
", k="
<<
k
<<
", bias: "
<<
(
has_bias
?
"true"
:
"false"
)
<<
", relu: "
<<
(
has_relu
?
"true"
:
"false"
)
<<
", trans A: "
<<
(
tra
?
"true"
:
"false"
)
<<
", threads: "
<<
th
<<
" failed
\n
"
;
}
}
}
}
}
}
}
}
}
TEST
(
TestSgemvCustom
,
Sgemv_custom
)
{
#ifdef LITE_WITH_ARM
paddle
::
lite
::
DeviceInfo
::
Init
();
#endif
auto
flag
=
test_sgemv
(
FLAGS_traA
,
FLAGS_M
,
FLAGS_K
,
FLAGS_flag_bias
,
FLAGS_flag_relu
,
FLAGS_cluster
,
FLAGS_threads
);
if
(
!
flag
)
{
LOG
(
FATAL
)
<<
"test m = "
<<
FLAGS_M
<<
", k="
<<
FLAGS_K
<<
", trans A: "
<<
FLAGS_traA
<<
", bias: "
<<
FLAGS_flag_bias
<<
", relu: "
<<
FLAGS_flag_relu
<<
" failed!!"
;
}
LOG
(
INFO
)
<<
"test m = "
<<
FLAGS_M
<<
", k="
<<
FLAGS_K
<<
", trans A: "
<<
FLAGS_traA
<<
", bias: "
<<
FLAGS_flag_bias
<<
", relu: "
<<
FLAGS_flag_relu
<<
" passed!!"
;
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录