Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
d5d27946
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d5d27946
编写于
7月 08, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
7月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update yolov5s act demo (#1274)
上级
f3f5467f
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
624 addition
and
29 deletion
+624
-29
example/auto_compression/pytorch_yolov5/README.md
example/auto_compression/pytorch_yolov5/README.md
+30
-9
example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml
...o_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml
+1
-1
example/auto_compression/pytorch_yolov5/cpp_infer/CMakeLists.txt
.../auto_compression/pytorch_yolov5/cpp_infer/CMakeLists.txt
+263
-0
example/auto_compression/pytorch_yolov5/cpp_infer/README.md
example/auto_compression/pytorch_yolov5/cpp_infer/README.md
+62
-0
example/auto_compression/pytorch_yolov5/cpp_infer/compile.sh
example/auto_compression/pytorch_yolov5/cpp_infer/compile.sh
+37
-0
example/auto_compression/pytorch_yolov5/cpp_infer/trt_run.cc
example/auto_compression/pytorch_yolov5/cpp_infer/trt_run.cc
+116
-0
example/auto_compression/pytorch_yolov5/eval.py
example/auto_compression/pytorch_yolov5/eval.py
+1
-9
example/auto_compression/pytorch_yolov5/post_quant.py
example/auto_compression/pytorch_yolov5/post_quant.py
+104
-0
example/auto_compression/pytorch_yolov5/run.py
example/auto_compression/pytorch_yolov5/run.py
+0
-10
example/auto_compression/pytorch_yolov6/README.md
example/auto_compression/pytorch_yolov6/README.md
+5
-0
example/auto_compression/pytorch_yolov6/post_quant.py
example/auto_compression/pytorch_yolov6/post_quant.py
+5
-0
未找到文件。
example/auto_compression/pytorch_yolov5/README.md
浏览文件 @
d5d27946
# 目标检测模型自动压缩示例
#
YOLOv5
目标检测模型自动压缩示例
目录:
-
[
1.简介
](
#1简介
)
...
...
@@ -22,12 +22,14 @@
| 模型 | 策略 | 输入尺寸 | mAP
<sup>
val
<br>
0.5:0.95 | 预测时延
<sup><small>
FP32
</small><sup><br><sup>
(ms) |预测时延
<sup><small>
FP16
</small><sup><br><sup>
(ms) | 预测时延
<sup><small>
INT8
</small><sup><br><sup>
(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv5s | Base模型 | 640
*
640 | 37.4 | 7.8ms | 4.3ms | - | - |
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar
)
|
| YOLOv5s | 量化+蒸馏 | 640
*
640 | 36.8 | - | - | 3.4ms |
[
config
](
./configs/yolov5s_qat_dis.yaml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar
)
|
| YOLOv5s | Base模型 | 640
*
640 | 37.4 | 5.95ms | 2.44ms | - | - |
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar
)
|
| YOLOv5s | KL离线量化 | 640
*
640 | 36.0 | - | - | 1.87ms | - | - |
| YOLOv5s | 量化蒸馏训练 | 640
*640 | **36.9** | - | - | **1.87ms*
*
|
[
config
](
./configs/yolov5s_qat_dis.yaml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar
)
|
说明:
-
mAP的指标均在COCO val2017数据集中评测得到。
-
YOLOv5s模型在Tesla T4的GPU环境下
测试,并且开启TensorRT,测试脚本是
[
benchmark demo
](
./paddle_trt_infer.py
)
。
-
YOLOv5s模型在Tesla T4的GPU环境下
开启TensorRT 8.4.1,batch_size=1, 测试脚本是
[
cpp_infer
](
./cpp_infer
)
。
## 3. 自动压缩流程
...
...
@@ -67,18 +69,19 @@ pip install x2paddle sympy onnx
本案例默认以COCO数据进行自动压缩实验,并且依赖PaddleDetection中数据读取模块,如果自定义COCO数据,或者其他格式数据,请参考
[
PaddleDetection数据准备文档
](
https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md
)
来准备数据。
如果已经准备好数据集,请直接修改[./configs/yolov6_reader.yml]中
`EvalDataset`
的
`dataset_dir`
字段为自己数据集路径即可。
#### 3.3 准备预测模型
(1)准备ONNX模型:
可通过
[
ultralytics/yolov5
](
https://github.com/ultralytics/yolov5
)
官方的
[
导出教程
](
https://github.com/ultralytics/yolov5/issues/251
)
来准备ONNX模型。也可以下载准备好的
[
yolov5s.onnx
](
https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx
)
。
```
```
shell
python export.py
--weights
yolov5s.pt
--include
onnx
```
(2) 转换模型:
```
```
shell
x2paddle
--framework
=
onnx
--model
=
yolov5s.onnx
--save_dir
=
pd_model
cp
-r
pd_model/inference_model/ yolov5s_infer
```
...
...
@@ -112,15 +115,33 @@ export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/yolov5s_qat_dis.yaml
```
**注意**
:
要测试的
模型路径需要在配置文件中
`model_dir`
字段下进行修改指定。
**注意**
:
如果要测试量化后的模型,
模型路径需要在配置文件中
`model_dir`
字段下进行修改指定。
## 4.预测部署
-
Paddle-TensorRT部署:
使用
[
paddle_trt_infer.py
](
./paddle_trt_infer.py
)
进行部署:
#### Paddle-TensorRT C++部署
进入
[
cpp_infer
](
./cpp_infer
)
文件夹内,请按照
[
C++ TensorRT Benchmark测试教程
](
./cpp_infer/README.md
)
进行准备环境及编译,然后开始测试:
```
shell
# 编译
bash complie.sh
# 执行
./build/trt_run
--model_file
yolov5s_quant/model.pdmodel
--params_file
yolov5s_quant/model.pdiparams
--run_mode
=
trt_int8
```
#### Paddle-TensorRT Python部署:
首先安装带有TensorRT的
[
Paddle安装包
](
https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python
)
。
然后使用
[
paddle_trt_infer.py
](
./paddle_trt_infer.py
)
进行部署:
```
shell
python paddle_trt_infer.py
--model_path
=
output
--image_file
=
images/000000570688.jpg
--benchmark
=
True
--run_mode
=
trt_int8
```
## 5.FAQ
-
如果想测试离线量化模型精度,可执行:
```
shell
python post_quant.py
--config_path
=
./configs/yolov5s_qat_dis.yaml
```
example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml
浏览文件 @
d5d27946
...
...
@@ -10,7 +10,7 @@ Global:
Distillation
:
alpha
:
1.0
loss
:
l2
loss
:
soft_label
Quantization
:
use_pact
:
true
...
...
example/auto_compression/pytorch_yolov5/cpp_infer/CMakeLists.txt
0 → 100644
浏览文件 @
d5d27946
cmake_minimum_required
(
VERSION 3.0
)
project
(
cpp_inference_demo CXX C
)
option
(
WITH_MKL
"Compile demo with MKL/OpenBlas support, default use MKL."
ON
)
option
(
WITH_GPU
"Compile demo with GPU/CPU, default use CPU."
OFF
)
option
(
WITH_STATIC_LIB
"Compile demo with static/shared library, default use static."
ON
)
option
(
USE_TENSORRT
"Compile demo with TensorRT."
OFF
)
option
(
WITH_ROCM
"Compile demo with rocm."
OFF
)
option
(
WITH_ONNXRUNTIME
"Compile demo with ONNXRuntime"
OFF
)
option
(
WITH_ARM
"Compile demo with ARM"
OFF
)
option
(
WITH_MIPS
"Compile demo with MIPS"
OFF
)
option
(
WITH_SW
"Compile demo with SW"
OFF
)
option
(
WITH_XPU
"Compile demow ith xpu"
OFF
)
option
(
WITH_NPU
"Compile demow ith npu"
OFF
)
if
(
NOT WITH_STATIC_LIB
)
add_definitions
(
"-DPADDLE_WITH_SHARED_LIB"
)
else
()
# PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode.
# Set it to empty in static library mode to avoid compilation issues.
add_definitions
(
"/DPD_INFER_DECL="
)
endif
()
macro
(
safe_set_static_flag
)
foreach
(
flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
)
if
(
${
flag_var
}
MATCHES
"/MD"
)
string
(
REGEX REPLACE
"/MD"
"/MT"
${
flag_var
}
"
${${
flag_var
}}
"
)
endif
(
${
flag_var
}
MATCHES
"/MD"
)
endforeach
(
flag_var
)
endmacro
()
if
(
NOT DEFINED PADDLE_LIB
)
message
(
FATAL_ERROR
"please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib"
)
endif
()
if
(
NOT DEFINED DEMO_NAME
)
message
(
FATAL_ERROR
"please set DEMO_NAME with -DDEMO_NAME=demo_name"
)
endif
()
include_directories
(
"
${
PADDLE_LIB
}
/"
)
set
(
PADDLE_LIB_THIRD_PARTY_PATH
"
${
PADDLE_LIB
}
/third_party/install/"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
protobuf/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
glog/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
gflags/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
xxhash/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
cryptopp/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/include"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
paddle2onnx/include"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
protobuf/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
glog/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
gflags/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
xxhash/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
cryptopp/lib"
)
link_directories
(
"
${
PADDLE_LIB
}
/paddle/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
paddle2onnx/lib"
)
if
(
WIN32
)
add_definitions
(
"/DGOOGLE_GLOG_DLL_DECL="
)
option
(
MSVC_STATIC_CRT
"use static C Runtime library by default"
ON
)
if
(
MSVC_STATIC_CRT
)
if
(
WITH_MKL
)
set
(
FLAG_OPENMP
"/openmp"
)
endif
()
set
(
CMAKE_C_FLAGS_DEBUG
"
${
CMAKE_C_FLAGS_DEBUG
}
/bigobj /MTd
${
FLAG_OPENMP
}
"
)
set
(
CMAKE_C_FLAGS_RELEASE
"
${
CMAKE_C_FLAGS_RELEASE
}
/bigobj /MT
${
FLAG_OPENMP
}
"
)
set
(
CMAKE_CXX_FLAGS_DEBUG
"
${
CMAKE_CXX_FLAGS_DEBUG
}
/bigobj /MTd
${
FLAG_OPENMP
}
"
)
set
(
CMAKE_CXX_FLAGS_RELEASE
"
${
CMAKE_CXX_FLAGS_RELEASE
}
/bigobj /MT
${
FLAG_OPENMP
}
"
)
safe_set_static_flag
()
if
(
WITH_STATIC_LIB
)
add_definitions
(
-DSTATIC_LIB
)
endif
()
endif
()
else
()
if
(
WITH_MKL
)
set
(
FLAG_OPENMP
"-fopenmp"
)
endif
()
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++11
${
FLAG_OPENMP
}
"
)
endif
()
if
(
WITH_GPU
)
if
(
NOT WIN32
)
include_directories
(
"/usr/local/cuda/include"
)
if
(
CUDA_LIB STREQUAL
""
)
set
(
CUDA_LIB
"/usr/local/cuda/lib64/"
CACHE STRING
"CUDA Library"
)
endif
()
else
()
include_directories
(
"C:
\\
Program
\
Files
\\
NVIDIA GPU Computing Toolkit
\\
CUDA
\\
v8.0
\\
include"
)
if
(
CUDA_LIB STREQUAL
""
)
set
(
CUDA_LIB
"C:
\\
Program
\
Files
\\
NVIDIA GPU Computing Toolkit
\\
CUDA
\\
v8.0
\\
lib
\\
x64"
)
endif
()
endif
(
NOT WIN32
)
endif
()
if
(
USE_TENSORRT AND WITH_GPU
)
set
(
TENSORRT_ROOT
""
CACHE STRING
"The root directory of TensorRT library"
)
if
(
"
${
TENSORRT_ROOT
}
"
STREQUAL
""
)
message
(
FATAL_ERROR
"The TENSORRT_ROOT is empty, you must assign it a value with CMake command. Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH "
)
endif
()
set
(
TENSORRT_INCLUDE_DIR
${
TENSORRT_ROOT
}
/include
)
set
(
TENSORRT_LIB_DIR
${
TENSORRT_ROOT
}
/lib
)
file
(
READ
${
TENSORRT_INCLUDE_DIR
}
/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS
)
string
(
REGEX MATCH
"define NV_TENSORRT_MAJOR +([0-9]+)"
TENSORRT_MAJOR_VERSION
"
${
TENSORRT_VERSION_FILE_CONTENTS
}
"
)
if
(
"
${
TENSORRT_MAJOR_VERSION
}
"
STREQUAL
""
)
file
(
READ
${
TENSORRT_INCLUDE_DIR
}
/NvInferVersion.h TENSORRT_VERSION_FILE_CONTENTS
)
string
(
REGEX MATCH
"define NV_TENSORRT_MAJOR +([0-9]+)"
TENSORRT_MAJOR_VERSION
"
${
TENSORRT_VERSION_FILE_CONTENTS
}
"
)
endif
()
if
(
"
${
TENSORRT_MAJOR_VERSION
}
"
STREQUAL
""
)
message
(
SEND_ERROR
"Failed to detect TensorRT version."
)
endif
()
string
(
REGEX REPLACE
"define NV_TENSORRT_MAJOR +([0-9]+)"
"
\\
1"
TENSORRT_MAJOR_VERSION
"
${
TENSORRT_MAJOR_VERSION
}
"
)
message
(
STATUS
"Current TensorRT header is
${
TENSORRT_INCLUDE_DIR
}
/NvInfer.h. "
"Current TensorRT version is v
${
TENSORRT_MAJOR_VERSION
}
. "
)
include_directories
(
"
${
TENSORRT_INCLUDE_DIR
}
"
)
link_directories
(
"
${
TENSORRT_LIB_DIR
}
"
)
endif
()
if
(
WITH_MKL
)
set
(
MATH_LIB_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
mklml"
)
include_directories
(
"
${
MATH_LIB_PATH
}
/include"
)
if
(
WIN32
)
set
(
MATH_LIB
${
MATH_LIB_PATH
}
/lib/mklml
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
${
MATH_LIB_PATH
}
/lib/libiomp5md
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
else
()
set
(
MATH_LIB
${
MATH_LIB_PATH
}
/lib/libmklml_intel
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
MATH_LIB_PATH
}
/lib/libiomp5
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
set
(
MKLDNN_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
mkldnn"
)
if
(
EXISTS
${
MKLDNN_PATH
}
)
include_directories
(
"
${
MKLDNN_PATH
}
/include"
)
if
(
WIN32
)
set
(
MKLDNN_LIB
${
MKLDNN_PATH
}
/lib/mkldnn.lib
)
else
(
WIN32
)
set
(
MKLDNN_LIB
${
MKLDNN_PATH
}
/lib/libmkldnn.so.0
)
endif
(
WIN32
)
endif
()
elseif
((
NOT WITH_MIPS
)
AND
(
NOT WITH_SW
))
set
(
OPENBLAS_LIB_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
openblas"
)
include_directories
(
"
${
OPENBLAS_LIB_PATH
}
/include/openblas"
)
if
(
WIN32
)
set
(
MATH_LIB
${
OPENBLAS_LIB_PATH
}
/lib/openblas
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
else
()
set
(
MATH_LIB
${
OPENBLAS_LIB_PATH
}
/lib/libopenblas
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
endif
()
endif
()
if
(
WITH_STATIC_LIB
)
set
(
DEPS
${
PADDLE_LIB
}
/paddle/lib/libpaddle_inference
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
else
()
if
(
WIN32
)
set
(
DEPS
${
PADDLE_LIB
}
/paddle/lib/paddle_inference
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
else
()
set
(
DEPS
${
PADDLE_LIB
}
/paddle/lib/libpaddle_inference
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
endif
()
if
(
WITH_ONNXRUNTIME
)
if
(
WIN32
)
set
(
DEPS
${
DEPS
}
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib/onnxruntime.lib paddle2onnx
)
elseif
(
APPLE
)
set
(
DEPS
${
DEPS
}
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib/libonnxruntime.1.10.0.dylib paddle2onnx
)
else
()
set
(
DEPS
${
DEPS
}
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib/libonnxruntime.so.1.10.0 paddle2onnx
)
endif
()
endif
()
if
(
NOT WIN32
)
set
(
EXTERNAL_LIB
"-lrt -ldl -lpthread"
)
set
(
DEPS
${
DEPS
}
${
MATH_LIB
}
${
MKLDNN_LIB
}
glog gflags protobuf xxhash cryptopp
${
EXTERNAL_LIB
}
)
else
()
set
(
DEPS
${
DEPS
}
${
MATH_LIB
}
${
MKLDNN_LIB
}
glog gflags_static libprotobuf xxhash cryptopp-static
${
EXTERNAL_LIB
}
)
set
(
DEPS
${
DEPS
}
shlwapi.lib
)
endif
(
NOT WIN32
)
if
(
WITH_GPU
)
if
(
NOT WIN32
)
if
(
USE_TENSORRT
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/libnvinfer
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/libnvinfer_plugin
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/libcudart
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
else
()
if
(
USE_TENSORRT
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/nvinfer
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/nvinfer_plugin
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
if
(
${
TENSORRT_MAJOR_VERSION
}
GREATER_EQUAL 7
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/myelin64_1
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
endif
()
endif
()
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cudart
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cublas
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cudnn
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
endif
()
endif
()
if
(
WITH_ROCM AND NOT WIN32
)
set
(
DEPS
${
DEPS
}
${
ROCM_LIB
}
/libamdhip64
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
if
(
WITH_XPU AND NOT WIN32
)
set
(
XPU_INSTALL_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
xpu"
)
set
(
DEPS
${
DEPS
}
${
XPU_INSTALL_PATH
}
/lib/libxpuapi
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
XPU_INSTALL_PATH
}
/lib/libxpurt
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
if
(
WITH_NPU AND NOT WIN32
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libgraph
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libge_runner
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libascendcl
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libascendcl
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64/libacl_op_compiler
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
endif
()
add_executable
(
${
DEMO_NAME
}
${
DEMO_NAME
}
.cc
)
target_link_libraries
(
${
DEMO_NAME
}
${
DEPS
}
)
if
(
WIN32
)
if
(
USE_TENSORRT
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
TENSORRT_LIB_DIR
}
/nvinfer
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
TENSORRT_LIB_DIR
}
/nvinfer_plugin
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
)
if
(
${
TENSORRT_MAJOR_VERSION
}
GREATER_EQUAL 7
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
TENSORRT_LIB_DIR
}
/myelin64_1
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
)
endif
()
endif
()
if
(
WITH_MKL
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
MATH_LIB_PATH
}
/lib/mklml.dll
${
CMAKE_BINARY_DIR
}
/Release
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
MATH_LIB_PATH
}
/lib/libiomp5md.dll
${
CMAKE_BINARY_DIR
}
/Release
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
MKLDNN_PATH
}
/lib/mkldnn.dll
${
CMAKE_BINARY_DIR
}
/Release
)
else
()
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
OPENBLAS_LIB_PATH
}
/lib/openblas.dll
${
CMAKE_BINARY_DIR
}
/Release
)
endif
()
if
(
WITH_ONNXRUNTIME
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
PADDLE_LIB_THIRD_PARTY_PATH
}
onnxruntime/lib/onnxruntime.dll
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
PADDLE_LIB_THIRD_PARTY_PATH
}
paddle2onnx/lib/paddle2onnx.dll
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
)
endif
()
if
(
NOT WITH_STATIC_LIB
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
"
${
PADDLE_LIB
}
/paddle/lib/paddle_inference.dll"
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
)
endif
()
endif
()
example/auto_compression/pytorch_yolov5/cpp_infer/README.md
0 → 100644
浏览文件 @
d5d27946
# YOLOv5 TensorRT Benchmark测试(Linux)
## 环境准备
-
CUDA、CUDNN:确认环境中已经安装CUDA和CUDNN,并且提前获取其安装路径。
-
TensorRT:可通过NVIDIA官网下载
[
TensorRT 8.4.1.5
](
https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.1/tars/tensorrt-8.4.1.5.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz
)
或其他版本安装包。
-
Paddle Inference C++预测库:编译develop版本请参考
[
编译文档
](
https://www.paddlepaddle.org.cn/inference/user_guides/source_compile.html
)
。编译完成后,会在build目录下生成
`paddle_inference_install_dir`
文件夹,这个就是我们需要的C++预测库文件。
## 编译可执行程序
-
(1)修改
`compile.sh`
中依赖库路径,主要是以下内容:
```
shell
# Paddle Inference预测库路径
LIB_DIR
=
/root/auto_compress/Paddle/build/paddle_inference_install_dir/
# CUDNN路径
CUDNN_LIB
=
/usr/lib/x86_64-linux-gnu/
# CUDA路径
CUDA_LIB
=
/usr/local/cuda/lib64
# TensorRT安装包路径,为TRT资源包解压完成后的绝对路径,其中包含`lib`和`include`文件夹
TENSORRT_ROOT
=
/root/auto_compress/trt/trt8.4/
```
## Paddle TensorRT测试
-
FP32
```
./build/trt_run --model_file yolov5s_infer/model.pdmodel --params_file yolov5s_infer/model.pdiparams --run_mode=trt_fp32
```
-
FP16
```
./build/trt_run --model_file yolov5s_infer/model.pdmodel --params_file yolov5s_infer/model.pdiparams --run_mode=trt_fp16
```
-
INT8
```
./build/trt_run --model_file yolov5s_quant/model.pdmodel --params_file yolov5s_quant/model.pdiparams --run_mode=trt_int8
```
## 原生TensorRT测试
```
shell
# FP32
trtexec
--onnx
=
yolov5s.onnx
--workspace
=
1024
--avgRuns
=
1000
--inputIOFormats
=
fp16:chw
--outputIOFormats
=
fp16:chw
# FP16
trtexec
--onnx
=
yolov5s.onnx
--workspace
=
1024
--avgRuns
=
1000
--inputIOFormats
=
fp16:chw
--outputIOFormats
=
fp16:chw
--fp16
# INT8
trtexec
--onnx
=
yolov5s.onnx
--workspace
=
1024
--avgRuns
=
1000
--inputIOFormats
=
fp16:chw
--outputIOFormats
=
fp16:chw
--int8
```
## 性能对比
| 预测库 | 模型 | 预测时延
<sup><small>
FP32
</small><sup><br><sup>
(ms) |预测时延
<sup><small>
FP16
</small><sup><br><sup>
(ms) | 预测时延
<sup><small>
INT8
</small><sup><br><sup>
(ms) |
| :--------: | :--------: |:-------- |:--------: | :---------------------: |
| Paddle TensorRT | yolov5s | 5.95ms | 2.44ms | 1.87ms |
| TensorRT | yolov5s | 6.16ms | 2.58ms | 2.07ms |
环境:
-
Tesla T4,TensorRT 8.4.1,CUDA 11.2
-
batch_size=1
example/auto_compression/pytorch_yolov5/cpp_infer/compile.sh
0 → 100644
浏览文件 @
d5d27946
#!/bin/bash
set
+x
set
-e
work_path
=
$(
dirname
$(
readlink
-f
$0
))
mkdir
-p
build
cd
build
rm
-rf
*
DEMO_NAME
=
trt_run
WITH_MKL
=
ON
WITH_GPU
=
ON
USE_TENSORRT
=
ON
LIB_DIR
=
/root/auto_compress/Paddle/build/paddle_inference_install_dir/
CUDNN_LIB
=
/usr/lib/x86_64-linux-gnu/
CUDA_LIB
=
/usr/local/cuda/lib64
TENSORRT_ROOT
=
/root/auto_compress/trt/trt8.4/
WITH_ROCM
=
OFF
ROCM_LIB
=
/opt/rocm/lib
cmake ..
-DPADDLE_LIB
=
${
LIB_DIR
}
\
-DWITH_MKL
=
${
WITH_MKL
}
\
-DDEMO_NAME
=
${
DEMO_NAME
}
\
-DWITH_GPU
=
${
WITH_GPU
}
\
-DWITH_STATIC_LIB
=
OFF
\
-DUSE_TENSORRT
=
${
USE_TENSORRT
}
\
-DWITH_ROCM
=
${
WITH_ROCM
}
\
-DROCM_LIB
=
${
ROCM_LIB
}
\
-DCUDNN_LIB
=
${
CUDNN_LIB
}
\
-DCUDA_LIB
=
${
CUDA_LIB
}
\
-DTENSORRT_ROOT
=
${
TENSORRT_ROOT
}
make
-j
example/auto_compression/pytorch_yolov5/cpp_infer/trt_run.cc
0 → 100644
浏览文件 @
d5d27946
#include <chrono>
#include <iostream>
#include <memory>
#include <numeric>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <cuda_runtime.h>
#include "paddle/include/paddle_inference_api.h"
#include "paddle/include/experimental/phi/common/float16.h"
using
paddle_infer
::
Config
;
using
paddle_infer
::
Predictor
;
using
paddle_infer
::
CreatePredictor
;
using
paddle_infer
::
PrecisionType
;
using
phi
::
dtype
::
float16
;
DEFINE_string
(
model_dir
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
model_file
,
""
,
"Path of the inference model file."
);
DEFINE_string
(
params_file
,
""
,
"Path of the inference params file."
);
DEFINE_string
(
run_mode
,
"trt_fp32"
,
"run_mode which can be: trt_fp32, trt_fp16 and trt_int8"
);
DEFINE_int32
(
batch_size
,
1
,
"Batch size."
);
DEFINE_int32
(
gpu_id
,
0
,
"GPU card ID num."
);
DEFINE_int32
(
trt_min_subgraph_size
,
3
,
"tensorrt min_subgraph_size"
);
DEFINE_int32
(
warmup
,
50
,
"warmup"
);
DEFINE_int32
(
repeats
,
1000
,
"repeats"
);
using
Time
=
decltype
(
std
::
chrono
::
high_resolution_clock
::
now
());
Time
time
()
{
return
std
::
chrono
::
high_resolution_clock
::
now
();
};
double
time_diff
(
Time
t1
,
Time
t2
)
{
typedef
std
::
chrono
::
microseconds
ms
;
auto
diff
=
t2
-
t1
;
ms
counter
=
std
::
chrono
::
duration_cast
<
ms
>
(
diff
);
return
counter
.
count
()
/
1000.0
;
}
std
::
shared_ptr
<
Predictor
>
InitPredictor
()
{
Config
config
;
std
::
string
model_path
;
if
(
FLAGS_model_dir
!=
""
)
{
config
.
SetModel
(
FLAGS_model_dir
);
model_path
=
FLAGS_model_dir
.
substr
(
0
,
FLAGS_model_dir
.
find_last_of
(
"/"
));
}
else
{
config
.
SetModel
(
FLAGS_model_file
,
FLAGS_params_file
);
model_path
=
FLAGS_model_file
.
substr
(
0
,
FLAGS_model_file
.
find_last_of
(
"/"
));
}
// enable tune
std
::
cout
<<
"model_path: "
<<
model_path
<<
std
::
endl
;
config
.
EnableUseGpu
(
256
,
FLAGS_gpu_id
);
if
(
FLAGS_run_mode
==
"trt_fp32"
)
{
config
.
EnableTensorRtEngine
(
1
<<
30
,
FLAGS_batch_size
,
FLAGS_trt_min_subgraph_size
,
PrecisionType
::
kFloat32
,
false
,
false
);
}
else
if
(
FLAGS_run_mode
==
"trt_fp16"
)
{
config
.
EnableTensorRtEngine
(
1
<<
30
,
FLAGS_batch_size
,
FLAGS_trt_min_subgraph_size
,
PrecisionType
::
kHalf
,
false
,
false
);
}
else
if
(
FLAGS_run_mode
==
"trt_int8"
)
{
config
.
EnableTensorRtEngine
(
1
<<
30
,
FLAGS_batch_size
,
FLAGS_trt_min_subgraph_size
,
PrecisionType
::
kInt8
,
false
,
false
);
}
config
.
EnableMemoryOptim
();
config
.
SwitchIrOptim
(
true
);
return
CreatePredictor
(
config
);
}
template
<
typename
type
>
void
run
(
Predictor
*
predictor
,
const
std
::
vector
<
type
>
&
input
,
const
std
::
vector
<
int
>
&
input_shape
,
type
*
out_data
,
std
::
vector
<
int
>
out_shape
)
{
// prepare input
int
input_num
=
std
::
accumulate
(
input_shape
.
begin
(),
input_shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
auto
input_names
=
predictor
->
GetInputNames
();
auto
input_t
=
predictor
->
GetInputHandle
(
input_names
[
0
]);
input_t
->
Reshape
(
input_shape
);
input_t
->
CopyFromCpu
(
input
.
data
());
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
CHECK
(
predictor
->
Run
());
auto
st
=
time
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
auto
input_names
=
predictor
->
GetInputNames
();
auto
input_t
=
predictor
->
GetInputHandle
(
input_names
[
0
]);
input_t
->
Reshape
(
input_shape
);
input_t
->
CopyFromCpu
(
input
.
data
());
CHECK
(
predictor
->
Run
());
auto
output_names
=
predictor
->
GetOutputNames
();
auto
output_t
=
predictor
->
GetOutputHandle
(
output_names
[
0
]);
std
::
vector
<
int
>
output_shape
=
output_t
->
shape
();
output_t
->
ShareExternalData
<
type
>
(
out_data
,
out_shape
,
paddle_infer
::
PlaceType
::
kGPU
);
}
LOG
(
INFO
)
<<
"["
<<
FLAGS_run_mode
<<
" bs-"
<<
FLAGS_batch_size
<<
" ] run avg time is "
<<
time_diff
(
st
,
time
())
/
FLAGS_repeats
<<
" ms"
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
auto
predictor
=
InitPredictor
();
std
::
vector
<
int
>
input_shape
=
{
FLAGS_batch_size
,
3
,
640
,
640
};
// float16
using
dtype
=
float16
;
std
::
vector
<
dtype
>
input_data
(
FLAGS_batch_size
*
3
*
640
*
640
,
dtype
(
1.0
));
dtype
*
out_data
;
int
out_data_size
=
FLAGS_batch_size
*
25200
*
85
;
cudaHostAlloc
((
void
**
)
&
out_data
,
sizeof
(
float
)
*
out_data_size
,
cudaHostAllocMapped
);
std
::
vector
<
int
>
out_shape
{
FLAGS_batch_size
,
1
,
25200
,
85
};
run
<
dtype
>
(
predictor
.
get
(),
input_data
,
input_shape
,
out_data
,
out_shape
);
return
0
;
}
example/auto_compression/pytorch_yolov5/eval.py
浏览文件 @
d5d27946
...
...
@@ -42,13 +42,6 @@ def argsparser():
return
parser
def
print_arguments
(
args
):
print
(
'----------- Running Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
items
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------'
)
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
...
...
@@ -84,7 +77,7 @@ def eval():
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
val_program
,
feed_target_names
,
fetch_targets
=
paddle
.
fluid
.
io
.
load_inference_model
(
val_program
,
feed_target_names
,
fetch_targets
=
paddle
.
static
.
load_inference_model
(
global_config
[
"model_dir"
],
exe
,
model_filename
=
global_config
[
"model_filename"
],
...
...
@@ -160,7 +153,6 @@ if __name__ == '__main__':
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
...
...
example/auto_compression/pytorch_yolov5/post_quant.py
0 → 100644
浏览文件 @
d5d27946
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
,
merge_config
from
ppdet.core.workspace
import
create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
paddleslim.quant
import
quant_post_static
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--config_path'
,
type
=
str
,
default
=
None
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
'ptq_out'
,
help
=
"directory to save compressed model."
)
parser
.
add_argument
(
'--devices'
,
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
parser
.
add_argument
(
'--algo'
,
type
=
str
,
default
=
'KL'
,
help
=
"post quant algo."
)
return
parser
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
main
():
global
global_config
all_config
=
load_slim_config
(
FLAGS
.
config_path
)
assert
"Global"
in
all_config
,
f
"Key 'Global' not found in config file.
\n
{
all_config
}
"
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
train_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'TrainDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
train_loader
=
reader_wrapper
(
train_loader
,
global_config
[
'input_list'
])
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
quant_post_static
(
executor
=
exe
,
model_dir
=
global_config
[
"model_dir"
],
quantize_model_path
=
FLAGS
.
save_dir
,
data_loader
=
train_loader
,
model_filename
=
global_config
[
"model_filename"
],
params_filename
=
global_config
[
"params_filename"
],
batch_size
=
32
,
batch_nums
=
10
,
algo
=
FLAGS
.
algo
,
hist_percent
=
0.999
,
is_full_quantize
=
False
,
bias_correction
=
False
,
onnx_format
=
False
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
main
()
example/auto_compression/pytorch_yolov5/run.py
浏览文件 @
d5d27946
...
...
@@ -44,19 +44,10 @@ def argsparser():
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
parser
.
add_argument
(
'--eval'
,
type
=
bool
,
default
=
False
,
help
=
"whether to run evaluation."
)
return
parser
def
print_arguments
(
args
):
print
(
'----------- Running Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
items
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------'
)
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
...
...
@@ -181,7 +172,6 @@ if __name__ == '__main__':
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
print_arguments
(
FLAGS
)
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
...
...
example/auto_compression/pytorch_yolov6/README.md
浏览文件 @
d5d27946
...
...
@@ -136,3 +136,8 @@ python paddle_trt_infer.py --model_path=output --image_file=images/000000570688.
```
## 5.FAQ
-
如果想测试离线量化模型精度,可执行:
```
shell
python post_quant.py
--config_path
=
./configs/yolov6s_qat_dis.yaml
```
example/auto_compression/pytorch_yolov6/post_quant.py
浏览文件 @
d5d27946
...
...
@@ -39,6 +39,11 @@ def argsparser():
type
=
str
,
default
=
'ptq_out'
,
help
=
"directory to save compressed model."
)
parser
.
add_argument
(
'--devices'
,
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
parser
.
add_argument
(
'--algo'
,
type
=
str
,
default
=
'KL'
,
help
=
"post quant algo."
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录