Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
d3d78272
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d3d78272
编写于
7月 07, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
7月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add yolov6s act example (#1257) (#1268)
上级
860d55f6
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
1599 addition
and
1 deletion
+1599
-1
README.md
README.md
+1
-1
example/auto_compression/pytorch_yolov6/README.md
example/auto_compression/pytorch_yolov6/README.md
+138
-0
example/auto_compression/pytorch_yolov6/configs/yolov6_reader.yml
...auto_compression/pytorch_yolov6/configs/yolov6_reader.yml
+27
-0
example/auto_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml
...o_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml
+31
-0
example/auto_compression/pytorch_yolov6/cpp_infer/CMakeLists.txt
.../auto_compression/pytorch_yolov6/cpp_infer/CMakeLists.txt
+263
-0
example/auto_compression/pytorch_yolov6/cpp_infer/README.md
example/auto_compression/pytorch_yolov6/cpp_infer/README.md
+50
-0
example/auto_compression/pytorch_yolov6/cpp_infer/compile.sh
example/auto_compression/pytorch_yolov6/cpp_infer/compile.sh
+37
-0
example/auto_compression/pytorch_yolov6/cpp_infer/trt_run.cc
example/auto_compression/pytorch_yolov6/cpp_infer/trt_run.cc
+116
-0
example/auto_compression/pytorch_yolov6/eval.py
example/auto_compression/pytorch_yolov6/eval.py
+159
-0
example/auto_compression/pytorch_yolov6/images/000000570688.jpg
...e/auto_compression/pytorch_yolov6/images/000000570688.jpg
+0
-0
example/auto_compression/pytorch_yolov6/paddle_trt_infer.py
example/auto_compression/pytorch_yolov6/paddle_trt_infer.py
+322
-0
example/auto_compression/pytorch_yolov6/post_process.py
example/auto_compression/pytorch_yolov6/post_process.py
+173
-0
example/auto_compression/pytorch_yolov6/post_quant.py
example/auto_compression/pytorch_yolov6/post_quant.py
+101
-0
example/auto_compression/pytorch_yolov6/run.py
example/auto_compression/pytorch_yolov6/run.py
+181
-0
未找到文件。
README.md
浏览文件 @
d3d78272
...
@@ -20,7 +20,7 @@ PaddleSlim是一个专注于深度学习模型压缩的工具库,提供**低
...
@@ -20,7 +20,7 @@ PaddleSlim是一个专注于深度学习模型压缩的工具库,提供**低
- 支持代码无感知压缩:用户只需提供推理模型文件和数据,既可进行离线量化(PTQ)、量化训练(QAT)、稀疏训练等压缩任务。
- 支持代码无感知压缩:用户只需提供推理模型文件和数据,既可进行离线量化(PTQ)、量化训练(QAT)、稀疏训练等压缩任务。
- 支持自动策略选择,根据任务特点和部署环境特性:自动搜索合适的离线量化方法,自动搜索最佳的压缩策略组合方式。
- 支持自动策略选择,根据任务特点和部署环境特性:自动搜索合适的离线量化方法,自动搜索最佳的压缩策略组合方式。
- 发布[自然语言处理](example/auto_compression/nlp)、[图像语义分割](example/auto_compression/semantic_segmentation)、[图像目标检测](example/auto_compression/detection)三个方向的自动化压缩示例。
- 发布[自然语言处理](example/auto_compression/nlp)、[图像语义分割](example/auto_compression/semantic_segmentation)、[图像目标检测](example/auto_compression/detection)三个方向的自动化压缩示例。
- 发布`X2Paddle`模型自动化压缩方案:[YOLOv5](example/auto_compression/pytorch_yolov5)、[
HuggingFace](example/auto_compression/pytorch_huggingface)
[MobileNet](example/auto_compression/tensorflow_mobilenet)。
- 发布`X2Paddle`模型自动化压缩方案:[YOLOv5](example/auto_compression/pytorch_yolov5)、[
YOLOv6](example/auto_compression/pytorch_yolov6)、[HuggingFace](example/auto_compression/pytorch_huggingface)、
[MobileNet](example/auto_compression/tensorflow_mobilenet)。
-
升级量化功能
-
升级量化功能
...
...
example/auto_compression/pytorch_yolov6/README.md
0 → 100644
浏览文件 @
d3d78272
# YOLOv6自动压缩示例
目录:
-
[
1.简介
](
#1简介
)
-
[
2.Benchmark
](
#2Benchmark
)
-
[
3.开始自动压缩
](
#自动压缩流程
)
-
[
3.1 环境准备
](
#31-准备环境
)
-
[
3.2 准备数据集
](
#32-准备数据集
)
-
[
3.3 准备预测模型
](
#33-准备预测模型
)
-
[
3.4 测试模型精度
](
#34-测试模型精度
)
-
[
3.5 自动压缩并产出模型
](
#35-自动压缩并产出模型
)
-
[
4.预测部署
](
#4预测部署
)
-
[
5.FAQ
](
5FAQ
)
## 1. 简介
飞桨模型转换工具
[
X2Paddle
](
https://github.com/PaddlePaddle/X2Paddle
)
支持将
```Caffe/TensorFlow/ONNX/PyTorch```
的模型一键转为飞桨(PaddlePaddle)的预测模型。借助X2Paddle的能力,各种框架的推理模型可以很方便的使用PaddleSlim的自动化压缩功能。
本示例将以
[
meituan/YOLOv6
](
https://github.com/meituan/YOLOv6
)
目标检测模型为例,将PyTorch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为量化训练。
## 2.Benchmark
| 模型 | 策略 | 输入尺寸 | mAP
<sup>
val
<br>
0.5:0.95 | 预测时延
<sup><small>
FP32
</small><sup><br><sup>
(ms) |预测时延
<sup><small>
FP16
</small><sup><br><sup>
(ms) | 预测时延
<sup><small>
INT8
</small><sup><br><sup>
(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv6s | Base模型 | 640
*
640 | 42.4 | 9.06ms | 2.90ms | - | - |
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov6s_infer.tar
)
|
| YOLOv6s | KL离线量化 | 640
*
640 | 30.3 | - | - | 1.83ms | - | - |
| YOLOv6s | 量化蒸馏训练 | 640
*640 | **41.3** | - | - | **1.83ms*
*
|
[
config
](
./configs/yolov6s_qat_dis.yaml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar
)
|
说明:
-
mAP的指标均在COCO val2017数据集中评测得到。
-
YOLOv6s模型在Tesla T4的GPU环境下开启TensorRT 8.4.1,batch_size=1, 测试脚本是
[
cpp_infer
](
./cpp_infer
)
。
## 3. 自动压缩流程
#### 3.1 准备环境
-
PaddlePaddle >= 2.3 (可从
[
Paddle官网
](
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html
)
下载安装)
-
PaddleSlim > 2.3版本
-
PaddleDet >= 2.4
-
[
X2Paddle
](
https://github.com/PaddlePaddle/X2Paddle
)
>= 1.3.6
-
opencv-python
(1)安装paddlepaddle:
```
shell
# CPU
pip
install
paddlepaddle
# GPU
pip
install
paddlepaddle-gpu
```
(2)安装paddleslim:
```
shell
pip
install
paddleslim
```
(3)安装paddledet:
```
shell
pip
install
paddledet
```
注:安装PaddleDet的目的只是为了直接使用PaddleDetection中的Dataloader组件。
(4)安装X2Paddle的1.3.6以上版本:
```
shell
pip
install
x2paddle sympy onnx
```
#### 3.2 准备数据集
本案例默认以COCO数据进行自动压缩实验,并且依赖PaddleDetection中数据读取模块,如果自定义COCO数据,或者其他格式数据,请参考
[
PaddleDetection数据准备文档
](
https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md
)
来准备数据。
如果已经准备好数据集,请直接修改[./configs/yolov6_reader.yml]中
`EvalDataset`
的
`dataset_dir`
字段为自己数据集路径即可。
#### 3.3 准备预测模型
(1)准备ONNX模型:
可通过
[
meituan/YOLOv6
](
https://github.com/meituan/YOLOv6
)
官方的
[
导出教程
](
https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md
)
来准备ONNX模型。也可以下载已经准备好的
[
yolov6s.onnx
](
https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx
)
。
(2) 转换模型:
```
x2paddle --framework=onnx --model=yolov6s.onnx --save_dir=pd_model
cp -r pd_model/inference_model/ yolov6s_infer
```
即可得到YOLOv6s模型的预测模型(
`model.pdmodel`
和
`model.pdiparams`
)。如想快速体验,可直接下载上方表格中YOLOv6s的
[
Paddle预测模型
](
https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov6s_infer.tar
)
。
预测模型的格式为:
`model.pdmodel`
和
`model.pdiparams`
两个,带
`pdmodel`
的是模型文件,带
`pdiparams`
后缀的是权重文件。
#### 3.4 自动压缩并产出模型
蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口
```paddleslim.auto_compression.AutoCompression```
对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
-
单卡训练:
```
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/yolov6s_qat_dis.yaml --save_dir='./output/'
```
-
多卡训练:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
--config_path=./configs/yolov6s_qat_dis.yaml --save_dir='./output/'
```
#### 3.5 测试模型精度
修改
[
yolov6s_qat_dis.yaml
](
./configs/yolov6s_qat_dis.yaml
)
中
`model_dir`
字段为模型存储路径,然后使用eval.py脚本得到模型的mAP:
```
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/yolov6s_qat_dis.yaml
```
## 4.预测部署
#### Paddle-TensorRT C++部署
进入
[
cpp_infer
](
./cpp_infer
)
文件夹内,请按照
[
C++ TensorRT Benchmark测试教程
](
./cpp_infer/README.md
)
进行准备环境及编译,然后开始测试:
```
shell
# 编译
bash complie.sh
# 执行
./build/trt_run
--model_file
yolov6s_quant/model.pdmodel
--params_file
yolov6s_quant/model.pdiparams
--run_mode
=
trt_int8
```
#### Paddle-TensorRT Python部署:
首先安装带有TensorRT的
[
Paddle安装包
](
https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python
)
。
然后使用
[
paddle_trt_infer.py
](
./paddle_trt_infer.py
)
进行部署:
```
shell
python paddle_trt_infer.py
--model_path
=
output
--image_file
=
images/000000570688.jpg
--benchmark
=
True
--run_mode
=
trt_int8
```
## 5.FAQ
example/auto_compression/pytorch_yolov6/configs/yolov6_reader.yml
0 → 100644
浏览文件 @
d3d78272
metric
:
COCO
num_classes
:
80
# Datset configuration
TrainDataset
:
!COCODataSet
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
dataset_dir
:
dataset/coco/
EvalDataset
:
!COCODataSet
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
dataset_dir
:
dataset/coco/
worker_num
:
0
# preprocess reader in test
EvalReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
target_size
:
[
640
,
640
],
keep_ratio
:
True
}
-
Pad
:
{
size
:
[
640
,
640
],
fill_value
:
[
114.
,
114.
,
114.
]}
-
NormalizeImage
:
{
mean
:
[
0
,
0
,
0
],
std
:
[
1
,
1
,
1
],
is_scale
:
True
}
-
Permute
:
{}
batch_size
:
1
example/auto_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml
0 → 100644
浏览文件 @
d3d78272
Global
:
reader_config
:
configs/yolov6_reader.yml
input_list
:
{
'
image'
:
'
x2paddle_image_arrays'
}
Evaluation
:
True
arch
:
'
YOLOv6'
model_dir
:
./yolov6s_infer
model_filename
:
model.pdmodel
params_filename
:
model.pdiparams
Distillation
:
alpha
:
1.0
loss
:
soft_label
Quantization
:
activation_quantize_type
:
'
moving_average_abs_max'
quantize_op_types
:
-
conv2d
-
depthwise_conv2d
TrainConfig
:
train_iter
:
8000
eval_iter
:
1000
learning_rate
:
type
:
CosineAnnealingDecay
learning_rate
:
0.00003
T_max
:
8000
optimizer_builder
:
optimizer
:
type
:
SGD
weight_decay
:
0.00004
example/auto_compression/pytorch_yolov6/cpp_infer/CMakeLists.txt
0 → 100644
浏览文件 @
d3d78272
cmake_minimum_required
(
VERSION 3.0
)
project
(
cpp_inference_demo CXX C
)
option
(
WITH_MKL
"Compile demo with MKL/OpenBlas support, default use MKL."
ON
)
option
(
WITH_GPU
"Compile demo with GPU/CPU, default use CPU."
OFF
)
option
(
WITH_STATIC_LIB
"Compile demo with static/shared library, default use static."
ON
)
option
(
USE_TENSORRT
"Compile demo with TensorRT."
OFF
)
option
(
WITH_ROCM
"Compile demo with rocm."
OFF
)
option
(
WITH_ONNXRUNTIME
"Compile demo with ONNXRuntime"
OFF
)
option
(
WITH_ARM
"Compile demo with ARM"
OFF
)
option
(
WITH_MIPS
"Compile demo with MIPS"
OFF
)
option
(
WITH_SW
"Compile demo with SW"
OFF
)
option
(
WITH_XPU
"Compile demow ith xpu"
OFF
)
option
(
WITH_NPU
"Compile demow ith npu"
OFF
)
if
(
NOT WITH_STATIC_LIB
)
add_definitions
(
"-DPADDLE_WITH_SHARED_LIB"
)
else
()
# PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode.
# Set it to empty in static library mode to avoid compilation issues.
add_definitions
(
"/DPD_INFER_DECL="
)
endif
()
macro
(
safe_set_static_flag
)
foreach
(
flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
)
if
(
${
flag_var
}
MATCHES
"/MD"
)
string
(
REGEX REPLACE
"/MD"
"/MT"
${
flag_var
}
"
${${
flag_var
}}
"
)
endif
(
${
flag_var
}
MATCHES
"/MD"
)
endforeach
(
flag_var
)
endmacro
()
if
(
NOT DEFINED PADDLE_LIB
)
message
(
FATAL_ERROR
"please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib"
)
endif
()
if
(
NOT DEFINED DEMO_NAME
)
message
(
FATAL_ERROR
"please set DEMO_NAME with -DDEMO_NAME=demo_name"
)
endif
()
include_directories
(
"
${
PADDLE_LIB
}
/"
)
set
(
PADDLE_LIB_THIRD_PARTY_PATH
"
${
PADDLE_LIB
}
/third_party/install/"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
protobuf/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
glog/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
gflags/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
xxhash/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
cryptopp/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
paddle2onnx/include"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
protobuf/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
glog/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
gflags/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
xxhash/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
cryptopp/lib"
)
link_directories
(
"
${
PADDLE_LIB
}
/paddle/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
paddle2onnx/lib"
)
if
(
WIN32
)
add_definitions
(
"/DGOOGLE_GLOG_DLL_DECL="
)
option
(
MSVC_STATIC_CRT
"use static C Runtime library by default"
ON
)
if
(
MSVC_STATIC_CRT
)
if
(
WITH_MKL
)
set
(
FLAG_OPENMP
"/openmp"
)
endif
()
set
(
CMAKE_C_FLAGS_DEBUG
"
${
CMAKE_C_FLAGS_DEBUG
}
/bigobj /MTd
${
FLAG_OPENMP
}
"
)
set
(
CMAKE_C_FLAGS_RELEASE
"
${
CMAKE_C_FLAGS_RELEASE
}
/bigobj /MT
${
FLAG_OPENMP
}
"
)
set
(
CMAKE_CXX_FLAGS_DEBUG
"
${
CMAKE_CXX_FLAGS_DEBUG
}
/bigobj /MTd
${
FLAG_OPENMP
}
"
)
set
(
CMAKE_CXX_FLAGS_RELEASE
"
${
CMAKE_CXX_FLAGS_RELEASE
}
/bigobj /MT
${
FLAG_OPENMP
}
"
)
safe_set_static_flag
()
if
(
WITH_STATIC_LIB
)
add_definitions
(
-DSTATIC_LIB
)
endif
()
endif
()
else
()
if
(
WITH_MKL
)
set
(
FLAG_OPENMP
"-fopenmp"
)
endif
()
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++11
${
FLAG_OPENMP
}
"
)
endif
()
if
(
WITH_GPU
)
if
(
NOT WIN32
)
include_directories
(
"/usr/local/cuda/include"
)
if
(
CUDA_LIB STREQUAL
""
)
set
(
CUDA_LIB
"/usr/local/cuda/lib64/"
CACHE STRING
"CUDA Library"
)
endif
()
else
()
include_directories
(
"C:
\\
Program
\
Files
\\
NVIDIA GPU Computing Toolkit
\\
CUDA
\\
v8.0
\\
include"
)
if
(
CUDA_LIB STREQUAL
""
)
set
(
CUDA_LIB
"C:
\\
Program
\
Files
\\
NVIDIA GPU Computing Toolkit
\\
CUDA
\\
v8.0
\\
lib
\\
x64"
)
endif
()
endif
(
NOT WIN32
)
endif
()
if
(
USE_TENSORRT AND WITH_GPU
)
set
(
TENSORRT_ROOT
""
CACHE STRING
"The root directory of TensorRT library"
)
if
(
"
${
TENSORRT_ROOT
}
"
STREQUAL
""
)
message
(
FATAL_ERROR
"The TENSORRT_ROOT is empty, you must assign it a value with CMake command. Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH "
)
endif
()
set
(
TENSORRT_INCLUDE_DIR
${
TENSORRT_ROOT
}
/include
)
set
(
TENSORRT_LIB_DIR
${
TENSORRT_ROOT
}
/lib
)
file
(
READ
${
TENSORRT_INCLUDE_DIR
}
/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS
)
string
(
REGEX MATCH
"define NV_TENSORRT_MAJOR +([0-9]+)"
TENSORRT_MAJOR_VERSION
"
${
TENSORRT_VERSION_FILE_CONTENTS
}
"
)
if
(
"
${
TENSORRT_MAJOR_VERSION
}
"
STREQUAL
""
)
file
(
READ
${
TENSORRT_INCLUDE_DIR
}
/NvInferVersion.h TENSORRT_VERSION_FILE_CONTENTS
)
string
(
REGEX MATCH
"define NV_TENSORRT_MAJOR +([0-9]+)"
TENSORRT_MAJOR_VERSION
"
${
TENSORRT_VERSION_FILE_CONTENTS
}
"
)
endif
()
if
(
"
${
TENSORRT_MAJOR_VERSION
}
"
STREQUAL
""
)
message
(
SEND_ERROR
"Failed to detect TensorRT version."
)
endif
()
string
(
REGEX REPLACE
"define NV_TENSORRT_MAJOR +([0-9]+)"
"
\\
1"
TENSORRT_MAJOR_VERSION
"
${
TENSORRT_MAJOR_VERSION
}
"
)
message
(
STATUS
"Current TensorRT header is
${
TENSORRT_INCLUDE_DIR
}
/NvInfer.h. "
"Current TensorRT version is v
${
TENSORRT_MAJOR_VERSION
}
. "
)
include_directories
(
"
${
TENSORRT_INCLUDE_DIR
}
"
)
link_directories
(
"
${
TENSORRT_LIB_DIR
}
"
)
endif
()
if
(
WITH_MKL
)
set
(
MATH_LIB_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
mklml"
)
include_directories
(
"
${
MATH_LIB_PATH
}
/include"
)
if
(
WIN32
)
set
(
MATH_LIB
${
MATH_LIB_PATH
}
/lib/mklml
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
${
MATH_LIB_PATH
}
/lib/libiomp5md
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
else
()
set
(
MATH_LIB
${
MATH_LIB_PATH
}
/lib/libmklml_intel
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
MATH_LIB_PATH
}
/lib/libiomp5
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
set
(
MKLDNN_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
mkldnn"
)
if
(
EXISTS
${
MKLDNN_PATH
}
)
include_directories
(
"
${
MKLDNN_PATH
}
/include"
)
if
(
WIN32
)
set
(
MKLDNN_LIB
${
MKLDNN_PATH
}
/lib/mkldnn.lib
)
else
(
WIN32
)
set
(
MKLDNN_LIB
${
MKLDNN_PATH
}
/lib/libmkldnn.so.0
)
endif
(
WIN32
)
endif
()
elseif
((
NOT WITH_MIPS
)
AND
(
NOT WITH_SW
))
set
(
OPENBLAS_LIB_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
openblas"
)
include_directories
(
"
${
OPENBLAS_LIB_PATH
}
/include/openblas"
)
if
(
WIN32
)
set
(
MATH_LIB
${
OPENBLAS_LIB_PATH
}
/lib/openblas
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
else
()
set
(
MATH_LIB
${
OPENBLAS_LIB_PATH
}
/lib/libopenblas
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
endif
()
endif
()
if
(
WITH_STATIC_LIB
)
set
(
DEPS
${
PADDLE_LIB
}
/paddle/lib/libpaddle_inference
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
else
()
if
(
WIN32
)
set
(
DEPS
${
PADDLE_LIB
}
/paddle/lib/paddle_inference
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
else
()
set
(
DEPS
${
PADDLE_LIB
}
/paddle/lib/libpaddle_inference
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
endif
()
if
(
WITH_ONNXRUNTIME
)
if
(
WIN32
)
set
(
DEPS
${
DEPS
}
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib/onnxruntime.lib paddle2onnx
)
elseif
(
APPLE
)
set
(
DEPS
${
DEPS
}
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib/libonnxruntime.1.10.0.dylib paddle2onnx
)
else
()
set
(
DEPS
${
DEPS
}
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib/libonnxruntime.so.1.10.0 paddle2onnx
)
endif
()
endif
()
if
(
NOT WIN32
)
set
(
EXTERNAL_LIB
"-lrt -ldl -lpthread"
)
set
(
DEPS
${
DEPS
}
${
MATH_LIB
}
${
MKLDNN_LIB
}
glog gflags protobuf xxhash cryptopp
${
EXTERNAL_LIB
}
)
else
()
set
(
DEPS
${
DEPS
}
${
MATH_LIB
}
${
MKLDNN_LIB
}
glog gflags_static libprotobuf xxhash cryptopp-static
${
EXTERNAL_LIB
}
)
set
(
DEPS
${
DEPS
}
shlwapi.lib
)
endif
(
NOT WIN32
)
if
(
WITH_GPU
)
if
(
NOT WIN32
)
if
(
USE_TENSORRT
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/libnvinfer
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/libnvinfer_plugin
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/libcudart
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
else
()
if
(
USE_TENSORRT
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/nvinfer
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/nvinfer_plugin
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
if
(
${
TENSORRT_MAJOR_VERSION
}
GREATER_EQUAL 7
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/myelin64_1
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
endif
()
endif
()
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cudart
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cublas
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cudnn
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
endif
()
endif
()
if
(
WITH_ROCM AND NOT WIN32
)
set
(
DEPS
${
DEPS
}
${
ROCM_LIB
}
/libamdhip64
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
if
(
WITH_XPU AND NOT WIN32
)
set
(
XPU_INSTALL_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
xpu"
)
set
(
DEPS
${
DEPS
}
${
XPU_INSTALL_PATH
}
/lib/libxpuapi
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
XPU_INSTALL_PATH
}
/lib/libxpurt
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
if
(
WITH_NPU AND NOT WIN32
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libgraph
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libge_runner
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libascendcl
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libascendcl
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libacl_op_compiler
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
add_executable
(
${
DEMO_NAME
}
${
DEMO_NAME
}
.cc
)
target_link_libraries
(
${
DEMO_NAME
}
${
DEPS
}
)
if
(
WIN32
)
if
(
USE_TENSORRT
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
TENSORRT_LIB_DIR
}
/nvinfer
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
TENSORRT_LIB_DIR
}
/nvinfer_plugin
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
)
if
(
${
TENSORRT_MAJOR_VERSION
}
GREATER_EQUAL 7
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
TENSORRT_LIB_DIR
}
/myelin64_1
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
)
endif
()
endif
()
if
(
WITH_MKL
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
MATH_LIB_PATH
}
/lib/mklml.dll
${
CMAKE_BINARY_DIR
}
/Release
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
MATH_LIB_PATH
}
/lib/libiomp5md.dll
${
CMAKE_BINARY_DIR
}
/Release
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
MKLDNN_PATH
}
/lib/mkldnn.dll
${
CMAKE_BINARY_DIR
}
/Release
)
else
()
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
OPENBLAS_LIB_PATH
}
/lib/openblas.dll
${
CMAKE_BINARY_DIR
}
/Release
)
endif
()
if
(
WITH_ONNXRUNTIME
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib/onnxruntime.dll
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
PADDLE_LIB_THIRD_PARTY_PATH
}
paddle2onnx/lib/paddle2onnx.dll
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
)
endif
()
if
(
NOT WITH_STATIC_LIB
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
"
${
PADDLE_LIB
}
/paddle/lib/paddle_inference.dll"
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
)
endif
()
endif
()
example/auto_compression/pytorch_yolov6/cpp_infer/README.md
0 → 100644
浏览文件 @
d3d78272
# YOLOv6 TensorRT Benchmark测试(Linux)
## 环境准备
-
CUDA、CUDNN:确认环境中已经安装CUDA和CUDNN,并且提前获取其安装路径。
-
TensorRT:可通过NVIDIA官网下载
[
TensorRT 8.4.1.5
](
https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.1/tars/tensorrt-8.4.1.5.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz
)
或其他版本安装包。
-
Paddle Inference C++预测库:编译develop版本请参考
[
编译文档
](
https://www.paddlepaddle.org.cn/inference/user_guides/source_compile.html
)
。编译完成后,会在build目录下生成
`paddle_inference_install_dir`
文件夹,这个就是我们需要的C++预测库文件。
## 编译可执行程序
-
(1)修改
`compile.sh`
中依赖库路径,主要是以下内容:
```
shell
# Paddle Inference预测库路径
LIB_DIR
=
/root/auto_compress/Paddle/build/paddle_inference_install_dir/
# CUDNN路径
CUDNN_LIB
=
/usr/lib/x86_64-linux-gnu/
# CUDA路径
CUDA_LIB
=
/usr/local/cuda/lib64
# TensorRT安装包路径,为TRT资源包解压完成后的绝对路径,其中包含`lib`和`include`文件夹
TENSORRT_ROOT
=
/root/auto_compress/trt/trt8.4/
```
## 测试
-
FP32
```
./build/trt_run --model_file yolov6s_infer/model.pdmodel --params_file yolov6s_infer/model.pdiparams --run_mode=trt_fp32
```
-
FP16
```
./build/trt_run --model_file yolov6s_infer/model.pdmodel --params_file yolov6s_infer/model.pdiparams --run_mode=trt_fp16
```
-
INT8
```
./build/trt_run --model_file yolov6s_quant/model.pdmodel --params_file yolov6s_quant/model.pdiparams --run_mode=trt_int8
```
## 性能对比
| 模型 | 预测时延
<sup><small>
FP32
</small><sup><br><sup>
(ms) |预测时延
<sup><small>
FP16
</small><sup><br><sup>
(ms) | 预测时延
<sup><small>
INT8
</small><sup><br><sup>
(ms) |
| :-------- |:-------- |:--------: | :---------------------: |
| YOLOv6s | 9.06ms | 2.90ms | 1.83ms |
环境:
-
Tesla T4,TensorRT 8.4.1,CUDA 11.2
-
batch_size=1
example/auto_compression/pytorch_yolov6/cpp_infer/compile.sh
0 → 100644
浏览文件 @
d3d78272
#!/bin/bash
set
+x
set
-e
work_path
=
$(
dirname
$(
readlink
-f
$0
))
mkdir
-p
build
cd
build
rm
-rf
*
DEMO_NAME
=
trt_run
WITH_MKL
=
ON
WITH_GPU
=
ON
USE_TENSORRT
=
ON
LIB_DIR
=
/root/auto_compress/Paddle/build/paddle_inference_install_dir/
CUDNN_LIB
=
/usr/lib/x86_64-linux-gnu/
CUDA_LIB
=
/usr/local/cuda/lib64
TENSORRT_ROOT
=
/root/auto_compress/trt/trt8.4/
WITH_ROCM
=
OFF
ROCM_LIB
=
/opt/rocm/lib
cmake ..
-DPADDLE_LIB
=
${
LIB_DIR
}
\
-DWITH_MKL
=
${
WITH_MKL
}
\
-DDEMO_NAME
=
${
DEMO_NAME
}
\
-DWITH_GPU
=
${
WITH_GPU
}
\
-DWITH_STATIC_LIB
=
OFF
\
-DUSE_TENSORRT
=
${
USE_TENSORRT
}
\
-DWITH_ROCM
=
${
WITH_ROCM
}
\
-DROCM_LIB
=
${
ROCM_LIB
}
\
-DCUDNN_LIB
=
${
CUDNN_LIB
}
\
-DCUDA_LIB
=
${
CUDA_LIB
}
\
-DTENSORRT_ROOT
=
${
TENSORRT_ROOT
}
make
-j
example/auto_compression/pytorch_yolov6/cpp_infer/trt_run.cc
0 → 100644
浏览文件 @
d3d78272
#include <chrono>
#include <iostream>
#include <memory>
#include <numeric>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <cuda_runtime.h>
#include "paddle/include/paddle_inference_api.h"
#include "paddle/include/experimental/phi/common/float16.h"
using
paddle_infer
::
Config
;
using
paddle_infer
::
Predictor
;
using
paddle_infer
::
CreatePredictor
;
using
paddle_infer
::
PrecisionType
;
using
phi
::
dtype
::
float16
;
DEFINE_string
(
model_dir
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
model_file
,
""
,
"Path of the inference model file."
);
DEFINE_string
(
params_file
,
""
,
"Path of the inference params file."
);
DEFINE_string
(
run_mode
,
"trt_fp32"
,
"run_mode which can be: trt_fp32, trt_fp16 and trt_int8"
);
DEFINE_int32
(
batch_size
,
1
,
"Batch size."
);
DEFINE_int32
(
gpu_id
,
0
,
"GPU card ID num."
);
DEFINE_int32
(
trt_min_subgraph_size
,
3
,
"tensorrt min_subgraph_size"
);
DEFINE_int32
(
warmup
,
50
,
"warmup"
);
DEFINE_int32
(
repeats
,
1000
,
"repeats"
);
using
Time
=
decltype
(
std
::
chrono
::
high_resolution_clock
::
now
());
Time
time
()
{
return
std
::
chrono
::
high_resolution_clock
::
now
();
};
double
time_diff
(
Time
t1
,
Time
t2
)
{
typedef
std
::
chrono
::
microseconds
ms
;
auto
diff
=
t2
-
t1
;
ms
counter
=
std
::
chrono
::
duration_cast
<
ms
>
(
diff
);
return
counter
.
count
()
/
1000.0
;
}
std
::
shared_ptr
<
Predictor
>
InitPredictor
()
{
Config
config
;
std
::
string
model_path
;
if
(
FLAGS_model_dir
!=
""
)
{
config
.
SetModel
(
FLAGS_model_dir
);
model_path
=
FLAGS_model_dir
.
substr
(
0
,
FLAGS_model_dir
.
find_last_of
(
"/"
));
}
else
{
config
.
SetModel
(
FLAGS_model_file
,
FLAGS_params_file
);
model_path
=
FLAGS_model_file
.
substr
(
0
,
FLAGS_model_file
.
find_last_of
(
"/"
));
}
// enable tune
std
::
cout
<<
"model_path: "
<<
model_path
<<
std
::
endl
;
config
.
EnableUseGpu
(
256
,
FLAGS_gpu_id
);
if
(
FLAGS_run_mode
==
"trt_fp32"
)
{
config
.
EnableTensorRtEngine
(
1
<<
30
,
FLAGS_batch_size
,
FLAGS_trt_min_subgraph_size
,
PrecisionType
::
kFloat32
,
false
,
false
);
}
else
if
(
FLAGS_run_mode
==
"trt_fp16"
)
{
config
.
EnableTensorRtEngine
(
1
<<
30
,
FLAGS_batch_size
,
FLAGS_trt_min_subgraph_size
,
PrecisionType
::
kHalf
,
false
,
false
);
}
else
if
(
FLAGS_run_mode
==
"trt_int8"
)
{
config
.
EnableTensorRtEngine
(
1
<<
30
,
FLAGS_batch_size
,
FLAGS_trt_min_subgraph_size
,
PrecisionType
::
kInt8
,
false
,
false
);
}
config
.
EnableMemoryOptim
();
config
.
SwitchIrOptim
(
true
);
return
CreatePredictor
(
config
);
}
template
<
typename
type
>
void
run
(
Predictor
*
predictor
,
const
std
::
vector
<
type
>
&
input
,
const
std
::
vector
<
int
>
&
input_shape
,
type
*
out_data
,
std
::
vector
<
int
>
out_shape
)
{
// prepare input
int
input_num
=
std
::
accumulate
(
input_shape
.
begin
(),
input_shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
auto
input_names
=
predictor
->
GetInputNames
();
auto
input_t
=
predictor
->
GetInputHandle
(
input_names
[
0
]);
input_t
->
Reshape
(
input_shape
);
input_t
->
CopyFromCpu
(
input
.
data
());
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
CHECK
(
predictor
->
Run
());
auto
st
=
time
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
auto
input_names
=
predictor
->
GetInputNames
();
auto
input_t
=
predictor
->
GetInputHandle
(
input_names
[
0
]);
input_t
->
Reshape
(
input_shape
);
input_t
->
CopyFromCpu
(
input
.
data
());
CHECK
(
predictor
->
Run
());
auto
output_names
=
predictor
->
GetOutputNames
();
auto
output_t
=
predictor
->
GetOutputHandle
(
output_names
[
0
]);
std
::
vector
<
int
>
output_shape
=
output_t
->
shape
();
output_t
->
ShareExternalData
<
type
>
(
out_data
,
out_shape
,
paddle_infer
::
PlaceType
::
kGPU
);
}
LOG
(
INFO
)
<<
"["
<<
FLAGS_run_mode
<<
" bs-"
<<
FLAGS_batch_size
<<
" ] run avg time is "
<<
time_diff
(
st
,
time
())
/
FLAGS_repeats
<<
" ms"
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
auto
predictor
=
InitPredictor
();
std
::
vector
<
int
>
input_shape
=
{
FLAGS_batch_size
,
3
,
640
,
640
};
// float16
using
dtype
=
float16
;
std
::
vector
<
dtype
>
input_data
(
FLAGS_batch_size
*
3
*
640
*
640
,
dtype
(
1.0
));
dtype
*
out_data
;
int
out_data_size
=
FLAGS_batch_size
*
8400
*
85
;
cudaHostAlloc
((
void
**
)
&
out_data
,
sizeof
(
float
)
*
out_data_size
,
cudaHostAllocMapped
);
std
::
vector
<
int
>
out_shape
{
FLAGS_batch_size
,
1
,
8400
,
85
};
run
<
dtype
>
(
predictor
.
get
(),
input_data
,
input_shape
,
out_data
,
out_shape
);
return
0
;
}
example/auto_compression/pytorch_yolov6/eval.py
0 → 100644
浏览文件 @
d3d78272
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
,
merge_config
from
ppdet.core.workspace
import
create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
post_process
import
YOLOv6PostProcess
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--config_path'
,
type
=
str
,
default
=
None
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--devices'
,
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
return
parser
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
convert_numpy_data
(
data
,
metric
):
data_all
=
{}
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
if
isinstance
(
metric
,
VOCMetric
):
for
k
,
v
in
data_all
.
items
():
if
not
isinstance
(
v
[
0
],
np
.
ndarray
):
tmp_list
=
[]
for
t
in
v
:
tmp_list
.
append
(
np
.
array
(
t
))
data_all
[
k
]
=
np
.
array
(
tmp_list
)
else
:
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
return
data_all
def
eval
():
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
val_program
,
feed_target_names
,
fetch_targets
=
paddle
.
static
.
load_inference_model
(
global_config
[
"model_dir"
],
exe
,
model_filename
=
global_config
[
"model_filename"
],
params_filename
=
global_config
[
"params_filename"
])
print
(
'Loaded model from: {}'
.
format
(
global_config
[
"model_dir"
]))
metric
=
global_config
[
'metric'
]
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
convert_numpy_data
(
data
,
metric
)
data_input
=
{}
for
k
,
v
in
data
.
items
():
if
isinstance
(
global_config
[
'input_list'
],
list
):
if
k
in
global_config
[
'input_list'
]:
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
global_config
[
'input_list'
],
dict
):
if
k
in
global_config
[
'input_list'
].
keys
():
data_input
[
global_config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
val_program
,
feed
=
data_input
,
fetch_list
=
fetch_targets
,
return_numpy
=
False
)
res
=
{}
if
'arch'
in
global_config
and
global_config
[
'arch'
]
==
'YOLOv6'
:
postprocess
=
YOLOv6PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.65
,
multi_label
=
True
)
res
=
postprocess
(
np
.
array
(
outs
[
0
]),
data_all
[
'scale_factor'
])
else
:
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter:'
,
batch_id
)
metric
.
accumulate
()
metric
.
log
()
metric
.
reset
()
def
main
():
global
global_config
all_config
=
load_slim_config
(
FLAGS
.
config_path
)
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
dataset
=
reader_cfg
[
'EvalDataset'
]
global
val_loader
val_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'EvalDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
metric
=
None
if
reader_cfg
[
'metric'
]
==
'COCO'
:
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
elif
reader_cfg
[
'metric'
]
==
'VOC'
:
metric
=
VOCMetric
(
label_list
=
dataset
.
get_label_list
(),
class_num
=
reader_cfg
[
'num_classes'
],
map_type
=
reader_cfg
[
'map_type'
])
else
:
raise
ValueError
(
"metric currently only supports COCO and VOC."
)
global_config
[
'metric'
]
=
metric
eval
()
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
main
()
example/auto_compression/pytorch_yolov6/images/000000570688.jpg
0 → 100644
浏览文件 @
d3d78272
135.1 KB
example/auto_compression/pytorch_yolov6/paddle_trt_infer.py
0 → 100644
浏览文件 @
d3d78272
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
numpy
as
np
import
argparse
import
time
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
post_process
import
YOLOv6PostProcess
CLASS_LABEL
=
[
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic light'
,
'fire hydrant'
,
'stop sign'
,
'parking meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports ball'
,
'kite'
,
'baseball bat'
,
'baseball glove'
,
'skateboard'
,
'surfboard'
,
'tennis racket'
,
'bottle'
,
'wine glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted plant'
,
'bed'
,
'dining table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy bear'
,
'hair drier'
,
'toothbrush'
]
def
generate_scale
(
im
,
target_shape
,
keep_ratio
=
True
):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape
=
im
.
shape
[:
2
]
if
keep_ratio
:
im_size_min
=
np
.
min
(
origin_shape
)
im_size_max
=
np
.
max
(
origin_shape
)
target_size_min
=
np
.
min
(
target_shape
)
target_size_max
=
np
.
max
(
target_shape
)
im_scale
=
float
(
target_size_min
)
/
float
(
im_size_min
)
if
np
.
round
(
im_scale
*
im_size_max
)
>
target_size_max
:
im_scale
=
float
(
target_size_max
)
/
float
(
im_size_max
)
im_scale_x
=
im_scale
im_scale_y
=
im_scale
else
:
resize_h
,
resize_w
=
target_shape
im_scale_y
=
resize_h
/
float
(
origin_shape
[
0
])
im_scale_x
=
resize_w
/
float
(
origin_shape
[
1
])
return
im_scale_y
,
im_scale_x
def
image_preprocess
(
img_path
,
target_shape
):
img
=
cv2
.
imread
(
img_path
)
# Resize
im_scale_y
,
im_scale_x
=
generate_scale
(
img
,
target_shape
)
img
=
cv2
.
resize
(
img
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
cv2
.
INTER_LINEAR
)
# Pad
im_h
,
im_w
=
img
.
shape
[:
2
]
h
,
w
=
target_shape
[:]
if
h
!=
im_h
or
w
!=
im_w
:
canvas
=
np
.
ones
((
h
,
w
,
3
),
dtype
=
np
.
float32
)
canvas
*=
np
.
array
([
114.0
,
114.0
,
114.0
],
dtype
=
np
.
float32
)
canvas
[
0
:
im_h
,
0
:
im_w
,
:]
=
img
.
astype
(
np
.
float32
)
img
=
canvas
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
img
=
np
.
transpose
(
img
,
[
2
,
0
,
1
])
/
255
img
=
np
.
expand_dims
(
img
,
0
)
scale_factor
=
np
.
array
([[
im_scale_y
,
im_scale_x
]])
return
img
.
astype
(
np
.
float32
),
scale_factor
def
get_color_map_list
(
num_classes
):
color_map
=
num_classes
*
[
0
,
0
,
0
]
for
i
in
range
(
0
,
num_classes
):
j
=
0
lab
=
i
while
lab
:
color_map
[
i
*
3
]
|=
(((
lab
>>
0
)
&
1
)
<<
(
7
-
j
))
color_map
[
i
*
3
+
1
]
|=
(((
lab
>>
1
)
&
1
)
<<
(
7
-
j
))
color_map
[
i
*
3
+
2
]
|=
(((
lab
>>
2
)
&
1
)
<<
(
7
-
j
))
j
+=
1
lab
>>=
3
color_map
=
[
color_map
[
i
:
i
+
3
]
for
i
in
range
(
0
,
len
(
color_map
),
3
)]
return
color_map
def
draw_box
(
image_file
,
results
,
class_label
,
threshold
=
0.5
):
srcimg
=
cv2
.
imread
(
image_file
,
1
)
for
i
in
range
(
len
(
results
)):
color_list
=
get_color_map_list
(
len
(
class_label
))
clsid2color
=
{}
classid
,
conf
=
int
(
results
[
i
,
0
]),
results
[
i
,
1
]
if
conf
<
threshold
:
continue
xmin
,
ymin
,
xmax
,
ymax
=
int
(
results
[
i
,
2
]),
int
(
results
[
i
,
3
]),
int
(
results
[
i
,
4
]),
int
(
results
[
i
,
5
])
if
classid
not
in
clsid2color
:
clsid2color
[
classid
]
=
color_list
[
classid
]
color
=
tuple
(
clsid2color
[
classid
])
cv2
.
rectangle
(
srcimg
,
(
xmin
,
ymin
),
(
xmax
,
ymax
),
color
,
thickness
=
2
)
print
(
class_label
[
classid
]
+
': '
+
str
(
round
(
conf
,
3
)))
cv2
.
putText
(
srcimg
,
class_label
[
classid
]
+
':'
+
str
(
round
(
conf
,
3
)),
(
xmin
,
ymin
-
10
),
cv2
.
FONT_HERSHEY_SIMPLEX
,
0.8
,
(
0
,
255
,
0
),
thickness
=
2
)
return
srcimg
def
load_predictor
(
model_dir
,
run_mode
=
'paddle'
,
batch_size
=
1
,
device
=
'CPU'
,
min_subgraph_size
=
3
,
use_dynamic_shape
=
False
,
trt_min_shape
=
1
,
trt_max_shape
=
1280
,
trt_opt_shape
=
640
,
trt_calib_mode
=
False
,
cpu_threads
=
1
,
enable_mkldnn
=
False
,
enable_mkldnn_bfloat16
=
False
,
delete_shuffle_pass
=
False
):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
Used by action model.
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
ValueError: predict by TensorRT need device == 'GPU'.
"""
if
device
!=
'GPU'
and
run_mode
!=
'paddle'
:
raise
ValueError
(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
.
format
(
run_mode
,
device
))
config
=
Config
(
os
.
path
.
join
(
model_dir
,
'model.pdmodel'
),
os
.
path
.
join
(
model_dir
,
'model.pdiparams'
))
if
device
==
'GPU'
:
# initial GPU memory(M), device ID
config
.
enable_use_gpu
(
200
,
0
)
# optimize graph and fuse op
config
.
switch_ir_optim
(
True
)
elif
device
==
'XPU'
:
config
.
enable_lite_engine
()
config
.
enable_xpu
(
10
*
1024
*
1024
)
else
:
config
.
disable_gpu
()
config
.
set_cpu_math_library_num_threads
(
cpu_threads
)
if
enable_mkldnn
:
try
:
# cache 10 different shapes for mkldnn to avoid memory leak
config
.
set_mkldnn_cache_capacity
(
10
)
config
.
enable_mkldnn
()
if
enable_mkldnn_bfloat16
:
config
.
enable_mkldnn_bfloat16
()
except
Exception
as
e
:
print
(
"The current environment does not support `mkldnn`, so disable mkldnn."
)
pass
precision_map
=
{
'trt_int8'
:
Config
.
Precision
.
Int8
,
'trt_fp32'
:
Config
.
Precision
.
Float32
,
'trt_fp16'
:
Config
.
Precision
.
Half
}
if
run_mode
in
precision_map
.
keys
():
config
.
enable_tensorrt_engine
(
workspace_size
=
(
1
<<
25
)
*
batch_size
,
max_batch_size
=
batch_size
,
min_subgraph_size
=
min_subgraph_size
,
precision_mode
=
precision_map
[
run_mode
],
use_static
=
False
,
use_calib_mode
=
trt_calib_mode
)
if
use_dynamic_shape
:
min_input_shape
=
{
'image'
:
[
batch_size
,
3
,
trt_min_shape
,
trt_min_shape
]
}
max_input_shape
=
{
'image'
:
[
batch_size
,
3
,
trt_max_shape
,
trt_max_shape
]
}
opt_input_shape
=
{
'image'
:
[
batch_size
,
3
,
trt_opt_shape
,
trt_opt_shape
]
}
config
.
set_trt_dynamic_shape_info
(
min_input_shape
,
max_input_shape
,
opt_input_shape
)
print
(
'trt set dynamic shape done!'
)
# disable print log when predict
config
.
disable_glog_info
()
# enable shared memory
config
.
enable_memory_optim
()
# disable feed, fetch OP, needed by zero_copy_run
config
.
switch_use_feed_fetch_ops
(
False
)
if
delete_shuffle_pass
:
config
.
delete_pass
(
"shuffle_channel_detect_pass"
)
predictor
=
create_predictor
(
config
)
return
predictor
def
predict_image
(
predictor
,
image_file
,
image_shape
=
[
640
,
640
],
warmup
=
1
,
repeats
=
1
,
threshold
=
0.5
,
arch
=
'YOLOv5'
):
img
,
scale_factor
=
image_preprocess
(
image_file
,
image_shape
)
inputs
=
{}
if
arch
==
'YOLOv5'
:
inputs
[
'x2paddle_images'
]
=
img
input_names
=
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
input_names
[
i
]])
for
i
in
range
(
warmup
):
predictor
.
run
()
np_boxes
=
None
predict_time
=
0.
time_min
=
float
(
"inf"
)
time_max
=
float
(
'-inf'
)
for
i
in
range
(
repeats
):
start_time
=
time
.
time
()
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
boxes_tensor
=
predictor
.
get_output_handle
(
output_names
[
0
])
np_boxes
=
boxes_tensor
.
copy_to_cpu
()
end_time
=
time
.
time
()
timed
=
end_time
-
start_time
time_min
=
min
(
time_min
,
timed
)
time_max
=
max
(
time_max
,
timed
)
predict_time
+=
timed
time_avg
=
predict_time
/
repeats
print
(
'Inference time(ms): min={}, max={}, avg={}'
.
format
(
round
(
time_min
*
1000
,
2
),
round
(
time_max
*
1000
,
1
),
round
(
time_avg
*
1000
,
1
)))
postprocess
=
YOLOv6PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.65
,
multi_label
=
True
)
res
=
postprocess
(
np_boxes
,
scale_factor
)
res_img
=
draw_box
(
image_file
,
res
[
'bbox'
],
CLASS_LABEL
,
threshold
=
threshold
)
cv2
.
imwrite
(
'result.jpg'
,
res_img
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--image_file'
,
type
=
str
,
default
=
None
,
help
=
"image path"
)
parser
.
add_argument
(
'--model_path'
,
type
=
str
,
help
=
"inference model filepath"
)
parser
.
add_argument
(
'--benchmark'
,
type
=
bool
,
default
=
False
,
help
=
"Whether run benchmark or not."
)
parser
.
add_argument
(
'--run_mode'
,
type
=
str
,
default
=
'paddle'
,
help
=
"mode of running(paddle/trt_fp32/trt_fp16/trt_int8)"
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'GPU'
,
help
=
"Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU"
)
parser
.
add_argument
(
'--img_shape'
,
type
=
int
,
default
=
640
,
help
=
"input_size"
)
args
=
parser
.
parse_args
()
predictor
=
load_predictor
(
args
.
model_path
,
run_mode
=
args
.
run_mode
,
device
=
args
.
device
)
warmup
,
repeats
=
1
,
1
if
args
.
benchmark
:
warmup
,
repeats
=
50
,
100
predict_image
(
predictor
,
args
.
image_file
,
image_shape
=
[
args
.
img_shape
,
args
.
img_shape
],
warmup
=
warmup
,
repeats
=
repeats
)
example/auto_compression/pytorch_yolov6/post_process.py
0 → 100644
浏览文件 @
d3d78272
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
cv2
def
box_area
(
boxes
):
"""
Args:
boxes(np.ndarray): [N, 4]
return: [N]
"""
return
(
boxes
[:,
2
]
-
boxes
[:,
0
])
*
(
boxes
[:,
3
]
-
boxes
[:,
1
])
def
box_iou
(
box1
,
box2
):
"""
Args:
box1(np.ndarray): [N, 4]
box2(np.ndarray): [M, 4]
return: [N, M]
"""
area1
=
box_area
(
box1
)
area2
=
box_area
(
box2
)
lt
=
np
.
maximum
(
box1
[:,
np
.
newaxis
,
:
2
],
box2
[:,
:
2
])
rb
=
np
.
minimum
(
box1
[:,
np
.
newaxis
,
2
:],
box2
[:,
2
:])
wh
=
rb
-
lt
wh
=
np
.
maximum
(
0
,
wh
)
inter
=
wh
[:,
:,
0
]
*
wh
[:,
:,
1
]
iou
=
inter
/
(
area1
[:,
np
.
newaxis
]
+
area2
-
inter
)
return
iou
def
nms
(
boxes
,
scores
,
iou_threshold
):
"""
Non Max Suppression numpy implementation.
args:
boxes(np.ndarray): [N, 4]
scores(np.ndarray): [N, 1]
iou_threshold(float): Threshold of IoU.
"""
idxs
=
scores
.
argsort
()
keep
=
[]
while
idxs
.
size
>
0
:
max_score_index
=
idxs
[
-
1
]
max_score_box
=
boxes
[
max_score_index
][
None
,
:]
keep
.
append
(
max_score_index
)
if
idxs
.
size
==
1
:
break
idxs
=
idxs
[:
-
1
]
other_boxes
=
boxes
[
idxs
]
ious
=
box_iou
(
max_score_box
,
other_boxes
)
idxs
=
idxs
[
ious
[
0
]
<=
iou_threshold
]
keep
=
np
.
array
(
keep
)
return
keep
class
YOLOv6PostProcess
(
object
):
"""
Post process of YOLOv6 network.
args:
score_threshold(float): Threshold to filter out bounding boxes with low
confidence score. If not provided, consider all boxes.
nms_threshold(float): The threshold to be used in NMS.
multi_label(bool): Whether keep multi label in boxes.
keep_top_k(int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step.
"""
def
__init__
(
self
,
score_threshold
=
0.25
,
nms_threshold
=
0.5
,
multi_label
=
False
,
keep_top_k
=
300
):
self
.
score_threshold
=
score_threshold
self
.
nms_threshold
=
nms_threshold
self
.
multi_label
=
multi_label
self
.
keep_top_k
=
keep_top_k
def
_xywh2xyxy
(
self
,
x
):
# Convert from [x, y, w, h] to [x1, y1, x2, y2]
y
=
np
.
copy
(
x
)
y
[:,
0
]
=
x
[:,
0
]
-
x
[:,
2
]
/
2
# top left x
y
[:,
1
]
=
x
[:,
1
]
-
x
[:,
3
]
/
2
# top left y
y
[:,
2
]
=
x
[:,
0
]
+
x
[:,
2
]
/
2
# bottom right x
y
[:,
3
]
=
x
[:,
1
]
+
x
[:,
3
]
/
2
# bottom right y
return
y
def
_non_max_suppression
(
self
,
prediction
):
max_wh
=
4096
# (pixels) minimum and maximum box width and height
nms_top_k
=
30000
cand_boxes
=
prediction
[...,
4
]
>
self
.
score_threshold
# candidates
output
=
[
np
.
zeros
((
0
,
6
))]
*
prediction
.
shape
[
0
]
for
batch_id
,
boxes
in
enumerate
(
prediction
):
# Apply constraints
boxes
=
boxes
[
cand_boxes
[
batch_id
]]
if
not
boxes
.
shape
[
0
]:
continue
# Compute conf (conf = obj_conf * cls_conf)
boxes
[:,
5
:]
*=
boxes
[:,
4
:
5
]
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
convert_box
=
self
.
_xywh2xyxy
(
boxes
[:,
:
4
])
# Detections matrix nx6 (xyxy, conf, cls)
if
self
.
multi_label
:
i
,
j
=
(
boxes
[:,
5
:]
>
self
.
score_threshold
).
nonzero
()
boxes
=
np
.
concatenate
(
(
convert_box
[
i
],
boxes
[
i
,
j
+
5
,
None
],
j
[:,
None
].
astype
(
np
.
float32
)),
axis
=
1
)
else
:
conf
=
np
.
max
(
boxes
[:,
5
:],
axis
=
1
)
j
=
np
.
argmax
(
boxes
[:,
5
:],
axis
=
1
)
re
=
np
.
array
(
conf
.
reshape
(
-
1
)
>
self
.
score_threshold
)
conf
=
conf
.
reshape
(
-
1
,
1
)
j
=
j
.
reshape
(
-
1
,
1
)
boxes
=
np
.
concatenate
((
convert_box
,
conf
,
j
),
axis
=
1
)[
re
]
num_box
=
boxes
.
shape
[
0
]
if
not
num_box
:
continue
elif
num_box
>
nms_top_k
:
boxes
=
boxes
[
boxes
[:,
4
].
argsort
()[::
-
1
][:
nms_top_k
]]
# Batched NMS
c
=
boxes
[:,
5
:
6
]
*
max_wh
clean_boxes
,
scores
=
boxes
[:,
:
4
]
+
c
,
boxes
[:,
4
]
keep
=
nms
(
clean_boxes
,
scores
,
self
.
nms_threshold
)
# limit detection box num
if
keep
.
shape
[
0
]
>
self
.
keep_top_k
:
keep
=
keep
[:
self
.
keep_top_k
]
output
[
batch_id
]
=
boxes
[
keep
]
return
output
def
__call__
(
self
,
outs
,
scale_factor
):
preds
=
self
.
_non_max_suppression
(
outs
)
bboxs
,
box_nums
=
[],
[]
for
i
,
pred
in
enumerate
(
preds
):
if
len
(
pred
.
shape
)
>
2
:
pred
=
np
.
squeeze
(
pred
)
if
len
(
pred
.
shape
)
==
1
:
pred
=
pred
[
np
.
newaxis
,
:]
pred_bboxes
=
pred
[:,
:
4
]
scale_factor
=
np
.
tile
(
scale_factor
[
i
][::
-
1
],
(
1
,
2
))
pred_bboxes
/=
scale_factor
bbox
=
np
.
concatenate
(
[
pred
[:,
-
1
][:,
np
.
newaxis
],
pred
[:,
-
2
][:,
np
.
newaxis
],
pred_bboxes
],
axis
=-
1
)
bboxs
.
append
(
bbox
)
box_num
=
bbox
.
shape
[
0
]
box_nums
.
append
(
box_num
)
bboxs
=
np
.
concatenate
(
bboxs
,
axis
=
0
)
box_nums
=
np
.
array
(
box_nums
)
return
{
'bbox'
:
bboxs
,
'bbox_num'
:
box_nums
}
example/auto_compression/pytorch_yolov6/post_quant.py
0 → 100644
浏览文件 @
d3d78272
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
,
merge_config
from
ppdet.core.workspace
import
create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
paddleslim.quant
import
quant_post_static
from
post_process
import
YOLOv6PostProcess
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--config_path'
,
type
=
str
,
default
=
None
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
'ptq_out'
,
help
=
"directory to save compressed model."
)
parser
.
add_argument
(
'--algo'
,
type
=
str
,
default
=
'KL'
,
help
=
"post quant algo."
)
return
parser
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
main
():
global
global_config
all_config
=
load_slim_config
(
FLAGS
.
config_path
)
assert
"Global"
in
all_config
,
f
"Key 'Global' not found in config file.
\n
{
all_config
}
"
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
train_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'TrainDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
train_loader
=
reader_wrapper
(
train_loader
,
global_config
[
'input_list'
])
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
quant_post_static
(
executor
=
exe
,
model_dir
=
global_config
[
"model_dir"
],
quantize_model_path
=
FLAGS
.
save_dir
,
data_loader
=
train_loader
,
model_filename
=
global_config
[
"model_filename"
],
params_filename
=
global_config
[
"params_filename"
],
batch_size
=
32
,
batch_nums
=
10
,
algo
=
FLAGS
.
algo
,
hist_percent
=
0.999
,
is_full_quantize
=
False
,
bias_correction
=
False
,
onnx_format
=
False
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
main
()
example/auto_compression/pytorch_yolov6/run.py
0 → 100644
浏览文件 @
d3d78272
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
,
merge_config
from
ppdet.core.workspace
import
create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
paddleslim.auto_compression
import
AutoCompression
from
post_process
import
YOLOv6PostProcess
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--config_path'
,
type
=
str
,
default
=
None
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
'output'
,
help
=
"directory to save compressed model."
)
parser
.
add_argument
(
'--devices'
,
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
parser
.
add_argument
(
'--eval'
,
type
=
bool
,
default
=
False
,
help
=
"whether to run evaluation."
)
return
parser
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
convert_numpy_data
(
data
,
metric
):
data_all
=
{}
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
if
isinstance
(
metric
,
VOCMetric
):
for
k
,
v
in
data_all
.
items
():
if
not
isinstance
(
v
[
0
],
np
.
ndarray
):
tmp_list
=
[]
for
t
in
v
:
tmp_list
.
append
(
np
.
array
(
t
))
data_all
[
k
]
=
np
.
array
(
tmp_list
)
else
:
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
return
data_all
def
eval_function
(
exe
,
compiled_test_program
,
test_feed_names
,
test_fetch_list
):
metric
=
global_config
[
'metric'
]
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
convert_numpy_data
(
data
,
metric
)
data_input
=
{}
for
k
,
v
in
data
.
items
():
if
isinstance
(
global_config
[
'input_list'
],
list
):
if
k
in
test_feed_names
:
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
global_config
[
'input_list'
],
dict
):
if
k
in
global_config
[
'input_list'
].
keys
():
data_input
[
global_config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
compiled_test_program
,
feed
=
data_input
,
fetch_list
=
test_fetch_list
,
return_numpy
=
False
)
res
=
{}
if
'arch'
in
global_config
and
global_config
[
'arch'
]
==
'YOLOv6'
:
postprocess
=
YOLOv6PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.65
,
multi_label
=
True
)
res
=
postprocess
(
np
.
array
(
outs
[
0
]),
data_all
[
'scale_factor'
])
else
:
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter:'
,
batch_id
)
metric
.
accumulate
()
metric
.
log
()
map_res
=
metric
.
get_results
()
metric
.
reset
()
return
map_res
[
'bbox'
][
0
]
def
main
():
global
global_config
all_config
=
load_slim_config
(
FLAGS
.
config_path
)
assert
"Global"
in
all_config
,
f
"Key 'Global' not found in config file.
\n
{
all_config
}
"
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
train_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'TrainDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
train_loader
=
reader_wrapper
(
train_loader
,
global_config
[
'input_list'
])
if
'Evaluation'
in
global_config
.
keys
()
and
global_config
[
'Evaluation'
]
and
paddle
.
distributed
.
get_rank
()
==
0
:
eval_func
=
eval_function
dataset
=
reader_cfg
[
'EvalDataset'
]
global
val_loader
_eval_batch_sampler
=
paddle
.
io
.
BatchSampler
(
dataset
,
batch_size
=
reader_cfg
[
'EvalReader'
][
'batch_size'
])
val_loader
=
create
(
'EvalReader'
)(
dataset
,
reader_cfg
[
'worker_num'
],
batch_sampler
=
_eval_batch_sampler
,
return_list
=
True
)
metric
=
None
if
reader_cfg
[
'metric'
]
==
'COCO'
:
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
elif
reader_cfg
[
'metric'
]
==
'VOC'
:
metric
=
VOCMetric
(
label_list
=
dataset
.
get_label_list
(),
class_num
=
reader_cfg
[
'num_classes'
],
map_type
=
reader_cfg
[
'map_type'
])
else
:
raise
ValueError
(
"metric currently only supports COCO and VOC."
)
global_config
[
'metric'
]
=
metric
else
:
eval_func
=
None
ac
=
AutoCompression
(
model_dir
=
global_config
[
"model_dir"
],
model_filename
=
global_config
[
"model_filename"
],
params_filename
=
global_config
[
"params_filename"
],
save_dir
=
FLAGS
.
save_dir
,
config
=
all_config
,
train_dataloader
=
train_loader
,
eval_callback
=
eval_func
)
ac
.
compress
()
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录