Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
5dcc4e19
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5dcc4e19
编写于
6月 28, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
6月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update yolov5s act demo (#1200)
上级
cfec7b34
变更
25
显示空白变更内容
内联
并排
Showing
25 changed file
with
527 addition
and
59 deletion
+527
-59
demo/auto_compression/detection/README.md
demo/auto_compression/detection/README.md
+3
-5
demo/auto_compression/detection/eval.py
demo/auto_compression/detection/eval.py
+6
-13
demo/auto_compression/detection/keypoint_utils.py
demo/auto_compression/detection/keypoint_utils.py
+8
-6
demo/auto_compression/detection/run.py
demo/auto_compression/detection/run.py
+6
-13
demo/auto_compression/detection/run_tinypose.py
demo/auto_compression/detection/run_tinypose.py
+22
-16
demo/auto_compression/pytorch_huggingface/README.md
demo/auto_compression/pytorch_huggingface/README.md
+4
-4
demo/auto_compression/pytorch_huggingface/configs/cola.yaml
demo/auto_compression/pytorch_huggingface/configs/cola.yaml
+0
-0
demo/auto_compression/pytorch_huggingface/configs/mnli.yaml
demo/auto_compression/pytorch_huggingface/configs/mnli.yaml
+0
-0
demo/auto_compression/pytorch_huggingface/configs/mrpc.yaml
demo/auto_compression/pytorch_huggingface/configs/mrpc.yaml
+0
-0
demo/auto_compression/pytorch_huggingface/configs/qnli.yaml
demo/auto_compression/pytorch_huggingface/configs/qnli.yaml
+0
-0
demo/auto_compression/pytorch_huggingface/configs/qqp.yaml
demo/auto_compression/pytorch_huggingface/configs/qqp.yaml
+0
-0
demo/auto_compression/pytorch_huggingface/configs/rte.yaml
demo/auto_compression/pytorch_huggingface/configs/rte.yaml
+0
-0
demo/auto_compression/pytorch_huggingface/configs/sst2.yaml
demo/auto_compression/pytorch_huggingface/configs/sst2.yaml
+0
-0
demo/auto_compression/pytorch_huggingface/configs/stsb.yaml
demo/auto_compression/pytorch_huggingface/configs/stsb.yaml
+0
-0
demo/auto_compression/pytorch_huggingface/infer.py
demo/auto_compression/pytorch_huggingface/infer.py
+0
-0
demo/auto_compression/pytorch_huggingface/run.py
demo/auto_compression/pytorch_huggingface/run.py
+0
-0
demo/auto_compression/pytorch_huggingface/run.sh
demo/auto_compression/pytorch_huggingface/run.sh
+0
-0
demo/auto_compression/pytorch_yolov5/README.md
demo/auto_compression/pytorch_yolov5/README.md
+122
-0
demo/auto_compression/pytorch_yolov5/configs/yolov5_reader.yml
...auto_compression/pytorch_yolov5/configs/yolov5_reader.yml
+0
-0
demo/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml
...o_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml
+0
-0
demo/auto_compression/pytorch_yolov5/eval.py
demo/auto_compression/pytorch_yolov5/eval.py
+168
-0
demo/auto_compression/pytorch_yolov5/images/000000570688.jpg
demo/auto_compression/pytorch_yolov5/images/000000570688.jpg
+0
-0
demo/auto_compression/pytorch_yolov5/paddle_trt_infer.py
demo/auto_compression/pytorch_yolov5/paddle_trt_infer.py
+2
-2
demo/auto_compression/pytorch_yolov5/post_process.py
demo/auto_compression/pytorch_yolov5/post_process.py
+0
-0
demo/auto_compression/pytorch_yolov5/run.py
demo/auto_compression/pytorch_yolov5/run.py
+186
-0
未找到文件。
demo/auto_compression/detection/README.md
浏览文件 @
5dcc4e19
...
...
@@ -113,8 +113,6 @@ wget https://bj.bcebos.com/v1/paddle-slim-models/detection/ppyoloe_crn_l_300e_co
tar
-xf
ppyoloe_crn_l_300e_coco.tar
```
**注意**
:TinyPose模型暂不支持精度测试。
#### 3.4 自动压缩并产出模型
蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口
```paddleslim.auto_compression.AutoCompression```
对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
...
...
@@ -128,14 +126,14 @@ python run.py --config_path=./configs/ppyoloe_l_qat_dis.yaml --save_dir='./outpu
#### 3.5 测试模型精度
使用
run
.py脚本得到模型的mAP:
使用
eval
.py脚本得到模型的mAP:
```
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/ppyoloe_l_qat_dis.yaml
```
**注意**
:
要测试的模型路径可以在配置文件中
`model_dir`
字段下进行修改。
**注意**
:
-
要测试的模型路径可以在配置文件中
`model_dir`
字段下进行修改。
## 4.预测部署
...
...
demo/auto_compression/detection/eval.py
浏览文件 @
5dcc4e19
...
...
@@ -22,8 +22,6 @@ from ppdet.core.workspace import create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
post_process
import
YOLOv5PostProcess
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
...
...
@@ -108,11 +106,6 @@ def eval():
fetch_list
=
fetch_targets
,
return_numpy
=
False
)
res
=
{}
if
'arch'
in
global_config
and
global_config
[
'arch'
]
==
'YOLOv5'
:
postprocess
=
YOLOv5PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.6
,
multi_label
=
True
)
res
=
postprocess
(
np
.
array
(
outs
[
0
]),
data_all
[
'scale_factor'
])
else
:
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
...
...
demo/auto_compression/detection/keypoint_utils.py
浏览文件 @
5dcc4e19
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -15,15 +14,16 @@
import
logging
import
os
import
json
from
collections
import
defaultdict
,
OrderedDict
import
numpy
as
np
from
pycocotools.coco
import
COCO
from
pycocotools.cocoeval
import
COCOeval
from
scipy.io
import
loadmat
,
savemat
import
cv2
from
paddleslim.common
import
get_logger
logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
def
get_affine_mat_kernel
(
h
,
w
,
s
,
inv
=
False
):
if
w
<
h
:
w_
=
s
...
...
@@ -231,6 +231,7 @@ def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
ious
[
n_d
]
=
np
.
sum
(
np
.
exp
(
-
e
))
/
e
.
shape
[
0
]
if
e
.
shape
[
0
]
!=
0
else
0.0
return
ious
def
oks_nms
(
kpts_db
,
thresh
,
sigmas
=
None
,
in_vis_thre
=
None
):
"""greedily select boxes with high confidence and overlap with current maximum <= thresh
rule out overlap >= thresh
...
...
@@ -268,6 +269,7 @@ def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
return
keep
def
rescore
(
overlap
,
scores
,
thresh
,
type
=
'gaussian'
):
assert
overlap
.
shape
[
0
]
==
scores
.
shape
[
0
]
if
type
==
'linear'
:
...
...
@@ -406,10 +408,10 @@ class HRNetPostProcess(object):
return
coord
def
dark_postprocess
(
self
,
hm
,
coords
,
kernelsize
):
'''DARK postpocessing, Zhang et al. Distribution-Aware Coordinate
'''
DARK postpocessing, Zhang et al. Distribution-Aware Coordinate
Representation for Human Pose Estimation (CVPR 2020).
'''
hm
=
self
.
gaussian_blur
(
hm
,
kernelsize
)
hm
=
np
.
maximum
(
hm
,
1e-10
)
hm
=
np
.
log
(
hm
)
...
...
@@ -419,7 +421,8 @@ class HRNetPostProcess(object):
return
coords
def
get_final_preds
(
self
,
heatmaps
,
center
,
scale
,
kernelsize
=
3
):
"""the highest heatvalue location with a quarter offset in the
"""
The highest heatvalue location with a quarter offset in the
direction from the highest response to the second highest response.
Args:
heatmaps (numpy.ndarray): The predicted heatmaps
...
...
@@ -465,4 +468,3 @@ class HRNetPostProcess(object):
maxvals
,
axis
=
1
)
]]
return
outputs
demo/auto_compression/detection/run.py
浏览文件 @
5dcc4e19
...
...
@@ -23,8 +23,6 @@ from ppdet.metrics import COCOMetric, VOCMetric
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
paddleslim.auto_compression
import
AutoCompression
from
post_process
import
YOLOv5PostProcess
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
...
...
@@ -104,11 +102,6 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
fetch_list
=
test_fetch_list
,
return_numpy
=
False
)
res
=
{}
if
'arch'
in
global_config
and
global_config
[
'arch'
]
==
'YOLOv5'
:
postprocess
=
YOLOv5PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.6
,
multi_label
=
True
)
res
=
postprocess
(
np
.
array
(
outs
[
0
]),
data_all
[
'scale_factor'
])
else
:
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
...
...
demo/auto_compression/detection/run_tinypose.py
浏览文件 @
5dcc4e19
...
...
@@ -27,6 +27,7 @@ from paddleslim.auto_compression import AutoCompression
from
paddleslim.quant
import
quant_post_static
from
keypoint_utils
import
HRNetPostProcess
,
transform_preds
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
...
...
@@ -69,6 +70,7 @@ def reader_wrapper(reader, input_list):
return
gen
def
flip_back
(
output_flipped
,
matched_parts
):
assert
output_flipped
.
ndim
==
4
,
\
'output_flipped should be [batch_size, num_joints, height, width]'
...
...
@@ -82,6 +84,7 @@ def flip_back(output_flipped, matched_parts):
return
output_flipped
def
eval
(
config
):
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
...
...
@@ -114,13 +117,15 @@ def eval(config):
return_numpy
=
False
)
output_flipped
=
np
.
array
(
output_flipped
[
0
])
flip_perm
=
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
],
[
9
,
10
],
[
11
,
12
],
[
13
,
14
],
[
15
,
16
]]
flip_perm
=
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
],
[
9
,
10
],
[
11
,
12
],
[
13
,
14
],
[
15
,
16
]]
output_flipped
=
flip_back
(
output_flipped
,
flip_perm
)
output_flipped
[:,
:,
:,
1
:]
=
copy
.
copy
(
output_flipped
)[:,
:,
:,
0
:
-
1
]
hrnet_outputs
=
(
np
.
array
(
outs
[
0
])
+
output_flipped
)
*
0.5
imshape
=
(
np
.
array
(
data
[
'im_shape'
])
)[:,
::
-
1
]
if
'im_shape'
in
data
else
None
center
=
np
.
array
(
data
[
'center'
])
if
'center'
in
data
else
np
.
round
(
imshape
/
2.
)
imshape
=
(
np
.
array
(
data
[
'im_shape'
]))[:,
::
-
1
]
if
'im_shape'
in
data
else
None
center
=
np
.
array
(
data
[
'center'
])
if
'center'
in
data
else
np
.
round
(
imshape
/
2.
)
scale
=
np
.
array
(
data
[
'scale'
])
if
'scale'
in
data
else
imshape
/
200.
outputs
=
post_process
(
hrnet_outputs
,
center
,
scale
)
outputs
=
{
'keypoint'
:
outputs
}
...
...
@@ -155,17 +160,18 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
return_numpy
=
False
)
output_flipped
=
np
.
array
(
output_flipped
[
0
])
flip_perm
=
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
],
[
9
,
10
],
[
11
,
12
],
[
13
,
14
],
[
15
,
16
]]
flip_perm
=
[[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
],
[
9
,
10
],
[
11
,
12
],
[
13
,
14
],
[
15
,
16
]]
output_flipped
=
flip_back
(
output_flipped
,
flip_perm
)
output_flipped
[:,
:,
:,
1
:]
=
copy
.
copy
(
output_flipped
)[:,
:,
:,
0
:
-
1
]
hrnet_outputs
=
(
np
.
array
(
outs
[
0
])
+
output_flipped
)
*
0.5
imshape
=
(
np
.
array
(
data
[
'im_shape'
])
)[:,
::
-
1
]
if
'im_shape'
in
data
else
None
center
=
np
.
array
(
data
[
'center'
])
if
'center'
in
data
else
np
.
round
(
imshape
/
2.
)
imshape
=
(
np
.
array
(
data
[
'im_shape'
]))[:,
::
-
1
]
if
'im_shape'
in
data
else
None
center
=
np
.
array
(
data
[
'center'
])
if
'center'
in
data
else
np
.
round
(
imshape
/
2.
)
scale
=
np
.
array
(
data
[
'scale'
])
if
'scale'
in
data
else
imshape
/
200.
outputs
=
post_process
(
hrnet_outputs
,
center
,
scale
)
outputs
=
{
'keypoint'
:
outputs
}
metric
.
update
(
data_all
,
outputs
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter:'
,
batch_id
)
...
...
demo/auto_compression/pytorch
-
huggingface/README.md
→
demo/auto_compression/pytorch
_
huggingface/README.md
浏览文件 @
5dcc4e19
...
...
@@ -15,7 +15,7 @@
飞桨模型转换工具
[
X2Paddle
](
https://github.com/PaddlePaddle/X2Paddle
)
支持将
```Caffe/TensorFlow/ONNX/PyTorch```
的模型一键转为飞桨(PaddlePaddle)的预测模型。借助X2Paddle的能力,PaddleSlim的自动压缩功能可方便地用于各种框架的推理模型。
本示例将以
[
Py
torch
](
https://github.com/pytorch/pytorch
)
框架的自然语言处理模型为例,介绍如何自动压缩其他框架中的自然语言处理模型。本示例会利用
[
huggingface
](
https://github.com/huggingface/transformers
)
开源transformers库,将Pyt
orch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为剪枝蒸馏和离线量化(
```Post-training quantization```
)。
本示例将以
[
Py
Torch
](
https://github.com/pytorch/pytorch
)
框架的自然语言处理模型为例,介绍如何自动压缩其他框架中的自然语言处理模型。本示例会利用
[
huggingface
](
https://github.com/huggingface/transformers
)
开源transformers库,将PyT
orch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为剪枝蒸馏和离线量化(
```Post-training quantization```
)。
...
...
@@ -87,7 +87,7 @@ pip install paddlenlp
#### 3.3 X2Paddle转换模型流程
**方式1: PyTorch2Paddle直接将Py
t
orch动态图模型转为Paddle静态图模型**
**方式1: PyTorch2Paddle直接将Py
T
orch动态图模型转为Paddle静态图模型**
```
shell
import torch
...
...
@@ -116,7 +116,7 @@ PyTorch2Paddle支持trace和script两种方式的转换,均是PyTorch动态图
-
使用PaddleNLP的tokenizer时需要在模型保存的文件夹中加入
```model_config.json, special_tokens_map.json, tokenizer_config.json, vocab.txt```
这些文件。
更多Py
t
orch2Paddle示例可参考
[
PyTorch模型转换文档
](
https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/inference_model_convertor/pytorch2paddle.md
)
。其他框架转换可参考
[
X2Paddle模型转换工具
](
https://github.com/PaddlePaddle/X2Paddle
)
更多Py
T
orch2Paddle示例可参考
[
PyTorch模型转换文档
](
https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/inference_model_convertor/pytorch2paddle.md
)
。其他框架转换可参考
[
X2Paddle模型转换工具
](
https://github.com/PaddlePaddle/X2Paddle
)
如想快速尝试运行实验,也可以直接下载已经转换好的模型,链接如下:
|
[
CoLA
](
https://paddle-slim-models.bj.bcebos.com/act/x2paddle_cola.tar
)
|
[
MRPC
](
https://paddle-slim-models.bj.bcebos.com/act/x2paddle_mrpc.tar
)
|
[
QNLI
](
https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qnli.tar
)
|
[
QQP
](
https://paddle-slim-models.bj.bcebos.com/act/x2paddle_qqp.tar
)
|
[
RTE
](
https://paddle-slim-models.bj.bcebos.com/act/x2paddle_rte.tar
)
|
[
SST2
](
https://paddle-slim-models.bj.bcebos.com/act/x2paddle_sst2.tar
)
|
...
...
@@ -126,7 +126,7 @@ wget https://paddle-slim-models.bj.bcebos.com/act/x2paddle_cola.tar
tar
xf x2paddle_cola.tar
```
**方式2: Onnx2Paddle将Py
t
orch动态图模型保存为Onnx格式后再转为Paddle静态图模型**
**方式2: Onnx2Paddle将Py
T
orch动态图模型保存为Onnx格式后再转为Paddle静态图模型**
PyTorch 导出 ONNX 动态图模型
...
...
demo/auto_compression/pytorch
-
huggingface/configs/cola.yaml
→
demo/auto_compression/pytorch
_
huggingface/configs/cola.yaml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/configs/mnli.yaml
→
demo/auto_compression/pytorch
_
huggingface/configs/mnli.yaml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/configs/mrpc.yaml
→
demo/auto_compression/pytorch
_
huggingface/configs/mrpc.yaml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/configs/qnli.yaml
→
demo/auto_compression/pytorch
_
huggingface/configs/qnli.yaml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/configs/qqp.yaml
→
demo/auto_compression/pytorch
_
huggingface/configs/qqp.yaml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/configs/rte.yaml
→
demo/auto_compression/pytorch
_
huggingface/configs/rte.yaml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/configs/sst2.yaml
→
demo/auto_compression/pytorch
_
huggingface/configs/sst2.yaml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/configs/stsb.yaml
→
demo/auto_compression/pytorch
_
huggingface/configs/stsb.yaml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/infer.py
→
demo/auto_compression/pytorch
_
huggingface/infer.py
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/run.py
→
demo/auto_compression/pytorch
_
huggingface/run.py
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch
-
huggingface/run.sh
→
demo/auto_compression/pytorch
_
huggingface/run.sh
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch_yolov5/README.md
0 → 100644
浏览文件 @
5dcc4e19
# 目标检测模型自动压缩示例
目录:
-
[
1.简介
](
#1简介
)
-
[
2.Benchmark
](
#2Benchmark
)
-
[
3.开始自动压缩
](
#自动压缩流程
)
-
[
3.1 环境准备
](
#31-准备环境
)
-
[
3.2 准备数据集
](
#32-准备数据集
)
-
[
3.3 准备预测模型
](
#33-准备预测模型
)
-
[
3.4 测试模型精度
](
#34-测试模型精度
)
-
[
3.5 自动压缩并产出模型
](
#35-自动压缩并产出模型
)
-
[
4.预测部署
](
#4预测部署
)
-
[
5.FAQ
](
5FAQ
)
## 1. 简介
飞桨模型转换工具
[
X2Paddle
](
https://github.com/PaddlePaddle/X2Paddle
)
支持将
```Caffe/TensorFlow/ONNX/PyTorch```
的模型一键转为飞桨(PaddlePaddle)的预测模型。借助X2Paddle的能力,各种框架的推理模型可以很方便的使用PaddleSlim的自动化压缩功能。
本示例将以
[
ultralytics/yolov5
](
https://github.com/ultralytics/yolov5
)
目标检测模型为例,将PyTorch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为量化训练。
## 2.Benchmark
| 模型 | 策略 | 输入尺寸 | mAP
<sup>
val
<br>
0.5:0.95 | 预测时延
<sup><small>
FP32
</small><sup><br><sup>
(ms) |预测时延
<sup><small>
FP16
</small><sup><br><sup>
(ms) | 预测时延
<sup><small>
INT8
</small><sup><br><sup>
(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv5s | Base模型 | 640
*
640 | 37.4 | 7.8ms | 4.3ms | - | - |
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar
)
|
| YOLOv5s | 量化+蒸馏 | 640
*
640 | 36.5 | - | - | 3.4ms |
[
config
](
https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/auto_compression/detection/configs/yolov5s_qat_dis.yaml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar
)
|
说明:
-
mAP的指标均在COCO val2017数据集中评测得到。
-
YOLOv5s模型在Tesla T4的GPU环境下测试,并且开启TensorRT,测试脚本是
[
benchmark demo
](
./paddle_trt_infer.py
)
## 3. 自动压缩流程
#### 3.1 准备环境
-
PaddlePaddle >= 2.3 (可从
[
Paddle官网
](
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html
)
下载安装)
-
PaddleSlim develop版本
-
PaddleDet >= 2.4
-
[
X2Paddle
](
https://github.com/PaddlePaddle/X2Paddle
)
>= 1.3.6
-
opencv-python
(1)安装paddlepaddle:
```
shell
# CPU
pip
install
paddlepaddle
# GPU
pip
install
paddlepaddle-gpu
```
(2)安装paddleslim:
```
shell
https://github.com/PaddlePaddle/PaddleSlim.git
python setup.py
install
```
(3)安装paddledet:
```
shell
pip
install
paddledet
```
注:安装PaddleDet的目的是为了直接使用PaddleDetection中的Dataloader组件。
(4)安装X2Paddle的1.3.6以上版本:
```
shell
pip
install
x2paddle
```
#### 3.2 准备数据集
本案例默认以COCO数据进行自动压缩实验,并且依赖PaddleDetection中数据读取模块,如果自定义COCO数据,或者其他格式数据,请参考
[
PaddleDetection数据准备文档
](
https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md
)
来准备数据。
#### 3.3 准备预测模型
(1)准备ONNX模型:
可通过
[
ultralytics/yolov5
](
https://github.com/ultralytics/yolov5
)
官方的
[
导出教程
](
https://github.com/ultralytics/yolov5/issues/251
)
来准备ONNX模型。
```
python export.py --weights yolov5s.pt --include onnx
```
(2) 转换模型:
```
x2paddle --framework=onnx --model=yolov5s.onnx --save_dir=pd_model
cp -r pd_model/inference_model/ yolov5_inference_model
```
即可得到YOLOv5s模型的预测模型(
`model.pdmodel`
和
`model.pdiparams`
)。如想快速体验,可直接下载上方表格中YOLOv5s的
[
Base预测模型
](
https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar
)
。
预测模型的格式为:
`model.pdmodel`
和
`model.pdiparams`
两个,带
`pdmodel`
的是模型文件,带
`pdiparams`
后缀的是权重文件。
#### 3.4 自动压缩并产出模型
蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口
```paddleslim.auto_compression.AutoCompression```
对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
```
# 单卡
export CUDA_VISIBLE_DEVICES=0
# 多卡
# export CUDA_VISIBLE_DEVICES=0,1,2,3
python run.py --config_path=./configs/yolov5s_qat_dis.yaml --save_dir='./output/'
```
#### 3.5 测试模型精度
使用eval.py脚本得到模型的mAP:
```
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/yolov5s_qat_dis.yaml
```
**注意**
:要测试的模型路径需要在配置文件中
`model_dir`
字段下进行修改指定。
## 4.预测部署
-
Paddle-TensorRT部署:
使用
[
paddle_trt_infer.py
](
./paddle_trt_infer.py
)
进行部署:
```
shell
python paddle_trt_infer.py
--model_path
=
output
--image_file
=
images/000000570688.jpg
--benchmark
=
True
--run_mode
=
trt_int8
```
## 5.FAQ
demo/auto_compression/
detection
/configs/yolov5_reader.yml
→
demo/auto_compression/
pytorch_yolov5
/configs/yolov5_reader.yml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/
detection
/configs/yolov5s_qat_dis.yaml
→
demo/auto_compression/
pytorch_yolov5
/configs/yolov5s_qat_dis.yaml
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch_yolov5/eval.py
0 → 100644
浏览文件 @
5dcc4e19
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
,
merge_config
from
ppdet.core.workspace
import
create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
post_process
import
YOLOv5PostProcess
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--config_path'
,
type
=
str
,
default
=
None
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--devices'
,
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
return
parser
def
print_arguments
(
args
):
print
(
'----------- Running Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
items
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------'
)
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
convert_numpy_data
(
data
,
metric
):
data_all
=
{}
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
if
isinstance
(
metric
,
VOCMetric
):
for
k
,
v
in
data_all
.
items
():
if
not
isinstance
(
v
[
0
],
np
.
ndarray
):
tmp_list
=
[]
for
t
in
v
:
tmp_list
.
append
(
np
.
array
(
t
))
data_all
[
k
]
=
np
.
array
(
tmp_list
)
else
:
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
return
data_all
def
eval
():
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
val_program
,
feed_target_names
,
fetch_targets
=
paddle
.
fluid
.
io
.
load_inference_model
(
global_config
[
"model_dir"
],
exe
,
model_filename
=
global_config
[
"model_filename"
],
params_filename
=
global_config
[
"params_filename"
])
print
(
'Loaded model from: {}'
.
format
(
global_config
[
"model_dir"
]))
metric
=
global_config
[
'metric'
]
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
convert_numpy_data
(
data
,
metric
)
data_input
=
{}
for
k
,
v
in
data
.
items
():
if
isinstance
(
global_config
[
'input_list'
],
list
):
if
k
in
global_config
[
'input_list'
]:
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
global_config
[
'input_list'
],
dict
):
if
k
in
global_config
[
'input_list'
].
keys
():
data_input
[
global_config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
val_program
,
feed
=
data_input
,
fetch_list
=
fetch_targets
,
return_numpy
=
False
)
res
=
{}
if
'arch'
in
global_config
and
global_config
[
'arch'
]
==
'YOLOv5'
:
postprocess
=
YOLOv5PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.6
,
multi_label
=
True
)
res
=
postprocess
(
np
.
array
(
outs
[
0
]),
data_all
[
'scale_factor'
])
else
:
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter:'
,
batch_id
)
metric
.
accumulate
()
metric
.
log
()
metric
.
reset
()
def
main
():
global
global_config
all_config
=
load_slim_config
(
FLAGS
.
config_path
)
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
dataset
=
reader_cfg
[
'EvalDataset'
]
global
val_loader
val_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'EvalDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
metric
=
None
if
reader_cfg
[
'metric'
]
==
'COCO'
:
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
elif
reader_cfg
[
'metric'
]
==
'VOC'
:
metric
=
VOCMetric
(
label_list
=
dataset
.
get_label_list
(),
class_num
=
reader_cfg
[
'num_classes'
],
map_type
=
reader_cfg
[
'map_type'
])
else
:
raise
ValueError
(
"metric currently only supports COCO and VOC."
)
global_config
[
'metric'
]
=
metric
eval
()
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
main
()
demo/auto_compression/pytorch_yolov5/images/000000570688.jpg
0 → 100644
浏览文件 @
5dcc4e19
135.1 KB
demo/auto_compression/
detection/
infer.py
→
demo/auto_compression/
pytorch_yolov5/paddle_trt_
infer.py
浏览文件 @
5dcc4e19
...
...
@@ -303,8 +303,8 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'
C
PU'
,
help
=
"Choose the device you want to run, it can be: CPU/GPU/XPU, default is
C
PU"
default
=
'
G
PU'
,
help
=
"Choose the device you want to run, it can be: CPU/GPU/XPU, default is
G
PU"
)
parser
.
add_argument
(
'--img_shape'
,
type
=
int
,
default
=
640
,
help
=
"input_size"
)
args
=
parser
.
parse_args
()
...
...
demo/auto_compression/
detection
/post_process.py
→
demo/auto_compression/
pytorch_yolov5
/post_process.py
浏览文件 @
5dcc4e19
文件已移动
demo/auto_compression/pytorch_yolov5/run.py
0 → 100644
浏览文件 @
5dcc4e19
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
,
merge_config
from
ppdet.core.workspace
import
create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
paddleslim.auto_compression
import
AutoCompression
from
post_process
import
YOLOv5PostProcess
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--config_path'
,
type
=
str
,
default
=
None
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
'output'
,
help
=
"directory to save compressed model."
)
parser
.
add_argument
(
'--devices'
,
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
parser
.
add_argument
(
'--eval'
,
type
=
bool
,
default
=
False
,
help
=
"whether to run evaluation."
)
return
parser
def
print_arguments
(
args
):
print
(
'----------- Running Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
items
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------'
)
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
convert_numpy_data
(
data
,
metric
):
data_all
=
{}
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
if
isinstance
(
metric
,
VOCMetric
):
for
k
,
v
in
data_all
.
items
():
if
not
isinstance
(
v
[
0
],
np
.
ndarray
):
tmp_list
=
[]
for
t
in
v
:
tmp_list
.
append
(
np
.
array
(
t
))
data_all
[
k
]
=
np
.
array
(
tmp_list
)
else
:
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
return
data_all
def
eval_function
(
exe
,
compiled_test_program
,
test_feed_names
,
test_fetch_list
):
metric
=
global_config
[
'metric'
]
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
convert_numpy_data
(
data
,
metric
)
data_input
=
{}
for
k
,
v
in
data
.
items
():
if
isinstance
(
global_config
[
'input_list'
],
list
):
if
k
in
test_feed_names
:
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
global_config
[
'input_list'
],
dict
):
if
k
in
global_config
[
'input_list'
].
keys
():
data_input
[
global_config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
compiled_test_program
,
feed
=
data_input
,
fetch_list
=
test_fetch_list
,
return_numpy
=
False
)
res
=
{}
if
'arch'
in
global_config
and
global_config
[
'arch'
]
==
'YOLOv5'
:
postprocess
=
YOLOv5PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.6
,
multi_label
=
True
)
res
=
postprocess
(
np
.
array
(
outs
[
0
]),
data_all
[
'scale_factor'
])
else
:
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter:'
,
batch_id
)
metric
.
accumulate
()
metric
.
log
()
map_res
=
metric
.
get_results
()
metric
.
reset
()
return
map_res
[
'bbox'
][
0
]
def
main
():
global
global_config
all_config
=
load_slim_config
(
FLAGS
.
config_path
)
assert
"Global"
in
all_config
,
f
"Key 'Global' not found in config file.
\n
{
all_config
}
"
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
train_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'TrainDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
train_loader
=
reader_wrapper
(
train_loader
,
global_config
[
'input_list'
])
dataset
=
reader_cfg
[
'EvalDataset'
]
global
val_loader
val_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'EvalDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
metric
=
None
if
reader_cfg
[
'metric'
]
==
'COCO'
:
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
elif
reader_cfg
[
'metric'
]
==
'VOC'
:
metric
=
VOCMetric
(
label_list
=
dataset
.
get_label_list
(),
class_num
=
reader_cfg
[
'num_classes'
],
map_type
=
reader_cfg
[
'map_type'
])
else
:
raise
ValueError
(
"metric currently only supports COCO and VOC."
)
global_config
[
'metric'
]
=
metric
if
'Evaluation'
in
global_config
.
keys
()
and
global_config
[
'Evaluation'
]:
eval_func
=
eval_function
else
:
eval_func
=
None
ac
=
AutoCompression
(
model_dir
=
global_config
[
"model_dir"
],
model_filename
=
global_config
[
"model_filename"
],
params_filename
=
global_config
[
"params_filename"
],
save_dir
=
FLAGS
.
save_dir
,
config
=
all_config
,
train_dataloader
=
train_loader
,
eval_callback
=
eval_func
)
ac
.
compress
()
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录