未验证 提交 ae630a78 编写于 作者: G Guanghua Yu 提交者: GitHub

add YOLOv7 ACT example (#1291)

上级 61ad6665
......@@ -20,7 +20,7 @@ PaddleSlim是一个专注于深度学习模型压缩的工具库,提供**低
- 支持代码无感知压缩:用户只需提供推理模型文件和数据,既可进行离线量化(PTQ)、量化训练(QAT)、稀疏训练等压缩任务。
- 支持自动策略选择,根据任务特点和部署环境特性:自动搜索合适的离线量化方法,自动搜索最佳的压缩策略组合方式。
- 发布[自然语言处理](example/auto_compression/nlp)、[图像语义分割](example/auto_compression/semantic_segmentation)、[图像目标检测](example/auto_compression/detection)三个方向的自动化压缩示例。
- 发布`X2Paddle`模型自动化压缩方案:[YOLOv5](example/auto_compression/pytorch_yolov5)、[YOLOv6](example/auto_compression/pytorch_yolov6)、[HuggingFace](example/auto_compression/pytorch_huggingface)、[MobileNet](example/auto_compression/tensorflow_mobilenet)。
- 发布`X2Paddle`模型自动化压缩方案:[YOLOv5](example/auto_compression/pytorch_yolov5)、[YOLOv6](example/auto_compression/pytorch_yolov6)、[YOLOv7](example/auto_compression/pytorch_yolov7)、[HuggingFace](example/auto_compression/pytorch_huggingface)、[MobileNet](example/auto_compression/tensorflow_mobilenet)。
- 升级量化功能
......
......@@ -82,13 +82,15 @@ ACT相比传统的模型压缩方法,
| [语义分割](./semantic_segmentation) | UNet | 65.00 | 64.93 | 15.29 | 10.23 | **1.49** | NVIDIA Tesla T4 |
| NLP | PP-MiniLM | 72.81 | 72.44 | 128.01 | 17.97 | **7.12** | NVIDIA Tesla T4 |
| NLP | ERNIE 3.0-Medium | 73.09 | 72.40 | 29.25(fp16) | 19.61 | **1.49** | NVIDIA Tesla T4 |
| [目标检测](./detection) | YOLOv5s<br/>(PyTorch) | 37.40 | 36.9 | 5.95 | 1.87 | **3.18** | NVIDIA Tesla T4 |
| [目标检测](./pytorch_yolov5) | YOLOv5s<br/>(PyTorch) | 37.40 | 36.9 | 5.95 | 1.87 | **3.18** | NVIDIA Tesla T4 |
| [目标检测](./pytorch_yolov6) | YOLOv6s<br/>(PyTorch) | 42.4 | 41.3 | 9.06 | 1.83 | **4.95** | NVIDIA Tesla T4 |
| [目标检测](./pytorch_yolov7) | YOLOv7<br/>(PyTorch) | 51.1 | 50.8 | 26.84 | 4.55 | **5.89** | NVIDIA Tesla T4 |
| [目标检测](./detection) | PP-YOLOE-l | 50.9 | 50.6 | 11.2 | 6.7 | **1.67** | NVIDIA Tesla V100 |
| [图像分类](./image_classification) | MobileNetV1<br/>(TensorFlow) | 71.0 | 70.22 | 30.45 | 15.86 | **1.92** | SDMM865(骁龙865) |
- 备注:目标检测精度指标为mAP(0.5:0.95)精度测量结果。图像分割精度指标为IoU精度测量结果。
- 更多飞桨模型应用示例及Benchmark可以参考:[图像分类](./image_classification)[目标检测](./detection)[语义分割](./semantic_segmentation)[自然语言处理](./nlp)
- 更多其它框架应用示例及Benchmark可以参考:[YOLOv5(PyTorch)](./pytorch_yolov5)[HuggingFace(PyTorch)](./pytorch_huggingface)[MobileNet(TensorFlow)](./tensorflow_mobilenet)
- 更多其它框架应用示例及Benchmark可以参考:[YOLOv5(PyTorch)](./pytorch_yolov5)[YOLOv6(PyTorch)](./pytorch_yolov6)[YOLOv7(PyTorch)](./pytorch_yolov7)[HuggingFace(PyTorch)](./pytorch_huggingface)[MobileNet(TensorFlow)](./tensorflow_mobilenet)
## **环境准备**
......
......@@ -22,7 +22,7 @@
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP16</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv6s | Base模型 | 640*640 | 42.4 | 9.06ms | 2.90ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov6s_infer.tar) |
| YOLOv6s | Base模型 | 640*640 | 42.4 | 9.06ms | 2.90ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_infer.tar) |
| YOLOv6s | KL离线量化 | 640*640 | 30.3 | - | - | 1.83ms | - | - |
| YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | - | - | **1.83ms** | [config](./configs/yolov6s_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) |
......@@ -83,7 +83,7 @@ pip install x2paddle sympy onnx
x2paddle --framework=onnx --model=yolov6s.onnx --save_dir=pd_model
cp -r pd_model/inference_model/ yolov6s_infer
```
即可得到YOLOv6s模型的预测模型(`model.pdmodel``model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv6s的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov6s_infer.tar)
即可得到YOLOv6s模型的预测模型(`model.pdmodel``model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv6s的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_infer.tar)
预测模型的格式为:`model.pdmodel``model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
......
# YOLOv7自动压缩示例
目录:
- [1.简介](#1简介)
- [2.Benchmark](#2Benchmark)
- [3.开始自动压缩](#自动压缩流程)
- [3.1 环境准备](#31-准备环境)
- [3.2 准备数据集](#32-准备数据集)
- [3.3 准备预测模型](#33-准备预测模型)
- [3.4 测试模型精度](#34-测试模型精度)
- [3.5 自动压缩并产出模型](#35-自动压缩并产出模型)
- [4.预测部署](#4预测部署)
- [5.FAQ](5FAQ)
## 1. 简介
飞桨模型转换工具[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)支持将```Caffe/TensorFlow/ONNX/PyTorch```的模型一键转为飞桨(PaddlePaddle)的预测模型。借助X2Paddle的能力,各种框架的推理模型可以很方便的使用PaddleSlim的自动化压缩功能。
本示例将以[WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7)目标检测模型为例,将PyTorch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为量化训练。
## 2.Benchmark
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP16</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv7 | Base模型 | 640*640 | 51.1 | 26.84ms | 7.44ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_infer.tar) |
| YOLOv7 | KL离线量化 | 640*640 | 50.2 | - | - | 4.55ms | - | - |
| YOLOv7 | 量化蒸馏训练 | 640*640 | **50.8** | - | - | **4.55ms** | [config](./configs/yolov7_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) |
说明:
- mAP的指标均在COCO val2017数据集中评测得到。
- YOLOv7模型在Tesla T4的GPU环境下开启TensorRT 8.4.1,batch_size=1, 测试脚本是[cpp_infer](./cpp_infer)
## 3. 自动压缩流程
#### 3.1 准备环境
- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
- PaddleSlim > 2.3版本
- PaddleDet >= 2.4
- [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) >= 1.3.6
- opencv-python
(1)安装paddlepaddle:
```shell
# CPU
pip install paddlepaddle
# GPU
pip install paddlepaddle-gpu
```
(2)安装paddleslim:
```shell
pip install paddleslim
```
(3)安装paddledet:
```shell
pip install paddledet
```
注:安装PaddleDet的目的只是为了直接使用PaddleDetection中的Dataloader组件。
(4)安装X2Paddle的1.3.6以上版本:
```shell
pip install x2paddle sympy onnx
```
#### 3.2 准备数据集
本案例默认以COCO数据进行自动压缩实验,并且依赖PaddleDetection中数据读取模块,如果自定义COCO数据,或者其他格式数据,请参考[PaddleDetection数据准备文档](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md) 来准备数据。
如果已经准备好数据集,请直接修改[./configs/yolov7_reader.yml]中`EvalDataset``dataset_dir`字段为自己数据集路径即可。
#### 3.3 准备预测模型
(1)准备ONNX模型:
可通过[WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7)的导出脚本来准备ONNX模型,具体步骤如下:
```shell
git clone https://github.com/WongKinYiu/yolov7.git
# 切换分支到u5分支,保持导出的ONNX模型后处理和YOLOv5一致
git checkout u5
# 下载好yolov7.pt权重后执行:
python export.py --weights yolov7.pt --include onnx
```
也可以直接下载我们已经准备好的[yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx)
(2) 转换模型:
```
x2paddle --framework=onnx --model=yolov7.onnx --save_dir=pd_model
cp -r pd_model/inference_model/ yolov7_infer
```
即可得到YOLOv7模型的预测模型(`model.pdmodel``model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv7的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_infer.tar)
预测模型的格式为:`model.pdmodel``model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
#### 3.4 自动压缩并产出模型
蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
- 单卡训练:
```
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/yolov7_qat_dis.yaml --save_dir='./output/'
```
- 多卡训练:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
--config_path=./configs/yolov7_qat_dis.yaml --save_dir='./output/'
```
#### 3.5 测试模型精度
修改[yolov7_qat_dis.yaml](./configs/yolov7_qat_dis.yaml)`model_dir`字段为模型存储路径,然后使用eval.py脚本得到模型的mAP:
```
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/yolov7_qat_dis.yaml
```
## 4.预测部署
#### Paddle-TensorRT C++部署
进入[cpp_infer](./cpp_infer)文件夹内,请按照[C++ TensorRT Benchmark测试教程](./cpp_infer/README.md)进行准备环境及编译,然后开始测试:
```shell
# 编译
bash complie.sh
# 执行
./build/trt_run --model_file yolov7_quant/model.pdmodel --params_file yolov7_quant/model.pdiparams --run_mode=trt_int8
```
#### Paddle-TensorRT Python部署:
首先安装带有TensorRT的[Paddle安装包](https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python)
然后使用[paddle_trt_infer.py](./paddle_trt_infer.py)进行部署:
```shell
python paddle_trt_infer.py --model_path=output --image_file=images/000000570688.jpg --benchmark=True --run_mode=trt_int8
```
## 5.FAQ
- 如果想测试离线量化模型精度,可执行:
```shell
python post_quant.py --config_path=./configs/yolov7_qat_dis.yaml
```
Global:
reader_config: configs/yolov7_reader.yaml
input_list: {'image': 'x2paddle_images'}
Evaluation: True
model_dir: ./yolov7_infer
model_filename: model.pdmodel
params_filename: model.pdiparams
Distillation:
alpha: 1.0
loss: soft_label
Quantization:
activation_quantize_type: 'moving_average_abs_max'
quantize_op_types:
- conv2d
- depthwise_conv2d
TrainConfig:
train_iter: 8000
eval_iter: 1000
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.00003
T_max: 8000
optimizer_builder:
optimizer:
type: SGD
weight_decay: 0.00004
metric: COCO
num_classes: 80
# Datset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco/
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/
worker_num: 0
# preprocess reader in test
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: True}
- Pad: {size: [640, 640], fill_value: [114., 114., 114.]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {}
batch_size: 1
cmake_minimum_required(VERSION 3.0)
project(cpp_inference_demo CXX C)
option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON)
option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF)
option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON)
option(USE_TENSORRT "Compile demo with TensorRT." OFF)
option(WITH_ROCM "Compile demo with rocm." OFF)
option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF)
option(WITH_ARM "Compile demo with ARM" OFF)
option(WITH_MIPS "Compile demo with MIPS" OFF)
option(WITH_SW "Compile demo with SW" OFF)
option(WITH_XPU "Compile demow ith xpu" OFF)
option(WITH_NPU "Compile demow ith npu" OFF)
if(NOT WITH_STATIC_LIB)
add_definitions("-DPADDLE_WITH_SHARED_LIB")
else()
# PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode.
# Set it to empty in static library mode to avoid compilation issues.
add_definitions("/DPD_INFER_DECL=")
endif()
macro(safe_set_static_flag)
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
if(${flag_var} MATCHES "/MD")
string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
endif(${flag_var} MATCHES "/MD")
endforeach(flag_var)
endmacro()
if(NOT DEFINED PADDLE_LIB)
message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib")
endif()
if(NOT DEFINED DEMO_NAME)
message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name")
endif()
include_directories("${PADDLE_LIB}/")
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/include")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
link_directories("${PADDLE_LIB}/paddle/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib")
if (WIN32)
add_definitions("/DGOOGLE_GLOG_DLL_DECL=")
option(MSVC_STATIC_CRT "use static C Runtime library by default" ON)
if (MSVC_STATIC_CRT)
if (WITH_MKL)
set(FLAG_OPENMP "/openmp")
endif()
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
safe_set_static_flag()
if (WITH_STATIC_LIB)
add_definitions(-DSTATIC_LIB)
endif()
endif()
else()
if(WITH_MKL)
set(FLAG_OPENMP "-fopenmp")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 ${FLAG_OPENMP}")
endif()
if(WITH_GPU)
if(NOT WIN32)
include_directories("/usr/local/cuda/include")
if(CUDA_LIB STREQUAL "")
set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library")
endif()
else()
include_directories("C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\include")
if(CUDA_LIB STREQUAL "")
set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64")
endif()
endif(NOT WIN32)
endif()
if (USE_TENSORRT AND WITH_GPU)
set(TENSORRT_ROOT "" CACHE STRING "The root directory of TensorRT library")
if("${TENSORRT_ROOT}" STREQUAL "")
message(FATAL_ERROR "The TENSORRT_ROOT is empty, you must assign it a value with CMake command. Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH ")
endif()
set(TENSORRT_INCLUDE_DIR ${TENSORRT_ROOT}/include)
set(TENSORRT_LIB_DIR ${TENSORRT_ROOT}/lib)
file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS)
string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")
if("${TENSORRT_MAJOR_VERSION}" STREQUAL "")
file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h TENSORRT_VERSION_FILE_CONTENTS)
string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")
endif()
if("${TENSORRT_MAJOR_VERSION}" STREQUAL "")
message(SEND_ERROR "Failed to detect TensorRT version.")
endif()
string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1"
TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}")
message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. "
"Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
include_directories("${TENSORRT_INCLUDE_DIR}")
link_directories("${TENSORRT_LIB_DIR}")
endif()
if(WITH_MKL)
set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
include_directories("${MATH_LIB_PATH}/include")
if(WIN32)
set(MATH_LIB ${MATH_LIB_PATH}/lib/mklml${CMAKE_STATIC_LIBRARY_SUFFIX}
${MATH_LIB_PATH}/lib/libiomp5md${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
if(EXISTS ${MKLDNN_PATH})
include_directories("${MKLDNN_PATH}/include")
if(WIN32)
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib)
else(WIN32)
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
endif(WIN32)
endif()
elseif((NOT WITH_MIPS) AND (NOT WITH_SW))
set(OPENBLAS_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}openblas")
include_directories("${OPENBLAS_LIB_PATH}/include/openblas")
if(WIN32)
set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
endif()
endif()
if(WITH_STATIC_LIB)
set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
if(WIN32)
set(DEPS ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
endif()
if (WITH_ONNXRUNTIME)
if(WIN32)
set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.lib paddle2onnx)
elseif(APPLE)
set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.1.10.0.dylib paddle2onnx)
else()
set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.so.1.10.0 paddle2onnx)
endif()
endif()
if (NOT WIN32)
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf xxhash cryptopp
${EXTERNAL_LIB})
else()
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags_static libprotobuf xxhash cryptopp-static ${EXTERNAL_LIB})
set(DEPS ${DEPS} shlwapi.lib)
endif(NOT WIN32)
if(WITH_GPU)
if(NOT WIN32)
if (USE_TENSORRT)
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
else()
if(USE_TENSORRT)
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_STATIC_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX})
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_STATIC_LIBRARY_SUFFIX})
endif()
endif()
set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} )
set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} )
set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} )
endif()
endif()
if(WITH_ROCM AND NOT WIN32)
set(DEPS ${DEPS} ${ROCM_LIB}/libamdhip64${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
if(WITH_XPU AND NOT WIN32)
set(XPU_INSTALL_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}xpu")
set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpuapi${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpurt${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
if(WITH_NPU AND NOT WIN32)
set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libgraph${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libge_runner${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libacl_op_compiler${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
target_link_libraries(${DEMO_NAME} ${DEPS})
if(WIN32)
if(USE_TENSORRT)
add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}
${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}
${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_SHARED_LIBRARY_SUFFIX}
${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE})
endif()
endif()
if(WITH_MKL)
add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/mklml.dll ${CMAKE_BINARY_DIR}/Release
COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/libiomp5md.dll ${CMAKE_BINARY_DIR}/Release
COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_PATH}/lib/mkldnn.dll ${CMAKE_BINARY_DIR}/Release
)
else()
add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${OPENBLAS_LIB_PATH}/lib/openblas.dll ${CMAKE_BINARY_DIR}/Release
)
endif()
if(WITH_ONNXRUNTIME)
add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.dll
${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib/paddle2onnx.dll
${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
)
endif()
if(NOT WITH_STATIC_LIB)
add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy "${PADDLE_LIB}/paddle/lib/paddle_inference.dll" ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
)
endif()
endif()
# YOLOv7 TensorRT Benchmark测试(Linux)
## 环境准备
- CUDA、CUDNN:确认环境中已经安装CUDA和CUDNN,并且提前获取其安装路径。
- TensorRT:可通过NVIDIA官网下载[TensorRT 8.4.1.5](https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.1/tars/tensorrt-8.4.1.5.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz)或其他版本安装包。
- Paddle Inference C++预测库:编译develop版本请参考[编译文档](https://www.paddlepaddle.org.cn/inference/user_guides/source_compile.html)。编译完成后,会在build目录下生成`paddle_inference_install_dir`文件夹,这个就是我们需要的C++预测库文件。
## 编译可执行程序
- (1)修改`compile.sh`中依赖库路径,主要是以下内容:
```shell
# Paddle Inference预测库路径
LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/
# CUDNN路径
CUDNN_LIB=/usr/lib/x86_64-linux-gnu/
# CUDA路径
CUDA_LIB=/usr/local/cuda/lib64
# TensorRT安装包路径,为TRT资源包解压完成后的绝对路径,其中包含`lib`和`include`文件夹
TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/
```
## 测试
- FP32
```
./build/trt_run --model_file yolov7_infer/model.pdmodel --params_file yolov7_infer/model.pdiparams --run_mode=trt_fp32
```
- FP16
```
./build/trt_run --model_file yolov7_infer/model.pdmodel --params_file yolov7_infer/model.pdiparams --run_mode=trt_fp16
```
- INT8
```
./build/trt_run --model_file yolov7_quant/model.pdmodel --params_file yolov7_quant/model.pdiparams --run_mode=trt_int8
```
## 性能对比
| 预测库 | 模型 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP16</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) |
| :--------: | :--------: |:-------- |:--------: | :---------------------: |
| Paddle TensorRT | YOLOv7 | 26.84ms | 7.44ms | 4.55ms |
| TensorRT | YOLOv7 | 28.25ms | 7.23ms | 4.67ms |
环境:
- Tesla T4,TensorRT 8.4.1,CUDA 11.2
- batch_size=1
#!/bin/bash
set +x
set -e
work_path=$(dirname $(readlink -f $0))
mkdir -p build
cd build
rm -rf *
DEMO_NAME=trt_run
WITH_MKL=ON
WITH_GPU=ON
USE_TENSORRT=ON
LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/
CUDNN_LIB=/usr/lib/x86_64-linux-gnu/
CUDA_LIB=/usr/local/cuda/lib64
TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/
WITH_ROCM=OFF
ROCM_LIB=/opt/rocm/lib
cmake .. -DPADDLE_LIB=${LIB_DIR} \
-DWITH_MKL=${WITH_MKL} \
-DDEMO_NAME=${DEMO_NAME} \
-DWITH_GPU=${WITH_GPU} \
-DWITH_STATIC_LIB=OFF \
-DUSE_TENSORRT=${USE_TENSORRT} \
-DWITH_ROCM=${WITH_ROCM} \
-DROCM_LIB=${ROCM_LIB} \
-DCUDNN_LIB=${CUDNN_LIB} \
-DCUDA_LIB=${CUDA_LIB} \
-DTENSORRT_ROOT=${TENSORRT_ROOT}
make -j
#include <chrono>
#include <iostream>
#include <memory>
#include <numeric>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <cuda_runtime.h>
#include "paddle/include/paddle_inference_api.h"
#include "paddle/include/experimental/phi/common/float16.h"
using paddle_infer::Config;
using paddle_infer::Predictor;
using paddle_infer::CreatePredictor;
using paddle_infer::PrecisionType;
using phi::dtype::float16;
DEFINE_string(model_dir, "", "Directory of the inference model.");
DEFINE_string(model_file, "", "Path of the inference model file.");
DEFINE_string(params_file, "", "Path of the inference params file.");
DEFINE_string(run_mode, "trt_fp32", "run_mode which can be: trt_fp32, trt_fp16 and trt_int8");
DEFINE_int32(batch_size, 1, "Batch size.");
DEFINE_int32(gpu_id, 0, "GPU card ID num.");
DEFINE_int32(trt_min_subgraph_size, 3, "tensorrt min_subgraph_size");
DEFINE_int32(warmup, 50, "warmup");
DEFINE_int32(repeats, 1000, "repeats");
using Time = decltype(std::chrono::high_resolution_clock::now());
Time time() { return std::chrono::high_resolution_clock::now(); };
double time_diff(Time t1, Time t2) {
typedef std::chrono::microseconds ms;
auto diff = t2 - t1;
ms counter = std::chrono::duration_cast<ms>(diff);
return counter.count() / 1000.0;
}
std::shared_ptr<Predictor> InitPredictor() {
Config config;
std::string model_path;
if (FLAGS_model_dir != "") {
config.SetModel(FLAGS_model_dir);
model_path = FLAGS_model_dir.substr(0, FLAGS_model_dir.find_last_of("/"));
} else {
config.SetModel(FLAGS_model_file, FLAGS_params_file);
model_path = FLAGS_model_file.substr(0, FLAGS_model_file.find_last_of("/"));
}
// enable tune
std::cout << "model_path: " << model_path << std::endl;
config.EnableUseGpu(256, FLAGS_gpu_id);
if (FLAGS_run_mode == "trt_fp32") {
config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size,
PrecisionType::kFloat32, false, false);
} else if (FLAGS_run_mode == "trt_fp16") {
config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size,
PrecisionType::kHalf, false, false);
} else if (FLAGS_run_mode == "trt_int8") {
config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size,
PrecisionType::kInt8, false, false);
}
config.EnableMemoryOptim();
config.SwitchIrOptim(true);
return CreatePredictor(config);
}
template <typename type>
void run(Predictor *predictor, const std::vector<type> &input,
const std::vector<int> &input_shape, type* out_data, std::vector<int> out_shape) {
// prepare input
int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
std::multiplies<int>());
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
input_t->Reshape(input_shape);
input_t->CopyFromCpu(input.data());
for (int i = 0; i < FLAGS_warmup; ++i)
CHECK(predictor->Run());
auto st = time();
for (int i = 0; i < FLAGS_repeats; ++i) {
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
input_t->Reshape(input_shape);
input_t->CopyFromCpu(input.data());
CHECK(predictor->Run());
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_t->shape();
output_t -> ShareExternalData<type>(out_data, out_shape, paddle_infer::PlaceType::kGPU);
}
LOG(INFO) << "[" << FLAGS_run_mode << " bs-" << FLAGS_batch_size << " ] run avg time is " << time_diff(st, time()) / FLAGS_repeats
<< " ms";
}
int main(int argc, char *argv[]) {
google::ParseCommandLineFlags(&argc, &argv, true);
auto predictor = InitPredictor();
std::vector<int> input_shape = {FLAGS_batch_size, 3, 640, 640};
// float16
using dtype = float16;
std::vector<dtype> input_data(FLAGS_batch_size * 3 * 640 * 640, dtype(1.0));
dtype *out_data;
int out_data_size = FLAGS_batch_size * 25200 * 85;
cudaHostAlloc((void**)&out_data, sizeof(float) * out_data_size, cudaHostAllocMapped);
std::vector<int> out_shape{ FLAGS_batch_size, 1, 25200, 85};
run<dtype>(predictor.get(), input_data, input_shape, out_data, out_shape);
return 0;
}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from post_process import YOLOv7PostProcess
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def convert_numpy_data(data, metric):
data_all = {}
data_all = {k: np.array(v) for k, v in data.items()}
if isinstance(metric, VOCMetric):
for k, v in data_all.items():
if not isinstance(v[0], np.ndarray):
tmp_list = []
for t in v:
tmp_list.append(np.array(t))
data_all[k] = np.array(tmp_list)
else:
data_all = {k: np.array(v) for k, v in data.items()}
return data_all
def eval():
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
global_config["model_dir"],
exe,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"])
print('Loaded model from: {}'.format(global_config["model_dir"]))
metric = global_config['metric']
for batch_id, data in enumerate(val_loader):
data_all = convert_numpy_data(data, metric)
data_input = {}
for k, v in data.items():
if isinstance(global_config['input_list'], list):
if k in global_config['input_list']:
data_input[k] = np.array(v)
elif isinstance(global_config['input_list'], dict):
if k in global_config['input_list'].keys():
data_input[global_config['input_list'][k]] = np.array(v)
outs = exe.run(val_program,
feed=data_input,
fetch_list=fetch_targets,
return_numpy=False)
res = {}
postprocess = YOLOv7PostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
metric.update(data_all, res)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
metric.reset()
def main():
global global_config
all_config = load_slim_config(FLAGS.config_path)
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
dataset = reader_cfg['EvalDataset']
global val_loader
val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
reader_cfg['worker_num'],
return_list=True)
metric = None
if reader_cfg['metric'] == 'COCO':
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
elif reader_cfg['metric'] == 'VOC':
metric = VOCMetric(
label_list=dataset.get_label_list(),
class_num=reader_cfg['num_classes'],
map_type=reader_cfg['map_type'])
else:
raise ValueError("metric currently only supports COCO and VOC.")
global_config['metric'] = metric
eval()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
import argparse
import time
from paddle.inference import Config
from paddle.inference import create_predictor
from post_process import YOLOv7PostProcess
CLASS_LABEL = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush'
]
def generate_scale(im, target_shape, keep_ratio=True):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
if keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(target_shape)
target_size_max = np.max(target_shape)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = target_shape
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
def image_preprocess(img_path, target_shape):
img = cv2.imread(img_path)
# Resize
im_scale_y, im_scale_x = generate_scale(img, target_shape)
img = cv2.resize(
img,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_LINEAR)
# Pad
im_h, im_w = img.shape[:2]
h, w = target_shape[:]
if h != im_h or w != im_w:
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array([114.0, 114.0, 114.0], dtype=np.float32)
canvas[0:im_h, 0:im_w, :] = img.astype(np.float32)
img = canvas
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img, [2, 0, 1]) / 255
img = np.expand_dims(img, 0)
scale_factor = np.array([[im_scale_y, im_scale_x]])
return img.astype(np.float32), scale_factor
def get_color_map_list(num_classes):
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
def draw_box(image_file, results, class_label, threshold=0.5):
srcimg = cv2.imread(image_file, 1)
for i in range(len(results)):
color_list = get_color_map_list(len(class_label))
clsid2color = {}
classid, conf = int(results[i, 0]), results[i, 1]
if conf < threshold:
continue
xmin, ymin, xmax, ymax = int(results[i, 2]), int(results[i, 3]), int(
results[i, 4]), int(results[i, 5])
if classid not in clsid2color:
clsid2color[classid] = color_list[classid]
color = tuple(clsid2color[classid])
cv2.rectangle(srcimg, (xmin, ymin), (xmax, ymax), color, thickness=2)
print(class_label[classid] + ': ' + str(round(conf, 3)))
cv2.putText(
srcimg,
class_label[classid] + ':' + str(round(conf, 3)), (xmin, ymin - 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.8, (0, 255, 0),
thickness=2)
return srcimg
def load_predictor(model_dir,
run_mode='paddle',
batch_size=1,
device='CPU',
min_subgraph_size=3,
use_dynamic_shape=False,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
delete_shuffle_pass=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
Used by action model.
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
ValueError: predict by TensorRT need device == 'GPU'.
"""
if device != 'GPU' and run_mode != 'paddle':
raise ValueError(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
.format(run_mode, device))
config = Config(
os.path.join(model_dir, 'model.pdmodel'),
os.path.join(model_dir, 'model.pdiparams'))
if device == 'GPU':
# initial GPU memory(M), device ID
config.enable_use_gpu(200, 0)
# optimize graph and fuse op
config.switch_ir_optim(True)
elif device == 'XPU':
config.enable_lite_engine()
config.enable_xpu(10 * 1024 * 1024)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads)
if enable_mkldnn:
try:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
if enable_mkldnn_bfloat16:
config.enable_mkldnn_bfloat16()
except Exception as e:
print(
"The current environment does not support `mkldnn`, so disable mkldnn."
)
pass
precision_map = {
'trt_int8': Config.Precision.Int8,
'trt_fp32': Config.Precision.Float32,
'trt_fp16': Config.Precision.Half
}
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=(1 << 25) * batch_size,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
use_static=False,
use_calib_mode=trt_calib_mode)
if use_dynamic_shape:
min_input_shape = {
'image': [batch_size, 3, trt_min_shape, trt_min_shape]
}
max_input_shape = {
'image': [batch_size, 3, trt_max_shape, trt_max_shape]
}
opt_input_shape = {
'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
opt_input_shape)
print('trt set dynamic shape done!')
# disable print log when predict
config.disable_glog_info()
# enable shared memory
config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False)
if delete_shuffle_pass:
config.delete_pass("shuffle_channel_detect_pass")
predictor = create_predictor(config)
return predictor
def predict_image(predictor,
image_file,
image_shape=[640, 640],
warmup=1,
repeats=1,
threshold=0.5,
arch='YOLOv5'):
img, scale_factor = image_preprocess(image_file, image_shape)
inputs = {}
if arch == 'YOLOv5':
inputs['x2paddle_images'] = img
input_names = predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
for i in range(warmup):
predictor.run()
np_boxes = None
predict_time = 0.
time_min = float("inf")
time_max = float('-inf')
for i in range(repeats):
start_time = time.time()
predictor.run()
output_names = predictor.get_output_names()
boxes_tensor = predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
end_time = time.time()
timed = end_time - start_time
time_min = min(time_min, timed)
time_max = max(time_max, timed)
predict_time += timed
time_avg = predict_time / repeats
print('Inference time(ms): min={}, max={}, avg={}'.format(
round(time_min * 1000, 2),
round(time_max * 1000, 1), round(time_avg * 1000, 1)))
postprocess = YOLOv7PostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np_boxes, scale_factor)
res_img = draw_box(
image_file, res['bbox'], CLASS_LABEL, threshold=threshold)
cv2.imwrite('result.jpg', res_img)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--image_file', type=str, default=None, help="image path")
parser.add_argument(
'--model_path', type=str, help="inference model filepath")
parser.add_argument(
'--benchmark',
type=bool,
default=False,
help="Whether run benchmark or not.")
parser.add_argument(
'--run_mode',
type=str,
default='paddle',
help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
'--device',
type=str,
default='GPU',
help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU"
)
parser.add_argument('--img_shape', type=int, default=640, help="input_size")
args = parser.parse_args()
predictor = load_predictor(
args.model_path, run_mode=args.run_mode, device=args.device)
warmup, repeats = 1, 1
if args.benchmark:
warmup, repeats = 50, 100
predict_image(
predictor,
args.image_file,
image_shape=[args.img_shape, args.img_shape],
warmup=warmup,
repeats=repeats)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
def box_area(boxes):
"""
Args:
boxes(np.ndarray): [N, 4]
return: [N]
"""
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
def box_iou(box1, box2):
"""
Args:
box1(np.ndarray): [N, 4]
box2(np.ndarray): [M, 4]
return: [N, M]
"""
area1 = box_area(box1)
area2 = box_area(box2)
lt = np.maximum(box1[:, np.newaxis, :2], box2[:, :2])
rb = np.minimum(box1[:, np.newaxis, 2:], box2[:, 2:])
wh = rb - lt
wh = np.maximum(0, wh)
inter = wh[:, :, 0] * wh[:, :, 1]
iou = inter / (area1[:, np.newaxis] + area2 - inter)
return iou
def nms(boxes, scores, iou_threshold):
"""
Non Max Suppression numpy implementation.
args:
boxes(np.ndarray): [N, 4]
scores(np.ndarray): [N, 1]
iou_threshold(float): Threshold of IoU.
"""
idxs = scores.argsort()
keep = []
while idxs.size > 0:
max_score_index = idxs[-1]
max_score_box = boxes[max_score_index][None, :]
keep.append(max_score_index)
if idxs.size == 1:
break
idxs = idxs[:-1]
other_boxes = boxes[idxs]
ious = box_iou(max_score_box, other_boxes)
idxs = idxs[ious[0] <= iou_threshold]
keep = np.array(keep)
return keep
class YOLOv7PostProcess(object):
"""
Post process of YOLOv6 network.
args:
score_threshold(float): Threshold to filter out bounding boxes with low
confidence score. If not provided, consider all boxes.
nms_threshold(float): The threshold to be used in NMS.
multi_label(bool): Whether keep multi label in boxes.
keep_top_k(int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step.
"""
def __init__(self,
score_threshold=0.25,
nms_threshold=0.5,
multi_label=False,
keep_top_k=300):
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.multi_label = multi_label
self.keep_top_k = keep_top_k
def _xywh2xyxy(self, x):
# Convert from [x, y, w, h] to [x1, y1, x2, y2]
y = np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
return y
def _non_max_suppression(self, prediction):
max_wh = 4096 # (pixels) minimum and maximum box width and height
nms_top_k = 30000
cand_boxes = prediction[..., 4] > self.score_threshold # candidates
output = [np.zeros((0, 6))] * prediction.shape[0]
for batch_id, boxes in enumerate(prediction):
# Apply constraints
boxes = boxes[cand_boxes[batch_id]]
if not boxes.shape[0]:
continue
# Compute conf (conf = obj_conf * cls_conf)
boxes[:, 5:] *= boxes[:, 4:5]
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
convert_box = self._xywh2xyxy(boxes[:, :4])
# Detections matrix nx6 (xyxy, conf, cls)
if self.multi_label:
i, j = (boxes[:, 5:] > self.score_threshold).nonzero()
boxes = np.concatenate(
(convert_box[i], boxes[i, j + 5, None],
j[:, None].astype(np.float32)),
axis=1)
else:
conf = np.max(boxes[:, 5:], axis=1)
j = np.argmax(boxes[:, 5:], axis=1)
re = np.array(conf.reshape(-1) > self.score_threshold)
conf = conf.reshape(-1, 1)
j = j.reshape(-1, 1)
boxes = np.concatenate((convert_box, conf, j), axis=1)[re]
num_box = boxes.shape[0]
if not num_box:
continue
elif num_box > nms_top_k:
boxes = boxes[boxes[:, 4].argsort()[::-1][:nms_top_k]]
# Batched NMS
c = boxes[:, 5:6] * max_wh
clean_boxes, scores = boxes[:, :4] + c, boxes[:, 4]
keep = nms(clean_boxes, scores, self.nms_threshold)
# limit detection box num
if keep.shape[0] > self.keep_top_k:
keep = keep[:self.keep_top_k]
output[batch_id] = boxes[keep]
return output
def __call__(self, outs, scale_factor):
preds = self._non_max_suppression(outs)
bboxs, box_nums = [], []
for i, pred in enumerate(preds):
if len(pred.shape) > 2:
pred = np.squeeze(pred)
if len(pred.shape) == 1:
pred = pred[np.newaxis, :]
pred_bboxes = pred[:, :4]
scale_factor = np.tile(scale_factor[i][::-1], (1, 2))
pred_bboxes /= scale_factor
bbox = np.concatenate(
[
pred[:, -1][:, np.newaxis], pred[:, -2][:, np.newaxis],
pred_bboxes
],
axis=-1)
bboxs.append(bbox)
box_num = bbox.shape[0]
box_nums.append(box_num)
bboxs = np.concatenate(bboxs, axis=0)
box_nums = np.array(box_nums)
return {'bbox': bboxs, 'bbox_num': box_nums}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.quant import quant_post_static
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='ptq_out',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
parser.add_argument(
'--algo', type=str, default='KL', help="post quant algo.")
return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def main():
global global_config
all_config = load_slim_config(FLAGS.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
reader_cfg['worker_num'],
return_list=True)
train_loader = reader_wrapper(train_loader, global_config['input_list'])
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place)
quant_post_static(
executor=exe,
model_dir=global_config["model_dir"],
quantize_model_path=FLAGS.save_dir,
data_loader=train_loader,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
batch_size=32,
batch_nums=10,
algo=FLAGS.algo,
hist_percent=0.999,
is_full_quantize=False,
bias_correction=False,
onnx_format=False)
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
from post_process import YOLOv7PostProcess
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='output',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
parser.add_argument(
'--eval', type=bool, default=False, help="whether to run evaluation.")
return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def convert_numpy_data(data, metric):
data_all = {}
data_all = {k: np.array(v) for k, v in data.items()}
if isinstance(metric, VOCMetric):
for k, v in data_all.items():
if not isinstance(v[0], np.ndarray):
tmp_list = []
for t in v:
tmp_list.append(np.array(t))
data_all[k] = np.array(tmp_list)
else:
data_all = {k: np.array(v) for k, v in data.items()}
return data_all
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric = global_config['metric']
for batch_id, data in enumerate(val_loader):
data_all = convert_numpy_data(data, metric)
data_input = {}
for k, v in data.items():
if isinstance(global_config['input_list'], list):
if k in test_feed_names:
data_input[k] = np.array(v)
elif isinstance(global_config['input_list'], dict):
if k in global_config['input_list'].keys():
data_input[global_config['input_list'][k]] = np.array(v)
outs = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
postprocess = YOLOv7PostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
metric.update(data_all, res)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
map_res = metric.get_results()
metric.reset()
return map_res['bbox'][0]
def main():
global global_config
all_config = load_slim_config(FLAGS.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
reader_cfg['worker_num'],
return_list=True)
train_loader = reader_wrapper(train_loader, global_config['input_list'])
if 'Evaluation' in global_config.keys() and global_config[
'Evaluation'] and paddle.distributed.get_rank() == 0:
eval_func = eval_function
dataset = reader_cfg['EvalDataset']
global val_loader
_eval_batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=reader_cfg['EvalReader']['batch_size'])
val_loader = create('EvalReader')(dataset,
reader_cfg['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
metric = None
if reader_cfg['metric'] == 'COCO':
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
elif reader_cfg['metric'] == 'VOC':
metric = VOCMetric(
label_list=dataset.get_label_list(),
class_num=reader_cfg['num_classes'],
map_type=reader_cfg['map_type'])
else:
raise ValueError("metric currently only supports COCO and VOC.")
global_config['metric'] = metric
else:
eval_func = None
ac = AutoCompression(
model_dir=global_config["model_dir"],
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
save_dir=FLAGS.save_dir,
config=all_config,
train_dataloader=train_loader,
eval_callback=eval_func)
ac.compress()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册