diff --git a/README.md b/README.md
index 0ee9d69463c4c8ef519473fb90ac38891a872c3d..5c22dd5e92d045a39e4b50802a21244328c250bf 100644
--- a/README.md
+++ b/README.md
@@ -237,6 +237,7 @@ pip install paddleslim==1.2.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
 - [SlimFaceNet](demo/slimfacenet/README.md)
 - [OCR模型压缩(基于PaddleOCR)](demo/ocr/README.md)
 - [检测模型压缩(基于PaddleDetection)](demo/detection/README.md)
+- [TensorRT Deployment](demo/quant/deploy/TensorRT): How to deploy models quantized by PaddleSlim with TensorRT.
 
 ## 部分压缩策略效果
diff --git a/README_en.md b/README_en.md
index cb7a849c604371d4ad51218332c96740afaa08c4..09e9784f22fd47d0fb3edabbda4f2aa56fbd8036 100644
--- a/README_en.md
+++ b/README_en.md
@@ -90,6 +90,7 @@ pip install paddleslim==1.2.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
 - [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection/tree/master/slim): Introduce how to use PaddleSlim in PaddleDetection library.
 - [PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg/tree/develop/slim): Introduce how to use PaddleSlim in PaddleSeg library.
 - [PaddleLite](https://paddlepaddle.github.io/Paddle-Lite/): How to use PaddleLite to deploy models generated by PaddleSlim.
+- [TensorRT Deployment](demo/quant/deploy/TensorRT): How to use TensorRT to deploy models quantized by PaddleSlim.
 
 ## Performance
diff --git a/demo/quant/deploy/TensorRT/CMakeLists.txt b/demo/quant/deploy/TensorRT/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8dce608fb7f4359641ac0138ddefe4fad884aa8e
--- /dev/null
+++ b/demo/quant/deploy/TensorRT/CMakeLists.txt
@@ -0,0 +1,104 @@
+cmake_minimum_required(VERSION 3.0)
+project(inference_test CXX C)
+option(WITH_MKL        "Compile demo with MKL/OpenBlas support, default use MKL."     OFF)
+option(WITH_GPU        "Compile demo with GPU/CPU, default use CPU."                  OFF)
+option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." OFF)
+option(USE_TENSORRT    "Compile demo with TensorRT."                                  OFF)
+
+if(NOT DEFINED PADDLE_LIB)
+  message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib")
+endif()
+
+# check file system
+file(READ "/etc/issue" ETC_ISSUE)
+string(REGEX MATCH "Debian|Ubuntu|CentOS" DIST ${ETC_ISSUE})
+
+if(DIST STREQUAL "Debian")
+  message(STATUS ">>>> Found Debian <<<<")
+elseif(DIST STREQUAL "Ubuntu")
+  message(STATUS ">>>> Found Ubuntu <<<<")
+elseif(DIST STREQUAL "CentOS")
+  message(STATUS ">>>> Found CentOS <<<<")
+else()
+  message(STATUS ">>>> Found unknown distribution <<<<")
+endif()
+
+include_directories("${PADDLE_LIB}/")
+set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}zlib/include")
+include_directories("${PADDLE_LIB}/third_party/boost")
+include_directories("${PADDLE_LIB}/third_party/eigen3")
+include_directories("${PADDLE_LIB}/paddle/include")
+include_directories("${PADDLE_LIB}/paddle")
+
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}zlib/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
+link_directories("${PADDLE_LIB}/paddle/lib")
+
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0 -std=c++11")
+message("flags" ${CMAKE_CXX_FLAGS})
+
+if (USE_TENSORRT AND WITH_GPU)
+  message("=====> TENSORRT_INCLUDE_DIR is ${TENSORRT_INCLUDE_DIR}")
+  message("=====> TENSORRT_LIB_DIR is ${TENSORRT_LIB_DIR}")
+  include_directories("${TENSORRT_INCLUDE_DIR}")
+  link_directories("${TENSORRT_LIB_DIR}")
+endif()
+
+if(WITH_MKL)
+  set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
+  include_directories("${MATH_LIB_PATH}/include")
+  set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
+               ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
+  set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
+  if(EXISTS ${MKLDNN_PATH})
+    include_directories("${MKLDNN_PATH}/include")
+    set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+  endif()
+else()
+  set(MATH_LIB ${PADDLE_LIB_THIRD_PARTY_PATH}openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
+endif()
+
+# Note: libpaddle_inference_api.so/a must be put before libpaddle_fluid.so/a
+if(WITH_STATIC_LIB)
+  set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
+else()
+  set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
+endif()
+
+set(EXTERNAL_LIB "-lrt -ldl -lpthread -lprotobuf")
+set(DEPS ${DEPS}
+    ${MATH_LIB} ${MKLDNN_LIB}
+    glog gflags protobuf z xxhash
+    ${EXTERNAL_LIB})
+
+if(WITH_GPU)
+  if (USE_TENSORRT)
+    set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX})
+    set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX})
+  endif()
+  set(DEPS ${DEPS} $ENV{CUDA_LIB})
+endif()
+
+set(TestFiles "trt_clas.cc";
+              "trt_gen_calib_table_test.cc";
+              "test_acc.cc";)
+
+foreach(testsourcefile ${TestFiles})
+  message("====> ${testsourcefile} will be compiled")
+  # add executable for all test files
+  string(REPLACE ".cc" "" testname ${testsourcefile})
+  add_executable(${testname} ${testsourcefile})
+
+  # link libs
+  target_link_libraries(${testname} ${DEPS})
+
+endforeach()
diff --git a/demo/quant/deploy/TensorRT/README.md b/demo/quant/deploy/TensorRT/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a74db14363a8d03a5db81a96bf91550d18d5421b
--- /dev/null
+++ b/demo/quant/deploy/TensorRT/README.md
@@ -0,0 +1,230 @@
+# TensorRT Inference for PaddleSlim Quantized Models
+
+This tutorial walks through the steps needed to deploy a model quantized by PaddleSlim using TensorRT.
+
+
+## 1. Prepare the environment
+
+* There are two ways to obtain the Paddle inference library; both are described below.
+
+### 1.1 Download a pre-built library
+
+* The [Paddle inference library page](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html) provides Linux inference libraries built against different CUDA versions. Pick a version that is built with TensorRT support.
+
+* After downloading, extract the archive:
+
+```
+tar -xf fluid_inference.tgz
+```
+
+This creates a `fluid_inference/` subdirectory in the current folder.
+
+
+### 1.2 Build the inference library from source
+* To get the latest inference library features, clone the latest Paddle code from GitHub and build the library from source.
+* Follow the instructions on the [Paddle inference library page](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html) to fetch the Paddle source from GitHub and build the latest inference library. Get the code with git:
+
+```shell
+git clone https://github.com/PaddlePaddle/Paddle.git
+```
+* Download TensorRT from the [NVIDIA website](https://developer.nvidia.com/TensorRT) and extract it. This example uses TensorRT 6.0.
+
+* Enter the Paddle directory and build as follows:
+
+```shell
+rm -rf build
+mkdir build
+cd build
+
+cmake .. \
+      -DWITH_MKL=ON \
+      -DWITH_MKLDNN=ON \
+      -DCMAKE_BUILD_TYPE=Release \
+      -DWITH_INFERENCE_API_TEST=OFF \
+      -DTENSORRT_ROOT=TensorRT-6.0.1.5 \
+      -DFLUID_INFERENCE_INSTALL_DIR=LIB_ROOT \
+      -DON_INFER=ON \
+      -DWITH_PYTHON=ON
+make -j
+make inference_lib_dist
+```
+
+Here `FLUID_INFERENCE_INSTALL_DIR` is the directory where the inference library is installed after the build, and `TENSORRT_ROOT` is the path of the downloaded and extracted TensorRT package.
+
+More build options are documented on the Paddle C++ inference library page: [https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html).
+
+
+* After the build finishes, the following files and folders are generated under `LIB_ROOT`:
+
+```
+LIB_ROOT/
+|-- CMakeCache.txt
+|-- paddle
+|-- third_party
+|-- version.txt
+```
+
+`paddle` is the Paddle library required later for TensorRT inference, and `version.txt` records the version of the inference library.
+
+
+## 2. Run the demo
+
+### 2.1 Export the model as an inference model
+
+* Follow the [quantization-aware training tutorial](https://paddleslim.readthedocs.io/zh_CN/latest/quick_start/quant_aware_tutorial.html#id9) and export an inference model after training finishes. The exported model directory looks like this:
+
+```
+inference/
+|-- model
+|-- params
+```
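+
+The snippet below is a minimal sketch of this export step. It assumes the names `val_program`, `image`, `out`, `exe`, `place` and `quant_config` from your own quant-aware training script (as in the tutorial linked above); adapt them to your code.
+
+```python
+import paddle.fluid as fluid
+from paddleslim.quant import convert
+
+# Freeze the quant-aware training graph into the final quantized inference graph.
+quant_infer_prog = convert(val_program, place, config=quant_config)
+
+# Save it as a combined inference model: inference/model + inference/params.
+fluid.io.save_inference_model(
+    dirname='inference',
+    feeded_var_names=[image.name],
+    target_vars=[out],
+    executor=exe,
+    main_program=quant_infer_prog,
+    model_filename='model',
+    params_filename='params')
+```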
+
+
+### 2.2 Build the TensorRT inference demo
+
+* Build with the command below; replace the Paddle and TensorRT paths with the actual paths on your machine.
+
+
+```shell
+sh tools/build.sh
+```
+
+Specifically, `tools/build.sh` contains the following:
+
+```shell
+PADDLE_LIB_PATH=trt_inference # change to your path
+USE_GPU=ON
+USE_MKL=ON
+USE_TRT=ON
+TENSORRT_INCLUDE_DIR=TensorRT-6.0.1.5/include # change to your path
+TENSORRT_LIB_DIR=TensorRT-6.0.1.5/lib # change to your path
+
+if [ "$USE_GPU" = "ON" ]; then
+  export CUDA_LIB=`find /usr/local -name libcudart.so`
+fi
+BUILD=build
+mkdir -p $BUILD
+cd $BUILD
+cmake .. \
+      -DPADDLE_LIB=${PADDLE_LIB_PATH} \
+      -DWITH_GPU=${USE_GPU} \
+      -DWITH_MKL=${USE_MKL} \
+      -DCUDA_LIB=${CUDA_LIB} \
+      -DUSE_TENSORRT=${USE_TRT} \
+      -DTENSORRT_INCLUDE_DIR=${TENSORRT_INCLUDE_DIR} \
+      -DTENSORRT_LIB_DIR=${TENSORRT_LIB_DIR}
+make -j4
+```
+
+`PADDLE_LIB_PATH` is the path of the Paddle inference library, either downloaded (the `fluid_inference` folder) or built from source (the `build/fluid_inference_install_dir` folder); `TENSORRT_INCLUDE_DIR` and `TENSORRT_LIB_DIR` are the TensorRT include and lib directories, respectively.
+
+
+* After the build finishes, the executables are generated under the `build` folder.
+
+
+### 2.3 Convert and preprocess the data
+
+For both accuracy and performance evaluation, the data must first be converted to a binary format. Running the script below converts the full ILSVRC2012 val dataset; use `--local` to convert your own data instead. Run the script from the directory that contains Paddle. The script lives in the Paddle repository: [full_ILSVRC2012_val_preprocess.py](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py)
+```
+python Paddle/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py --local --data_dir=/PATH/TO/USER/DATASET/ --output_file=/PATH/TO/SAVE/BINARY/FILE
+```
+
+Optional arguments:
+- With no arguments, the script downloads the ILSVRC2012_img_val dataset and converts it to binary files.
+- **local:** if set to true, you provide your own data.
+- **data_dir:** the directory of your own data.
+- **label_list:** a list file of image path / image class pairs, similar to `val_list.txt`.
+- **output_file:** path of the generated binary file.
+- **data_dim:** height and width of the preprocessed images. Default: 224.
+
+Your own dataset directory should be structured as follows:
+```
+imagenet_user
+├── val
+│   ├── ILSVRC2012_val_00000001.jpg
+│   ├── ILSVRC2012_val_00000002.jpg
+│   ├── ...
+└── val_list.txt
+```
+where the content of val_list.txt should look like:
+```
+val/ILSVRC2012_val_00000001.jpg 0
+val/ILSVRC2012_val_00000002.jpg 0
+```
+
+Note:
+- Why convert the dataset to binary files? Data preprocessing in Paddle (resize, crop, and so on) is done with Python's PIL Image module, so the trained models are based on images preprocessed in Python. However, the Python preprocessing adds a large overhead at test time and degrades the measured inference performance. To get reliable numbers, the quantized models are evaluated with a C++ test; in C++ we would only have libraries such as OpenCV, which Paddle prefers not to depend on. We therefore preprocess the images in Python, store them in binary files, and read them back in the C++ test. If you prefer, you can modify the C++ test to read and preprocess the raw data directly; accuracy will not drop much.
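+
+The snippet below is a minimal sketch of the per-sample binary layout that the C++ demos (`trt_clas.cc`, `test_acc.cc`) read back: `3*224*224` float32 values in CHW order followed by one int32 label. The resize/crop/normalization parameters used here (shorter side 256, center crop 224, the usual ImageNet mean/std) are assumptions intended to mirror the official script; refer to `full_ILSVRC2012_val_preprocess.py` for the authoritative preprocessing.
+
+```python
+import numpy as np
+from PIL import Image
+
+def save_sample(img_path, label, out_path, size=224):
+    img = Image.open(img_path).convert('RGB')
+    # resize the shorter side to 256, then center-crop to size x size
+    w, h = img.size
+    scale = 256.0 / min(w, h)
+    img = img.resize((int(w * scale + 0.5), int(h * scale + 0.5)))
+    w, h = img.size
+    left, top = (w - size) // 2, (h - size) // 2
+    img = img.crop((left, top, left + size, top + size))
+
+    data = np.asarray(img).astype('float32') / 255.0
+    mean = np.array([0.485, 0.456, 0.406], dtype='float32')
+    std = np.array([0.229, 0.224, 0.225], dtype='float32')
+    data = ((data - mean) / std).transpose((2, 0, 1))  # HWC -> CHW
+
+    with open(out_path, 'wb') as f:
+        f.write(data.astype('float32').tobytes())  # 3*224*224 float32 values
+        f.write(np.int32(label).tobytes())          # 4-byte label, as read by the demos
+
+save_sample('imagenet_user/val/ILSVRC2012_val_00000001.jpg', 0, '0.data')
+```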
+
+### 2.4 Deploy and predict
+
+
+#### Run the demo
+* Run the following command to perform TensorRT inference with a classification model:
+
+```shell
+sh tools/run.sh
+```
+In `tools/run.sh`, `MODEL_DIR` and `DATA_FILE` are the model directory and the data file; replace them with the actual paths you want to use.
+
+You should see prediction output similar to:
+
+```shell
+I1123 11:30:49.160024 10999 trt_clas.cc:103] finish prediction
+I1123 11:30:49.160050 10999 trt_clas.cc:136] pred image class is : 65, ground truth label is : 65
+```
+
+* To benchmark the TensorRT speed of a model, set `repeat_times` in `tools/run.sh` to a value greater than 1 so that the prediction time is averaged over multiple runs:
+
+```shell
+sh tools/run.sh
+```
+
+You should see benchmark output similar to:
+
+```shell
+I1123 11:40:30.936796 11681 trt_clas.cc:83] finish warm up 10 times
+I1123 11:40:30.947906 11681 trt_clas.cc:101] total predict cost is : 11.042 ms, repeat 10 times
+I1123 11:40:30.947947 11681 trt_clas.cc:102] average predict cost is : 1.1042 ms
+```
+
+
+* Run the following command to evaluate the TensorRT accuracy of a model:
+
+```shell
+sh tools/test_acc.sh
+```
+
+As above, replace the paths in the script with the actual paths you want to use.
+
+You should see evaluation output similar to:
+
+```shell
+I1123 11:23:11.856046 10913 test_acc.cc:64] 5000
+I1123 11:23:50.318663 10913 test_acc.cc:64] 10000
+I1123 11:24:28.793603 10913 test_acc.cc:64] 15000
+I1123 11:25:07.277580 10913 test_acc.cc:64] 20000
+I1123 11:25:45.698241 10913 test_acc.cc:64] 25000
+I1123 11:26:24.195798 10913 test_acc.cc:64] 30000
+I1123 11:27:02.625052 10913 test_acc.cc:64] 35000
+I1123 11:27:41.178545 10913 test_acc.cc:64] 40000
+I1123 11:28:19.798691 10913 test_acc.cc:64] 45000
+I1123 11:28:58.457620 10913 test_acc.cc:107] final result:
+I1123 11:28:58.457688 10913 test_acc.cc:108] top1 acc:0.70664
+I1123 11:28:58.457712 10913 test_acc.cc:109] top5 acc:0.89494
+```
+
+
+## 3. Benchmark
+
+GPU: NVIDIA® Tesla® P4
+
+Dataset: ImageNet-2012
+
+Inference engine: Paddle-TensorRT
+
+
+| Model       | FP32 accuracy (Top1/Top5) | INT8 accuracy (Top1/Top5) | FP32 latency (ms) | INT8 latency (ms) | Latency reduction |
+| :---------: | :-----------------------: | :-----------------------: | :---------------: | :---------------: | :---------------: |
+| MobileNetV1 | 71.00%/89.69%             | 70.66%/89.27%             | 1.083             | 0.568             | 47.55%            |
+| MobileNetV2 | 72.16%/90.65%             | 71.09%/90.16%             | 1.821             | 0.980             | 46.19%            |
+| ResNet50    | 76.50%/93.00%             | 76.27%/92.95%             | 4.960             | 2.014             | 59.39%            |
+
+The last column is computed as (FP32 latency - INT8 latency) / FP32 latency; for MobileNetV1, (1.083 - 0.568) / 1.083 ≈ 47.55%.
diff --git a/demo/quant/deploy/TensorRT/test_acc.cc b/demo/quant/deploy/TensorRT/test_acc.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fb0bc863e8ce224a3f872e0d6e308aca8bd3a54f
--- /dev/null
+++ b/demo/quant/deploy/TensorRT/test_acc.cc
@@ -0,0 +1,121 @@
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <algorithm>  // std::sort
+#include <numeric>    // std::iota
+#include <chrono>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/include/paddle_inference_api.h"
+
+namespace paddle {
+using paddle::AnalysisConfig;
+
+DEFINE_string(model_dir, "resnet50_quant", "Directory of the inference model.");
+DEFINE_string(data_dir, "imagenet-eval-binary", "Directory of the data.");
+DEFINE_bool(int8, true, "use int8 or not");
+
+using Time = decltype(std::chrono::high_resolution_clock::now());
+Time time() { return std::chrono::high_resolution_clock::now(); }
+double time_diff(Time t1, Time t2) {
+  typedef std::chrono::microseconds ms;
+  auto diff = t2 - t1;
+  ms counter = std::chrono::duration_cast<ms>(diff);
+  return counter.count() / 1000.0;
+}
+
+void PrepareTRTConfig(AnalysisConfig *config, int batch_size, int id = 0) {
+  int min_subgraph_size = 3;
+  config->SetModel(FLAGS_model_dir + "/model", FLAGS_model_dir + "/params");
+  config->EnableUseGpu(1500, id);
+  // We use ZeroCopyTensor here, so we set config->SwitchUseFeedFetchOps(false)
+  config->SwitchUseFeedFetchOps(false);
+  config->SwitchIrDebug(true);
+  if (FLAGS_int8)
+    config->EnableTensorRtEngine(1 << 30, batch_size, min_subgraph_size, AnalysisConfig::Precision::kInt8, false, false);
+  else
+    config->EnableTensorRtEngine(1 << 30, batch_size, min_subgraph_size, AnalysisConfig::Precision::kFloat32, false, false);
+}
+
+bool test_map_cnn(int batch_size, int repeat) {
+  AnalysisConfig config;
+  PrepareTRTConfig(&config, batch_size);
+  auto predictor = CreatePaddlePredictor(config);
+
+  int channels = 3;
+  int height = 224;
+  int width = 224;
+  int input_num = channels * height * width * batch_size;
+
+  // prepare inputs
+  float *input = new float[input_num];
+  memset(input, 0, input_num * sizeof(float));
+  float test_num = 0;
+  float top1_num = 0;
+  float top5_num = 0;
+  std::vector<int> index(1000);
+
+  for (size_t ind = 0; ind < 50000; ind++) {
+    if (ind % 5000 == 0)
+      LOG(INFO) << ind;
+    std::ifstream fs(FLAGS_data_dir + "/" + std::to_string(ind) + ".data", std::ifstream::binary);
+    if (!fs.is_open()) {
+      LOG(FATAL) << "open input file fail.";
+    }
+    auto input_data_tmp = input;
+    for (int i = 0; i < input_num; ++i) {
+      fs.read(reinterpret_cast<char *>(input_data_tmp), sizeof(*input_data_tmp));
+      input_data_tmp++;
+    }
+    int label = 0;
+    fs.read(reinterpret_cast<char *>(&label), sizeof(label));
+    fs.close();
+
+    auto input_names = predictor->GetInputNames();
+    auto input_t = predictor->GetInputTensor(input_names[0]);
+    input_t->Reshape({batch_size, channels, height, width});
+    input_t->copy_from_cpu(input);
+
+    CHECK(predictor->ZeroCopyRun());
+    std::vector<float> out_data;
+    auto output_names = predictor->GetOutputNames();
+    auto output_t = predictor->GetOutputTensor(output_names[0]);
+    std::vector<int> output_shape = output_t->shape();
+    int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
+
+    out_data.resize(out_num);
+    output_t->copy_to_cpu(out_data.data());
+
+    // rank the class scores and count top-1 / top-5 hits
+    std::iota(index.begin(), index.end(), 0);
+    std::sort(index.begin(), index.end(), [out_data](size_t i1, size_t i2) {
+      return out_data[i1] > out_data[i2];
+    });
+    test_num++;
+    if (label == index[0]) {
+      top1_num++;
+    }
+    for (int i = 0; i < 5; i++) {
+      if (label == index[i]) {
+        top5_num++;
+      }
+    }
+  }
+  LOG(INFO) << "final result:";
+  LOG(INFO) << "top1 acc:" << top1_num / test_num;
+  LOG(INFO) << "top5 acc:" << top5_num / test_num;
+
+  delete[] input;
+  return true;
+}
+}  // namespace paddle
+
+int main(int argc, char* argv[]) {
+  google::ParseCommandLineFlags(&argc, &argv, false);
+  for (int i = 0; i < 1; i++) {
+    paddle::test_map_cnn(1 << i, 1);
+  }
+  return 0;
+}
diff --git a/demo/quant/deploy/TensorRT/tools/build.sh b/demo/quant/deploy/TensorRT/tools/build.sh
new file mode 100755
index 0000000000000000000000000000000000000000..481f8803b21e9951cdf052c1d93a609cbd8f8fc1
--- /dev/null
+++ b/demo/quant/deploy/TensorRT/tools/build.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+PADDLE_LIB_PATH=trt_inference # change to your path
+USE_GPU=ON
+USE_MKL=ON
+USE_TRT=ON
+TENSORRT_INCLUDE_DIR=TensorRT-6.0.1.5/include # change to your path
+TENSORRT_LIB_DIR=TensorRT-6.0.1.5/lib # change to your path
+
+if [ "$USE_GPU" = "ON" ]; then
+  export CUDA_LIB=`find /usr/local -name libcudart.so`
+fi
+rm -rf build
+BUILD=build
+mkdir -p $BUILD
+cd $BUILD
+cmake .. \
+      -DPADDLE_LIB=${PADDLE_LIB_PATH} \
+      -DWITH_GPU=${USE_GPU} \
+      -DWITH_MKL=${USE_MKL} \
+      -DCUDA_LIB=${CUDA_LIB} \
+      -DUSE_TENSORRT=${USE_TRT} \
+      -DTENSORRT_INCLUDE_DIR=${TENSORRT_INCLUDE_DIR} \
+      -DTENSORRT_LIB_DIR=${TENSORRT_LIB_DIR}
+make -j4
diff --git a/demo/quant/deploy/TensorRT/tools/run.sh b/demo/quant/deploy/TensorRT/tools/run.sh
new file mode 100755
index 0000000000000000000000000000000000000000..ace71f8e1150251d716c178abeba158a8f8a7959
--- /dev/null
+++ b/demo/quant/deploy/TensorRT/tools/run.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+project_path=$(cd "$(dirname "$0")"; pwd)
+echo "${project_path}"
+unset GREP_OPTIONS;
+if [ ! -d "./build" ]; then
+  echo -e "\033[33m run failed! \033[0m";
+  echo -e "\033[33m you should build first \033[0m";
+  exit;
+fi
+
+MODEL_DIR=MobileNetV1-quant # change to your model
+BATCH_SIZE=1
+USE_CALIB=false
+USE_INT8=true
+DATA_FILE=imagenet-eval-binary/0.data # change to your data file
+
+
+build/trt_clas --model_dir=${MODEL_DIR} \
+               --batch_size=${BATCH_SIZE} \
+               --use_calib=${USE_CALIB} \
+               --use_int8=${USE_INT8} \
+               --data_file=${DATA_FILE} \
+               --repeat_times=1
+
diff --git a/demo/quant/deploy/TensorRT/tools/test_acc.sh b/demo/quant/deploy/TensorRT/tools/test_acc.sh
new file mode 100755
index 0000000000000000000000000000000000000000..73aba70f819d4c3fa0f18e78a3d701b391db1b8d
--- /dev/null
+++ b/demo/quant/deploy/TensorRT/tools/test_acc.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+unset GREP_OPTIONS;
+if [ ! -d "./build" ]; then
+  echo -e "\033[33m run failed! \033[0m";
+  echo -e "\033[33m you should build first \033[0m";
+  exit;
+fi
+
+MODEL_DIR=MobileNetV1-quant # change to your model_dir
+DATA_DIR=imagenet-eval-binary # change to your data_dir
+USE_INT8=true
+
+./build/test_acc --model_dir=$MODEL_DIR --data_dir=$DATA_DIR --int8=$USE_INT8
+
diff --git a/demo/quant/deploy/TensorRT/trt_clas.cc b/demo/quant/deploy/TensorRT/trt_clas.cc
new file mode 100644
index 0000000000000000000000000000000000000000..72e25fcd267bc4ebe874e5cda5fc0f996610d5ac
--- /dev/null
+++ b/demo/quant/deploy/TensorRT/trt_clas.cc
@@ -0,0 +1,140 @@
+#include <chrono>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+
+#include "paddle/include/paddle_inference_api.h"
+
+DEFINE_string(model_dir, "", "Directory of the inference model.");
+DEFINE_int32(batch_size, 1, "Batch size.");
+DEFINE_bool(use_calib, true, "Whether to use calib. Set to true if you are using TRT calibration; \
+                              Set to false if you are using PaddleSlim quant models.");
+DEFINE_string(data_file, "", "Path of the inference data file.");
+DEFINE_bool(use_int8, true, "use trt int8 or not");
+DEFINE_int32(repeat_times, 1000, "benchmark repeat time");
+
+using Time = decltype(std::chrono::high_resolution_clock::now());
+Time time() { return std::chrono::high_resolution_clock::now(); }
+double time_diff(Time t1, Time t2) {
+  typedef std::chrono::microseconds ms;
+  auto diff = t2 - t1;
+  ms counter = std::chrono::duration_cast<ms>(diff);
+  return counter.count() / 1000.0;
+}
+
+// Return the index and value of the largest element (the predicted class).
+std::pair<int, float> findMax(std::vector<float>& vec) {
+  float max = -1000;
+  int idx = -1;
+  for (int i = 0; i < vec.size(); ++i) {
+    if (vec[i] > max) {
+      max = vec[i];
+      idx = i;
+    }
+  }
+  return {idx, max};
+}
+
+std::unique_ptr<paddle::PaddlePredictor> CreatePredictor() {
+  paddle::AnalysisConfig config;
+  config.SetModel(FLAGS_model_dir + "/model", FLAGS_model_dir + "/params");
+  config.EnableUseGpu(500, 0);
+  // We use ZeroCopy, so we set config.SwitchUseFeedFetchOps(false) here.
+  config.SwitchUseFeedFetchOps(false);
+  // EnableTensorRtEngine(workspace_size, max_batch_size, min_subgraph_size,
+  //                      precision, use_static, use_calib_mode)
+  if (FLAGS_use_int8) {
+    config.EnableTensorRtEngine(1 << 30,
+                                FLAGS_batch_size,
+                                5,
+                                paddle::AnalysisConfig::Precision::kInt8,
+                                false,
+                                FLAGS_use_calib);
+  } else {
+    config.EnableTensorRtEngine(1 << 30,
+                                FLAGS_batch_size,
+                                5,
+                                paddle::AnalysisConfig::Precision::kFloat32,
+                                false,
+                                false);
+  }
+  return CreatePaddlePredictor(config);
+}
+
+void run(paddle::PaddlePredictor *predictor,
+         const std::vector<float>& input,
+         const std::vector<int>& input_shape,
+         std::vector<float> *out_data) {
+  int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());
+
+  auto input_names = predictor->GetInputNames();
+  auto input_t = predictor->GetInputTensor(input_names[0]);
+  input_t->Reshape(input_shape);
+  input_t->copy_from_cpu(input.data());
+
+  // warm up 10 times
+  if (FLAGS_repeat_times != 1) {
+    for (int i = 0; i < 10; i++) {
+      CHECK(predictor->ZeroCopyRun());
+    }
+    LOG(INFO) << "finish warm up 10 times";
+  }
+
+  auto time1 = time();
+  for (int i = 0; i < FLAGS_repeat_times; i++) {
+    CHECK(predictor->ZeroCopyRun());
+    auto output_names = predictor->GetOutputNames();
+    // there is only one output of Resnet50
+    auto output_t = predictor->GetOutputTensor(output_names[0]);
+    std::vector<int> output_shape = output_t->shape();
+    int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
+
+    out_data->resize(out_num);
+    output_t->copy_to_cpu(out_data->data());
+  }
+  auto time2 = time();
+  auto total_time = time_diff(time1, time2);
+  auto average_time = total_time / FLAGS_repeat_times;
+  LOG(INFO) << "total predict cost is : " << total_time << " ms, repeat " << FLAGS_repeat_times << " times";
+  LOG(INFO) << "average predict cost is : " << average_time << " ms";
+  LOG(INFO) << "finish prediction";
+}
+
+int main(int argc, char* argv[]) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  auto predictor = CreatePredictor();
+  std::vector<int> input_shape = {FLAGS_batch_size, 3, 224, 224};
+  // Init input as 1.0 here for example. You can also load preprocessed real pictures to vectors as input.
+  // std::vector<float> input_vec(FLAGS_batch_size * 3 * 224 * 224, 1.0);
+
+  int input_num = FLAGS_batch_size * 3 * 224 * 224;
+  float *input_data = new float[input_num];
+  memset(input_data, 0, input_num * sizeof(float));
+
+  std::ifstream fs(FLAGS_data_file, std::ifstream::binary);
+  if (!fs.is_open()) {
+    LOG(FATAL) << "open input file fail.";
+  }
+  auto input_data_tmp = input_data;
+  for (int i = 0; i < input_num; ++i) {
+    fs.read(reinterpret_cast<char *>(input_data_tmp), sizeof(*input_data_tmp));
+    input_data_tmp++;
+  }
+  int label = 0;
+  fs.read(reinterpret_cast<char *>(&label), sizeof(label));
+  fs.close();
+
+  std::vector<float> input_vec{input_data, input_data + input_num};
+  delete[] input_data;
+
+  std::vector<float> out_data;
+  run(predictor.get(), input_vec, input_shape, &out_data);
+  if (FLAGS_batch_size == 1) {
+    std::pair<int, float> result = findMax(out_data);
+    LOG(INFO) << "pred image class is : " << result.first << ", ground truth label is : " << label;
+  }
+  return 0;
+}
+
diff --git a/demo/quant/deploy/TensorRT/trt_gen_calib_table_test.cc b/demo/quant/deploy/TensorRT/trt_gen_calib_table_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d5ffc3bf657ddff46e059df5656d0f276197c34b
--- /dev/null
+++ b/demo/quant/deploy/TensorRT/trt_gen_calib_table_test.cc
@@ -0,0 +1,80 @@
+#include <chrono>
+#include <memory>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+
+#include "paddle/include/paddle_inference_api.h"
+
+using paddle::AnalysisConfig;
+
+DEFINE_string(model_file, "", "Path of the inference model file.");
+DEFINE_string(params_file, "", "Path of the inference params file.");
+DEFINE_string(model_dir, "", "Directory of the inference model.");
+DEFINE_int32(batch_size, 1, "Batch size.");
+
+float Random(float low, float high) {
+  static std::random_device rd;
+  static std::mt19937 mt(rd());
+  std::uniform_real_distribution<float> dist(low, high);
+  return dist(mt);
+}
+
+using Time = decltype(std::chrono::high_resolution_clock::now());
+Time time() { return std::chrono::high_resolution_clock::now(); }
+double time_diff(Time t1, Time t2) {
+  typedef std::chrono::microseconds ms;
+  auto diff = t2 - t1;
+  ms counter = std::chrono::duration_cast<ms>(diff);
+  return counter.count() / 1000.0;
+}
+
+std::unique_ptr<paddle::PaddlePredictor> CreatePredictor() {
+  AnalysisConfig config;
+  if (FLAGS_model_dir != "") {
+    config.SetModel(FLAGS_model_dir);
+  } else {
+    config.SetModel(FLAGS_model_file,
+                    FLAGS_params_file);
+  }
+  config.EnableUseGpu(500, 0);
+  // We use ZeroCopy, so we set config.SwitchUseFeedFetchOps(false) here.
+  config.SwitchUseFeedFetchOps(false);
+  config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, 5, AnalysisConfig::Precision::kInt8, false, true /*use_calib*/);
+  return CreatePaddlePredictor(config);
+}
+
+void run(paddle::PaddlePredictor *predictor,
+         std::vector<float>& input,
+         const std::vector<int>& input_shape,
+         std::vector<float> *out_data) {
+  int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());
+  for (size_t i = 0; i < 500; i++) {
+    // We use random data here for example. Change this to real data in your application.
+    for (int j = 0; j < input_num; j++) {
+      input[j] = Random(0, 1.0);
+    }
+    auto input_names = predictor->GetInputNames();
+    auto input_t = predictor->GetInputTensor(input_names[0]);
+    input_t->Reshape(input_shape);
+    input_t->copy_from_cpu(input.data());
+
+    // Run predictor to generate calibration table. Can be very time-consuming.
+    CHECK(predictor->ZeroCopyRun());
+  }
+}
+
+int main(int argc, char* argv[]) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  auto predictor = CreatePredictor();
+  std::vector<int> input_shape = {FLAGS_batch_size, 3, 224, 224};
+  // Init input as 1.0 here for example. You can also load preprocessed real pictures to vectors as input.
+  std::vector<float> input_data(FLAGS_batch_size * 3 * 224 * 224, 1.0);
+  std::vector<float> out_data;
+  run(predictor.get(), input_data, input_shape, &out_data);
+  return 0;
+}