Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
dde12f0d
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dde12f0d
编写于
11月 20, 2019
作者:
Y
yiicy
提交者:
GitHub
11月 20, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ARM] sgemv support transA, test=develop (#2453)
* [ARM] sgemv support transA, test=develop * add sgemv ut, test=develop
上级
b094b2b6
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
739 addition
and
51 deletion
+739
-51
lite/backends/arm/math/conv_impl.cc
lite/backends/arm/math/conv_impl.cc
+4
-2
lite/backends/arm/math/sgemv.cc
lite/backends/arm/math/sgemv.cc
+529
-41
lite/backends/arm/math/sgemv.h
lite/backends/arm/math/sgemv.h
+6
-3
lite/kernels/arm/fc_compute.cc
lite/kernels/arm/fc_compute.cc
+2
-1
lite/kernels/arm/matmul_compute.cc
lite/kernels/arm/matmul_compute.cc
+1
-1
lite/kernels/arm/mul_compute.cc
lite/kernels/arm/mul_compute.cc
+2
-3
lite/tests/math/CMakeLists.txt
lite/tests/math/CMakeLists.txt
+1
-0
lite/tests/math/sgemv_compute_test.cc
lite/tests/math/sgemv_compute_test.cc
+194
-0
未找到文件。
lite/backends/arm/math/conv_impl.cc
浏览文件 @
dde12f0d
...
@@ -202,7 +202,8 @@ void conv1x1s1_gemm(const float* i_data,
...
@@ -202,7 +202,8 @@ void conv1x1s1_gemm(const float* i_data,
k
,
k
,
flag_bias
,
flag_bias
,
bias_group
,
bias_group
,
flag_relu
);
flag_relu
,
ctx
);
}
else
{
}
else
{
sgemm_prepack
(
false
,
sgemm_prepack
(
false
,
m
,
m
,
...
@@ -395,7 +396,8 @@ void conv_im2col_gemm(const float* i_data,
...
@@ -395,7 +396,8 @@ void conv_im2col_gemm(const float* i_data,
k
,
k
,
flag_bias
,
flag_bias
,
bias_group
,
bias_group
,
flag_relu
);
flag_relu
,
ctx
);
}
else
{
}
else
{
int
ldb
=
n
;
int
ldb
=
n
;
sgemm_prepack
(
false
,
sgemm_prepack
(
false
,
...
...
lite/backends/arm/math/sgemv.cc
浏览文件 @
dde12f0d
此差异已折叠。
点击以展开。
lite/backends/arm/math/sgemv.h
浏览文件 @
dde12f0d
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
#pragma once
#pragma once
#include <cmath>
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/device_info.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -28,9 +30,10 @@ bool sgemv(const float* A,
...
@@ -28,9 +30,10 @@ bool sgemv(const float* A,
bool
transA
,
bool
transA
,
int
M
,
int
M
,
int
N
,
int
N
,
bool
is_bias
=
false
,
bool
is_bias
,
const
float
*
bias
=
nullptr
,
const
float
*
bias
,
bool
is_relu
=
false
);
bool
is_relu
,
const
ARMContext
*
ctx
);
}
// namespace math
}
// namespace math
}
// namespace arm
}
// namespace arm
...
...
lite/kernels/arm/fc_compute.cc
浏览文件 @
dde12f0d
...
@@ -127,7 +127,8 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
...
@@ -127,7 +127,8 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
k_
,
k_
,
param
.
bias
!=
nullptr
,
param
.
bias
!=
nullptr
,
b_data
,
b_data
,
false
);
false
,
&
ctx
);
}
}
}
}
}
}
...
...
lite/kernels/arm/matmul_compute.cc
浏览文件 @
dde12f0d
...
@@ -232,7 +232,7 @@ void MatMulCompute::Run() {
...
@@ -232,7 +232,7 @@ void MatMulCompute::Run() {
int
ldc
=
n_
;
int
ldc
=
n_
;
if
(
n_
==
1
)
{
if
(
n_
==
1
)
{
lite
::
arm
::
math
::
sgemv
(
lite
::
arm
::
math
::
sgemv
(
x_data
,
y_data
,
o_data
,
false
,
m_
,
k_
,
false
,
nullptr
,
false
);
x_data
,
y_data
,
o_data
,
false
,
m_
,
k_
,
false
,
nullptr
,
false
,
&
ctx
);
if
(
fabsf
(
alpha
-
1.
f
)
>
1e-8
f
)
{
if
(
fabsf
(
alpha
-
1.
f
)
>
1e-8
f
)
{
for
(
size_t
i
=
0
;
i
<
param
.
Out
->
dims
().
production
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param
.
Out
->
dims
().
production
();
++
i
)
{
o_data
[
i
]
*=
alpha
;
o_data
[
i
]
*=
alpha
;
...
...
lite/kernels/arm/mul_compute.cc
浏览文件 @
dde12f0d
...
@@ -48,14 +48,13 @@ void MulCompute::Run() {
...
@@ -48,14 +48,13 @@ void MulCompute::Run() {
CHECK_EQ
(
x_w
,
y_h
)
<<
"x_w must be equal with y_h"
;
CHECK_EQ
(
x_w
,
y_h
)
<<
"x_w must be equal with y_h"
;
k_
=
x_w
;
k_
=
x_w
;
auto
&
ctx
=
this
->
ctx_
->
template
As
<
ARMContext
>();
if
(
n_
==
1
)
{
if
(
n_
==
1
)
{
lite
::
arm
::
math
::
sgemv
(
lite
::
arm
::
math
::
sgemv
(
x_data
,
y_data
,
o_data
,
false
,
m_
,
k_
,
false
,
nullptr
,
false
);
x_data
,
y_data
,
o_data
,
false
,
m_
,
k_
,
false
,
nullptr
,
false
,
&
ctx
);
}
else
{
}
else
{
constexpr
bool
is_tranposed_y
=
false
;
constexpr
bool
is_tranposed_y
=
false
;
auto
&
ctx
=
this
->
ctx_
->
template
As
<
ARMContext
>();
int
hblock
=
lite
::
arm
::
math
::
get_hblock
(
&
ctx
);
int
hblock
=
lite
::
arm
::
math
::
get_hblock
(
&
ctx
);
int
m_round
=
hblock
*
((
m_
+
hblock
-
1
)
/
hblock
);
int
m_round
=
hblock
*
((
m_
+
hblock
-
1
)
/
hblock
);
ctx
.
ExtendWorkspace
(
m_round
*
k_
*
sizeof
(
float
));
ctx
.
ExtendWorkspace
(
m_round
*
k_
*
sizeof
(
float
));
...
...
lite/tests/math/CMakeLists.txt
浏览文件 @
dde12f0d
if
((
NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA
)
AND
(
LITE_WITH_X86 OR LITE_WITH_ARM
))
if
((
NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA
)
AND
(
LITE_WITH_X86 OR LITE_WITH_ARM
))
lite_cc_test
(
sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
sgemv_compute_test SRCS sgemv_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
...
...
lite/tests/math/sgemv_compute_test.cc
0 → 100644
浏览文件 @
dde12f0d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
#ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h"
#endif // LITE_WITH_ARM
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"
typedef
paddle
::
lite
::
Tensor
Tensor
;
DEFINE_int32
(
cluster
,
3
,
"cluster id"
);
DEFINE_int32
(
threads
,
1
,
"threads num"
);
DEFINE_int32
(
warmup
,
0
,
"warmup times"
);
DEFINE_int32
(
repeats
,
1
,
"repeats times"
);
DEFINE_bool
(
basic_test
,
true
,
"do all tests"
);
DEFINE_bool
(
check_result
,
true
,
"check the result"
);
DEFINE_int32
(
M
,
512
,
"sgemv: M"
);
DEFINE_int32
(
K
,
512
,
"sgemv: K"
);
DEFINE_bool
(
traA
,
false
,
"gemv: A transpose"
);
DEFINE_bool
(
flag_relu
,
false
,
"do relu"
);
DEFINE_bool
(
flag_bias
,
false
,
"with bias"
);
bool
test_sgemv
(
bool
tra
,
int
m
,
int
k
,
bool
has_bias
,
bool
has_relu
,
int
cls
,
int
ths
)
{
Tensor
ta
;
Tensor
tb
;
Tensor
tc
;
Tensor
tc_basic
;
Tensor
tbias
;
ta
.
Resize
({
m
,
k
});
tb
.
Resize
({
k
,
1
});
tc
.
Resize
({
m
,
1
});
tc_basic
.
Resize
({
m
,
1
});
tbias
.
Resize
({
m
});
ta
.
set_precision
(
PRECISION
(
kFloat
));
tb
.
set_precision
(
PRECISION
(
kFloat
));
tc
.
set_precision
(
PRECISION
(
kFloat
));
tc_basic
.
set_precision
(
PRECISION
(
kFloat
));
tbias
.
set_precision
(
PRECISION
(
kFloat
));
fill_tensor_rand
(
ta
,
-
1.
f
,
1.
f
);
// fill_tensor_const(ta, 1.f);
fill_tensor_rand
(
tb
,
-
1.
f
,
1.
f
);
// fill_tensor_const(tb, 1.f);
fill_tensor_rand
(
tbias
,
-
1.
f
,
1.
f
);
LOG
(
INFO
)
<<
"sgemv M: "
<<
m
<<
", K: "
<<
k
<<
", transA: "
<<
(
tra
?
"true"
:
"false"
)
<<
", relu: "
<<
(
has_relu
?
"true"
:
"false"
)
<<
", bias: "
<<
(
has_bias
?
"true"
:
"false"
);
#ifdef LITE_WITH_ARM
auto
da
=
ta
.
mutable_data
<
float
>
();
auto
db
=
tb
.
mutable_data
<
float
>
();
auto
dc
=
tc
.
mutable_data
<
float
>
();
auto
dc_basic
=
tc_basic
.
mutable_data
<
float
>
();
auto
dbias
=
tbias
.
mutable_data
<
float
>
();
if
(
FLAGS_check_result
)
{
basic_gemv
(
m
,
k
,
da
,
db
,
dbias
,
dc_basic
,
1.
f
,
0.
f
,
tra
,
has_bias
,
has_relu
);
}
paddle
::
lite
::
Timer
t0
;
//! compute
double
ops
=
2.0
*
m
*
k
;
std
::
unique_ptr
<
paddle
::
lite
::
KernelContext
>
ctx1
(
new
paddle
::
lite
::
KernelContext
);
auto
&
ctx
=
ctx1
->
As
<
paddle
::
lite
::
ARMContext
>
();
ctx
.
SetRunMode
(
static_cast
<
paddle
::
lite_api
::
PowerMode
>
(
cls
),
ths
);
/// warmup
for
(
int
j
=
0
;
j
<
FLAGS_warmup
;
++
j
)
{
paddle
::
lite
::
arm
::
math
::
sgemv
(
da
,
db
,
dc
,
tra
,
m
,
k
,
has_bias
,
dbias
,
has_relu
,
&
ctx
);
}
t0
.
clear
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
t0
.
start
();
paddle
::
lite
::
arm
::
math
::
sgemv
(
da
,
db
,
dc
,
tra
,
m
,
k
,
has_bias
,
dbias
,
has_relu
,
&
ctx
);
t0
.
end
();
}
LOG
(
INFO
)
<<
"gemv output: M: "
<<
m
<<
", K: "
<<
k
<<
", cluster: "
<<
cls
<<
", threads: "
<<
ths
<<
", GOPS: "
<<
ops
*
1e-9
f
<<
" GOPS, avg time: "
<<
t0
.
get_average_ms
()
<<
" ms, min time: "
<<
t0
.
get_min_time
()
<<
" ms, mean GOPs: "
<<
ops
*
1e-6
f
/
t0
.
get_average_ms
()
<<
" GOPs, max GOPs: "
<<
ops
*
1e-6
f
/
t0
.
get_min_time
()
<<
" GOPs"
;
if
(
FLAGS_check_result
)
{
double
max_ratio
=
0
;
double
max_diff
=
0
;
/// fp32 result
tensor_cmp_host
(
tc_basic
,
tc
,
max_ratio
,
max_diff
);
LOG
(
INFO
)
<<
"compare result, max diff: "
<<
max_diff
<<
", max ratio: "
<<
max_ratio
;
if
(
std
::
abs
(
max_ratio
)
>
1e-4
f
&&
std
::
abs
(
max_diff
)
>
5e-5
f
)
{
Tensor
tdiff
;
tdiff
.
set_precision
(
PRECISION
(
kFloat
));
tdiff
.
Resize
(
tc
.
dims
());
tensor_diff
(
tc_basic
,
tc
,
tdiff
);
LOG
(
INFO
)
<<
"basic result: "
;
print_tensor
(
tc_basic
);
LOG
(
INFO
)
<<
"saber result: "
;
print_tensor
(
tc
);
LOG
(
INFO
)
<<
"diff result: "
;
print_tensor
(
tdiff
);
return
false
;
}
}
#endif
return
true
;
}
TEST
(
TestLiteSgemv
,
Sgemv
)
{
if
(
FLAGS_basic_test
)
{
#ifdef LITE_WITH_ARM
paddle
::
lite
::
DeviceInfo
::
Init
();
#endif
LOG
(
INFO
)
<<
"run basic sgemv test"
;
for
(
auto
&
m
:
{
1
,
3
,
8
,
21
,
32
,
397
})
{
for
(
auto
&
k
:
{
1
,
3
,
8
,
17
,
59
,
234
})
{
for
(
auto
&
tra
:
{
true
,
false
})
{
for
(
auto
&
has_bias
:
{
false
,
true
})
{
for
(
auto
&
has_relu
:
{
false
,
true
})
{
for
(
auto
&
th
:
{
1
,
2
,
4
})
{
auto
flag
=
test_sgemv
(
tra
,
m
,
k
,
has_bias
,
has_relu
,
FLAGS_cluster
,
th
);
if
(
flag
)
{
LOG
(
INFO
)
<<
"test m = "
<<
m
<<
", k="
<<
k
<<
", bias: "
<<
(
has_bias
?
"true"
:
"false"
)
<<
", relu: "
<<
(
has_relu
?
"true"
:
"false"
)
<<
", trans A: "
<<
(
tra
?
"true"
:
"false"
)
<<
", threads: "
<<
th
<<
" passed
\n
"
;
}
else
{
LOG
(
FATAL
)
<<
"test m = "
<<
m
<<
", k="
<<
k
<<
", bias: "
<<
(
has_bias
?
"true"
:
"false"
)
<<
", relu: "
<<
(
has_relu
?
"true"
:
"false"
)
<<
", trans A: "
<<
(
tra
?
"true"
:
"false"
)
<<
", threads: "
<<
th
<<
" failed
\n
"
;
}
}
}
}
}
}
}
}
}
TEST
(
TestSgemvCustom
,
Sgemv_custom
)
{
#ifdef LITE_WITH_ARM
paddle
::
lite
::
DeviceInfo
::
Init
();
#endif
auto
flag
=
test_sgemv
(
FLAGS_traA
,
FLAGS_M
,
FLAGS_K
,
FLAGS_flag_bias
,
FLAGS_flag_relu
,
FLAGS_cluster
,
FLAGS_threads
);
if
(
!
flag
)
{
LOG
(
FATAL
)
<<
"test m = "
<<
FLAGS_M
<<
", k="
<<
FLAGS_K
<<
", trans A: "
<<
FLAGS_traA
<<
", bias: "
<<
FLAGS_flag_bias
<<
", relu: "
<<
FLAGS_flag_relu
<<
" failed!!"
;
}
LOG
(
INFO
)
<<
"test m = "
<<
FLAGS_M
<<
", k="
<<
FLAGS_K
<<
", trans A: "
<<
FLAGS_traA
<<
", bias: "
<<
FLAGS_flag_bias
<<
", relu: "
<<
FLAGS_flag_relu
<<
" passed!!"
;
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录