Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
744b4d0e
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
接近 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
744b4d0e
编写于
8月 12, 2022
作者:
C
Chang Xu
提交者:
GitHub
8月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Full Quantization/PicoDet Demo (#1335)
上级
28b179f8
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
520 addition
and
0 deletion
+520
-0
example/full_quantization/detection/README.md
example/full_quantization/detection/README.md
+120
-0
example/full_quantization/detection/configs/picodet_reader.yml
...le/full_quantization/detection/configs/picodet_reader.yml
+32
-0
example/full_quantization/detection/configs/picodet_s_qat_dis.yaml
...ull_quantization/detection/configs/picodet_s_qat_dis.yaml
+34
-0
example/full_quantization/detection/eval.py
example/full_quantization/detection/eval.py
+157
-0
example/full_quantization/detection/run.py
example/full_quantization/detection/run.py
+177
-0
未找到文件。
example/full_quantization/detection/README.md
0 → 100644
浏览文件 @
744b4d0e
# 目标检测模型全量化示例
目录:
-
[
1.简介
](
#1简介
)
-
[
2.Benchmark
](
#2Benchmark
)
-
[
3.开始全量化
](
#全量化流程
)
-
[
3.1 环境准备
](
#31-准备环境
)
-
[
3.2 准备数据集
](
#32-准备数据集
)
-
[
3.3 准备预测模型
](
#33-准备预测模型
)
-
[
3.4 测试模型精度
](
#34-测试模型精度
)
-
[
3.5 全量化并产出模型
](
#35-全量化并产出模型
)
-
[
4.预测部署
](
#4预测部署
)
-
[
5.FAQ
](
5FAQ
)
## 1. 简介
本示例将以目标检测模型PicoDet为例,介绍如何使用PaddleDetection中Inference部署模型进行全量化。本示例使用的全量化策略为全量化加蒸馏。
## 2.Benchmark
### PicoDet-S-NPU
| 模型 | Base mAP | 离线量化mAP | ACT量化mAP | TRT-FP32 | TRT-FP16 | TRT-INT8 | 配置文件 | 量化模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :----------------------: | :---------------------: |
| PicoDet-S-NPU | 30.1 | - | 44.2 | - | - | - |
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco_npu.yml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar
)
|
| PicoDet-S-NPU | 29.7 | - | 43.9 | - | - | - |
[
config
](
https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/full_quantization/detection/configs/picodet_s_qat_dis.yaml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_npu_quant.tar
)
|
-
mAP的指标均在COCO val2017数据集中评测得到,IoU=0.5:0.95。
## 3. 全量化流程
#### 3.1 准备环境
-
PaddlePaddle >= 2.3 (可从
[
Paddle官网
](
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html
)
下载安装)
-
PaddleSlim >= 2.3
-
PaddleDet >= 2.4
-
opencv-python
安装paddlepaddle:
```
shell
# CPU
pip
install
paddlepaddle
# GPU
pip
install
paddlepaddle-gpu
```
安装paddleslim:
```
shell
pip
install
paddleslim
```
安装paddledet:
```
shell
pip
install
paddledet
```
注:安装PaddleDet的目的是为了直接使用PaddleDetection中的Dataloader组件。
#### 3.2 准备数据集
本案例默认以COCO数据进行全量化实验,如果自定义COCO数据,或者其他格式数据,请参考
[
PaddleDetection数据准备文档
](
https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md
)
来准备数据。
如果数据集为非COCO格式数据,请修改
[
configs
](
./configs
)
中reader配置文件中的Dataset字段。
以PicoDet-S-NPU模型为例,如果已经准备好数据集,请直接修改[./configs/picodet_reader.yml]中
`EvalDataset`
的
`dataset_dir`
字段为自己数据集路径即可。
#### 3.3 准备预测模型
预测模型的格式为:
`model.pdmodel`
和
`model.pdiparams`
两个,带
`pdmodel`
的是模型文件,带
`pdiparams`
后缀的是权重文件。
注:其他像
`__model__`
和
`__params__`
分别对应
`model.pdmodel`
和
`model.pdiparams`
文件。
根据
[
PaddleDetection文档
](
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/GETTING_STARTED_cn.md#8-%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA
)
导出Inference模型,具体可参考下方PicoDet-S-NPU模型的导出示例:
-
下载代码
```
git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
-
导出预测模型
PicoDet-S-NPU模型,包含NMS:如快速体验,可直接下载
[
PicoDet-S-NPU导出模型
](
https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar
)
```
shell
python tools/export_model.py
\
-c
configs/picodet/picodet_s_416_coco_npu.yml
\
-o
weights
=
https://paddledet.bj.bcebos.com/models/picodet_s_416_coco_npu.pdparams
\
trt
=
True
\
```
#### 3.4 全量化并产出模型
全量化示例通过run.py脚本启动,会使用接口
```paddleslim.auto_compression.AutoCompression```
对模型进行全量化。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
-
单卡训练:
```
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/picodet_s_qat_dis.yaml --save_dir='./output/'
```
-
多卡训练:
```
CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
--config_path=./configs/picodet_s_qat_dis.yaml --save_dir='./output/'
```
#### 3.5 测试模型精度
使用eval.py脚本得到模型的mAP:
```
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/picodet_s_qat_dis.yaml
```
**注意**
:
-
要测试的模型路径可以在配置文件中
`model_dir`
字段下进行修改。
## 4.预测部署
## 5.FAQ
example/full_quantization/detection/configs/picodet_reader.yml
0 → 100644
浏览文件 @
744b4d0e
metric
:
COCO
num_classes
:
80
# Datset configuration
TrainDataset
:
!COCODataSet
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
dataset_dir
:
/paddle/dataset/coco/
EvalDataset
:
!COCODataSet
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
dataset_dir
:
/paddle/dataset/coco/
worker_num
:
6
eval_height
:
&eval_height
416
eval_width
:
&eval_width
416
eval_size
:
&eval_size
[
*eval_height
,
*eval_width
]
EvalReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
interp
:
2
,
target_size
:
*eval_size
,
keep_ratio
:
False
}
-
NormalizeImage
:
{
mean
:
[
0
,
0
,
0
],
std
:
[
1
,
1
,
1
],
is_scale
:
True
}
-
Permute
:
{}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
}
batch_size
:
8
shuffle
:
false
example/full_quantization/detection/configs/picodet_s_qat_dis.yaml
0 → 100644
浏览文件 @
744b4d0e
Global
:
reader_config
:
./configs/picodet_reader.yml
input_list
:
[
'
image'
,
'
scale_factor'
]
Evaluation
:
True
model_dir
:
./picodet_s_416_coco_npu/
model_filename
:
model.pdmodel
params_filename
:
model.pdiparams
Distillation
:
alpha
:
1.0
loss
:
l2
Quantization
:
use_pact
:
true
activation_quantize_type
:
'
moving_average_abs_max'
weight_bits
:
8
activation_bits
:
8
quantize_op_types
:
-
conv2d
-
depthwise_conv2d
TrainConfig
:
train_iter
:
8000
eval_iter
:
1000
learning_rate
:
type
:
CosineAnnealingDecay
learning_rate
:
0.00001
T_max
:
8000
optimizer_builder
:
optimizer
:
type
:
SGD
weight_decay
:
4.0e-05
example/full_quantization/detection/eval.py
0 → 100644
浏览文件 @
744b4d0e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
,
merge_config
from
ppdet.core.workspace
import
create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
,
KeyPointTopDownCOCOEval
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--config_path'
,
type
=
str
,
default
=
None
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--devices'
,
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
return
parser
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
convert_numpy_data
(
data
,
metric
):
data_all
=
{}
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
if
isinstance
(
metric
,
VOCMetric
):
for
k
,
v
in
data_all
.
items
():
if
not
isinstance
(
v
[
0
],
np
.
ndarray
):
tmp_list
=
[]
for
t
in
v
:
tmp_list
.
append
(
np
.
array
(
t
))
data_all
[
k
]
=
np
.
array
(
tmp_list
)
else
:
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
return
data_all
def
eval
():
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
val_program
,
feed_target_names
,
fetch_targets
=
paddle
.
static
.
load_inference_model
(
global_config
[
"model_dir"
].
rstrip
(
'/'
),
exe
,
model_filename
=
global_config
[
"model_filename"
],
params_filename
=
global_config
[
"params_filename"
])
print
(
'Loaded model from: {}'
.
format
(
global_config
[
"model_dir"
]))
metric
=
global_config
[
'metric'
]
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
convert_numpy_data
(
data
,
metric
)
data_input
=
{}
for
k
,
v
in
data
.
items
():
if
isinstance
(
global_config
[
'input_list'
],
list
):
if
k
in
global_config
[
'input_list'
]:
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
global_config
[
'input_list'
],
dict
):
if
k
in
global_config
[
'input_list'
].
keys
():
data_input
[
global_config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
val_program
,
feed
=
data_input
,
fetch_list
=
fetch_targets
,
return_numpy
=
False
)
res
=
{}
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter:'
,
batch_id
)
metric
.
accumulate
()
metric
.
log
()
metric
.
reset
()
def
main
():
global
global_config
all_config
=
load_slim_config
(
FLAGS
.
config_path
)
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
dataset
=
reader_cfg
[
'EvalDataset'
]
global
val_loader
val_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'EvalDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
metric
=
None
if
reader_cfg
[
'metric'
]
==
'COCO'
:
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
elif
reader_cfg
[
'metric'
]
==
'VOC'
:
metric
=
VOCMetric
(
label_list
=
dataset
.
get_label_list
(),
class_num
=
reader_cfg
[
'num_classes'
],
map_type
=
reader_cfg
[
'map_type'
])
elif
reader_cfg
[
'metric'
]
==
'KeyPointTopDownCOCOEval'
:
anno_file
=
dataset
.
get_anno
()
metric
=
KeyPointTopDownCOCOEval
(
anno_file
,
len
(
dataset
),
17
,
'output_eval'
)
else
:
raise
ValueError
(
"metric currently only supports COCO and VOC."
)
global_config
[
'metric'
]
=
metric
eval
()
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
main
()
example/full_quantization/detection/run.py
0 → 100644
浏览文件 @
744b4d0e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
,
merge_config
from
ppdet.core.workspace
import
create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
,
KeyPointTopDownCOCOEval
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
paddleslim.auto_compression
import
AutoCompression
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--config_path'
,
type
=
str
,
default
=
None
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
'output'
,
help
=
"directory to save compressed model."
)
parser
.
add_argument
(
'--devices'
,
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
return
parser
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
convert_numpy_data
(
data
,
metric
):
data_all
=
{}
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
if
isinstance
(
metric
,
VOCMetric
):
for
k
,
v
in
data_all
.
items
():
if
not
isinstance
(
v
[
0
],
np
.
ndarray
):
tmp_list
=
[]
for
t
in
v
:
tmp_list
.
append
(
np
.
array
(
t
))
data_all
[
k
]
=
np
.
array
(
tmp_list
)
else
:
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
return
data_all
def
eval_function
(
exe
,
compiled_test_program
,
test_feed_names
,
test_fetch_list
):
metric
=
global_config
[
'metric'
]
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
convert_numpy_data
(
data
,
metric
)
data_input
=
{}
for
k
,
v
in
data
.
items
():
if
isinstance
(
global_config
[
'input_list'
],
list
):
if
k
in
test_feed_names
:
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
global_config
[
'input_list'
],
dict
):
if
k
in
global_config
[
'input_list'
].
keys
():
data_input
[
global_config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
compiled_test_program
,
feed
=
data_input
,
fetch_list
=
test_fetch_list
,
return_numpy
=
False
)
res
=
{}
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter:'
,
batch_id
)
metric
.
accumulate
()
metric
.
log
()
map_res
=
metric
.
get_results
()
metric
.
reset
()
map_key
=
'keypoint'
if
'arch'
in
global_config
and
global_config
[
'arch'
]
==
'keypoint'
else
'bbox'
return
map_res
[
map_key
][
0
]
def
main
():
global
global_config
all_config
=
load_slim_config
(
FLAGS
.
config_path
)
assert
"Global"
in
all_config
,
f
"Key 'Global' not found in config file.
\n
{
all_config
}
"
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
train_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'TrainDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
train_loader
=
reader_wrapper
(
train_loader
,
global_config
[
'input_list'
])
if
'Evaluation'
in
global_config
.
keys
()
and
global_config
[
'Evaluation'
]
and
paddle
.
distributed
.
get_rank
()
==
0
:
eval_func
=
eval_function
dataset
=
reader_cfg
[
'EvalDataset'
]
global
val_loader
_eval_batch_sampler
=
paddle
.
io
.
BatchSampler
(
dataset
,
batch_size
=
reader_cfg
[
'EvalReader'
][
'batch_size'
])
val_loader
=
create
(
'EvalReader'
)(
dataset
,
reader_cfg
[
'worker_num'
],
batch_sampler
=
_eval_batch_sampler
,
return_list
=
True
)
metric
=
None
if
reader_cfg
[
'metric'
]
==
'COCO'
:
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
elif
reader_cfg
[
'metric'
]
==
'VOC'
:
metric
=
VOCMetric
(
label_list
=
dataset
.
get_label_list
(),
class_num
=
reader_cfg
[
'num_classes'
],
map_type
=
reader_cfg
[
'map_type'
])
elif
reader_cfg
[
'metric'
]
==
'KeyPointTopDownCOCOEval'
:
anno_file
=
dataset
.
get_anno
()
metric
=
KeyPointTopDownCOCOEval
(
anno_file
,
len
(
dataset
),
17
,
'output_eval'
)
else
:
raise
ValueError
(
"metric currently only supports COCO and VOC."
)
global_config
[
'metric'
]
=
metric
else
:
eval_func
=
None
ac
=
AutoCompression
(
model_dir
=
global_config
[
"model_dir"
],
model_filename
=
global_config
[
"model_filename"
],
params_filename
=
global_config
[
"params_filename"
],
save_dir
=
FLAGS
.
save_dir
,
config
=
all_config
,
train_dataloader
=
train_loader
,
eval_callback
=
eval_func
)
ac
.
compress
()
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录