From c3032f3a0445add1bbf5a4acfa0651c1c5404137 Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Mon, 25 Jul 2022 14:00:20 +0800
Subject: [PATCH] update ppyoloe act demo (#1307)
---
example/auto_compression/README.md | 2 +-
example/auto_compression/detection/README.md | 52 ++-
.../detection/configs/ppyoloe_l_qat_dis.yaml | 1 +
.../detection/configs/ppyoloe_s_qat_dis.yaml | 33 ++
.../cpp_infer_ppyoloe/CMakeLists.txt | 263 +++++++++++++++
.../detection/cpp_infer_ppyoloe/README.md | 51 +++
.../detection/cpp_infer_ppyoloe/compile.sh | 37 +++
.../detection/cpp_infer_ppyoloe/trt_run.cc | 116 +++++++
example/auto_compression/detection/eval.py | 5 +
.../detection/paddle_trt_infer.py | 311 ++++++++++++++++++
.../detection/post_process.py | 157 +++++++++
.../auto_compression/detection/post_quant.py | 103 ++++++
example/auto_compression/detection/run.py | 5 +
13 files changed, 1123 insertions(+), 13 deletions(-)
create mode 100644 example/auto_compression/detection/configs/ppyoloe_s_qat_dis.yaml
create mode 100644 example/auto_compression/detection/cpp_infer_ppyoloe/CMakeLists.txt
create mode 100644 example/auto_compression/detection/cpp_infer_ppyoloe/README.md
create mode 100644 example/auto_compression/detection/cpp_infer_ppyoloe/compile.sh
create mode 100644 example/auto_compression/detection/cpp_infer_ppyoloe/trt_run.cc
create mode 100644 example/auto_compression/detection/paddle_trt_infer.py
create mode 100644 example/auto_compression/detection/post_process.py
create mode 100644 example/auto_compression/detection/post_quant.py
diff --git a/example/auto_compression/README.md b/example/auto_compression/README.md
index e907908b..055ae3e7 100644
--- a/example/auto_compression/README.md
+++ b/example/auto_compression/README.md
@@ -85,7 +85,7 @@ Compared with traditional model compression methods, ACT
| [Object Detection](./pytorch_yolov5) | YOLOv5s (PyTorch) | 37.40 | 36.9 | 5.95 | 1.87 | **3.18** | NVIDIA Tesla T4 |
| [Object Detection](./pytorch_yolov6) | YOLOv6s (PyTorch) | 42.4 | 41.3 | 9.06 | 1.83 | **4.95** | NVIDIA Tesla T4 |
| [Object Detection](./pytorch_yolov7) | YOLOv7 (PyTorch) | 51.1 | 50.8 | 26.84 | 4.55 | **5.89** | NVIDIA Tesla T4 |
-| [Object Detection](./detection) | PP-YOLOE-l | 50.9 | 50.6 | 11.2 | 6.7 | **1.67** | NVIDIA Tesla V100 |
+| [Object Detection](./detection) | PP-YOLOE-s | 43.1 | 42.6 | 6.51 | 2.12 | **3.07** | NVIDIA Tesla T4 |
| [Image Classification](./image_classification) | MobileNetV1 (TensorFlow) | 71.0 | 70.22 | 30.45 | 15.86 | **1.92** | SDM865 (Snapdragon 865) |
- Note: object detection accuracy is reported as mAP (0.5:0.95); image segmentation accuracy is reported as IoU.
diff --git a/example/auto_compression/detection/README.md b/example/auto_compression/detection/README.md
index 6c35915a..58f10b27 100644
--- a/example/auto_compression/detection/README.md
+++ b/example/auto_compression/detection/README.md
@@ -13,21 +13,21 @@
- [5.FAQ](#5FAQ)
## 1. Introduction
-This example uses the object detection model PP-YOLOE-l to show how to apply auto compression to a PaddleDetection inference model. The auto-compression strategy used here is quantization-aware training with distillation.
+This example uses the object detection model PP-YOLOE to show how to apply auto compression to a PaddleDetection inference model. The auto-compression strategy used here is quantization-aware training with distillation.
## 2.Benchmark
### PP-YOLOE
-| Model | Strategy | Input size | mAP val 0.5:0.95 | FP32 latency (ms) | FP16 latency (ms) | INT8 latency (ms) | Config | Inference model |
-| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
-| PP-YOLOE-l | Base model | 640*640 | 50.9 | 11.2 | 7.7ms | - | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/ppyoloe_crn_l_300e_coco.tar) |
-| PP-YOLOE-l | Quant-distillation training | 640*640 | 50.6 | - | - | 6.7ms | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/auto_compression/detection/configs/ppyoloe_l_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_crn_l_300e_coco_quant.tar) |
-
-- All mAP metrics are evaluated on the COCO val2017 dataset.
-- PP-YOLOE is tested on a Tesla V100 GPU with TensorRT enabled, batch_size=1, NMS included; the test script is the [benchmark demo](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/python).
+| Model | Base mAP | Post-training quant mAP | ACT quant mAP | TRT-FP32 | TRT-FP16 | TRT-INT8 | Config | Quantized model |
+| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :----------------------: | :---------------------: |
+| PP-YOLOE-l | 50.9 | - | 50.6 | 11.2ms | 7.7ms | **6.7ms** | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/auto_compression/detection/configs/ppyoloe_l_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_crn_l_300e_coco_quant.tar) |
+| PP-YOLOE-s | 43.1 | 26.2 | 42.6 | 6.51ms | 2.77ms | **2.12ms** | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/auto_compression/detection/configs/ppyoloe_s_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_s_quant.tar) |
+- All mAP metrics are evaluated on the COCO val2017 dataset with IoU=0.5:0.95.
+- PP-YOLOE-l is tested on a Tesla V100 GPU with TensorRT enabled, batch_size=1, NMS included; the test script is the [benchmark demo](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/python).
+- PP-YOLOE-s is tested on a Tesla T4 with TensorRT 8.4.1 and CUDA 11.2, batch_size=1, NMS excluded; the test script is [cpp_infer_ppyoloe](./cpp_infer_ppyoloe).
## 3. Auto Compression Workflow
#### 3.1 Prepare the environment
@@ -56,7 +56,6 @@ pip install paddledet
Note: PaddleDet is installed so that the Dataloader components in PaddleDetection can be used directly.
-
#### 3.2 Prepare the dataset
This example runs the auto-compression experiment on COCO data by default. For custom COCO-format data or data in other formats, please follow the [PaddleDetection data preparation guide](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md) to prepare the data.
@@ -79,6 +78,7 @@ git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
- Export the inference model
+PP-YOLOE-l model, NMS included: for a quick start, you can directly download the exported [PP-YOLOE-l model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_crn_l_300e_coco.tar)
```shell
python tools/export_model.py \
-c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml \
@@ -86,7 +86,13 @@ python tools/export_model.py \
trt=True \
```
-**Note**: Setting `trt=True` when exporting PP-YOLOE is meant to optimize performance on TensorRT; if TensorRT is not used, or for other models, `trt=True` is not needed. For a quick start, you can directly download the exported [PP-YOLOE-l model](https://bj.bcebos.com/v1/paddle-slim-models/detection/ppyoloe_crn_l_300e_coco.tar).
+PP-YOLOE-s model, NMS excluded: for a quick start, you can directly download the exported [PP-YOLOE-s model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_crn_s_300e_coco.tar). A quick sanity check on the exported model is sketched after the command below.
+```shell
+python tools/export_model.py \
+ -c configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml \
+ -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams \
+ trt=True exclude_nms=True \
+```
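+
+To double-check that `exclude_nms=True` took effect, here is a minimal sketch (assuming the exported files sit in `./ppyoloe_crn_s_300e_coco`) that loads the inference model and prints its feed names; with NMS excluded the model should take only the `image` input, which is why the quantization config below sets `input_list: ['image']` and `arch: PPYOLOE`:
+```python
+import paddle
+
+paddle.enable_static()
+exe = paddle.static.Executor(paddle.CPUPlace())
+# Path prefix is an assumption: point it at wherever export_model.py wrote the files.
+program, feed_names, fetch_targets = paddle.static.load_inference_model(
+    path_prefix='./ppyoloe_crn_s_300e_coco/model', executor=exe)
+print(feed_names)          # expected: ['image'] (no NMS inputs such as scale_factor)
+print(len(fetch_targets))  # raw head outputs instead of NMS-ed boxes
+```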
#### 3.4 Run auto compression and produce the model
@@ -117,7 +123,29 @@ python eval.py --config_path=./configs/ppyoloe_l_qat_dis.yaml
## 4. Inference Deployment
-Refer to the [PaddleDetection deployment tutorial](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy):
-- On GPU, deploy the quantized model with TensorRT enabled and the trt_int8 mode set.
+- If the model includes NMS, refer to the [PaddleDetection deployment tutorial](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy): on GPU, deploy the quantized model with TensorRT enabled and the trt_int8 mode set.
+
+- If the model is PP-YOLOE and does not include NMS, deploy it with the following demos:
+ - Paddle-TensorRT C++ deployment
+
+ Enter the [cpp_infer](./cpp_infer_ppyoloe) folder, prepare the environment and build following the [C++ TensorRT benchmark guide](./cpp_infer_ppyoloe/README.md), then run the test:
+ ```shell
+ # build
+ bash compile.sh
+ # run
+ ./build/trt_run --model_file ppyoloe_s_quant/model.pdmodel --params_file ppyoloe_s_quant/model.pdiparams --run_mode=trt_int8
+ ```
+
+ - Paddle-TensorRT Python deployment:
+
+ First install a [Paddle package](https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python) built with TensorRT, then deploy with [paddle_trt_infer.py](./paddle_trt_infer.py); the core Paddle Inference calls it makes are sketched below:
+ ```shell
+ python paddle_trt_infer.py --model_path=output --image_file=images/000000570688.jpg --benchmark=True --run_mode=trt_int8
+ ```
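+
+ For reference, a minimal sketch of the core Paddle Inference calls behind `paddle_trt_infer.py`, assuming the quantized model has been unpacked to `ppyoloe_s_quant/`:
+ ```python
+ from paddle.inference import Config, create_predictor
+
+ # Assumed local path of the quantized PP-YOLOE-s model.
+ config = Config('ppyoloe_s_quant/model.pdmodel', 'ppyoloe_s_quant/model.pdiparams')
+ config.enable_use_gpu(200, 0)  # initial GPU memory (MB), device id
+ config.enable_tensorrt_engine(
+     workspace_size=1 << 25,
+     max_batch_size=1,
+     min_subgraph_size=3,
+     precision_mode=Config.Precision.Int8,  # trt_int8 mode
+     use_static=False,
+     use_calib_mode=False)
+ predictor = create_predictor(config)
+ ```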
## 5.FAQ
+
+- To evaluate the accuracy of the post-training (offline) quantized model, run:
+```shell
+python post_quant.py --config_path=./configs/ppyoloe_s_qat_dis.yaml
+```
diff --git a/example/auto_compression/detection/configs/ppyoloe_l_qat_dis.yaml b/example/auto_compression/detection/configs/ppyoloe_l_qat_dis.yaml
index 0b28ef89..cd39981c 100644
--- a/example/auto_compression/detection/configs/ppyoloe_l_qat_dis.yaml
+++ b/example/auto_compression/detection/configs/ppyoloe_l_qat_dis.yaml
@@ -2,6 +2,7 @@
Global:
reader_config: configs/yolo_reader.yml
input_list: ['image', 'scale_factor']
+ arch: YOLO
Evaluation: True
model_dir: ./ppyoloe_crn_l_300e_coco
model_filename: model.pdmodel
diff --git a/example/auto_compression/detection/configs/ppyoloe_s_qat_dis.yaml b/example/auto_compression/detection/configs/ppyoloe_s_qat_dis.yaml
new file mode 100644
index 00000000..1efd1175
--- /dev/null
+++ b/example/auto_compression/detection/configs/ppyoloe_s_qat_dis.yaml
@@ -0,0 +1,33 @@
+
+Global:
+ reader_config: configs/yolo_reader.yml
+ input_list: ['image']
+ arch: PPYOLOE # when exporting with exclude_nms=True, arch must be set to PPYOLOE
+ Evaluation: True
+ model_dir: ./ppyoloe_crn_s_300e_coco
+ model_filename: model.pdmodel
+ params_filename: model.pdiparams
+
+Distillation:
+ alpha: 1.0
+ loss: soft_label
+
+Quantization:
+ use_pact: true
+ activation_quantize_type: 'moving_average_abs_max'
+ quantize_op_types:
+ - conv2d
+ - depthwise_conv2d
+
+TrainConfig:
+ train_iter: 5000
+ eval_iter: 1000
+ learning_rate:
+ type: CosineAnnealingDecay
+ learning_rate: 0.00003
+ T_max: 6000
+ optimizer_builder:
+ optimizer:
+ type: SGD
+ weight_decay: 4.0e-05
+
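For reference, a minimal sketch of how this config can be inspected, using the same `load_config` helper that `run.py`, `eval.py`, and `post_quant.py` in this patch import (the printed keys are what the YAML above defines):

```python
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config

all_config = load_slim_config('./configs/ppyoloe_s_qat_dis.yaml')
print(list(all_config.keys()))             # expected: Global, Distillation, Quantization, TrainConfig
print(all_config['Global']['arch'])        # 'PPYOLOE' -> selects the NMS-free post-processing path
print(all_config['Global']['input_list'])  # ['image'] -> only the image tensor is fed during compression
```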
diff --git a/example/auto_compression/detection/cpp_infer_ppyoloe/CMakeLists.txt b/example/auto_compression/detection/cpp_infer_ppyoloe/CMakeLists.txt
new file mode 100644
index 00000000..d5307c65
--- /dev/null
+++ b/example/auto_compression/detection/cpp_infer_ppyoloe/CMakeLists.txt
@@ -0,0 +1,263 @@
+cmake_minimum_required(VERSION 3.0)
+project(cpp_inference_demo CXX C)
+option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON)
+option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF)
+option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON)
+option(USE_TENSORRT "Compile demo with TensorRT." OFF)
+option(WITH_ROCM "Compile demo with rocm." OFF)
+option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF)
+option(WITH_ARM "Compile demo with ARM" OFF)
+option(WITH_MIPS "Compile demo with MIPS" OFF)
+option(WITH_SW "Compile demo with SW" OFF)
+option(WITH_XPU "Compile demo with XPU" OFF)
+option(WITH_NPU "Compile demo with NPU" OFF)
+
+if(NOT WITH_STATIC_LIB)
+ add_definitions("-DPADDLE_WITH_SHARED_LIB")
+else()
+ # PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode.
+ # Set it to empty in static library mode to avoid compilation issues.
+ add_definitions("/DPD_INFER_DECL=")
+endif()
+
+macro(safe_set_static_flag)
+ foreach(flag_var
+ CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
+ CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
+ if(${flag_var} MATCHES "/MD")
+ string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
+ endif(${flag_var} MATCHES "/MD")
+ endforeach(flag_var)
+endmacro()
+
+if(NOT DEFINED PADDLE_LIB)
+ message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib")
+endif()
+if(NOT DEFINED DEMO_NAME)
+ message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name")
+endif()
+
+include_directories("${PADDLE_LIB}/")
+set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/include")
+
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
+link_directories("${PADDLE_LIB}/paddle/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib")
+
+if (WIN32)
+ add_definitions("/DGOOGLE_GLOG_DLL_DECL=")
+ option(MSVC_STATIC_CRT "use static C Runtime library by default" ON)
+ if (MSVC_STATIC_CRT)
+ if (WITH_MKL)
+ set(FLAG_OPENMP "/openmp")
+ endif()
+ set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
+ set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
+ set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
+ safe_set_static_flag()
+ if (WITH_STATIC_LIB)
+ add_definitions(-DSTATIC_LIB)
+ endif()
+ endif()
+else()
+ if(WITH_MKL)
+ set(FLAG_OPENMP "-fopenmp")
+ endif()
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 ${FLAG_OPENMP}")
+endif()
+
+if(WITH_GPU)
+ if(NOT WIN32)
+ include_directories("/usr/local/cuda/include")
+ if(CUDA_LIB STREQUAL "")
+ set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library")
+ endif()
+ else()
+ include_directories("C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\include")
+ if(CUDA_LIB STREQUAL "")
+ set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64")
+ endif()
+ endif(NOT WIN32)
+endif()
+
+if (USE_TENSORRT AND WITH_GPU)
+ set(TENSORRT_ROOT "" CACHE STRING "The root directory of TensorRT library")
+ if("${TENSORRT_ROOT}" STREQUAL "")
+ message(FATAL_ERROR "The TENSORRT_ROOT is empty, you must assign it a value with CMake command. Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH ")
+ endif()
+ set(TENSORRT_INCLUDE_DIR ${TENSORRT_ROOT}/include)
+ set(TENSORRT_LIB_DIR ${TENSORRT_ROOT}/lib)
+ file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS)
+ string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
+ "${TENSORRT_VERSION_FILE_CONTENTS}")
+ if("${TENSORRT_MAJOR_VERSION}" STREQUAL "")
+ file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h TENSORRT_VERSION_FILE_CONTENTS)
+ string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
+ "${TENSORRT_VERSION_FILE_CONTENTS}")
+ endif()
+ if("${TENSORRT_MAJOR_VERSION}" STREQUAL "")
+ message(SEND_ERROR "Failed to detect TensorRT version.")
+ endif()
+ string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1"
+ TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}")
+ message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. "
+ "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
+ include_directories("${TENSORRT_INCLUDE_DIR}")
+ link_directories("${TENSORRT_LIB_DIR}")
+endif()
+
+if(WITH_MKL)
+ set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
+ include_directories("${MATH_LIB_PATH}/include")
+ if(WIN32)
+ set(MATH_LIB ${MATH_LIB_PATH}/lib/mklml${CMAKE_STATIC_LIBRARY_SUFFIX}
+ ${MATH_LIB_PATH}/lib/libiomp5md${CMAKE_STATIC_LIBRARY_SUFFIX})
+ else()
+ set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
+ ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
+ endif()
+ set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
+ if(EXISTS ${MKLDNN_PATH})
+ include_directories("${MKLDNN_PATH}/include")
+ if(WIN32)
+ set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib)
+ else(WIN32)
+ set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+ endif(WIN32)
+ endif()
+elseif((NOT WITH_MIPS) AND (NOT WITH_SW))
+ set(OPENBLAS_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}openblas")
+ include_directories("${OPENBLAS_LIB_PATH}/include/openblas")
+ if(WIN32)
+ set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX})
+ else()
+ set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
+ endif()
+endif()
+
+if(WITH_STATIC_LIB)
+ set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
+else()
+ if(WIN32)
+ set(DEPS ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
+ else()
+ set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
+ endif()
+endif()
+
+if (WITH_ONNXRUNTIME)
+ if(WIN32)
+ set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.lib paddle2onnx)
+ elseif(APPLE)
+ set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.1.10.0.dylib paddle2onnx)
+ else()
+ set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.so.1.10.0 paddle2onnx)
+ endif()
+endif()
+
+if (NOT WIN32)
+ set(EXTERNAL_LIB "-lrt -ldl -lpthread")
+ set(DEPS ${DEPS}
+ ${MATH_LIB} ${MKLDNN_LIB}
+ glog gflags protobuf xxhash cryptopp
+ ${EXTERNAL_LIB})
+else()
+ set(DEPS ${DEPS}
+ ${MATH_LIB} ${MKLDNN_LIB}
+ glog gflags_static libprotobuf xxhash cryptopp-static ${EXTERNAL_LIB})
+ set(DEPS ${DEPS} shlwapi.lib)
+endif(NOT WIN32)
+
+if(WITH_GPU)
+ if(NOT WIN32)
+ if (USE_TENSORRT)
+ set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
+ set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
+ endif()
+ set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
+ else()
+ if(USE_TENSORRT)
+ set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_STATIC_LIBRARY_SUFFIX})
+ set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX})
+ if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
+ set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_STATIC_LIBRARY_SUFFIX})
+ endif()
+ endif()
+ set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} )
+ set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} )
+ set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} )
+ endif()
+endif()
+
+if(WITH_ROCM AND NOT WIN32)
+ set(DEPS ${DEPS} ${ROCM_LIB}/libamdhip64${CMAKE_SHARED_LIBRARY_SUFFIX})
+endif()
+
+if(WITH_XPU AND NOT WIN32)
+ set(XPU_INSTALL_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}xpu")
+ set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpuapi${CMAKE_SHARED_LIBRARY_SUFFIX})
+ set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpurt${CMAKE_SHARED_LIBRARY_SUFFIX})
+endif()
+
+if(WITH_NPU AND NOT WIN32)
+ set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libgraph${CMAKE_SHARED_LIBRARY_SUFFIX})
+ set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libge_runner${CMAKE_SHARED_LIBRARY_SUFFIX})
+ set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX})
+ set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libacl_op_compiler${CMAKE_SHARED_LIBRARY_SUFFIX})
+endif()
+
+add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
+target_link_libraries(${DEMO_NAME} ${DEPS})
+if(WIN32)
+ if(USE_TENSORRT)
+ add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}
+ ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+ COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}
+ ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+ )
+ if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
+ add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_SHARED_LIBRARY_SUFFIX}
+ ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE})
+ endif()
+ endif()
+ if(WITH_MKL)
+ add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/mklml.dll ${CMAKE_BINARY_DIR}/Release
+ COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/libiomp5md.dll ${CMAKE_BINARY_DIR}/Release
+ COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_PATH}/lib/mkldnn.dll ${CMAKE_BINARY_DIR}/Release
+ )
+ else()
+ add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${OPENBLAS_LIB_PATH}/lib/openblas.dll ${CMAKE_BINARY_DIR}/Release
+ )
+ endif()
+ if(WITH_ONNXRUNTIME)
+ add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.dll
+ ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+ COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib/paddle2onnx.dll
+ ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+ )
+ endif()
+ if(NOT WITH_STATIC_LIB)
+ add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy "${PADDLE_LIB}/paddle/lib/paddle_inference.dll" ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+ )
+ endif()
+endif()
diff --git a/example/auto_compression/detection/cpp_infer_ppyoloe/README.md b/example/auto_compression/detection/cpp_infer_ppyoloe/README.md
new file mode 100644
index 00000000..8b2a2eba
--- /dev/null
+++ b/example/auto_compression/detection/cpp_infer_ppyoloe/README.md
@@ -0,0 +1,51 @@
+# PPYOLOE TensorRT Benchmark (Linux)
+
+## Environment Setup
+
+- CUDA and cuDNN: make sure CUDA and cuDNN are installed and note their installation paths in advance.
+
+- TensorRT: download the [TensorRT 8.4.1.5](https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.1/tars/tensorrt-8.4.1.5.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz) package (or another version) from the NVIDIA website.
+
+- Paddle Inference C++ library: to build the develop branch, follow the [build guide](https://www.paddlepaddle.org.cn/inference/user_guides/source_compile.html). After the build finishes, a `paddle_inference_install_dir` folder is generated under the build directory; this is the C++ inference library we need.
+
+## Build the Executable
+
+- (1) Edit the dependency library paths in `compile.sh`, mainly the following:
+```shell
+# Path to the Paddle Inference library
+LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/
+# Path to cuDNN
+CUDNN_LIB=/usr/lib/x86_64-linux-gnu/
+# Path to CUDA
+CUDA_LIB=/usr/local/cuda/lib64
+# Path to the TensorRT package: the absolute path of the extracted TRT archive, which contains the `lib` and `include` folders
+TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/
+```
+
+## Run the Tests
+
+- FP32
+```
+./build/trt_run --model_file ppyoloe_crn_s_300e_coco/model.pdmodel --params_file ppyoloe_crn_s_300e_coco/model.pdiparams --run_mode=trt_fp32
+```
+
+- FP16
+```
+./build/trt_run --model_file ppyoloe_crn_s_300e_coco/model.pdmodel --params_file ppyoloe_crn_s_300e_coco/model.pdiparams --run_mode=trt_fp16
+```
+
+- INT8
+```
+./build/trt_run --model_file ppyoloe_s_quant/model.pdmodel --params_file ppyoloe_s_quant/model.pdiparams --run_mode=trt_int8
+```
+
+## Performance Comparison
+
+| Inference engine | Model | FP32 latency (ms) | FP16 latency (ms) | INT8 latency (ms) |
+| :--------: | :--------: |:-------- |:--------: | :---------------------: |
+| Paddle TensorRT | PPYOLOE-s | 6.51ms | 2.77ms | 2.12ms |
+| TensorRT | PPYOLOE-s | 6.61ms | 2.90ms | 2.31ms |
+
+Environment:
+- Tesla T4,TensorRT 8.4.1,CUDA 11.2
+- batch_size=1
diff --git a/example/auto_compression/detection/cpp_infer_ppyoloe/compile.sh b/example/auto_compression/detection/cpp_infer_ppyoloe/compile.sh
new file mode 100644
index 00000000..afff924b
--- /dev/null
+++ b/example/auto_compression/detection/cpp_infer_ppyoloe/compile.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+set +x
+set -e
+
+work_path=$(dirname $(readlink -f $0))
+
+mkdir -p build
+cd build
+rm -rf *
+
+DEMO_NAME=trt_run
+
+WITH_MKL=ON
+WITH_GPU=ON
+USE_TENSORRT=ON
+
+LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/
+CUDNN_LIB=/usr/lib/x86_64-linux-gnu/
+CUDA_LIB=/usr/local/cuda/lib64
+TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/
+
+WITH_ROCM=OFF
+ROCM_LIB=/opt/rocm/lib
+
+cmake .. -DPADDLE_LIB=${LIB_DIR} \
+ -DWITH_MKL=${WITH_MKL} \
+ -DDEMO_NAME=${DEMO_NAME} \
+ -DWITH_GPU=${WITH_GPU} \
+ -DWITH_STATIC_LIB=OFF \
+ -DUSE_TENSORRT=${USE_TENSORRT} \
+ -DWITH_ROCM=${WITH_ROCM} \
+ -DROCM_LIB=${ROCM_LIB} \
+ -DCUDNN_LIB=${CUDNN_LIB} \
+ -DCUDA_LIB=${CUDA_LIB} \
+ -DTENSORRT_ROOT=${TENSORRT_ROOT}
+
+make -j
diff --git a/example/auto_compression/detection/cpp_infer_ppyoloe/trt_run.cc b/example/auto_compression/detection/cpp_infer_ppyoloe/trt_run.cc
new file mode 100644
index 00000000..fc0ac436
--- /dev/null
+++ b/example/auto_compression/detection/cpp_infer_ppyoloe/trt_run.cc
@@ -0,0 +1,116 @@
+#include <chrono>
+#include <iostream>
+#include <memory>
+#include <numeric>
+
+#include <cuda_runtime.h>
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+
+#include "paddle/include/paddle_inference_api.h"
+#include "paddle/include/experimental/phi/common/float16.h"
+
+using paddle_infer::Config;
+using paddle_infer::Predictor;
+using paddle_infer::CreatePredictor;
+using paddle_infer::PrecisionType;
+using phi::dtype::float16;
+
+DEFINE_string(model_dir, "", "Directory of the inference model.");
+DEFINE_string(model_file, "", "Path of the inference model file.");
+DEFINE_string(params_file, "", "Path of the inference params file.");
+DEFINE_string(run_mode, "trt_fp32", "run_mode which can be: trt_fp32, trt_fp16 and trt_int8");
+DEFINE_int32(batch_size, 1, "Batch size.");
+DEFINE_int32(gpu_id, 0, "GPU card ID num.");
+DEFINE_int32(trt_min_subgraph_size, 3, "tensorrt min_subgraph_size");
+DEFINE_int32(warmup, 50, "warmup");
+DEFINE_int32(repeats, 1000, "repeats");
+
+using Time = decltype(std::chrono::high_resolution_clock::now());
+Time time() { return std::chrono::high_resolution_clock::now(); };
+double time_diff(Time t1, Time t2) {
+ typedef std::chrono::microseconds ms;
+ auto diff = t2 - t1;
+ ms counter = std::chrono::duration_cast<ms>(diff);
+ return counter.count() / 1000.0;
+}
+
+std::shared_ptr<Predictor> InitPredictor() {
+ Config config;
+ std::string model_path;
+ if (FLAGS_model_dir != "") {
+ config.SetModel(FLAGS_model_dir);
+ model_path = FLAGS_model_dir.substr(0, FLAGS_model_dir.find_last_of("/"));
+ } else {
+ config.SetModel(FLAGS_model_file, FLAGS_params_file);
+ model_path = FLAGS_model_file.substr(0, FLAGS_model_file.find_last_of("/"));
+ }
+ // enable tune
+ std::cout << "model_path: " << model_path << std::endl;
+ config.EnableUseGpu(256, FLAGS_gpu_id);
+ if (FLAGS_run_mode == "trt_fp32") {
+ config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size,
+ PrecisionType::kFloat32, false, false);
+ } else if (FLAGS_run_mode == "trt_fp16") {
+ config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size,
+ PrecisionType::kHalf, false, false);
+ } else if (FLAGS_run_mode == "trt_int8") {
+ config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size,
+ PrecisionType::kInt8, false, false);
+ }
+ config.EnableMemoryOptim();
+ config.SwitchIrOptim(true);
+ return CreatePredictor(config);
+}
+
+template <typename type>
+void run(Predictor *predictor, const std::vector<type> &input,
+ const std::vector<int> &input_shape, type* out_data, std::vector<int> out_shape) {
+
+ // prepare input
+ int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
+ std::multiplies<int>());
+
+ auto input_names = predictor->GetInputNames();
+ auto input_t = predictor->GetInputHandle(input_names[0]);
+ input_t->Reshape(input_shape);
+ input_t->CopyFromCpu(input.data());
+
+ for (int i = 0; i < FLAGS_warmup; ++i)
+ CHECK(predictor->Run());
+
+ auto st = time();
+ for (int i = 0; i < FLAGS_repeats; ++i) {
+ auto input_names = predictor->GetInputNames();
+ auto input_t = predictor->GetInputHandle(input_names[0]);
+ input_t->Reshape(input_shape);
+ input_t->CopyFromCpu(input.data());
+
+ CHECK(predictor->Run());
+
+ auto output_names = predictor->GetOutputNames();
+ auto output_t = predictor->GetOutputHandle(output_names[0]);
+ std::vector<int> output_shape = output_t->shape();
+ output_t -> ShareExternalData(out_data, out_shape, paddle_infer::PlaceType::kGPU);
+ }
+
+ LOG(INFO) << "[" << FLAGS_run_mode << " bs-" << FLAGS_batch_size << " ] run avg time is " << time_diff(st, time()) / FLAGS_repeats
+ << " ms";
+}
+
+int main(int argc, char *argv[]) {
+ google::ParseCommandLineFlags(&argc, &argv, true);
+ auto predictor = InitPredictor();
+ std::vector<int> input_shape = {FLAGS_batch_size, 3, 640, 640};
+ // float16
+ using dtype = float16;
+ std::vector<dtype> input_data(FLAGS_batch_size * 3 * 640 * 640, dtype(1.0));
+
+ dtype *out_data;
+ int out_data_size = FLAGS_batch_size * 8400 * 84;
+ cudaHostAlloc((void**)&out_data, sizeof(float) * out_data_size, cudaHostAllocMapped);
+
+ std::vector<int> out_shape{ FLAGS_batch_size, 1, 8400, 84};
+ run(predictor.get(), input_data, input_shape, out_data, out_shape);
+ return 0;
+}
diff --git a/example/auto_compression/detection/eval.py b/example/auto_compression/detection/eval.py
index 3a723653..d80f3cfe 100644
--- a/example/auto_compression/detection/eval.py
+++ b/example/auto_compression/detection/eval.py
@@ -22,6 +22,7 @@ from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from keypoint_utils import keypoint_post_process
+from post_process import PPYOLOEPostProcess
def argsparser():
@@ -103,6 +104,10 @@ def eval():
if 'arch' in global_config and global_config['arch'] == 'keypoint':
res = keypoint_post_process(data, data_input, exe, val_program,
fetch_targets, outs)
+ elif 'arch' in global_config and global_config['arch'] == 'PPYOLOE':
+ postprocess = PPYOLOEPostProcess(
+ score_threshold=0.01, nms_threshold=0.6)
+ res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
diff --git a/example/auto_compression/detection/paddle_trt_infer.py b/example/auto_compression/detection/paddle_trt_infer.py
new file mode 100644
index 00000000..6f62d2af
--- /dev/null
+++ b/example/auto_compression/detection/paddle_trt_infer.py
@@ -0,0 +1,311 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import numpy as np
+import argparse
+import time
+
+from paddle.inference import Config
+from paddle.inference import create_predictor
+
+from post_process import PPYOLOEPostProcess
+
+CLASS_LABEL = [
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+ 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+ 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
+ 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
+ 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+ 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
+ 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+ 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
+ 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
+ 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
+ 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+ 'hair drier', 'toothbrush'
+]
+
+
+def generate_scale(im, target_shape, keep_ratio=True):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ Returns:
+ im_scale_x: the resize ratio of X
+ im_scale_y: the resize ratio of Y
+ """
+ origin_shape = im.shape[:2]
+ if keep_ratio:
+ im_size_min = np.min(origin_shape)
+ im_size_max = np.max(origin_shape)
+ target_size_min = np.min(target_shape)
+ target_size_max = np.max(target_shape)
+ im_scale = float(target_size_min) / float(im_size_min)
+ if np.round(im_scale * im_size_max) > target_size_max:
+ im_scale = float(target_size_max) / float(im_size_max)
+ im_scale_x = im_scale
+ im_scale_y = im_scale
+ else:
+ resize_h, resize_w = target_shape
+ im_scale_y = resize_h / float(origin_shape[0])
+ im_scale_x = resize_w / float(origin_shape[1])
+ return im_scale_y, im_scale_x
+
+
+def image_preprocess(img_path, target_shape):
+ img = cv2.imread(img_path)
+ im_scale_y, im_scale_x = generate_scale(img, target_shape, keep_ratio=False)
+ img = cv2.resize(
+ img, (target_shape[0], target_shape[0]),
+ interpolation=cv2.INTER_LANCZOS4)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = np.transpose(img, [2, 0, 1]) / 255
+ img = np.expand_dims(img, 0)
+ img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+ img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+ img -= img_mean
+ img /= img_std
+ scale_factor = np.array([[im_scale_y, im_scale_x]])
+ return img.astype(np.float32), scale_factor
+
+
+def get_color_map_list(num_classes):
+ color_map = num_classes * [0, 0, 0]
+ for i in range(0, num_classes):
+ j = 0
+ lab = i
+ while lab:
+ color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+ color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+ color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+ j += 1
+ lab >>= 3
+ color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
+ return color_map
+
+
+def draw_box(image_file, results, class_label, threshold=0.5):
+ srcimg = cv2.imread(image_file, 1)
+ for i in range(len(results)):
+ color_list = get_color_map_list(len(class_label))
+ clsid2color = {}
+ classid, conf = int(results[i, 0]), results[i, 1]
+ if conf < threshold:
+ continue
+ xmin, ymin, xmax, ymax = int(results[i, 2]), int(results[i, 3]), int(
+ results[i, 4]), int(results[i, 5])
+
+ if classid not in clsid2color:
+ clsid2color[classid] = color_list[classid]
+ color = tuple(clsid2color[classid])
+
+ cv2.rectangle(srcimg, (xmin, ymin), (xmax, ymax), color, thickness=2)
+ print(class_label[classid] + ': ' + str(round(conf, 3)))
+ cv2.putText(
+ srcimg,
+ class_label[classid] + ':' + str(round(conf, 3)), (xmin, ymin - 10),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.8, (0, 255, 0),
+ thickness=2)
+ return srcimg
+
+
+def load_predictor(model_dir,
+ run_mode='paddle',
+ batch_size=1,
+ device='CPU',
+ min_subgraph_size=3,
+ use_dynamic_shape=False,
+ trt_min_shape=1,
+ trt_max_shape=1280,
+ trt_opt_shape=640,
+ trt_calib_mode=False,
+ cpu_threads=1,
+ enable_mkldnn=False,
+ enable_mkldnn_bfloat16=False,
+ delete_shuffle_pass=False):
+ """set AnalysisConfig, generate AnalysisPredictor
+ Args:
+ model_dir (str): root path of __model__ and __params__
+ device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
+ run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
+ use_dynamic_shape (bool): use dynamic shape or not
+ trt_min_shape (int): min shape for dynamic shape in trt
+ trt_max_shape (int): max shape for dynamic shape in trt
+ trt_opt_shape (int): opt shape for dynamic shape in trt
+ trt_calib_mode (bool): If the model is produced by TRT offline quantitative
+ calibration, trt_calib_mode need to set True
+ delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
+ Used by action model.
+ Returns:
+ predictor (PaddlePredictor): AnalysisPredictor
+ Raises:
+ ValueError: predict by TensorRT need device == 'GPU'.
+ """
+ if device != 'GPU' and run_mode != 'paddle':
+ raise ValueError(
+ "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
+ .format(run_mode, device))
+ config = Config(
+ os.path.join(model_dir, 'model.pdmodel'),
+ os.path.join(model_dir, 'model.pdiparams'))
+ if device == 'GPU':
+ # initial GPU memory(M), device ID
+ config.enable_use_gpu(200, 0)
+ # optimize graph and fuse op
+ config.switch_ir_optim(True)
+ elif device == 'XPU':
+ config.enable_lite_engine()
+ config.enable_xpu(10 * 1024 * 1024)
+ else:
+ config.disable_gpu()
+ config.set_cpu_math_library_num_threads(cpu_threads)
+ if enable_mkldnn:
+ try:
+ # cache 10 different shapes for mkldnn to avoid memory leak
+ config.set_mkldnn_cache_capacity(10)
+ config.enable_mkldnn()
+ if enable_mkldnn_bfloat16:
+ config.enable_mkldnn_bfloat16()
+ except Exception as e:
+ print(
+ "The current environment does not support `mkldnn`, so disable mkldnn."
+ )
+ pass
+
+ precision_map = {
+ 'trt_int8': Config.Precision.Int8,
+ 'trt_fp32': Config.Precision.Float32,
+ 'trt_fp16': Config.Precision.Half
+ }
+ if run_mode in precision_map.keys():
+ config.enable_tensorrt_engine(
+ workspace_size=(1 << 25) * batch_size,
+ max_batch_size=batch_size,
+ min_subgraph_size=min_subgraph_size,
+ precision_mode=precision_map[run_mode],
+ use_static=False,
+ use_calib_mode=trt_calib_mode)
+
+ if use_dynamic_shape:
+ min_input_shape = {
+ 'image': [batch_size, 3, trt_min_shape, trt_min_shape]
+ }
+ max_input_shape = {
+ 'image': [batch_size, 3, trt_max_shape, trt_max_shape]
+ }
+ opt_input_shape = {
+ 'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
+ }
+ config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
+ opt_input_shape)
+ print('trt set dynamic shape done!')
+
+ # disable print log when predict
+ config.disable_glog_info()
+ # enable shared memory
+ config.enable_memory_optim()
+ # disable feed, fetch OP, needed by zero_copy_run
+ config.switch_use_feed_fetch_ops(False)
+ if delete_shuffle_pass:
+ config.delete_pass("shuffle_channel_detect_pass")
+ predictor = create_predictor(config)
+ return predictor
+
+
+def predict_image(predictor,
+ image_file,
+ image_shape=[640, 640],
+ warmup=1,
+ repeats=1,
+ threshold=0.5,
+ arch='YOLOv5'):
+ img, scale_factor = image_preprocess(image_file, image_shape)
+ inputs = {}
+ inputs['image'] = img
+ input_names = predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = predictor.get_input_handle(input_names[i])
+ input_tensor.copy_from_cpu(inputs[input_names[i]])
+
+ for i in range(warmup):
+ predictor.run()
+
+ np_boxes = None
+ predict_time = 0.
+ time_min = float("inf")
+ time_max = float('-inf')
+ for i in range(repeats):
+ start_time = time.time()
+ predictor.run()
+ output_names = predictor.get_output_names()
+ boxes_tensor = predictor.get_output_handle(output_names[0])
+ np_boxes = boxes_tensor.copy_to_cpu()
+ end_time = time.time()
+ timed = end_time - start_time
+ time_min = min(time_min, timed)
+ time_max = max(time_max, timed)
+ predict_time += timed
+
+ time_avg = predict_time / repeats
+ print('Inference time(ms): min={}, max={}, avg={}'.format(
+ round(time_min * 1000, 2),
+ round(time_max * 1000, 1), round(time_avg * 1000, 1)))
+ postprocess = PPYOLOEPostProcess(score_threshold=0.3, nms_threshold=0.6)
+ res = postprocess(np_boxes, scale_factor)
+ res_img = draw_box(
+ image_file, res['bbox'], CLASS_LABEL, threshold=threshold)
+ cv2.imwrite('result.jpg', res_img)
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--image_file', type=str, default=None, help="image path")
+ parser.add_argument(
+ '--model_path', type=str, help="inference model filepath")
+ parser.add_argument(
+ '--benchmark',
+ type=bool,
+ default=False,
+ help="Whether run benchmark or not.")
+ parser.add_argument(
+ '--run_mode',
+ type=str,
+ default='paddle',
+ help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
+ parser.add_argument(
+ '--device',
+ type=str,
+ default='GPU',
+ help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU"
+ )
+ parser.add_argument('--img_shape', type=int, default=640, help="input_size")
+ args = parser.parse_args()
+
+ predictor = load_predictor(
+ args.model_path, run_mode=args.run_mode, device=args.device)
+ warmup, repeats = 1, 1
+ if args.benchmark:
+ warmup, repeats = 50, 100
+ predict_image(
+ predictor,
+ args.image_file,
+ image_shape=[args.img_shape, args.img_shape],
+ warmup=warmup,
+ repeats=repeats)
diff --git a/example/auto_compression/detection/post_process.py b/example/auto_compression/detection/post_process.py
new file mode 100644
index 00000000..eea2f019
--- /dev/null
+++ b/example/auto_compression/detection/post_process.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import cv2
+
+
+def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+ """
+ Args:
+ box_scores (N, 5): boxes in corner-form and probabilities.
+ iou_threshold: intersection over union threshold.
+ top_k: keep top_k results. If k <= 0, keep all the results.
+ candidate_size: only consider the candidates with the highest scores.
+ Returns:
+ picked: a list of indexes of the kept boxes
+ """
+ scores = box_scores[:, -1]
+ boxes = box_scores[:, :-1]
+ picked = []
+ indexes = np.argsort(scores)
+ indexes = indexes[-candidate_size:]
+ while len(indexes) > 0:
+ current = indexes[-1]
+ picked.append(current)
+ if 0 < top_k == len(picked) or len(indexes) == 1:
+ break
+ current_box = boxes[current, :]
+ indexes = indexes[:-1]
+ rest_boxes = boxes[indexes, :]
+ iou = iou_of(
+ rest_boxes,
+ np.expand_dims(
+ current_box, axis=0), )
+ indexes = indexes[iou <= iou_threshold]
+
+ return box_scores[picked, :]
+
+
+def iou_of(boxes0, boxes1, eps=1e-5):
+ """Return intersection-over-union (Jaccard index) of boxes.
+ Args:
+ boxes0 (N, 4): ground truth boxes.
+ boxes1 (N or 1, 4): predicted boxes.
+ eps: a small number to avoid 0 as denominator.
+ Returns:
+ iou (N): IoU values.
+ """
+ overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
+ overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
+
+ overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+ area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+ area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+ return overlap_area / (area0 + area1 - overlap_area + eps)
+
+
+def area_of(left_top, right_bottom):
+ """Compute the areas of rectangles given two corners.
+ Args:
+ left_top (N, 2): left top corner.
+ right_bottom (N, 2): right bottom corner.
+ Returns:
+ area (N): return the area.
+ """
+ hw = np.clip(right_bottom - left_top, 0.0, None)
+ return hw[..., 0] * hw[..., 1]
+
+
+class PPYOLOEPostProcess(object):
+ """
+ Args:
+ score_threshold (float): threshold used to filter out low-confidence boxes
+ nms_threshold (float): IoU threshold used in hard NMS
+ nms_top_k (int): maximum number of candidates kept per class before NMS
+ keep_top_k (int): maximum number of boxes kept per image after NMS
+ """
+
+ def __init__(self,
+ score_threshold=0.4,
+ nms_threshold=0.5,
+ nms_top_k=10000,
+ keep_top_k=300):
+ self.score_threshold = score_threshold
+ self.nms_threshold = nms_threshold
+ self.nms_top_k = nms_top_k
+ self.keep_top_k = keep_top_k
+
+ def _non_max_suppression(self, prediction, scale_factor):
+ batch_size = prediction.shape[0]
+ out_boxes_list = []
+ box_num_list = []
+ for batch_id in range(batch_size):
+ bboxes, confidences = prediction[batch_id][..., :4], prediction[
+ batch_id][..., 4:]
+ # nms
+ picked_box_probs = []
+ picked_labels = []
+ for class_index in range(0, confidences.shape[1]):
+ probs = confidences[:, class_index]
+ mask = probs > self.score_threshold
+ probs = probs[mask]
+ if probs.shape[0] == 0:
+ continue
+ subset_boxes = bboxes[mask, :]
+ box_probs = np.concatenate(
+ [subset_boxes, probs.reshape(-1, 1)], axis=1)
+ box_probs = hard_nms(
+ box_probs,
+ iou_threshold=self.nms_threshold,
+ top_k=self.nms_top_k)
+ picked_box_probs.append(box_probs)
+ picked_labels.extend([class_index] * box_probs.shape[0])
+
+ if len(picked_box_probs) == 0:
+ out_boxes_list.append(np.empty((0, 6)))
+ box_num_list.append(0)
+
+ else:
+ picked_box_probs = np.concatenate(picked_box_probs)
+ # resize output boxes
+ picked_box_probs[:, 0] /= scale_factor[batch_id][1]
+ picked_box_probs[:, 2] /= scale_factor[batch_id][1]
+ picked_box_probs[:, 1] /= scale_factor[batch_id][0]
+ picked_box_probs[:, 3] /= scale_factor[batch_id][0]
+
+ # clas score box
+ out_box = np.concatenate(
+ [
+ np.expand_dims(
+ np.array(picked_labels), axis=-1), np.expand_dims(
+ picked_box_probs[:, 4], axis=-1),
+ picked_box_probs[:, :4]
+ ],
+ axis=1)
+ if out_box.shape[0] > self.keep_top_k:
+ out_box = out_box[out_box[:, 1].argsort()[::-1]
+ [:self.keep_top_k]]
+ out_boxes_list.append(out_box)
+ box_num_list.append(out_box.shape[0])
+
+ out_boxes_list = np.concatenate(out_boxes_list, axis=0)
+ box_num_list = np.array(box_num_list)
+ return out_boxes_list, box_num_list
+
+ def __call__(self, outs, scale_factor):
+ out_boxes_list, box_num_list = self._non_max_suppression(outs,
+ scale_factor)
+ return {'bbox': out_boxes_list, 'bbox_num': box_num_list}
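For reference, a small usage sketch of the `PPYOLOEPostProcess` class added above; the 1 x 8400 x 84 shape mirrors the PP-YOLOE-s head output used elsewhere in this patch (4 box coordinates plus 80 COCO class scores), and the random input is only there to keep the snippet self-contained:

```python
import numpy as np
from post_process import PPYOLOEPostProcess

# Dummy head output: [batch, num_anchors, 4 box coords + 80 class scores]
prediction = np.random.rand(1, 8400, 84).astype('float32')
# Per-image [scale_y, scale_x], as produced by the preprocessing in paddle_trt_infer.py
scale_factor = np.array([[1.0, 1.0]], dtype='float32')

post = PPYOLOEPostProcess(score_threshold=0.3, nms_threshold=0.6)
res = post(prediction, scale_factor)
print(res['bbox'].shape)  # (num_kept, 6): class_id, score, x1, y1, x2, y2
print(res['bbox_num'])    # number of boxes kept per image
```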
diff --git a/example/auto_compression/detection/post_quant.py b/example/auto_compression/detection/post_quant.py
new file mode 100644
index 00000000..b3a70900
--- /dev/null
+++ b/example/auto_compression/detection/post_quant.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import numpy as np
+import argparse
+import paddle
+from ppdet.core.workspace import load_config, merge_config
+from ppdet.core.workspace import create
+from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
+from paddleslim.quant import quant_post_static
+
+
+def argsparser():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ '--config_path',
+ type=str,
+ default=None,
+ help="path of compression strategy config.",
+ required=True)
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default='ptq_out',
+ help="directory to save compressed model.")
+ parser.add_argument(
+ '--devices',
+ type=str,
+ default='gpu',
+ help="which device used to compress.")
+ parser.add_argument(
+ '--algo', type=str, default='KL', help="post quant algo.")
+
+ return parser
+
+
+def reader_wrapper(reader, input_list):
+ def gen():
+ for data in reader:
+ in_dict = {}
+ if isinstance(input_list, list):
+ for input_name in input_list:
+ in_dict[input_name] = data[input_name]
+ elif isinstance(input_list, dict):
+ for input_name in input_list.keys():
+ in_dict[input_list[input_name]] = data[input_name]
+ yield in_dict
+
+ return gen
+
+
+def main():
+ global global_config
+ all_config = load_slim_config(FLAGS.config_path)
+ assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
+ global_config = all_config["Global"]
+ reader_cfg = load_config(global_config['reader_config'])
+
+ train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
+ reader_cfg['worker_num'],
+ return_list=True)
+ train_loader = reader_wrapper(train_loader, global_config['input_list'])
+
+ place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
+ exe = paddle.static.Executor(place)
+ quant_post_static(
+ executor=exe,
+ model_dir=global_config["model_dir"],
+ quantize_model_path=FLAGS.save_dir,
+ data_loader=train_loader,
+ model_filename=global_config["model_filename"],
+ params_filename=global_config["params_filename"],
+ batch_size=4,
+ batch_nums=64,
+ algo=FLAGS.algo,
+ hist_percent=0.999,
+ is_full_quantize=False,
+ bias_correction=False,
+ onnx_format=False)
+
+
+if __name__ == '__main__':
+ paddle.enable_static()
+ parser = argsparser()
+ FLAGS = parser.parse_args()
+
+ assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
+ paddle.set_device(FLAGS.devices)
+
+ main()
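A small illustration of the feed dicts that `reader_wrapper` above hands to `quant_post_static`; the loader here is a stand-in list of batches rather than the real PaddleDetection EvalReader, purely to show the shape of what gets fed:

```python
import numpy as np
from post_quant import reader_wrapper

# Stand-in for one batch from the PaddleDetection dataloader.
fake_loader = [{'image': np.zeros((1, 3, 640, 640), dtype='float32'),
                'scale_factor': np.ones((1, 2), dtype='float32')}]

gen = reader_wrapper(fake_loader, input_list=['image'])
for feed in gen():
    print({name: value.shape for name, value in feed.items()})  # {'image': (1, 3, 640, 640)}
```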
diff --git a/example/auto_compression/detection/run.py b/example/auto_compression/detection/run.py
index b7cc7505..a3c46d47 100644
--- a/example/auto_compression/detection/run.py
+++ b/example/auto_compression/detection/run.py
@@ -23,6 +23,7 @@ from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
from keypoint_utils import keypoint_post_process
+from post_process import PPYOLOEPostProcess
def argsparser():
@@ -98,6 +99,10 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
res = keypoint_post_process(data, data_input, exe,
compiled_test_program, test_fetch_list,
outs)
+ elif 'arch' in global_config and global_config['arch'] == 'PPYOLOE':
+ postprocess = PPYOLOEPostProcess(
+ score_threshold=0.01, nms_threshold=0.6)
+ res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
--
GitLab