From d5d279467dbef526f9f0b4e0cecf77ba7ae1951f Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Fri, 8 Jul 2022 12:28:25 +0800
Subject: [PATCH] update yolov5s act demo (#1274)

---
 .../auto_compression/pytorch_yolov5/README.md |  39 ++-
 .../configs/yolov5s_qat_dis.yaml              |   2 +-
 .../pytorch_yolov5/cpp_infer/CMakeLists.txt   | 263 ++++++++++++++++++
 .../pytorch_yolov5/cpp_infer/README.md        |  62 +++++
 .../pytorch_yolov5/cpp_infer/compile.sh       |  37 +++
 .../pytorch_yolov5/cpp_infer/trt_run.cc       | 116 ++++++++
 .../auto_compression/pytorch_yolov5/eval.py   |  10 +-
 .../pytorch_yolov5/post_quant.py              | 104 +++++++
 .../auto_compression/pytorch_yolov5/run.py    |  10 -
 .../auto_compression/pytorch_yolov6/README.md |   5 +
 .../pytorch_yolov6/post_quant.py              |   5 +
 11 files changed, 624 insertions(+), 29 deletions(-)
 create mode 100644 example/auto_compression/pytorch_yolov5/cpp_infer/CMakeLists.txt
 create mode 100644 example/auto_compression/pytorch_yolov5/cpp_infer/README.md
 create mode 100644 example/auto_compression/pytorch_yolov5/cpp_infer/compile.sh
 create mode 100644 example/auto_compression/pytorch_yolov5/cpp_infer/trt_run.cc
 create mode 100644 example/auto_compression/pytorch_yolov5/post_quant.py

diff --git a/example/auto_compression/pytorch_yolov5/README.md b/example/auto_compression/pytorch_yolov5/README.md
index 8afab4cf..16709408 100644
--- a/example/auto_compression/pytorch_yolov5/README.md
+++ b/example/auto_compression/pytorch_yolov5/README.md
@@ -1,4 +1,4 @@
-# Auto Compression Example for Object Detection Models
+# Auto Compression Example for YOLOv5 Object Detection
 
 Contents:
 - [1. Introduction](#1简介)
@@ -22,12 +22,14 @@
 | Model | Strategy | Input Size | mAP<sup>val</sup><br>0.5:0.95 | FP32 Latency<br>(ms) | FP16 Latency<br>(ms) | INT8 Latency<br>(ms) | Config | Inference Model |
 | :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
-| YOLOv5s | Base model | 640*640 | 37.4 | 7.8ms | 4.3ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar) |
-| YOLOv5s | Quantization + distillation | 640*640 | 36.8 | - | - | 3.4ms | [config](./configs/yolov5s_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) |
+| YOLOv5s | Base model | 640*640 | 37.4 | 5.95ms | 2.44ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar) |
+| YOLOv5s | KL offline quantization | 640*640 | 36.0 | - | - | 1.87ms | - | - |
+| YOLOv5s | Quantization + distillation training | 640*640 | **36.9** | - | - | **1.87ms** | [config](./configs/yolov5s_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) |
+
 Notes:
 - All mAP metrics are evaluated on the COCO val2017 dataset.
-- The YOLOv5s model is tested on a Tesla T4 GPU with TensorRT enabled; the test script is the [benchmark demo](./paddle_trt_infer.py).
+- The YOLOv5s model is tested on a Tesla T4 GPU with TensorRT 8.4.1 enabled and batch_size=1; the test script is [cpp_infer](./cpp_infer).
 
 ## 3. Auto Compression Workflow
@@ -67,18 +69,19 @@ pip install x2paddle sympy onnx
 This example runs the auto compression experiment on COCO data by default and relies on the data loading module of PaddleDetection. For custom COCO data or data in other formats, please prepare it by following the [PaddleDetection data preparation guide](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md).
+If the dataset is already prepared, simply set the `dataset_dir` field of `EvalDataset` in [./configs/yolov5_reader.yml] to the path of your own dataset.
 
 #### 3.3 Prepare the Prediction Model
 
 (1) Prepare the ONNX model:
 
 The ONNX model can be prepared with the official [export tutorial](https://github.com/ultralytics/yolov5/issues/251) of [ultralytics/yolov5](https://github.com/ultralytics/yolov5), or you can download the ready-made [yolov5s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx).
-```
+```shell
 python export.py --weights yolov5s.pt --include onnx
 ```
 
 (2) Convert the model:
-```
+```shell
 x2paddle --framework=onnx --model=yolov5s.onnx --save_dir=pd_model
 cp -r pd_model/inference_model/ yolov5s_infer
 ```
@@ -112,15 +115,33 @@ export CUDA_VISIBLE_DEVICES=0
 python eval.py --config_path=./configs/yolov5s_qat_dis.yaml
 ```
 
-**Note**: The path of the model to be evaluated needs to be specified in the `model_dir` field of the config file.
+**Note**: To evaluate the quantized model, set its path in the `model_dir` field of the config file.
 
 ## 4. Inference Deployment
 
-- Paddle-TensorRT deployment:
-Deploy with [paddle_trt_infer.py](./paddle_trt_infer.py):
+#### Paddle-TensorRT C++ Deployment
+
+Go to the [cpp_infer](./cpp_infer) directory, prepare the environment and compile by following the [C++ TensorRT benchmark tutorial](./cpp_infer/README.md), then run the test:
+```shell
+# Compile
+bash compile.sh
+# Run
+./build/trt_run --model_file yolov5s_quant/model.pdmodel --params_file yolov5s_quant/model.pdiparams --run_mode=trt_int8
+```
+
+#### Paddle-TensorRT Python Deployment
+
+First install a TensorRT-enabled [Paddle inference package](https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python).
+
+Then deploy with [paddle_trt_infer.py](./paddle_trt_infer.py):
 ```shell
 python paddle_trt_infer.py --model_path=output --image_file=images/000000570688.jpg --benchmark=True --run_mode=trt_int8
 ```
 
 ## 5.FAQ
+
+- To evaluate the accuracy of the post-training (offline) quantized model, run:
+```shell
+python post_quant.py --config_path=./configs/yolov5s_qat_dis.yaml
+```
diff --git a/example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml b/example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml
index de8d89e4..ef9bf8b7 100644
--- a/example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml
+++ b/example/auto_compression/pytorch_yolov5/configs/yolov5s_qat_dis.yaml
@@ -10,7 +10,7 @@ Global:
 
 Distillation:
   alpha: 1.0
-  loss: l2
+  loss: soft_label
 
 Quantization:
   use_pact: true
diff --git a/example/auto_compression/pytorch_yolov5/cpp_infer/CMakeLists.txt
b/example/auto_compression/pytorch_yolov5/cpp_infer/CMakeLists.txt new file mode 100644 index 00000000..d5307c65 --- /dev/null +++ b/example/auto_compression/pytorch_yolov5/cpp_infer/CMakeLists.txt @@ -0,0 +1,263 @@ +cmake_minimum_required(VERSION 3.0) +project(cpp_inference_demo CXX C) +option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON) +option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) +option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) +option(USE_TENSORRT "Compile demo with TensorRT." OFF) +option(WITH_ROCM "Compile demo with rocm." OFF) +option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF) +option(WITH_ARM "Compile demo with ARM" OFF) +option(WITH_MIPS "Compile demo with MIPS" OFF) +option(WITH_SW "Compile demo with SW" OFF) +option(WITH_XPU "Compile demow ith xpu" OFF) +option(WITH_NPU "Compile demow ith npu" OFF) + +if(NOT WITH_STATIC_LIB) + add_definitions("-DPADDLE_WITH_SHARED_LIB") +else() + # PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode. + # Set it to empty in static library mode to avoid compilation issues. + add_definitions("/DPD_INFER_DECL=") +endif() + +macro(safe_set_static_flag) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif(${flag_var} MATCHES "/MD") + endforeach(flag_var) +endmacro() + +if(NOT DEFINED PADDLE_LIB) + message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib") +endif() +if(NOT DEFINED DEMO_NAME) + message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name") +endif() + +include_directories("${PADDLE_LIB}/") +set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/include") + +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib") +link_directories("${PADDLE_LIB}/paddle/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib") + +if (WIN32) + add_definitions("/DGOOGLE_GLOG_DLL_DECL=") + option(MSVC_STATIC_CRT "use static C Runtime library by default" ON) + if (MSVC_STATIC_CRT) + if (WITH_MKL) + set(FLAG_OPENMP "/openmp") + endif() + set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}") + set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}") + safe_set_static_flag() + if (WITH_STATIC_LIB) + add_definitions(-DSTATIC_LIB) + endif() + endif() +else() + 
if(WITH_MKL) + set(FLAG_OPENMP "-fopenmp") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 ${FLAG_OPENMP}") +endif() + +if(WITH_GPU) + if(NOT WIN32) + include_directories("/usr/local/cuda/include") + if(CUDA_LIB STREQUAL "") + set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library") + endif() + else() + include_directories("C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\include") + if(CUDA_LIB STREQUAL "") + set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64") + endif() + endif(NOT WIN32) +endif() + +if (USE_TENSORRT AND WITH_GPU) + set(TENSORRT_ROOT "" CACHE STRING "The root directory of TensorRT library") + if("${TENSORRT_ROOT}" STREQUAL "") + message(FATAL_ERROR "The TENSORRT_ROOT is empty, you must assign it a value with CMake command. Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH ") + endif() + set(TENSORRT_INCLUDE_DIR ${TENSORRT_ROOT}/include) + set(TENSORRT_LIB_DIR ${TENSORRT_ROOT}/lib) + file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS) + string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION + "${TENSORRT_VERSION_FILE_CONTENTS}") + if("${TENSORRT_MAJOR_VERSION}" STREQUAL "") + file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h TENSORRT_VERSION_FILE_CONTENTS) + string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION + "${TENSORRT_VERSION_FILE_CONTENTS}") + endif() + if("${TENSORRT_MAJOR_VERSION}" STREQUAL "") + message(SEND_ERROR "Failed to detect TensorRT version.") + endif() + string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1" + TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}") + message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. " + "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. 
") + include_directories("${TENSORRT_INCLUDE_DIR}") + link_directories("${TENSORRT_LIB_DIR}") +endif() + +if(WITH_MKL) + set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml") + include_directories("${MATH_LIB_PATH}/include") + if(WIN32) + set(MATH_LIB ${MATH_LIB_PATH}/lib/mklml${CMAKE_STATIC_LIBRARY_SUFFIX} + ${MATH_LIB_PATH}/lib/libiomp5md${CMAKE_STATIC_LIBRARY_SUFFIX}) + else() + set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() + set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn") + if(EXISTS ${MKLDNN_PATH}) + include_directories("${MKLDNN_PATH}/include") + if(WIN32) + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib) + else(WIN32) + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) + endif(WIN32) + endif() +elseif((NOT WITH_MIPS) AND (NOT WITH_SW)) + set(OPENBLAS_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}openblas") + include_directories("${OPENBLAS_LIB_PATH}/include/openblas") + if(WIN32) + set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX}) + else() + set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() +endif() + +if(WITH_STATIC_LIB) + set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX}) +else() + if(WIN32) + set(DEPS ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX}) + else() + set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() +endif() + +if (WITH_ONNXRUNTIME) + if(WIN32) + set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.lib paddle2onnx) + elseif(APPLE) + set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.1.10.0.dylib paddle2onnx) + else() + set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.so.1.10.0 paddle2onnx) + endif() +endif() + +if (NOT WIN32) + set(EXTERNAL_LIB "-lrt -ldl -lpthread") + set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags protobuf xxhash cryptopp + ${EXTERNAL_LIB}) +else() + set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags_static libprotobuf xxhash cryptopp-static ${EXTERNAL_LIB}) + set(DEPS ${DEPS} shlwapi.lib) +endif(NOT WIN32) + +if(WITH_GPU) + if(NOT WIN32) + if (USE_TENSORRT) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() + set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) + else() + if(USE_TENSORRT) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) + if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() + endif() + set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} ) + endif() +endif() + +if(WITH_ROCM AND NOT WIN32) + set(DEPS ${DEPS} ${ROCM_LIB}/libamdhip64${CMAKE_SHARED_LIBRARY_SUFFIX}) +endif() + +if(WITH_XPU AND NOT WIN32) + set(XPU_INSTALL_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}xpu") + set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpuapi${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpurt${CMAKE_SHARED_LIBRARY_SUFFIX}) +endif() + +if(WITH_NPU AND NOT WIN32) + set(DEPS 
${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libgraph${CMAKE_SHARED_LIBRARY_SUFFIX})
+  set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libge_runner${CMAKE_SHARED_LIBRARY_SUFFIX})
+  set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX})
+  set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX})
+  set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libacl_op_compiler${CMAKE_SHARED_LIBRARY_SUFFIX})
+endif()
+
+add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
+target_link_libraries(${DEMO_NAME} ${DEPS})
+if(WIN32)
+  if(USE_TENSORRT)
+    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+            COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}
+              ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+            COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}
+              ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+    )
+    if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
+      add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+              COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_SHARED_LIBRARY_SUFFIX}
+                ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE})
+    endif()
+  endif()
+  if(WITH_MKL)
+    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+          COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/mklml.dll ${CMAKE_BINARY_DIR}/Release
+          COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/libiomp5md.dll ${CMAKE_BINARY_DIR}/Release
+          COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_PATH}/lib/mkldnn.dll ${CMAKE_BINARY_DIR}/Release
+    )
+  else()
+    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+          COMMAND ${CMAKE_COMMAND} -E copy ${OPENBLAS_LIB_PATH}/lib/openblas.dll ${CMAKE_BINARY_DIR}/Release
+    )
+  endif()
+  if(WITH_ONNXRUNTIME)
+    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+            COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.dll
+              ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+            COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib/paddle2onnx.dll
+              ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+    )
+  endif()
+  if(NOT WITH_STATIC_LIB)
+    add_custom_command(TARGET ${DEMO_NAME} POST_BUILD
+          COMMAND ${CMAKE_COMMAND} -E copy "${PADDLE_LIB}/paddle/lib/paddle_inference.dll" ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}
+    )
+  endif()
+endif()
diff --git a/example/auto_compression/pytorch_yolov5/cpp_infer/README.md b/example/auto_compression/pytorch_yolov5/cpp_infer/README.md
new file mode 100644
index 00000000..9566728a
--- /dev/null
+++ b/example/auto_compression/pytorch_yolov5/cpp_infer/README.md
@@ -0,0 +1,62 @@
+# YOLOv5 TensorRT Benchmark (Linux)
+
+## Environment Setup
+
+- CUDA and CUDNN: make sure CUDA and CUDNN are installed in your environment and note their installation paths in advance.
+
+- TensorRT: download the [TensorRT 8.4.1.5](https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.1/tars/tensorrt-8.4.1.5.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz) package, or another version, from the NVIDIA website.
+
+- Paddle Inference C++ library: to build the develop branch, refer to the [compilation guide](https://www.paddlepaddle.org.cn/inference/user_guides/source_compile.html). After compilation, a `paddle_inference_install_dir` folder is generated under the build directory; this is the C++ inference library we need.
+
+## Build the Executable
+
+- (1) Update the dependency paths in `compile.sh`, mainly the following:
+```shell
+# Path of the Paddle Inference library
+LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/
+# CUDNN path
+CUDNN_LIB=/usr/lib/x86_64-linux-gnu/
+# CUDA path
+CUDA_LIB=/usr/local/cuda/lib64
+# TensorRT path: the absolute path of the extracted TensorRT package, which contains the `lib` and `include` folders
+TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/
+```
+
+## Paddle TensorRT Benchmark
+
+- FP32
+```
+./build/trt_run --model_file yolov5s_infer/model.pdmodel --params_file yolov5s_infer/model.pdiparams --run_mode=trt_fp32
+```
+
+- FP16
+```
+./build/trt_run --model_file yolov5s_infer/model.pdmodel --params_file yolov5s_infer/model.pdiparams --run_mode=trt_fp16
+```
+
+- INT8
+```
+./build/trt_run --model_file yolov5s_quant/model.pdmodel --params_file yolov5s_quant/model.pdiparams --run_mode=trt_int8
+```
+
+## Native TensorRT Benchmark
+
+```shell
+# FP32
+trtexec --onnx=yolov5s.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw
+# FP16
+trtexec --onnx=yolov5s.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --fp16
+# INT8
+trtexec --onnx=yolov5s.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --int8
+```
+
+## Performance Comparison
+
+| Inference Library | Model | FP32 Latency<br>(ms) | FP16 Latency<br>(ms) | INT8 Latency<br>(ms) |
+| :--------: | :--------: |:-------- |:--------: | :---------------------: |
+| Paddle TensorRT | yolov5s | 5.95ms | 2.44ms | 1.87ms |
+| TensorRT | yolov5s | 6.16ms | 2.58ms | 2.07ms |
+
+Environment:
+- Tesla T4, TensorRT 8.4.1, CUDA 11.2
+- batch_size=1
diff --git a/example/auto_compression/pytorch_yolov5/cpp_infer/compile.sh b/example/auto_compression/pytorch_yolov5/cpp_infer/compile.sh
new file mode 100644
index 00000000..afff924b
--- /dev/null
+++ b/example/auto_compression/pytorch_yolov5/cpp_infer/compile.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+set +x
+set -e
+
+work_path=$(dirname $(readlink -f $0))
+
+mkdir -p build
+cd build
+rm -rf *
+
+DEMO_NAME=trt_run
+
+WITH_MKL=ON
+WITH_GPU=ON
+USE_TENSORRT=ON
+
+LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/
+CUDNN_LIB=/usr/lib/x86_64-linux-gnu/
+CUDA_LIB=/usr/local/cuda/lib64
+TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/
+
+WITH_ROCM=OFF
+ROCM_LIB=/opt/rocm/lib
+
+cmake .. -DPADDLE_LIB=${LIB_DIR} \
+  -DWITH_MKL=${WITH_MKL} \
+  -DDEMO_NAME=${DEMO_NAME} \
+  -DWITH_GPU=${WITH_GPU} \
+  -DWITH_STATIC_LIB=OFF \
+  -DUSE_TENSORRT=${USE_TENSORRT} \
+  -DWITH_ROCM=${WITH_ROCM} \
+  -DROCM_LIB=${ROCM_LIB} \
+  -DCUDNN_LIB=${CUDNN_LIB} \
+  -DCUDA_LIB=${CUDA_LIB} \
+  -DTENSORRT_ROOT=${TENSORRT_ROOT}
+
+make -j
diff --git a/example/auto_compression/pytorch_yolov5/cpp_infer/trt_run.cc b/example/auto_compression/pytorch_yolov5/cpp_infer/trt_run.cc
new file mode 100644
index 00000000..0ae055ac
--- /dev/null
+++ b/example/auto_compression/pytorch_yolov5/cpp_infer/trt_run.cc
@@ -0,0 +1,116 @@
+#include <chrono>
+#include <iostream>
+#include <memory>
+#include <numeric>
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <cuda_runtime.h>
+
+#include "paddle/include/paddle_inference_api.h"
+#include "paddle/include/experimental/phi/common/float16.h"
+
+using paddle_infer::Config;
+using paddle_infer::Predictor;
+using paddle_infer::CreatePredictor;
+using paddle_infer::PrecisionType;
+using phi::dtype::float16;
+
+DEFINE_string(model_dir, "", "Directory of the inference model.");
+DEFINE_string(model_file, "", "Path of the inference model file.");
+DEFINE_string(params_file, "", "Path of the inference params file.");
+DEFINE_string(run_mode, "trt_fp32", "run_mode which can be: trt_fp32, trt_fp16 and trt_int8");
+DEFINE_int32(batch_size, 1, "Batch size.");
+DEFINE_int32(gpu_id, 0, "GPU card ID num.");
+DEFINE_int32(trt_min_subgraph_size, 3, "tensorrt min_subgraph_size");
+DEFINE_int32(warmup, 50, "warmup");
+DEFINE_int32(repeats, 1000, "repeats");
+
+using Time = decltype(std::chrono::high_resolution_clock::now());
+Time time() { return std::chrono::high_resolution_clock::now(); };
+double time_diff(Time t1, Time t2) {
+  typedef std::chrono::microseconds ms;
+  auto diff = t2 - t1;
+  ms counter = std::chrono::duration_cast<ms>(diff);
+  return counter.count() / 1000.0;
+}
+
+std::shared_ptr<Predictor> InitPredictor() {
+  Config config;
+  std::string model_path;
+  if (FLAGS_model_dir != "") {
+    config.SetModel(FLAGS_model_dir);
+    model_path = FLAGS_model_dir.substr(0, FLAGS_model_dir.find_last_of("/"));
+  } else {
+    config.SetModel(FLAGS_model_file, FLAGS_params_file);
+    model_path = FLAGS_model_file.substr(0, FLAGS_model_file.find_last_of("/"));
+  }
+  // enable tune
+  std::cout << "model_path: " << model_path << std::endl;
+  config.EnableUseGpu(256, FLAGS_gpu_id);
+  if (FLAGS_run_mode == "trt_fp32") {
+    config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size,
+                                PrecisionType::kFloat32, false, false);
+  } else if (FLAGS_run_mode == "trt_fp16") {
+    config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size,
+                                PrecisionType::kHalf, false, false);
+  } else if (FLAGS_run_mode == "trt_int8") {
+    config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size,
+                                PrecisionType::kInt8, false, false);
+  }
+  config.EnableMemoryOptim();
+  config.SwitchIrOptim(true);
+  return CreatePredictor(config);
+}
+
+template <typename type>
+void run(Predictor *predictor, const std::vector<type> &input,
+         const std::vector<int> &input_shape, type* out_data, std::vector<int> out_shape) {
+
+  // prepare input
+  int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
+                                  std::multiplies<int>());
+
+  auto input_names = predictor->GetInputNames();
+  auto input_t = predictor->GetInputHandle(input_names[0]);
+  input_t->Reshape(input_shape);
+  input_t->CopyFromCpu(input.data());
+
+  for (int i = 0; i < FLAGS_warmup; ++i)
+    CHECK(predictor->Run());
+
+  auto st = time();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    auto input_names = predictor->GetInputNames();
+    auto input_t = predictor->GetInputHandle(input_names[0]);
+    input_t->Reshape(input_shape);
+    input_t->CopyFromCpu(input.data());
+
+    CHECK(predictor->Run());
+
+    auto output_names = predictor->GetOutputNames();
+    auto output_t = predictor->GetOutputHandle(output_names[0]);
+    std::vector<int> output_shape = output_t->shape();
+    output_t->ShareExternalData(out_data, out_shape, paddle_infer::PlaceType::kGPU);
+  }
+
+  LOG(INFO) << "[" << FLAGS_run_mode << " bs-" << FLAGS_batch_size << " ] run avg time is " << time_diff(st, time()) / FLAGS_repeats
+            << " ms";
+}
+
+int main(int argc, char *argv[]) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  auto predictor = InitPredictor();
+  std::vector<int> input_shape = {FLAGS_batch_size, 3, 640, 640};
+  // float16
+  using dtype = float16;
+  std::vector<dtype> input_data(FLAGS_batch_size * 3 * 640 * 640, dtype(1.0));
+
+  dtype *out_data;
+  int out_data_size = FLAGS_batch_size * 25200 * 85;
+  cudaHostAlloc((void**)&out_data, sizeof(float) * out_data_size, cudaHostAllocMapped);
+
+  std::vector<int> out_shape{ FLAGS_batch_size, 1, 25200, 85};
+  run(predictor.get(), input_data, input_shape, out_data, out_shape);
+  return 0;
+}
diff --git a/example/auto_compression/pytorch_yolov5/eval.py b/example/auto_compression/pytorch_yolov5/eval.py
index 8cc252cd..55be2feb 100644
--- a/example/auto_compression/pytorch_yolov5/eval.py
+++ b/example/auto_compression/pytorch_yolov5/eval.py
@@ -42,13 +42,6 @@ def argsparser():
     return parser
 
 
-def print_arguments(args):
-    print('----------- Running Arguments -----------')
-    for arg, value in sorted(vars(args).items()):
-        print('%s: %s' % (arg, value))
-    print('------------------------------------------')
-
-
 def reader_wrapper(reader, input_list):
     def gen():
         for data in reader:
@@ -84,7 +77,7 @@ def eval():
     place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
     exe = paddle.static.Executor(place)
 
-    val_program, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
+    val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
         global_config["model_dir"],
         exe,
         model_filename=global_config["model_filename"],
@@ -160,7 +153,6 @@ if __name__ == '__main__':
     paddle.enable_static()
     parser = argsparser()
     FLAGS = parser.parse_args()
-    print_arguments(FLAGS)
 
     assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
     paddle.set_device(FLAGS.devices)
diff --git a/example/auto_compression/pytorch_yolov5/post_quant.py b/example/auto_compression/pytorch_yolov5/post_quant.py
new file mode 100644
index 00000000..8c866727
--- /dev/null
+++ 
b/example/auto_compression/pytorch_yolov5/post_quant.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import numpy as np +import argparse +import paddle +from ppdet.core.workspace import load_config, merge_config +from ppdet.core.workspace import create +from ppdet.metrics import COCOMetric, VOCMetric +from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.quant import quant_post_static + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--save_dir', + type=str, + default='ptq_out', + help="directory to save compressed model.") + parser.add_argument( + '--devices', + type=str, + default='gpu', + help="which device used to compress.") + parser.add_argument( + '--algo', type=str, default='KL', help="post quant algo.") + + return parser + + +def reader_wrapper(reader, input_list): + def gen(): + for data in reader: + in_dict = {} + if isinstance(input_list, list): + for input_name in input_list: + in_dict[input_name] = data[input_name] + elif isinstance(input_list, dict): + for input_name in input_list.keys(): + in_dict[input_list[input_name]] = data[input_name] + yield in_dict + + return gen + + +def main(): + global global_config + all_config = load_slim_config(FLAGS.config_path) + assert "Global" in all_config, f"Key 'Global' not found in config file. 
\n{all_config}"
+    global_config = all_config["Global"]
+    reader_cfg = load_config(global_config['reader_config'])
+
+    train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
+                                        reader_cfg['worker_num'],
+                                        return_list=True)
+    train_loader = reader_wrapper(train_loader, global_config['input_list'])
+
+    place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
+    exe = paddle.static.Executor(place)
+    quant_post_static(
+        executor=exe,
+        model_dir=global_config["model_dir"],
+        quantize_model_path=FLAGS.save_dir,
+        data_loader=train_loader,
+        model_filename=global_config["model_filename"],
+        params_filename=global_config["params_filename"],
+        batch_size=32,
+        batch_nums=10,
+        algo=FLAGS.algo,
+        hist_percent=0.999,
+        is_full_quantize=False,
+        bias_correction=False,
+        onnx_format=False)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    parser = argsparser()
+    FLAGS = parser.parse_args()
+
+    assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
+    paddle.set_device(FLAGS.devices)
+
+    main()
diff --git a/example/auto_compression/pytorch_yolov5/run.py b/example/auto_compression/pytorch_yolov5/run.py
index 130003c0..965a546f 100644
--- a/example/auto_compression/pytorch_yolov5/run.py
+++ b/example/auto_compression/pytorch_yolov5/run.py
@@ -44,19 +44,10 @@ def argsparser():
         type=str,
         default='gpu',
         help="which device used to compress.")
-    parser.add_argument(
-        '--eval', type=bool, default=False, help="whether to run evaluation.")
 
     return parser
 
 
-def print_arguments(args):
-    print('----------- Running Arguments -----------')
-    for arg, value in sorted(vars(args).items()):
-        print('%s: %s' % (arg, value))
-    print('------------------------------------------')
-
-
 def reader_wrapper(reader, input_list):
     def gen():
         for data in reader:
@@ -181,7 +172,6 @@ if __name__ == '__main__':
     paddle.enable_static()
     parser = argsparser()
     FLAGS = parser.parse_args()
-    print_arguments(FLAGS)
 
     assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
     paddle.set_device(FLAGS.devices)
diff --git a/example/auto_compression/pytorch_yolov6/README.md b/example/auto_compression/pytorch_yolov6/README.md
index 662778ff..41a61b54 100644
--- a/example/auto_compression/pytorch_yolov6/README.md
+++ b/example/auto_compression/pytorch_yolov6/README.md
@@ -136,3 +136,8 @@ python paddle_trt_infer.py --model_path=output --image_file=images/000000570688.
 ```
 
 ## 5.FAQ
+
+- To evaluate the accuracy of the post-training (offline) quantized model, run:
+```shell
+python post_quant.py --config_path=./configs/yolov6s_qat_dis.yaml
+```
diff --git a/example/auto_compression/pytorch_yolov6/post_quant.py b/example/auto_compression/pytorch_yolov6/post_quant.py
index 7fa929dc..aa4f5d8f 100644
--- a/example/auto_compression/pytorch_yolov6/post_quant.py
+++ b/example/auto_compression/pytorch_yolov6/post_quant.py
@@ -39,6 +39,11 @@ def argsparser():
         type=str,
         default='ptq_out',
         help="directory to save compressed model.")
+    parser.add_argument(
+        '--devices',
+        type=str,
+        default='gpu',
+        help="which device used to compress.")
     parser.add_argument(
         '--algo', type=str, default='KL', help="post quant algo.")
 
-- 
GitLab
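
Supplementary note (not part of the patch): the C++ benchmark in cpp_infer and paddle_trt_infer.py are the deployment paths this PR documents. For reference only, below is a minimal Python sketch of loading the quantized model with Paddle-TensorRT INT8 through `paddle.inference`, mirroring the engine settings hard-coded in trt_run.cc (1 << 30 workspace, min_subgraph_size 3, batch size 1). The `yolov5s_quant/*` paths and the 640x640 input shape are assumptions taken from the README; adjust them to your own export.

```python
import numpy as np
import paddle.inference as paddle_infer

# Quantized model produced by this demo (paths assumed from the README).
config = paddle_infer.Config("yolov5s_quant/model.pdmodel",
                             "yolov5s_quant/model.pdiparams")
config.enable_use_gpu(256, 0)  # 256 MB initial GPU memory pool on GPU 0
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Int8,
    use_static=False,
    use_calib_mode=False)
config.enable_memory_optim()
config.switch_ir_optim(True)
predictor = paddle_infer.create_predictor(config)

# Dummy NCHW float32 input at the 640x640 resolution used throughout this demo.
img = np.ones((1, 3, 640, 640), dtype=np.float32)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.reshape(img.shape)
input_handle.copy_from_cpu(img)

predictor.run()

output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
pred = output_handle.copy_to_cpu()
print(pred.shape)  # 25200 candidate boxes x 85 channels for YOLOv5s at 640x640
```

This mirrors trt_run.cc but skips its float16 host buffer and ShareExternalData optimization, which only matter for latency benchmarking.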