Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
a72f988a
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a72f988a
编写于
10月 09, 2019
作者:
L
Liufang Sang
提交者:
whs
10月 09, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PaddleSlim]Yolov3 quantization demo (#3440)
上级
b860c3ba
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
844 addition
and
0 deletion
+844
-0
PaddleCV/PaddleDetection/slim/quantization/README.md
PaddleCV/PaddleDetection/slim/quantization/README.md
+130
-0
PaddleCV/PaddleDetection/slim/quantization/compress.py
PaddleCV/PaddleDetection/slim/quantization/compress.py
+267
-0
PaddleCV/PaddleDetection/slim/quantization/eval.py
PaddleCV/PaddleDetection/slim/quantization/eval.py
+184
-0
PaddleCV/PaddleDetection/slim/quantization/freeze.py
PaddleCV/PaddleDetection/slim/quantization/freeze.py
+243
-0
PaddleCV/PaddleDetection/slim/quantization/yolov3_mobilenet_v1_slim.yaml
...Detection/slim/quantization/yolov3_mobilenet_v1_slim.yaml
+20
-0
未找到文件。
PaddleCV/PaddleDetection/slim/quantization/README.md
0 → 100644
浏览文件 @
a72f988a
>运行该示例前请安装Paddle1.6或更高版本
# 检测模型量化压缩示例
## 概述
该示例使用PaddleSlim提供的
[
量化压缩策略
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#1-quantization-aware-training%E9%87%8F%E5%8C%96%E4%BB%8B%E7%BB%8D
)
对分类模型进行压缩。
在阅读该示例前,建议您先了解以下内容:
-
[
检测模型的常规训练方法
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleDetection
)
-
[
PaddleSlim使用文档
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md
)
## 配置文件说明
关于配置文件如何编写您可以参考:
-
[
PaddleSlim配置文件编写说明
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#122-%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E7%9A%84%E4%BD%BF%E7%94%A8
)
-
[
量化策略配置文件编写说明
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/usage.md#21-%E9%87%8F%E5%8C%96%E8%AE%AD%E7%BB%83
)
其中save_out_nodes需要得到检测结果的Variable的名称,下面介绍如何确定save_out_nodes的参数
以MobileNet V1为例,可在compress.py中构建好网络之后,直接打印Variable得到Variable的名称信息。
代码示例:
```
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys)
# print(eval_values)
```
根据运行结果可看到Variable的名字为:
`multiclass_nms_0.tmp_0`
。
## 训练
根据
[
PaddleCV/PaddleDetection/tools/train.py
](
https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/PaddleDetection/tools/train.py
)
编写压缩脚本compress.py。
在该脚本中定义了Compressor对象,用于执行压缩任务。
通过
`python compress.py --help`
查看可配置参数,简述如下:
-
config: 检测库的配置,其中配置了训练超参数、数据集信息等。
-
slim_file: PaddleSlim的配置文件,参见
[
配置文件说明
](
#配置文件说明
)
。
您可以通过运行脚本
`run.sh`
运行该示例,请确保已正确下载
[
pretrained model
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD
)
。
### 训练时的模型结构
这部分介绍来源于
[
量化low-level API介绍
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api#1-%E9%87%8F%E5%8C%96%E8%AE%AD%E7%BB%83low-level-apis%E4%BB%8B%E7%BB%8D
)
。
PaddlePaddle框架中有四个和量化相关的IrPass, 分别是QuantizationTransformPass、QuantizationFreezePass、ConvertToInt8Pass以及TransformForMobilePass。在训练时,对网络应用了QuantizationTransformPass,作用是在网络中的conv2d、depthwise_conv2d、mul等算子的各个输入前插入连续的量化op和反量化op,并改变相应反向算子的某些输入。示例图如下:
<p
align=
"center"
>
<img
src=
"./images/TransformPass.png"
height=
400
width=
520
hspace=
'10'
/>
<br
/>
<strong>
图1:应用QuantizationTransformPass后的结果
</strong>
</p>
### 保存断点(checkpoint)
如果在配置文件中设置了
`checkpoint_path`
, 则在压缩任务执行过程中会自动保存断点,当任务异常中断时,
重启任务会自动从
`checkpoint_path`
路径下按数字顺序加载最新的checkpoint文件。如果不想让重启的任务从断点恢复,
需要修改配置文件中的
`checkpoint_path`
,或者将
`checkpoint_path`
路径下文件清空。
>注意:配置文件中的信息不会保存在断点中,重启前对配置文件的修改将会生效。
## 评估
如果在配置文件中设置了
`checkpoint_path`
,则每个epoch会保存一个量化后的用于评估的模型,
该模型会保存在
`${checkpoint_path}/${epoch_id}/eval_model/`
路径下,包含
`__model__`
和
`__params__`
两个文件。
其中,
`__model__`
用于保存模型结构信息,
`__params__`
用于保存参数(parameters)信息。模型结构和训练时一样。
如果不需要保存评估模型,可以在定义Compressor对象时,将
`save_eval_model`
选项设置为False(默认为True)。
脚本
<a
href=
"eval.py"
>
slim/quantization/eval.py
</a>
中为使用该模型在评估数据集上做评估的示例。
## 预测
如果在配置文件的量化策略中设置了
`float_model_save_path`
,
`int8_model_save_path`
,
`mobile_model_save_path`
, 在训练结束后,会保存模型量化压缩之后用于预测的模型。接下来介绍这三种预测模型的区别。
### float预测模型
在介绍量化训练时的模型结构时介绍了PaddlePaddle框架中有四个和量化相关的IrPass, 分别是QuantizationTransformPass、QuantizationFreezePass、ConvertToInt8Pass以及TransformForMobilePass。float预测模型是在应用QuantizationFreezePass并删除eval_program中多余的operators之后,保存的模型。
QuantizationFreezePass主要用于改变IrGraph中量化op和反量化op的顺序,即将类似图1中的量化op和反量化op顺序改变为图2中的布局。除此之外,QuantizationFreezePass还会将
`conv2d`
、
`depthwise_conv2d`
、
`mul`
等算子的权重离线量化为int8_t范围内的值(但数据类型仍为float32),以减少预测过程中对权重的量化操作,示例如图2:
<p
align=
"center"
>
<img
src=
"./images/FreezePass.png"
height=
400
width=
420
hspace=
'10'
/>
<br
/>
<strong>
图2:应用QuantizationFreezePass后的结果
</strong>
</p>
### int8预测模型
在对训练网络进行QuantizationFreezePass之后,执行ConvertToInt8Pass,
其主要目的是将执行完QuantizationFreezePass后输出的权重类型由
`FP32`
更改为
`INT8`
。换言之,用户可以选择将量化后的权重保存为float32类型(不执行ConvertToInt8Pass)或者int8_t类型(执行ConvertToInt8Pass),示例如图3:
<p
align=
"center"
>
<img
src=
"./images/ConvertToInt8Pass.png"
height=
400
width=
400
hspace=
'10'
/>
<br
/>
<strong>
图3:应用ConvertToInt8Pass后的结果
</strong>
</p>
### mobile预测模型
经TransformForMobilePass转换后,用户可得到兼容
[
paddle-lite
](
https://github.com/PaddlePaddle/Paddle-Lite
)
移动端预测库的量化模型。paddle-mobile中的量化op和反量化op的名称分别为
`quantize`
和
`dequantize`
。
`quantize`
算子和PaddlePaddle框架中的
`fake_quantize_abs_max`
算子簇的功能类似,
`dequantize`
算子和PaddlePaddle框架中的
`fake_dequantize_max_abs`
算子簇的功能相同。若选择paddle-mobile执行量化训练输出的模型,则需要将
`fake_quantize_abs_max`
等算子改为
`quantize`
算子以及将
`fake_dequantize_max_abs`
等算子改为
`dequantize`
算子,示例如图4:
<p
align=
"center"
>
<img
src=
"./images/TransformForMobilePass.png"
height=
400
width=
400
hspace=
'10'
/>
<br
/>
<strong>
图4:应用TransformForMobilePass后的结果
</strong>
</p>
### python预测
### PaddleLite预测
float预测模型可使用PaddleLite进行加载预测,可参见教程
[
Paddle-Lite如何加载运行量化模型
](
https://github.com/PaddlePaddle/Paddle-Lite/wiki/model_quantization
)
## 从评估模型保存预测模型
从
[
配置文件说明
](
#配置文件说明
)
中可以看到,在
`end_epoch`
时将保存可用于预测的
`float`
,
`int8`
,
`mobile`
模型,但是在训练之前不能准确地保存结果最好的epoch的结果,因此,提供了从
`${checkpoint_path}/${epoch_id}/eval_model/`
下保存的评估模型转化为预测模型的接口
`freeze.py `
, 需要配置的参数为:
-
model_path, 加载的模型路径,
`为${checkpoint_path}/${epoch_id}/eval_model/`
-
weight_quant_type 模型参数的量化方式,和配置文件中的类型保持一致
-
save_path
`float`
,
`int8`
,
`mobile`
模型的保存路径,分别为
`${save_path}/float/`
,
`${save_path}/int8/`
,
`${save_path}/mobile/`
## 示例结果
### MobileNetV1
| weight量化方式 | activation量化方式| Box ap |Paddle Fluid inference time(ms)| Paddle Lite inference time(ms)|
|---|---|---|---|---|
|baseline|- |76.2%|- |-|
|abs_max|abs_max|- |- |-|
|abs_max|moving_average_abs_max|- |- |-|
|channel_wise_abs_max|abs_max|- |- |-|
>训练超参:
## FAQ
PaddleCV/PaddleDetection/slim/quantization/compress.py
0 → 100644
浏览文件 @
a72f988a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
time
import
multiprocessing
import
numpy
as
np
import
datetime
from
collections
import
deque
import
sys
sys
.
path
.
append
(
"../../"
)
from
paddle.fluid.contrib.slim
import
Compressor
from
paddle.fluid.framework
import
IrGraph
from
paddle.fluid
import
core
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
from
paddle
import
fluid
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.data.data_feed
import
create_reader
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_results
from
ppdet.utils.stats
import
TrainingStats
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.check
import
check_gpu
import
ppdet.utils.checkpoint
as
checkpoint
from
ppdet.modeling.model_input
import
create_feed
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
def
eval_run
(
exe
,
compile_program
,
reader
,
keys
,
values
,
cls
,
test_feed
):
"""
Run evaluation program, return program outputs.
"""
iter_id
=
0
results
=
[]
if
len
(
cls
)
!=
0
:
values
=
[]
for
i
in
range
(
len
(
cls
)):
_
,
accum_map
=
cls
[
i
].
get_map_var
()
cls
[
i
].
reset
(
exe
)
values
.
append
(
accum_map
)
images_num
=
0
start_time
=
time
.
time
()
has_bbox
=
'bbox'
in
keys
for
data
in
reader
():
data
=
test_feed
.
feed
(
data
)
feed_data
=
{
'image'
:
data
[
'image'
],
'im_size'
:
data
[
'im_size'
]}
outs
=
exe
.
run
(
compile_program
,
feed
=
feed_data
,
fetch_list
=
values
[
0
],
return_numpy
=
False
)
outs
.
append
(
data
[
'gt_box'
])
outs
.
append
(
data
[
'gt_label'
])
outs
.
append
(
data
[
'is_difficult'
])
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
keys
,
outs
)
}
results
.
append
(
res
)
if
iter_id
%
100
==
0
:
logger
.
info
(
'Test iter {}'
.
format
(
iter_id
))
iter_id
+=
1
images_num
+=
len
(
res
[
'bbox'
][
1
][
0
])
if
has_bbox
else
1
logger
.
info
(
'Test finish iter {}'
.
format
(
iter_id
))
end_time
=
time
.
time
()
fps
=
images_num
/
(
end_time
-
start_time
)
if
has_bbox
:
logger
.
info
(
'Total number of images: {}, inference time: {} fps.'
.
format
(
images_num
,
fps
))
else
:
logger
.
info
(
'Total iteration: {}, inference time: {} batch/s.'
.
format
(
images_num
,
fps
))
return
results
def
main
():
cfg
=
load_config
(
FLAGS
.
config
)
if
'architecture'
in
cfg
:
main_arch
=
cfg
.
architecture
else
:
raise
ValueError
(
"'architecture' not specified in config file."
)
merge_config
(
FLAGS
.
opt
)
if
'log_iter'
not
in
cfg
:
cfg
.
log_iter
=
20
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu
(
cfg
.
use_gpu
)
if
cfg
.
use_gpu
:
devices_num
=
fluid
.
core
.
get_cuda_device_count
()
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
if
'train_feed'
not
in
cfg
:
train_feed
=
create
(
main_arch
+
'TrainFeed'
)
else
:
train_feed
=
create
(
cfg
.
train_feed
)
if
'eval_feed'
not
in
cfg
:
eval_feed
=
create
(
main_arch
+
'EvalFeed'
)
else
:
eval_feed
=
create
(
cfg
.
eval_feed
)
place
=
fluid
.
CUDAPlace
(
0
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
lr_builder
=
create
(
'LearningRate'
)
optim_builder
=
create
(
'OptimizerBuilder'
)
# build program
startup_prog
=
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
model
=
create
(
main_arch
)
train_pyreader
,
feed_vars
=
create_feed
(
train_feed
)
train_fetches
=
model
.
train
(
feed_vars
)
loss
=
train_fetches
[
'loss'
]
lr
=
lr_builder
()
optimizer
=
optim_builder
(
lr
)
optimizer
.
minimize
(
loss
)
train_reader
=
create_reader
(
train_feed
,
cfg
.
max_iters
*
devices_num
,
FLAGS
.
dataset_dir
)
train_pyreader
.
decorate_sample_list_generator
(
train_reader
,
place
)
# parse train fetches
train_keys
,
train_values
,
_
=
parse_fetches
(
train_fetches
)
train_values
.
append
(
lr
)
train_fetch_list
=
[]
for
k
,
v
in
zip
(
train_keys
,
train_values
):
train_fetch_list
.
append
((
k
,
v
))
print
(
"train_fetch_list: {}"
.
format
(
train_fetch_list
))
eval_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
eval_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
model
=
create
(
main_arch
)
eval_pyreader
,
test_feed_vars
=
create_feed
(
eval_feed
,
use_pyreader
=
False
)
fetches
=
model
.
eval
(
test_feed_vars
)
eval_prog
=
eval_prog
.
clone
(
True
)
eval_reader
=
create_reader
(
eval_feed
,
args_path
=
FLAGS
.
dataset_dir
)
#eval_pyreader.decorate_sample_list_generator(eval_reader, place)
test_data_feed
=
fluid
.
DataFeeder
(
test_feed_vars
.
values
(),
place
)
# parse eval fetches
extra_keys
=
[]
if
cfg
.
metric
==
'COCO'
:
extra_keys
=
[
'im_info'
,
'im_id'
,
'im_shape'
]
if
cfg
.
metric
==
'VOC'
:
extra_keys
=
[
'gt_box'
,
'gt_label'
,
'is_difficult'
]
eval_keys
,
eval_values
,
eval_cls
=
parse_fetches
(
fetches
,
eval_prog
,
extra_keys
)
# print(eval_values)
eval_fetch_list
=
[]
for
k
,
v
in
zip
(
eval_keys
,
eval_values
):
eval_fetch_list
.
append
((
k
,
v
))
exe
.
run
(
startup_prog
)
start_iter
=
0
checkpoint
.
load_pretrain
(
exe
,
train_prog
,
cfg
.
pretrain_weights
)
def
eval_func
(
program
,
scope
):
#place = fluid.CPUPlace()
#exe = fluid.Executor(place)
results
=
eval_run
(
exe
,
program
,
eval_reader
,
eval_keys
,
eval_values
,
eval_cls
,
test_data_feed
)
best_box_ap_list
=
[]
resolution
=
None
if
'mask'
in
results
[
0
]:
resolution
=
model
.
mask_head
.
resolution
box_ap_stats
=
eval_results
(
results
,
eval_feed
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
False
,
FLAGS
.
output_eval
)
if
len
(
best_box_ap_list
)
==
0
:
best_box_ap_list
.
append
(
box_ap_stats
[
0
])
elif
box_ap_stats
[
0
]
>
best_box_ap_list
[
0
]:
best_box_ap_list
[
0
]
=
box_ap_stats
[
0
]
checkpoint
.
save
(
exe
,
train_prog
,
os
.
path
.
join
(
save_dir
,
"best_model"
))
logger
.
info
(
"Best test box ap: {}"
.
format
(
best_box_ap_list
[
0
]))
return
best_box_ap_list
[
0
]
test_feed
=
[(
'image'
,
test_feed_vars
[
'image'
].
name
),
(
'im_size'
,
test_feed_vars
[
'im_size'
].
name
)]
com
=
Compressor
(
place
,
fluid
.
global_scope
(),
train_prog
,
train_reader
=
train_pyreader
,
train_feed_list
=
None
,
train_fetch_list
=
train_fetch_list
,
eval_program
=
eval_prog
,
eval_reader
=
eval_reader
,
eval_feed_list
=
test_feed
,
eval_func
=
{
'map'
:
eval_func
},
eval_fetch_list
=
[
eval_fetch_list
[
0
]],
train_optimizer
=
None
)
com
.
config
(
FLAGS
.
slim_file
)
com
.
run
()
if
__name__
==
'__main__'
:
parser
=
ArgsParser
()
parser
.
add_argument
(
"-s"
,
"--slim_file"
,
default
=
None
,
type
=
str
,
help
=
"Config file of PaddleSlim."
)
parser
.
add_argument
(
"--output_eval"
,
default
=
None
,
type
=
str
,
help
=
"Evaluation directory, default is current directory."
)
parser
.
add_argument
(
"-d"
,
"--dataset_dir"
,
default
=
None
,
type
=
str
,
help
=
"Dataset path, same as DataFeed.dataset.dataset_dir"
)
FLAGS
=
parser
.
parse_args
()
main
()
PaddleCV/PaddleDetection/slim/quantization/eval.py
0 → 100644
浏览文件 @
a72f988a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
time
import
multiprocessing
import
numpy
as
np
import
datetime
from
collections
import
deque
import
sys
sys
.
path
.
append
(
"../../"
)
from
paddle.fluid.contrib.slim
import
Compressor
from
paddle.fluid.framework
import
IrGraph
from
paddle.fluid
import
core
from
paddle.fluid.contrib.slim.quantization
import
QuantizationTransformPass
from
paddle.fluid.contrib.slim.quantization
import
QuantizationFreezePass
from
paddle.fluid.contrib.slim.quantization
import
ConvertToInt8Pass
from
paddle.fluid.contrib.slim.quantization
import
TransformForMobilePass
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
from
paddle
import
fluid
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.data.data_feed
import
create_reader
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_results
from
ppdet.utils.stats
import
TrainingStats
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.check
import
check_gpu
import
ppdet.utils.checkpoint
as
checkpoint
from
ppdet.modeling.model_input
import
create_feed
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
def
eval_run
(
exe
,
compile_program
,
reader
,
keys
,
values
,
cls
,
test_feed
):
"""
Run evaluation program, return program outputs.
"""
iter_id
=
0
results
=
[]
images_num
=
0
start_time
=
time
.
time
()
has_bbox
=
'bbox'
in
keys
for
data
in
reader
():
data
=
test_feed
.
feed
(
data
)
feed_data
=
{
'image'
:
data
[
'image'
],
'im_size'
:
data
[
'im_size'
]}
outs
=
exe
.
run
(
compile_program
,
feed
=
feed_data
,
fetch_list
=
values
[
0
],
return_numpy
=
False
)
outs
.
append
(
data
[
'gt_box'
])
outs
.
append
(
data
[
'gt_label'
])
outs
.
append
(
data
[
'is_difficult'
])
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
keys
,
outs
)
}
results
.
append
(
res
)
if
iter_id
%
100
==
0
:
logger
.
info
(
'Test iter {}'
.
format
(
iter_id
))
iter_id
+=
1
images_num
+=
len
(
res
[
'bbox'
][
1
][
0
])
if
has_bbox
else
1
logger
.
info
(
'Test finish iter {}'
.
format
(
iter_id
))
end_time
=
time
.
time
()
fps
=
images_num
/
(
end_time
-
start_time
)
if
has_bbox
:
logger
.
info
(
'Total number of images: {}, inference time: {} fps.'
.
format
(
images_num
,
fps
))
else
:
logger
.
info
(
'Total iteration: {}, inference time: {} batch/s.'
.
format
(
images_num
,
fps
))
return
results
def
main
():
cfg
=
load_config
(
FLAGS
.
config
)
if
'architecture'
in
cfg
:
main_arch
=
cfg
.
architecture
else
:
raise
ValueError
(
"'architecture' not specified in config file."
)
merge_config
(
FLAGS
.
opt
)
if
'log_iter'
not
in
cfg
:
cfg
.
log_iter
=
20
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu
(
cfg
.
use_gpu
)
if
cfg
.
use_gpu
:
devices_num
=
fluid
.
core
.
get_cuda_device_count
()
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
if
'eval_feed'
not
in
cfg
:
eval_feed
=
create
(
main_arch
+
'EvalFeed'
)
else
:
eval_feed
=
create
(
cfg
.
eval_feed
)
place
=
fluid
.
CUDAPlace
(
0
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
eval_pyreader
,
test_feed_vars
=
create_feed
(
eval_feed
,
use_pyreader
=
False
)
eval_reader
=
create_reader
(
eval_feed
,
args_path
=
FLAGS
.
dataset_dir
)
#eval_pyreader.decorate_sample_list_generator(eval_reader, place)
test_data_feed
=
fluid
.
DataFeeder
(
test_feed_vars
.
values
(),
place
)
assert
os
.
path
.
exists
(
FLAGS
.
model_path
)
infer_prog
,
feed_names
,
fetch_targets
=
fluid
.
io
.
load_inference_model
(
dirname
=
FLAGS
.
model_path
,
executor
=
exe
,
model_filename
=
'model'
,
params_filename
=
'params'
)
eval_keys
=
[
'bbox'
,
'gt_box'
,
'gt_label'
,
'is_difficult'
]
eval_values
=
[
'multiclass_nms_0.tmp_0'
,
'gt_box'
,
'gt_label'
,
'is_difficult'
]
eval_cls
=
[]
eval_values
[
0
]
=
fetch_targets
[
0
]
results
=
eval_run
(
exe
,
infer_prog
,
eval_reader
,
eval_keys
,
eval_values
,
eval_cls
,
test_data_feed
)
resolution
=
None
if
'mask'
in
results
[
0
]:
resolution
=
model
.
mask_head
.
resolution
eval_results
(
results
,
eval_feed
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
False
,
FLAGS
.
output_eval
)
if
__name__
==
'__main__'
:
parser
=
ArgsParser
()
parser
.
add_argument
(
"-m"
,
"--model_path"
,
default
=
None
,
type
=
str
,
help
=
"path of checkpoint"
)
parser
.
add_argument
(
"--output_eval"
,
default
=
None
,
type
=
str
,
help
=
"Evaluation directory, default is current directory."
)
parser
.
add_argument
(
"-d"
,
"--dataset_dir"
,
default
=
None
,
type
=
str
,
help
=
"Dataset path, same as DataFeed.dataset.dataset_dir"
)
FLAGS
=
parser
.
parse_args
()
main
()
PaddleCV/PaddleDetection/slim/quantization/freeze.py
0 → 100644
浏览文件 @
a72f988a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
time
import
multiprocessing
import
numpy
as
np
import
datetime
from
collections
import
deque
import
sys
sys
.
path
.
append
(
"../../"
)
from
paddle.fluid.contrib.slim
import
Compressor
from
paddle.fluid.framework
import
IrGraph
from
paddle.fluid
import
core
from
paddle.fluid.contrib.slim.quantization
import
QuantizationTransformPass
from
paddle.fluid.contrib.slim.quantization
import
QuantizationFreezePass
from
paddle.fluid.contrib.slim.quantization
import
ConvertToInt8Pass
from
paddle.fluid.contrib.slim.quantization
import
TransformForMobilePass
def
set_paddle_flags
(
**
kwargs
):
for
key
,
value
in
kwargs
.
items
():
if
os
.
environ
.
get
(
key
,
None
)
is
None
:
os
.
environ
[
key
]
=
str
(
value
)
# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags
(
FLAGS_eager_delete_tensor_gb
=
0
,
# enable GC to save memory
)
from
paddle
import
fluid
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.data.data_feed
import
create_reader
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_results
from
ppdet.utils.stats
import
TrainingStats
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.check
import
check_gpu
import
ppdet.utils.checkpoint
as
checkpoint
from
ppdet.modeling.model_input
import
create_feed
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
def
eval_run
(
exe
,
compile_program
,
reader
,
keys
,
values
,
cls
,
test_feed
):
"""
Run evaluation program, return program outputs.
"""
iter_id
=
0
results
=
[]
images_num
=
0
start_time
=
time
.
time
()
has_bbox
=
'bbox'
in
keys
for
data
in
reader
():
data
=
test_feed
.
feed
(
data
)
feed_data
=
{
'image'
:
data
[
'image'
],
'im_size'
:
data
[
'im_size'
]}
outs
=
exe
.
run
(
compile_program
,
feed
=
feed_data
,
fetch_list
=
values
[
0
],
return_numpy
=
False
)
outs
.
append
(
data
[
'gt_box'
])
outs
.
append
(
data
[
'gt_label'
])
outs
.
append
(
data
[
'is_difficult'
])
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
keys
,
outs
)
}
results
.
append
(
res
)
if
iter_id
%
100
==
0
:
logger
.
info
(
'Test iter {}'
.
format
(
iter_id
))
iter_id
+=
1
images_num
+=
len
(
res
[
'bbox'
][
1
][
0
])
if
has_bbox
else
1
logger
.
info
(
'Test finish iter {}'
.
format
(
iter_id
))
end_time
=
time
.
time
()
fps
=
images_num
/
(
end_time
-
start_time
)
if
has_bbox
:
logger
.
info
(
'Total number of images: {}, inference time: {} fps.'
.
format
(
images_num
,
fps
))
else
:
logger
.
info
(
'Total iteration: {}, inference time: {} batch/s.'
.
format
(
images_num
,
fps
))
return
results
def
main
():
cfg
=
load_config
(
FLAGS
.
config
)
if
'architecture'
in
cfg
:
main_arch
=
cfg
.
architecture
else
:
raise
ValueError
(
"'architecture' not specified in config file."
)
merge_config
(
FLAGS
.
opt
)
if
'log_iter'
not
in
cfg
:
cfg
.
log_iter
=
20
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu
(
cfg
.
use_gpu
)
if
cfg
.
use_gpu
:
devices_num
=
fluid
.
core
.
get_cuda_device_count
()
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
if
'eval_feed'
not
in
cfg
:
eval_feed
=
create
(
main_arch
+
'EvalFeed'
)
else
:
eval_feed
=
create
(
cfg
.
eval_feed
)
place
=
fluid
.
CUDAPlace
(
0
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
eval_pyreader
,
test_feed_vars
=
create_feed
(
eval_feed
,
use_pyreader
=
False
)
eval_reader
=
create_reader
(
eval_feed
,
args_path
=
FLAGS
.
dataset_dir
)
#eval_pyreader.decorate_sample_list_generator(eval_reader, place)
test_data_feed
=
fluid
.
DataFeeder
(
test_feed_vars
.
values
(),
place
)
assert
os
.
path
.
exists
(
FLAGS
.
model_path
)
infer_prog
,
feed_names
,
fetch_targets
=
fluid
.
io
.
load_inference_model
(
dirname
=
FLAGS
.
model_path
,
executor
=
exe
,
model_filename
=
'__model__'
,
params_filename
=
'__params__'
)
eval_keys
=
[
'bbox'
,
'gt_box'
,
'gt_label'
,
'is_difficult'
]
eval_values
=
[
'multiclass_nms_0.tmp_0'
,
'gt_box'
,
'gt_label'
,
'is_difficult'
]
eval_cls
=
[]
eval_values
[
0
]
=
fetch_targets
[
0
]
results
=
eval_run
(
exe
,
infer_prog
,
eval_reader
,
eval_keys
,
eval_values
,
eval_cls
,
test_data_feed
)
resolution
=
None
if
'mask'
in
results
[
0
]:
resolution
=
model
.
mask_head
.
resolution
box_ap_stats
=
eval_results
(
results
,
eval_feed
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
False
,
FLAGS
.
output_eval
)
logger
.
info
(
"freeze the graph for inference"
)
test_graph
=
IrGraph
(
core
.
Graph
(
infer_prog
.
desc
),
for_test
=
True
)
freeze_pass
=
QuantizationFreezePass
(
scope
=
fluid
.
global_scope
(),
place
=
place
,
weight_quantize_type
=
FLAGS
.
weight_quant_type
)
freeze_pass
.
apply
(
test_graph
)
server_program
=
test_graph
.
to_program
()
fluid
.
io
.
save_inference_model
(
dirname
=
os
.
path
.
join
(
FLAGS
.
save_path
,
'float'
),
feeded_var_names
=
feed_names
,
target_vars
=
fetch_targets
,
executor
=
exe
,
main_program
=
server_program
,
model_filename
=
'model'
,
params_filename
=
'params'
)
logger
.
info
(
"convert the weights into int8 type"
)
convert_int8_pass
=
ConvertToInt8Pass
(
scope
=
fluid
.
global_scope
(),
place
=
place
)
convert_int8_pass
.
apply
(
test_graph
)
server_int8_program
=
test_graph
.
to_program
()
fluid
.
io
.
save_inference_model
(
dirname
=
os
.
path
.
join
(
FLAGS
.
save_path
,
'int8'
),
feeded_var_names
=
feed_names
,
target_vars
=
fetch_targets
,
executor
=
exe
,
main_program
=
server_int8_program
,
model_filename
=
'model'
,
params_filename
=
'params'
)
logger
.
info
(
"convert the freezed pass to paddle-lite execution"
)
mobile_pass
=
TransformForMobilePass
()
mobile_pass
.
apply
(
test_graph
)
mobile_program
=
test_graph
.
to_program
()
fluid
.
io
.
save_inference_model
(
dirname
=
os
.
path
.
join
(
FLAGS
.
save_path
,
'mobile'
),
feeded_var_names
=
feed_names
,
target_vars
=
fetch_targets
,
executor
=
exe
,
main_program
=
mobile_program
,
model_filename
=
'model'
,
params_filename
=
'params'
)
if
__name__
==
'__main__'
:
parser
=
ArgsParser
()
parser
.
add_argument
(
"-m"
,
"--model_path"
,
default
=
None
,
type
=
str
,
help
=
"path of checkpoint"
)
parser
.
add_argument
(
"--output_eval"
,
default
=
None
,
type
=
str
,
help
=
"Evaluation directory, default is current directory."
)
parser
.
add_argument
(
"-d"
,
"--dataset_dir"
,
default
=
None
,
type
=
str
,
help
=
"Dataset path, same as DataFeed.dataset.dataset_dir"
)
parser
.
add_argument
(
"--weight_quant_type"
,
default
=
'abs_max'
,
type
=
str
,
help
=
"quantization type for weight"
)
parser
.
add_argument
(
"--save_path"
,
default
=
'./output'
,
type
=
str
,
help
=
"path to save quantization inference model"
)
FLAGS
=
parser
.
parse_args
()
main
()
PaddleCV/PaddleDetection/slim/quantization/yolov3_mobilenet_v1_slim.yaml
0 → 100644
浏览文件 @
a72f988a
version
:
1.0
strategies
:
quantization_strategy
:
class
:
'
QuantizationStrategy'
start_epoch
:
0
end_epoch
:
0
float_model_save_path
:
'
./output/yolov3/float'
mobile_model_save_path
:
'
./output/yolov3/mobile'
int8_model_save_path
:
'
./output/yolov3/int8'
weight_bits
:
8
activation_bits
:
8
weight_quantize_type
:
'
abs_max'
activation_quantize_type
:
'
moving_average_abs_max'
save_in_nodes
:
[
'
image'
,
'
im_size'
]
save_out_nodes
:
[
'
multiclass_nms_0.tmp_0'
]
compressor
:
epoch
:
1
checkpoint_path
:
'
./checkpoints/yolov3/'
strategies
:
-
quantization_strategy
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录