Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
9004e510
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9004e510
编写于
10月 25, 2018
作者:
X
xiebaiyuan
提交者:
GitHub
10月 25, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1208 from hjchen2/dev-latest
Fix load quant and dequant ops
上级
a11a3e94
b2dd1faf
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
47 addition
and
33 deletion
+47
-33
src/framework/load_ops.h
src/framework/load_ops.h
+4
-0
src/operators/kernel/arm/quantize_kernel.cpp
src/operators/kernel/arm/quantize_kernel.cpp
+36
-24
src/operators/kernel/central-arm-func/elementwise_add_arm_func.h
...rators/kernel/central-arm-func/elementwise_add_arm_func.h
+1
-0
test/net/test_googlenet.cpp
test/net/test_googlenet.cpp
+6
-9
未找到文件。
src/framework/load_ops.h
浏览文件 @
9004e510
...
...
@@ -224,5 +224,9 @@ LOAD_FUSION_MATCHER(fusion_conv_bn);
#ifdef ELEMENTWISESUB_OP
LOAD_OP1
(
elementwise_sub
,
CPU
)
#endif
#ifdef QUANT_OP
LOAD_OP1
(
quantize
,
CPU
);
#endif
#ifdef DEQUANT_OP
LOAD_OP1
(
dequantize
,
CPU
);
#endif
src/operators/kernel/arm/quantize_kernel.cpp
浏览文件 @
9004e510
...
...
@@ -135,11 +135,15 @@ static void quantize_round_to_even(const Tensor *input, const float scale,
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t
loop
=
size
>>
4
;
size_t
remain
=
size
&
0xF
;
#pragma omp parallel for
for
(
size_t
i
=
0
;
i
<
loop
;
++
i
)
{
float32x4_t
r0
=
vld1q_f32
(
x
);
float32x4_t
r1
=
vld1q_f32
(
x
+
4
);
float32x4_t
r2
=
vld1q_f32
(
x
+
8
);
float32x4_t
r3
=
vld1q_f32
(
x
+
12
);
const
float
*
local_x
=
x
+
(
i
<<
4
);
int8_t
*
local_y
=
y
+
(
i
<<
4
);
float32x4_t
r0
=
vld1q_f32
(
local_x
);
float32x4_t
r1
=
vld1q_f32
(
local_x
+
4
);
float32x4_t
r2
=
vld1q_f32
(
local_x
+
8
);
float32x4_t
r3
=
vld1q_f32
(
local_x
+
12
);
r0
=
vmulq_n_f32
(
r0
,
scale
);
r1
=
vmulq_n_f32
(
r1
,
scale
);
r2
=
vmulq_n_f32
(
r2
,
scale
);
...
...
@@ -156,12 +160,12 @@ static void quantize_round_to_even(const Tensor *input, const float scale,
int16x8_t
q6
=
vcombine_s16
(
d2
,
d3
);
int8x8_t
d5
=
vmovn_s16
(
q5
);
int8x8_t
d6
=
vmovn_s16
(
q6
);
vst1_s8
(
y
,
d5
);
vst1_s8
(
y
+
8
,
d6
);
x
+=
16
;
y
+=
16
;
vst1_s8
(
local_y
,
d5
);
vst1_s8
(
local_y
+
8
,
d6
);
}
size
=
remain
;
x
+=
(
loop
<<
4
);
y
+=
(
loop
<<
4
);
#endif
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
float
value
=
x
[
i
]
*
scale
;
...
...
@@ -187,11 +191,15 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
#ifdef defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t
loop
=
size
>>
4
;
size_t
remain
=
size
&
0xF
;
#pragma omp parallel for
for
(
size_t
i
=
0
;
i
<
loop
;
++
i
)
{
float32x4_t
r0
=
vld1q_f32
(
x
);
float32x4_t
r1
=
vld1q_f32
(
x
+
4
);
float32x4_t
r2
=
vld1q_f32
(
x
+
8
);
float32x4_t
r3
=
vld1q_f32
(
x
+
12
);
const
float
*
local_x
=
x
+
(
i
<<
4
);
int8_t
*
local_y
=
y
+
(
i
<<
4
);
float32x4_t
r0
=
vld1q_f32
(
local_x
);
float32x4_t
r1
=
vld1q_f32
(
local_x
+
4
);
float32x4_t
r2
=
vld1q_f32
(
local_x
+
8
);
float32x4_t
r3
=
vld1q_f32
(
local_x
+
12
);
r0
=
vmulq_n_f32
(
r0
,
scale
);
r1
=
vmulq_n_f32
(
r1
,
scale
);
r2
=
vmulq_n_f32
(
r2
,
scale
);
...
...
@@ -208,12 +216,12 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
int16x8_t
q6
=
vcombine_s16
(
d2
,
d3
);
int8x8_t
d5
=
vmovn_s16
(
q5
);
int8x8_t
d6
=
vmovn_s16
(
q6
);
vst1_s8
(
y
,
d5
);
vst1_s8
(
y
+
8
,
d6
);
x
+=
16
;
y
+=
16
;
vst1_s8
(
local_y
,
d5
);
vst1_s8
(
local_y
+
8
,
d6
);
}
size
=
remain
;
x
+=
(
loop
<<
4
);
y
+=
(
loop
<<
4
);
#endif
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
y
[
i
]
=
trunc
(
x
[
i
]
*
scale
);
...
...
@@ -228,11 +236,15 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale,
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t
loop
=
size
>>
4
;
size_t
remain
=
size
&
0xF
;
#pragma omp parallel for
for
(
size_t
i
=
0
;
i
<
loop
;
++
i
)
{
float32x4_t
r0
=
vld1q_f32
(
x
);
float32x4_t
r1
=
vld1q_f32
(
x
+
4
);
float32x4_t
r2
=
vld1q_f32
(
x
+
8
);
float32x4_t
r3
=
vld1q_f32
(
x
+
12
);
const
float
*
local_x
=
x
+
(
i
<<
4
);
int8_t
*
local_y
=
y
+
(
i
<<
4
);
float32x4_t
r0
=
vld1q_f32
(
local_x
);
float32x4_t
r1
=
vld1q_f32
(
local_x
+
4
);
float32x4_t
r2
=
vld1q_f32
(
local_x
+
8
);
float32x4_t
r3
=
vld1q_f32
(
local_x
+
12
);
r0
=
vmulq_n_f32
(
r0
,
scale
);
r1
=
vmulq_n_f32
(
r1
,
scale
);
r2
=
vmulq_n_f32
(
r2
,
scale
);
...
...
@@ -249,12 +261,12 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale,
int16x8_t
q6
=
vcombine_s16
(
d2
,
d3
);
int8x8_t
d5
=
vmovn_s16
(
q5
);
int8x8_t
d6
=
vmovn_s16
(
q6
);
vst1_s8
(
y
,
d5
);
vst1_s8
(
y
+
8
,
d6
);
x
+=
16
;
y
+=
16
;
vst1_s8
(
local_y
,
d5
);
vst1_s8
(
local_y
+
8
,
d6
);
}
size
=
remain
;
x
+=
(
loop
<<
4
);
y
+=
(
loop
<<
4
);
#endif
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
y
[
i
]
=
round
(
x
[
i
]
*
scale
);
...
...
src/operators/kernel/central-arm-func/elementwise_add_arm_func.h
浏览文件 @
9004e510
...
...
@@ -58,6 +58,7 @@ void ElementwiseAddCompute(const ElementwiseAddParam<CPU> ¶m) {
const
float
*
input_data
=
input_x
->
data
<
float
>
();
float
*
output_data
=
Out
->
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
#pragma omp parallel for
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
size_t
offset
=
(
i
*
channels
+
j
)
*
elementwise_num
;
const
float
*
input
=
input_data
+
offset
;
...
...
test/net/test_googlenet.cpp
浏览文件 @
9004e510
...
...
@@ -25,8 +25,8 @@ int main() {
paddle_mobile
::
PaddleMobile
<
paddle_mobile
::
CPU
>
paddle_mobile
;
#endif
paddle_mobile
.
SetThreadNum
(
1
);
bool
optimize
=
fals
e
;
paddle_mobile
.
SetThreadNum
(
4
);
bool
optimize
=
tru
e
;
auto
time1
=
time
();
if
(
paddle_mobile
.
Load
(
g_googlenet
,
optimize
))
{
auto
time2
=
time
();
...
...
@@ -35,10 +35,10 @@ int main() {
std
::
vector
<
float
>
output
;
std
::
vector
<
int64_t
>
dims
{
1
,
3
,
224
,
224
};
GetInput
<
float
>
(
g_test_image_1x3x224x224
,
&
input
,
dims
);
//
//
预热十次
//
for (int i = 0; i < 10; ++i) {
//
output = paddle_mobile.Predict(input, dims);
//
}
// 预热十次
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
output
=
paddle_mobile
.
Predict
(
input
,
dims
);
}
auto
time3
=
time
();
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
output
=
paddle_mobile
.
Predict
(
input
,
dims
);
...
...
@@ -47,9 +47,6 @@ int main() {
std
::
cout
<<
"predict cost :"
<<
time_diff
(
time3
,
time4
)
/
10
<<
"ms"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
output
.
size
();
++
i
)
{
DLOG
<<
"result["
<<
i
<<
"] = "
<<
output
[
i
];
}
}
return
0
;
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录