From aa9e0d546f3e2bd3abc8e50e8df46f0585061913 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 11 May 2021 17:20:23 +0800 Subject: [PATCH] add inference benchmark (#2508) * add inference benchmark --- deploy/README.md | 68 ++--- deploy/TENSOR_RT.md | 2 +- deploy/benchmark/benchmark.sh | 36 +++ deploy/benchmark/benchmark_quant.sh | 23 ++ deploy/benchmark/log_parser_excel.py | 301 +++++++++++++++++++ deploy/cpp/docs/Jetson_build.md | 9 +- deploy/cpp/docs/linux_build.md | 9 +- deploy/cpp/docs/windows_vs2019_build.md | 9 +- deploy/cpp/include/object_detector.h | 42 ++- deploy/cpp/src/main.cc | 181 ++++++++--- deploy/cpp/src/object_detector.cc | 60 ++-- deploy/python/README.md | 57 +--- deploy/python/infer.py | 252 ++++++++-------- deploy/python/preprocess.py | 11 +- deploy/python/utils.py | 262 ++++++++++++++++ {deploy/imgs => docs/images}/input_shape.png | Bin ppdet/engine/export_utils.py | 5 + 17 files changed, 1016 insertions(+), 311 deletions(-) create mode 100644 deploy/benchmark/benchmark.sh create mode 100644 deploy/benchmark/benchmark_quant.sh create mode 100644 deploy/benchmark/log_parser_excel.py create mode 100644 deploy/python/utils.py rename {deploy/imgs => docs/images}/input_shape.png (100%) diff --git a/deploy/README.md b/deploy/README.md index b026ded94..50e4d50ec 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -1,19 +1,16 @@ # PaddleDetection 预测部署 -训练得到一个满足要求的模型后,如果想要将该模型部署到已选择的平台上,需要通过`tools/export_model.py`将模型导出预测部署的模型和配置文件。 -并在同一文件夹下导出预测时使用的配置文件,配置文件名为`infer_cfg.yml`。 - -## 1、`PaddleDetection`目前支持的部署方式按照部署设备可以分为: -- 在本机`python`语言部署,支持在有`python paddle`(支持`CPU`、`GPU`)环境下部署,有两种方式: - - 使用`tools/infer.py`,此种方式依赖`PaddleDetection`代码库。 - - 将模型导出,使用`deploy/python/infer.py`,此种方式不依赖`PaddleDetection`代码库,可以单个`python`文件部署。 -- 在本机`C++`语言使用`paddle inference`预测库部署,支持在`Linux`和`Windows`系统下部署。请参考文档[C++部署](cpp/README.md)。 -- 在服务器端以服务形式部署,使用[PaddleServing](./serving/README.md)部署。 -- 在手机移动端部署,使用[Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite) 在手机移动端部署。 - 常见模型部署Demo请参考[Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo) 。 -- `NV Jetson`嵌入式设备上部署 -- `TensorRT`加速请参考文档[TensorRT预测部署教程](TENSOR_RT.md) - -## 2、模型导出 + +目前支持的部署方式有: +- `Paddle Inference预测库`部署: + - `Python`语言部署,支持`CPU`、`GPU`和`XPU`环境,参考文档[python部署](python/README.md)。 + - `C++`语言部署 ,支持`CPU`、`GPU`和`XPU`环境,支持在`Linux`、`Windows`系统下部署,支持`NV Jetson`嵌入式设备上部署。请参考文档[C++部署](cpp/README.md)。 + - `TensorRT`加速:请参考文档[TensorRT预测部署教程](TENSOR_RT.md) +- 服务器端部署:使用[PaddleServing](./serving/README.md)部署。 +- 手机移动端部署:使用[Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite) 在手机移动端部署。 + + +## 1.模型导出 + 使用`tools/export_model.py`脚本导出模型已经部署时使用的配置文件,配置文件名字为`infer_cfg.yml`。模型导出脚本如下: ```bash # 导出YOLOv3模型 @@ -29,47 +26,50 @@ python tools/export_model.py -c configs/yolov3/yolov3_darknet53_270e_coco.yml -o 模型导出具体请参考文档[PaddleDetection模型导出教程](EXPORT_MODEL.md)。 -## 3、如何选择部署时依赖库的版本 - -### (1)CUDA、cuDNN、TensorRT版本选择 -由于CUDA、cuDNN、TENSORRT不一定都是向前兼容的,需要使用与编译Paddle预测库使用的环境完全一致的环境进行部署。 - -### (2)部署时预测库版本、预测引擎版本选择 +## 2.部署环境准备 -- Linux、Windows平台下C++部署,需要使用Paddle预测库进行部署。 - (1)Paddle官网提供在不同平台、不同环境下编译好的预测库,您可以直接使用,请在这里[Paddle预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/05_inference_deployment/inference/build_and_install_lib_cn.html) 选择。 - (2)如果您将要部署的平台环境,Paddle官网上没有提供已编译好的预测库,您可以自行编译,编译过程请参考[Paddle源码编译](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html)。 +- Python预测:在python环境下安装PaddlePaddle环境即可,如需TensorRT预测,在[Paddle 
Release版本](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-release)中下载合适的wheel包即可。 +- C++预测库:请从[这里](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/05_inference_deployment/inference/build_and_install_lib_cn.html),如果需要使用TensorRT,请下载带有TensorRT编译的预测库。您也可以自行编译,编译过程请参考[Paddle源码编译](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html)。 **注意:** Paddle预测库版本需要>=2.0 -- Python语言部署,需要在对应平台上安装Paddle Python包。如果Paddle官网上没有提供该平台下的Paddle Python包,您可以自行编译,编译过程请参考[Paddle源码编译](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html)。 - - PaddleServing部署 - PaddleServing 0.4.0是基于Paddle 1.8.4开发,PaddleServing 0.5.0是基于Paddle2.0开发。 + 请选择PaddleServing>0.5.0以上版本,具体可参考[PaddleServing安装文档](https://github.com/PaddlePaddle/Serving/blob/develop/README.md#installation)。 - Paddle-Lite部署 Paddle-Lite支持OP列表请参考:[Paddle-Lite支持的OP列表](https://paddle-lite.readthedocs.io/zh/latest/source_compile/library.html) ,请跟进所部署模型中使用到的op选择Paddle-Lite版本。 - NV Jetson部署 - Paddle官网提供在NV Jetson平台上已经编译好的预测库,[Paddle NV Jetson预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/05_inference_deployment/inference/build_and_install_lib_cn.html) 。 - 若列表中没有您需要的预测库,您可以在您的平台上自行编译,编译过程请参考[Paddle源码编译](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html)。 + Paddle官网提供在NV Jetson平台上已经编译好的预测库,[Paddle NV Jetson预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/05_inference_deployment/inference/build_and_install_lib_cn.html)。若列表中没有您需要的预测库,您可以在您的平台上自行编译,编译过程请参考[Paddle源码编译](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html)。 +## 3.部署预测 +- Python部署:使用`deploy/python/infer.py`进行预测,可具体参考[python部署文档](python/README.md)。 +```shell +python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/image --use_gpu=(False/True) +``` -## 4、部署 -- C++部署,先使用跨平台编译工具`CMake`根据`CMakeLists.txt`生成`Makefile`,支持`Windows、Linux、NV Jetson`平台,然后进行编译产出可执行文件。可以直接使用`cpp/scripts/build.sh`脚本编译: +- C++部署,先使用跨平台编译工具`CMake`根据`CMakeLists.txt`生成`Makefile`,支持[Windows](cpp/docs/windows_vs2019_build.md)、[Linux](cpp/docs/linux_build.md)、[NV Jetson](cpp/docs/Jetson_build.md)平台部署,然后进行编译产出可执行文件。可以直接使用`cpp/scripts/build.sh`脚本编译: ```buildoutcfg cd cpp sh scripts/build.sh ``` -- Python部署,可以使用使用`tools/infer.py`(以来PaddleDetection源码)部署,或者使用`deploy/python/infer.py`单文件部署 - - PaddleServing部署请参考,[PaddleServing部署](./serving/README.md)部署。 - 手机移动端部署,请参考[Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo)部署。 +## 4.BenchMark测试 +- 使用导出的模型,运行BenchMark批量测试脚本: +```shell +sh deploy/benchmark/benchmark_ppdet.sh {model_dir} {model_name} +``` +**注意** 如果是量化模型,请使用`deploy/benchmark/benchmark_ppdet_quant.sh`脚本。 +- 将测试结果log导出至Excel中: +``` +python deploy/benchmark/log_parser_excel.py --log_path=./output_pipeline --output_name=benchmark_excel.xlsx +``` -## 5、常见问题QA +## 5.常见问题QA - 1、`Paddle 1.8.4`训练的模型,可以用`Paddle2.0`部署吗? 
Paddle 2.0是兼容Paddle 1.8.4的,因此是可以的。但是部分模型(如SOLOv2)使用到了Paddle 2.0中新增OP,这类模型不可以。 diff --git a/deploy/TENSOR_RT.md b/deploy/TENSOR_RT.md index 9d97cf294..986933576 100644 --- a/deploy/TENSOR_RT.md +++ b/deploy/TENSOR_RT.md @@ -38,7 +38,7 @@ TensorRT版本<=5时,使用TensorRT预测时,只支持固定尺寸输入。 `TestReader.inputs_def.image_shape`设置的是输入TensorRT引擎的数据尺寸(在像FasterRCNN中,`TestReader.inputs_def.image_shape`指定的是在`Pad`操作之前的图像数据尺寸)。 可以通过[visualdl](https://www.paddlepaddle.org.cn/paddle/visualdl/demo/graph) 打开`model.pdmodel`文件,查看输入的第一个Tensor尺寸是否是固定的,如果不指定,尺寸会用`?`表示,如下图所示: -![img](imgs/input_shape.png) +![img](../docs/images/input_shape.png) 同时需要将图像预处理后的尺寸与设置车模型输入尺寸保持一致,需要设置`infer_cfg.yml`配置文件中`Resize OP`的`target_size`参数和`keep_ratio`参数。 diff --git a/deploy/benchmark/benchmark.sh b/deploy/benchmark/benchmark.sh new file mode 100644 index 000000000..84840892f --- /dev/null +++ b/deploy/benchmark/benchmark.sh @@ -0,0 +1,36 @@ +# All rights `PaddleDetection` reserved +#!/bin/bash +model_dir=$1 +model_name=$2 + +export img_dir="demo" +export log_path="output_pipeline" + + +echo "model_dir : ${model_dir}" +echo "img_dir: ${img_dir}" + +# TODO: support batch size>1 +for use_mkldnn in "True" "False"; do + for threads in "1" "6"; do + echo "${model_name} ${model_dir}, use_mkldnn: ${use_mkldnn} threads: ${threads}" + python deploy/python/infer.py \ + --model_dir=${model_dir} \ + --run_benchmark True \ + --enable_mkldnn=${use_mkldnn} \ + --use_gpu=False \ + --cpu_threads=${threads} \ + --image_dir=${img_dir} 2>&1 | tee ${log_path}/${model_name}_cpu_usemkldnn_${use_mkldnn}_cputhreads_${threads}_bs1_infer.log + done +done + +for run_mode in "fluid" "trt_fp32" "trt_fp16"; do + echo "${model_name} ${model_dir}, run_mode: ${run_mode}" + python deploy/python/infer.py \ + --model_dir=${model_dir} \ + --run_benchmark=True \ + --use_gpu=True \ + --run_mode=${run_mode} \ + --image_dir=${img_dir} 2>&1 | tee ${log_path}/${model_name}_gpu_runmode_${run_mode}_bs1_infer.log +done + diff --git a/deploy/benchmark/benchmark_quant.sh b/deploy/benchmark/benchmark_quant.sh new file mode 100644 index 000000000..1abdca900 --- /dev/null +++ b/deploy/benchmark/benchmark_quant.sh @@ -0,0 +1,23 @@ +# All rights `PaddleDetection` reserved +#!/bin/bash +model_dir=$1 +model_name=$2 + +export img_dir="demo" +export log_path="output_pipeline" + + +echo "model_dir : ${model_dir}" +echo "img_dir: ${img_dir}" + +# TODO: support batch size>1 +for run_mode in "trt_int8"; do + echo "${model_name} ${model_dir}, run_mode: ${run_mode}" + python deploy/python/infer.py \ + --model_dir=${model_dir} \ + --run_benchmark=True \ + --use_gpu=True \ + --run_mode=${run_mode} \ + --image_dir=${img_dir} 2>&1 | tee ${log_path}/${model_name}_gpu_runmode_${run_mode}_bs1_infer.log +done + diff --git a/deploy/benchmark/log_parser_excel.py b/deploy/benchmark/log_parser_excel.py new file mode 100644 index 000000000..b4d841d9d --- /dev/null +++ b/deploy/benchmark/log_parser_excel.py @@ -0,0 +1,301 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import re +import argparse +import pandas as pd + + +def parse_args(): + """ + parse input args + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--log_path", + type=str, + default="./output_pipeline", + help="benchmark log path") + parser.add_argument( + "--output_name", + type=str, + default="benchmark_excel.xlsx", + help="output excel file name") + parser.add_argument( + "--analysis_trt", dest="analysis_trt", action='store_true') + parser.add_argument( + "--analysis_mkl", dest="analysis_mkl", action='store_true') + return parser.parse_args() + + +def find_all_logs(path_walk): + """ + find all .log files from target dir + """ + for root, ds, files in os.walk(path_walk): + for file_name in files: + if re.match(r'.*.log', file_name): + full_path = os.path.join(root, file_name) + yield file_name, full_path + + +def process_log(file_name): + """ + process log to dict + """ + output_dict = {} + with open(file_name, 'r') as f: + for i, data in enumerate(f.readlines()): + if i == 0: + continue + line_lists = data.split(" ") + + # conf info + if "runtime_device:" in line_lists: + pos_buf = line_lists.index("runtime_device:") + output_dict["runtime_device"] = line_lists[pos_buf + 1].strip() + if "ir_optim:" in line_lists: + pos_buf = line_lists.index("ir_optim:") + output_dict["ir_optim"] = line_lists[pos_buf + 1].strip() + if "enable_memory_optim:" in line_lists: + pos_buf = line_lists.index("enable_memory_optim:") + output_dict["enable_memory_optim"] = line_lists[pos_buf + + 1].strip() + if "enable_tensorrt:" in line_lists: + pos_buf = line_lists.index("enable_tensorrt:") + output_dict["enable_tensorrt"] = line_lists[pos_buf + 1].strip() + if "precision:" in line_lists: + pos_buf = line_lists.index("precision:") + output_dict["precision"] = line_lists[pos_buf + 1].strip() + if "enable_mkldnn:" in line_lists: + pos_buf = line_lists.index("enable_mkldnn:") + output_dict["enable_mkldnn"] = line_lists[pos_buf + 1].strip() + if "cpu_math_library_num_threads:" in line_lists: + pos_buf = line_lists.index("cpu_math_library_num_threads:") + output_dict["cpu_math_library_num_threads"] = line_lists[ + pos_buf + 1].strip() + + # model info + if "model_name:" in line_lists: + pos_buf = line_lists.index("model_name:") + output_dict["model_name"] = list( + filter(None, line_lists[pos_buf + 1].strip().split('/')))[ + -1] + + # data info + if "batch_size:" in line_lists: + pos_buf = line_lists.index("batch_size:") + output_dict["batch_size"] = line_lists[pos_buf + 1].strip() + if "input_shape:" in line_lists: + pos_buf = line_lists.index("input_shape:") + output_dict["input_shape"] = line_lists[pos_buf + 1].strip() + + # perf info + if "cpu_rss(MB):" in line_lists: + pos_buf = line_lists.index("cpu_rss(MB):") + output_dict["cpu_rss(MB)"] = line_lists[pos_buf + 1].strip( + ).split(',')[0] + if "gpu_rss(MB):" in line_lists: + pos_buf = line_lists.index("gpu_rss(MB):") + output_dict["gpu_rss(MB)"] = line_lists[pos_buf + 1].strip( + ).split(',')[0] + if "gpu_util:" in line_lists: + pos_buf = line_lists.index("gpu_util:") + output_dict["gpu_util"] = line_lists[pos_buf + 1].strip().split( + ',')[0] + if "preproce_time(ms):" in line_lists: + pos_buf = line_lists.index("preproce_time(ms):") + output_dict["preproce_time(ms)"] = line_lists[ + pos_buf + 1].strip().split(',')[0] + if "inference_time(ms):" in line_lists: + pos_buf = line_lists.index("inference_time(ms):") + 
output_dict["inference_time(ms)"] = line_lists[ + pos_buf + 1].strip().split(',')[0] + if "postprocess_time(ms):" in line_lists: + pos_buf = line_lists.index("postprocess_time(ms):") + output_dict["postprocess_time(ms)"] = line_lists[ + pos_buf + 1].strip().split(',')[0] + return output_dict + + +def filter_df_merge(cpu_df, filter_column=None): + """ + process cpu data frame, merge by 'model_name', 'batch_size' + Args: + cpu_df ([type]): [description] + """ + if not filter_column: + raise Exception( + "please assign filter_column for filter_df_merge function") + + df_lists = [] + filter_column_lists = [] + for k, v in cpu_df.groupby(filter_column, dropna=True): + filter_column_lists.append(k) + df_lists.append(v) + final_output_df = df_lists[-1] + + # merge same model + for i in range(len(df_lists) - 1): + left_suffix = cpu_df[filter_column].unique()[0] + right_suffix = df_lists[i][filter_column].unique()[0] + print(left_suffix, right_suffix) + if not pd.isnull(right_suffix): + final_output_df = pd.merge( + final_output_df, + df_lists[i], + how='left', + left_on=['model_name', 'batch_size'], + right_on=['model_name', 'batch_size'], + suffixes=('', '_{0}_{1}'.format(filter_column, right_suffix))) + + # rename default df columns + origin_column_names = list(cpu_df.columns.values) + origin_column_names.remove(filter_column) + suffix = final_output_df[filter_column].unique()[0] + for name in origin_column_names: + final_output_df.rename( + columns={name: "{0}_{1}_{2}".format(name, filter_column, suffix)}, + inplace=True) + final_output_df.rename( + columns={ + filter_column: "{0}_{1}_{2}".format(filter_column, filter_column, + suffix) + }, + inplace=True) + + final_output_df.sort_values( + by=[ + "model_name_{0}_{1}".format(filter_column, suffix), + "batch_size_{0}_{1}".format(filter_column, suffix) + ], + inplace=True) + return final_output_df + + +def trt_perf_analysis(raw_df): + """ + sperate raw dataframe to a list of dataframe + compare tensorrt percision performance + """ + # filter df by gpu, compare tensorrt and gpu + # define default dataframe for gpu performance analysis + gpu_df = raw_df.loc[raw_df['runtime_device'] == 'gpu'] + new_df = filter_df_merge(gpu_df, "precision") + + # calculate qps diff percentail + infer_fp32 = "inference_time(ms)_precision_fp32" + infer_fp16 = "inference_time(ms)_precision_fp16" + infer_int8 = "inference_time(ms)_precision_int8" + new_df["fp32_fp16_diff"] = new_df[[infer_fp32, infer_fp16]].apply( + lambda x: (float(x[infer_fp16]) - float(x[infer_fp32])) / float(x[infer_fp32]), + axis=1) + new_df["fp32_gpu_diff"] = new_df[["inference_time(ms)", infer_fp32]].apply( + lambda x: (float(x[infer_fp32]) - float(x[infer_fp32])) / float(x["inference_time(ms)"]), + axis=1) + new_df["fp16_int8_diff"] = new_df[[infer_fp16, infer_int8]].apply( + lambda x: (float(x[infer_int8]) - float(x[infer_fp16])) / float(x[infer_fp16]), + axis=1) + + return new_df + + +def mkl_perf_analysis(raw_df): + """ + sperate raw dataframe to a list of dataframe + compare mkldnn performance with not enable mkldnn + """ + # filter df by cpu, compare mkl and cpu + # define default dataframe for cpu mkldnn analysis + cpu_df = raw_df.loc[raw_df['runtime_device'] == 'cpu'] + mkl_compare_df = cpu_df.loc[cpu_df['cpu_math_library_num_threads'] == '1'] + thread_compare_df = cpu_df.loc[cpu_df['enable_mkldnn'] == 'True'] + + # define dataframe need to be analyzed + output_mkl_df = filter_df_merge(mkl_compare_df, 'enable_mkldnn') + output_thread_df = filter_df_merge(thread_compare_df, + 
'cpu_math_library_num_threads') + + # calculate performance diff percentail + # compare mkl performance with cpu + enable_mkldnn = "inference_time(ms)_enable_mkldnn_True" + disable_mkldnn = "inference_time(ms)_enable_mkldnn_False" + output_mkl_df["mkl_infer_diff"] = output_mkl_df[[ + enable_mkldnn, disable_mkldnn + ]].apply( + lambda x: (float(x[enable_mkldnn]) - float(x[disable_mkldnn])) / float(x[disable_mkldnn]), + axis=1) + cpu_enable_mkldnn = "cpu_rss(MB)_enable_mkldnn_True" + cpu_disable_mkldnn = "cpu_rss(MB)_enable_mkldnn_False" + output_mkl_df["mkl_cpu_rss_diff"] = output_mkl_df[[ + cpu_enable_mkldnn, cpu_disable_mkldnn + ]].apply( + lambda x: (float(x[cpu_enable_mkldnn]) - float(x[cpu_disable_mkldnn])) / float(x[cpu_disable_mkldnn]), + axis=1) + + # compare cpu_multi_thread performance with cpu + num_threads_1 = "inference_time(ms)_cpu_math_library_num_threads_1" + num_threads_6 = "inference_time(ms)_cpu_math_library_num_threads_6" + output_thread_df["mkl_infer_diff"] = output_thread_df[[ + num_threads_6, num_threads_1 + ]].apply( + lambda x: (float(x[num_threads_6]) - float(x[num_threads_1])) / float(x[num_threads_1]), + axis=1) + cpu_num_threads_1 = "cpu_rss(MB)_cpu_math_library_num_threads_1" + cpu_num_threads_6 = "cpu_rss(MB)_cpu_math_library_num_threads_6" + output_thread_df["mkl_cpu_rss_diff"] = output_thread_df[[ + cpu_num_threads_6, cpu_num_threads_1 + ]].apply( + lambda x: (float(x[cpu_num_threads_6]) - float(x[cpu_num_threads_1])) / float(x[cpu_num_threads_1]), + axis=1) + + return output_mkl_df, output_thread_df + + +def main(): + """ + main + """ + args = parse_args() + # create empty DataFrame + origin_df = pd.DataFrame(columns=[ + "model_name", "batch_size", "input_shape", "runtime_device", "ir_optim", + "enable_memory_optim", "enable_tensorrt", "precision", "enable_mkldnn", + "cpu_math_library_num_threads", "preproce_time(ms)", + "inference_time(ms)", "postprocess_time(ms)", "cpu_rss(MB)", + "gpu_rss(MB)", "gpu_util" + ]) + + for file_name, full_path in find_all_logs(args.log_path): + dict_log = process_log(full_path) + origin_df = origin_df.append(dict_log, ignore_index=True) + + raw_df = origin_df.sort_values(by='model_name') + raw_df.sort_values(by=["model_name", "batch_size"], inplace=True) + raw_df.to_excel(args.output_name) + + if args.analysis_trt: + trt_df = trt_perf_analysis(raw_df) + trt_df.to_excel("trt_analysis_{}".format(args.output_name)) + + if args.analysis_mkl: + mkl_df, thread_df = mkl_perf_analysis(raw_df) + mkl_df.to_excel("mkl_enable_analysis_{}".format(args.output_name)) + thread_df.to_excel("mkl_threads_analysis_{}".format(args.output_name)) + + +if __name__ == "__main__": + main() diff --git a/deploy/cpp/docs/Jetson_build.md b/deploy/cpp/docs/Jetson_build.md index d7ece3058..04c1be493 100644 --- a/deploy/cpp/docs/Jetson_build.md +++ b/deploy/cpp/docs/Jetson_build.md @@ -155,16 +155,19 @@ CUDNN_LIB=/usr/lib/aarch64-linux-gnu/ | 参数 | 说明 | | ---- | ---- | | --model_dir | 导出的预测模型所在路径 | -| --image_path | 要预测的图片文件路径 | -| --video_path | 要预测的视频文件路径 | +| --image_file | 要预测的图片文件路径 | +| --image_dir | 要预测的图片文件夹路径 | +| --video_file | 要预测的视频文件路径 | | --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测)| | --use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0)| | --gpu_id | 指定进行推理的GPU device id(默认值为0)| | --run_mode | 使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16/trt_int8)| | --run_benchmark | 是否重复预测来进行benchmark测速 | | --output_dir | 输出图片所在的文件夹, 默认为output | +| --use_mkldnn | CPU预测中是否开启MKLDNN加速 | +| --cpu_threads | 设置cpu线程数,默认为1 | -**注意**: 
如果同时设置了`video_path`和`image_path`,程序仅预测`video_path`。 +**注意**: 优先级顺序:`camera_id` > `video_file` > `image_dir` > `image_file`。 `样例一`: diff --git a/deploy/cpp/docs/linux_build.md b/deploy/cpp/docs/linux_build.md index 76b961955..41a85a765 100644 --- a/deploy/cpp/docs/linux_build.md +++ b/deploy/cpp/docs/linux_build.md @@ -97,16 +97,19 @@ make | 参数 | 说明 | | ---- | ---- | | --model_dir | 导出的预测模型所在路径 | -| --image_path | 要预测的图片文件路径 | -| --video_path | 要预测的视频文件路径 | +| --image_file | 要预测的图片文件路径 | +| --image_dir | 要预测的图片文件夹路径 | +| --video_file | 要预测的视频文件路径 | | --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测)| | --use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0)| | --gpu_id | 指定进行推理的GPU device id(默认值为0)| | --run_mode | 使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16/trt_int8)| | --run_benchmark | 是否重复预测来进行benchmark测速 | | --output_dir | 输出图片所在的文件夹, 默认为output | +| --use_mkldnn | CPU预测中是否开启MKLDNN加速 | +| --cpu_threads | 设置cpu线程数,默认为1 | -**注意**: 如果同时设置了`video_path`和`image_path`,程序仅预测`video_path`。 +**注意**: 优先级顺序:`camera_id` > `video_file` > `image_dir` > `image_file`。 `样例一`: diff --git a/deploy/cpp/docs/windows_vs2019_build.md b/deploy/cpp/docs/windows_vs2019_build.md index 34607b21d..b8a4902c0 100644 --- a/deploy/cpp/docs/windows_vs2019_build.md +++ b/deploy/cpp/docs/windows_vs2019_build.md @@ -92,17 +92,20 @@ cd D:\projects\PaddleDetection\deploy\cpp\out\build\x64-Release | 参数 | 说明 | | ---- | ---- | | --model_dir | 导出的预测模型所在路径 | -| --image_path | 要预测的图片文件路径 | -| --video_path | 要预测的视频文件路径 | +| --image_file | 要预测的图片文件路径 | +| --image_dir | 要预测的图片文件夹路径 | +| --video_file | 要预测的视频文件路径 | | --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测)| | --use_gpu | 是否使用 GPU 预测, 支持值为0或1(默认值为0)| | --gpu_id | 指定进行推理的GPU device id(默认值为0)| | --run_mode | 使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16/trt_int8)| | --run_benchmark | 是否重复预测来进行benchmark测速 | | --output_dir | 输出图片所在的文件夹, 默认为output | +| --use_mkldnn | CPU预测中是否开启MKLDNN加速 | +| --cpu_threads | 设置cpu线程数,默认为1 | **注意**: -(1)如果同时设置了`video_path`和`image_path`,程序仅预测`video_path`。 +(1)优先级顺序:`camera_id` > `video_file` > `image_dir` > `image_file`。 (2)如果提示找不到`opencv_world346.dll`,把`D:\projects\packages\opencv3_4_6\build\x64\vc14\bin`文件夹下的`opencv_world346.dll`拷贝到`main.exe`文件夹下即可。 diff --git a/deploy/cpp/include/object_detector.h b/deploy/cpp/include/object_detector.h index 4c1846a24..7e22d4d91 100644 --- a/deploy/cpp/include/object_detector.h +++ b/deploy/cpp/include/object_detector.h @@ -58,40 +58,46 @@ class ObjectDetector { public: explicit ObjectDetector(const std::string& model_dir, bool use_gpu=false, + bool use_mkldnn=false, + int cpu_threads=1, const std::string& run_mode="fluid", const int gpu_id=0, bool use_dynamic_shape=false, const int trt_min_shape=1, const int trt_max_shape=1280, - const int trt_opt_shape=640) { + const int trt_opt_shape=640, + bool trt_calib_mode=false) { + this->use_gpu_ = use_gpu; + this->gpu_id_ = gpu_id; + this->cpu_math_library_num_threads_ = cpu_threads; + this->use_mkldnn_ = use_mkldnn; + + this->use_dynamic_shape_ = use_dynamic_shape; + this->trt_min_shape_ = trt_min_shape; + this->trt_max_shape_ = trt_max_shape; + this->trt_opt_shape_ = trt_opt_shape; + this->trt_calib_mode_ = trt_calib_mode; config_.load_config(model_dir); + this->min_subgraph_size_ = config_.min_subgraph_size_; threshold_ = config_.draw_threshold_; image_shape_ = config_.image_shape_; preprocessor_.Init(config_.preprocess_info_, image_shape_); - LoadModel(model_dir, use_gpu, config_.min_subgraph_size_, 1, run_mode, gpu_id, - use_dynamic_shape, trt_min_shape, trt_max_shape, 
trt_opt_shape); + LoadModel(model_dir, 1, run_mode); } // Load Paddle inference model void LoadModel( const std::string& model_dir, - bool use_gpu, - const int min_subgraph_size, const int batch_size = 1, - const std::string& run_mode = "fluid", - const int gpu_id=0, - bool use_dynamic_shape=false, - const int trt_min_shape=1, - const int trt_max_shape=1280, - const int trt_opt_shape=640); + const std::string& run_mode = "fluid"); // Run predictor void Predict(const cv::Mat& im, const double threshold = 0.5, const int warmup = 0, const int repeats = 1, - const bool run_benchmark = false, - std::vector* result = nullptr); + std::vector* result = nullptr, + std::vector* times = nullptr); // Get Model Label list const std::vector& GetLabelList() const { @@ -99,6 +105,16 @@ class ObjectDetector { } private: + bool use_gpu_ = false; + int gpu_id_ = 0; + int cpu_math_library_num_threads_ = 1; + bool use_mkldnn_ = false; + int min_subgraph_size_ = 3; + bool use_dynamic_shape_ = false; + int trt_min_shape_ = 1; + int trt_max_shape_ = 1280; + int trt_opt_shape_ = 640; + bool trt_calib_mode_ = false; // Preprocess image and copy data to input buffer void Preprocess(const cv::Mat& image_mat); // Postprocess result diff --git a/deploy/cpp/src/main.cc b/deploy/cpp/src/main.cc index cd696be0e..74bf3660b 100644 --- a/deploy/cpp/src/main.cc +++ b/deploy/cpp/src/main.cc @@ -14,9 +14,11 @@ #include +#include #include #include #include +#include #include #include @@ -33,20 +35,54 @@ DEFINE_string(model_dir, "", "Path of inference model"); -DEFINE_string(image_path, "", "Path of input image"); -DEFINE_string(video_path, "", "Path of input video"); +DEFINE_string(image_file, "", "Path of input image"); +DEFINE_string(image_dir, "", "Dir of input image, `image_file` has a higher priority."); +DEFINE_string(video_file, "", "Path of input video, `video_file` or `camera_id` has a highest priority."); +DEFINE_int32(camera_id, -1, "Device id of camera to predict"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); -DEFINE_bool(use_camera, false, "Use camera or not"); +DEFINE_double(threshold, 0.5, "Threshold of score."); +DEFINE_string(output_dir, "output", "Directory of output visualization files."); DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16/trt_int8)"); DEFINE_int32(gpu_id, 0, "Device id of GPU to execute"); -DEFINE_int32(camera_id, -1, "Device id of camera to predict"); DEFINE_bool(run_benchmark, false, "Whether to predict a image_file repeatedly for benchmark"); -DEFINE_double(threshold, 0.5, "Threshold of score."); -DEFINE_string(output_dir, "output", "Directory of output visualization files."); +DEFINE_bool(use_mkldnn, false, "Whether use mkldnn with CPU"); +DEFINE_int32(cpu_threads, 1, "Num of threads with CPU"); DEFINE_bool(use_dynamic_shape, false, "Trt use dynamic shape or not"); DEFINE_int32(trt_min_shape, 1, "Min shape of TRT DynamicShapeI"); DEFINE_int32(trt_max_shape, 1280, "Max shape of TRT DynamicShapeI"); DEFINE_int32(trt_opt_shape, 640, "Opt shape of TRT DynamicShapeI"); +DEFINE_bool(trt_calib_mode, false, "If the model is produced by TRT offline quantitative calibration, trt_calib_mode need to set True"); + +void PrintBenchmarkLog(std::vector det_time, int img_num){ + LOG(INFO) << "----------------------- Config info -----------------------"; + LOG(INFO) << "runtime_device: " << (FLAGS_use_gpu ? 
"gpu" : "cpu"); + LOG(INFO) << "ir_optim: " << "True"; + LOG(INFO) << "enable_memory_optim: " << "True"; + int has_trt = FLAGS_run_mode.find("trt"); + if (has_trt >= 0) { + LOG(INFO) << "enable_tensorrt: " << "True"; + std::string precision = FLAGS_run_mode.substr(4, 8); + LOG(INFO) << "precision: " << precision; + } else { + LOG(INFO) << "enable_tensorrt: " << "False"; + LOG(INFO) << "precision: " << "fp32"; + } + LOG(INFO) << "enable_mkldnn: " << (FLAGS_use_mkldnn ? "True" : "False"); + LOG(INFO) << "cpu_math_library_num_threads: " << FLAGS_cpu_threads; + LOG(INFO) << "----------------------- Data info -----------------------"; + LOG(INFO) << "batch_size: " << 1; + LOG(INFO) << "input_shape: " << "dynamic shape"; + LOG(INFO) << "----------------------- Model info -----------------------"; + FLAGS_model_dir.erase(FLAGS_model_dir.find_last_not_of("/") + 1); + LOG(INFO) << "model_name: " << FLAGS_model_dir.substr(FLAGS_model_dir.find_last_of('/') + 1); + LOG(INFO) << "----------------------- Perf info ------------------------"; + LOG(INFO) << "Total number of predicted data: " << img_num + << " and total time spent(s): " + << std::accumulate(det_time.begin(), det_time.end(), 0); + LOG(INFO) << "preproce_time(ms): " << det_time[0] / img_num + << ", inference_time(ms): " << det_time[1] / img_num + << ", postprocess_time(ms): " << det_time[2]; +} static std::string DirName(const std::string &filepath) { auto pos = filepath.rfind(OS_PATH_SEP); @@ -89,6 +125,37 @@ static void MkDirs(const std::string& path) { MkDir(path); } +void GetAllFiles(const char *dir_name, + std::vector &all_inputs) { + if (NULL == dir_name) { + std::cout << " dir_name is null ! " << std::endl; + return; + } + struct stat s; + lstat(dir_name, &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dir_name is not a valid directory !" << std::endl; + all_inputs.push_back(dir_name); + return; + } else { + struct dirent *filename; // return value for readdir() + DIR *dir; // return value for opendir() + dir = opendir(dir_name); + if (NULL == dir) { + std::cout << "Can not open dir " << dir_name << std::endl; + return; + } + std::cout << "Successfully opened the dir !" 
<< std::endl; + while ((filename = readdir(dir)) != NULL) { + if (strcmp(filename->d_name, ".") == 0 || + strcmp(filename->d_name, "..") == 0) + continue; + all_inputs.push_back(dir_name + std::string("/") + + std::string(filename->d_name)); + } + } +} + void PredictVideo(const std::string& video_path, PaddleDetection::ObjectDetector* det) { // Open video @@ -122,6 +189,7 @@ void PredictVideo(const std::string& video_path, } std::vector result; + std::vector det_times; auto labels = det->GetLabelList(); auto colormap = PaddleDetection::GenerateColorMap(labels.size()); // Capture all frames and do inference @@ -131,7 +199,7 @@ void PredictVideo(const std::string& video_path, if (frame.empty()) { break; } - det->Predict(frame, 0.5, 0, 1, false, &result); + det->Predict(frame, 0.5, 0, 1, &result, &det_times); cv::Mat out_im = PaddleDetection::VisualizeResult( frame, result, labels, colormap); for (const auto& item : result) { @@ -151,55 +219,62 @@ void PredictVideo(const std::string& video_path, video_out.release(); } -void PredictImage(const std::string& image_path, +void PredictImage(const std::vector all_img_list, const double threshold, const bool run_benchmark, PaddleDetection::ObjectDetector* det, const std::string& output_dir = "output") { - // Open input image as an opencv cv::Mat object - cv::Mat im = cv::imread(image_path, 1); - // Store all detected result - std::vector result; - if (run_benchmark) - { - det->Predict(im, threshold, 100, 100, run_benchmark, &result); - }else - { - det->Predict(im, 0.5, 0, 1, run_benchmark, &result); - for (const auto& item : result) { - printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n", - item.class_id, - item.confidence, - item.rect[0], - item.rect[1], - item.rect[2], - item.rect[3]); + std::vector det_t = {0, 0, 0}; + for (auto image_file : all_img_list) { + // Open input image as an opencv cv::Mat object + cv::Mat im = cv::imread(image_file, 1); + // Store all detected result + std::vector result; + std::vector det_times; + if (run_benchmark) { + det->Predict(im, threshold, 10, 10, &result, &det_times); + } else { + det->Predict(im, 0.5, 0, 1, &result, &det_times); + for (const auto& item : result) { + printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n", + item.class_id, + item.confidence, + item.rect[0], + item.rect[1], + item.rect[2], + item.rect[3]); + } + // Visualization result + auto labels = det->GetLabelList(); + auto colormap = PaddleDetection::GenerateColorMap(labels.size()); + cv::Mat vis_img = PaddleDetection::VisualizeResult( + im, result, labels, colormap); + std::vector compression_params; + compression_params.push_back(CV_IMWRITE_JPEG_QUALITY); + compression_params.push_back(95); + std::string output_path(output_dir); + if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) { + output_path += OS_PATH_SEP; + } + ; + output_path += image_file.substr(image_file.find_last_of('/') + 1); + cv::imwrite(output_path, vis_img, compression_params); + printf("Visualized output saved as %s\n", output_path.c_str()); } - // Visualization result - auto labels = det->GetLabelList(); - auto colormap = PaddleDetection::GenerateColorMap(labels.size()); - cv::Mat vis_img = PaddleDetection::VisualizeResult( - im, result, labels, colormap); - std::vector compression_params; - compression_params.push_back(CV_IMWRITE_JPEG_QUALITY); - compression_params.push_back(95); - std::string output_path(output_dir); - if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) { - output_path += OS_PATH_SEP; - } - output_path += "output.jpg"; - 
cv::imwrite(output_path, vis_img, compression_params); - printf("Visualized output saved as %s\n", output_path.c_str()); + det_t[0] += det_times[0]; + det_t[1] += det_times[1]; + det_t[2] += det_times[2]; } + PrintBenchmarkLog(det_t, all_img_list.size()); } int main(int argc, char** argv) { // Parsing command-line google::ParseCommandLineFlags(&argc, &argv, true); if (FLAGS_model_dir.empty() - || (FLAGS_image_path.empty() && FLAGS_video_path.empty())) { + || (FLAGS_image_file.empty() && FLAGS_image_dir.empty() && FLAGS_video_file.empty())) { std::cout << "Usage: ./main --model_dir=/PATH/TO/INFERENCE_MODEL/ " - << "--image_path=/PATH/TO/INPUT/IMAGE/" << std::endl; + << "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl; return -1; } if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32" @@ -208,17 +283,23 @@ int main(int argc, char** argv) { return -1; } // Load model and create a object detector - PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_run_mode, - FLAGS_gpu_id, FLAGS_use_dynamic_shape, FLAGS_trt_min_shape, - FLAGS_trt_max_shape, FLAGS_trt_opt_shape); + PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_use_gpu, FLAGS_use_mkldnn, + FLAGS_cpu_threads, FLAGS_run_mode, FLAGS_gpu_id, FLAGS_use_dynamic_shape, + FLAGS_trt_min_shape, FLAGS_trt_max_shape, FLAGS_trt_opt_shape, FLAGS_trt_calib_mode); // Do inference on input video or image - if (!FLAGS_video_path.empty() || FLAGS_use_camera) { - PredictVideo(FLAGS_video_path, &det); - } else if (!FLAGS_image_path.empty()) { + if (!FLAGS_video_file.empty() || FLAGS_camera_id != -1) { + PredictVideo(FLAGS_video_file, &det); + } else if (!FLAGS_image_file.empty() || !FLAGS_image_dir.empty()) { if (!PathExists(FLAGS_output_dir)) { MkDirs(FLAGS_output_dir); } - PredictImage(FLAGS_image_path, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir); + std::vector all_img_list; + if (!FLAGS_image_file.empty()) { + all_img_list.push_back(FLAGS_image_file); + } else { + GetAllFiles((char *)FLAGS_image_dir.c_str(), all_img_list); + } + PredictImage(all_img_list, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir); } return 0; } diff --git a/deploy/cpp/src/object_detector.cc b/deploy/cpp/src/object_detector.cc index 95b8dbb22..8df375683 100644 --- a/deploy/cpp/src/object_detector.cc +++ b/deploy/cpp/src/object_detector.cc @@ -24,24 +24,16 @@ namespace PaddleDetection { // Load Model and create model predictor void ObjectDetector::LoadModel(const std::string& model_dir, - bool use_gpu, - const int min_subgraph_size, const int batch_size, - const std::string& run_mode, - const int gpu_id, - bool use_dynamic_shape, - const int trt_min_shape, - const int trt_max_shape, - const int trt_opt_shape) { + const std::string& run_mode) { paddle_infer::Config config; std::string prog_file = model_dir + OS_PATH_SEP + "model.pdmodel"; std::string params_file = model_dir + OS_PATH_SEP + "model.pdiparams"; config.SetModel(prog_file, params_file); - if (use_gpu) { - config.EnableUseGpu(200, gpu_id); + if (this->use_gpu_) { + config.EnableUseGpu(200, this->gpu_id_); config.SwitchIrOptim(true); // use tensorrt - bool use_calib_mode = false; if (run_mode != "fluid") { auto precision = paddle_infer::Config::Precision::kFloat32; if (run_mode == "trt_fp32") { @@ -52,7 +44,6 @@ void ObjectDetector::LoadModel(const std::string& model_dir, } else if (run_mode == "trt_int8") { precision = paddle_infer::Config::Precision::kInt8; - use_calib_mode = true; } else { printf("run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 
'trt_int8'"); } @@ -60,17 +51,17 @@ void ObjectDetector::LoadModel(const std::string& model_dir, config.EnableTensorRtEngine( 1 << 30, batch_size, - min_subgraph_size, + this->min_subgraph_size_, precision, false, - use_calib_mode); + this->trt_calib_mode_); // set use dynamic shape - if (use_dynamic_shape) { + if (this->use_dynamic_shape_) { // set DynamicShsape for image tensor - const std::vector min_input_shape = {1, 3, trt_min_shape, trt_min_shape}; - const std::vector max_input_shape = {1, 3, trt_max_shape, trt_max_shape}; - const std::vector opt_input_shape = {1, 3, trt_opt_shape, trt_opt_shape}; + const std::vector min_input_shape = {1, 3, this->trt_min_shape_, this->trt_min_shape_}; + const std::vector max_input_shape = {1, 3, this->trt_max_shape_, this->trt_max_shape_}; + const std::vector opt_input_shape = {1, 3, this->trt_opt_shape_, this->trt_opt_shape_}; const std::map> map_min_input_shape = {{"image", min_input_shape}}; const std::map> map_max_input_shape = {{"image", max_input_shape}}; const std::map> map_opt_input_shape = {{"image", opt_input_shape}}; @@ -84,8 +75,15 @@ void ObjectDetector::LoadModel(const std::string& model_dir, } else { config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + // cache 10 different shapes for mkldnn to avoid memory leak + config.SetMkldnnCacheCapacity(10); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); } config.SwitchUseFeedFetchOps(false); + config.SwitchIrOptim(true); config.DisableGlogInfo(); // Memory optimization config.EnableMemoryOptim(); @@ -189,8 +187,9 @@ void ObjectDetector::Predict(const cv::Mat& im, const double threshold, const int warmup, const int repeats, - const bool run_benchmark, - std::vector* result) { + std::vector* result, + std::vector* times) { + auto preprocess_start = std::chrono::steady_clock::now(); // Preprocess image Preprocess(im); // Prepare input tensor @@ -210,6 +209,7 @@ void ObjectDetector::Predict(const cv::Mat& im, in_tensor->CopyFromCpu(inputs_.scale_factor_.data()); } } + auto preprocess_end = std::chrono::steady_clock::now(); // Run predictor for (int i = 0; i < warmup; i++) { @@ -231,7 +231,7 @@ void ObjectDetector::Predict(const cv::Mat& im, out_tensor->CopyToCpu(output_data_.data()); } - auto start = std::chrono::steady_clock::now(); + auto inference_start = std::chrono::steady_clock::now(); for (int i = 0; i < repeats; i++) { predictor_->Run(); @@ -251,14 +251,18 @@ void ObjectDetector::Predict(const cv::Mat& im, output_data_.resize(output_size); out_tensor->CopyToCpu(output_data_.data()); } - auto end = std::chrono::steady_clock::now(); - std::chrono::duration diff = end - start; - float ms = diff.count() / repeats * 1000; - printf("Inference: %f ms per batch image\n", ms); + auto inference_end = std::chrono::steady_clock::now(); + auto postprocess_start = std::chrono::steady_clock::now(); // Postprocessing result - if(!run_benchmark) { - Postprocess(im, result); - } + Postprocess(im, result); + auto postprocess_end = std::chrono::steady_clock::now(); + + std::chrono::duration preprocess_diff = preprocess_end - preprocess_start; + times->push_back(double(preprocess_diff.count() * 1000)); + std::chrono::duration inference_diff = inference_end - inference_start; + times->push_back(double(inference_diff.count() / repeats * 1000)); + std::chrono::duration postprocess_diff = postprocess_end - postprocess_start; + times->push_back(double(postprocess_diff.count() * 1000)); } std::vector GenerateColorMap(int num_class) { diff --git 
a/deploy/python/README.md b/deploy/python/README.md index e0a5a32b0..786756ec4 100644 --- a/deploy/python/README.md +++ b/deploy/python/README.md @@ -1,17 +1,12 @@ # Python端预测部署 -Python预测可以使用`tools/infer.py`,此种方式依赖PaddleDetection源码;也可以使用本篇教程预测方式,先将模型导出,使用一个独立的文件进行预测。 - - -本篇教程使用AnalysisPredictor对[导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md)进行高性能预测。 - 在PaddlePaddle中预测引擎和训练引擎底层有着不同的优化方法, 预测引擎使用了AnalysisPredictor,专门针对推理进行了优化,是基于[C++预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html)的Python接口,该引擎可以对模型进行多项图优化,减少不必要的内存拷贝。如果用户在部署已训练模型的过程中对性能有较高的要求,我们提供了独立于PaddleDetection的预测脚本,方便用户直接集成部署。 主要包含两个步骤: - 导出预测模型 -- 基于Python的预测 +- 基于Python进行预测 ## 1. 导出预测模型 @@ -22,16 +17,14 @@ PaddleDetection在训练过程包括网络的前向和优化器相关参数, ## 2. 基于python的预测 ### 2.1 安装依赖 - - `PaddlePaddle`的安装: - 请点击[官方安装文档](https://paddlepaddle.org.cn/install/quick) 选择适合的方式,版本为2.0rc1以上即可 - - 切换到`PaddleDetection`代码库根目录,执行`pip install -r requirements.txt`安装其它依赖 +- `PaddlePaddle`的安装: 请点击[官方安装文档](https://paddlepaddle.org.cn/install/quick) 选择适合的版本进行安装,要求PaddlePaddle>=2.0.1以上。 +- 切换到`PaddleDetection`代码库根目录,执行`pip install -r requirements.txt`安装其它依赖。 ### 2.2 执行预测程序 在终端输入以下命令进行预测: ```bash -python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/image ---use_gpu=(False/True) +python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/image --use_gpu=(False/True) ``` 参数说明如下: @@ -40,41 +33,19 @@ python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/ |-------|-------|----------| | --model_dir | Yes|上述导出的模型路径 | | --image_file | Option |需要预测的图片 | +| --image_dir | Option | 要预测的图片文件夹路径 | | --video_file | Option |需要预测的视频 | | --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测,可设置为:0 - (摄像头数目-1) ),预测过程中在可视化界面按`q`退出输出预测结果到:output/output.mp4| -| --use_gpu |No|是否GPU,默认为False| -| --run_mode |No|使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16/trt_int8)| -| --threshold |No|预测得分的阈值,默认为0.5| -| --output_dir |No|可视化结果保存的根目录,默认为output/| -| --run_benchmark |No|是否运行benchmark,同时需指定--image_file| +| --use_gpu | No |是否GPU,默认为False| +| --run_mode | No |使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16/trt_int8)| +| --threshold | No|预测得分的阈值,默认为0.5| +| --output_dir | No|可视化结果保存的根目录,默认为output/| +| --run_benchmark | No| 是否运行benchmark,同时需指定`--image_file`或`--image_dir` | +| --use_mkldnn | No | CPU预测中是否开启MKLDNN加速 | +| --cpu_threads | No| 设置cpu线程数,默认为1 | 说明: +- 参数优先级顺序:`camera_id` > `video_file` > `image_dir` > `image_file`。 - run_mode:fluid代表使用AnalysisPredictor,精度float32来推理,其他参数指用AnalysisPredictor,TensorRT不同精度来推理。 -- PaddlePaddle默认的GPU安装包(<=1.7),不支持基于TensorRT进行预测,如果想基于TensorRT加速预测,需要自行编译,详细可参考[预测库编译教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/deploy/inference/paddle_tensorrt_infer.html)。 - -## 3. 
部署性能对比测试 -对比AnalysisPredictor相对Executor的推理速度 - -### 3.1 测试环境: - -- CUDA 9.0 -- CUDNN 7.5 -- PaddlePaddle 1.71 -- GPU: Tesla P40 - -### 3.2 测试方式: - -- Batch Size=1 -- 去掉前100轮warmup时间,测试100轮的平均时间,单位ms/image,只计算模型运行时间,不包括数据的处理和拷贝。 - - -### 3.3 测试结果 - -|模型 | AnalysisPredictor | Executor | 输入| -|---|----|---|---| -| YOLOv3-MobileNetv1 | 15.20 | 19.54 | 608*608 -| faster_rcnn_r50_fpn_1x | 50.05 | 69.58 |800*1088 -| faster_rcnn_r50_1x | 326.11 | 347.22 | 800*1067 -| mask_rcnn_r50_fpn_1x | 67.49 | 91.02 | 800*1088 -| mask_rcnn_r50_1x | 326.11 | 350.94 | 800*1067 +- 如果安装的PaddlePaddle不支持基于TensorRT进行预测,需要自行编译,详细可参考[预测库编译教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/deploy/inference/paddle_tensorrt_infer.html)。 diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 5bfd54554..d378aaf10 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -13,21 +13,22 @@ # limitations under the License. import os -import argparse import time import yaml -import ast +import glob from functools import reduce from PIL import Image import cv2 import numpy as np import paddle -from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride -from visualize import visualize_box_mask from paddle.inference import Config from paddle.inference import create_predictor +from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride +from visualize import visualize_box_mask +from utils import argsparser, Timer, get_current_memory_mb, LoggerHelper + # Global dictionary SUPPORT_MODELS = { 'YOLO', @@ -63,7 +64,7 @@ class Detector(object): trt_min_shape=1, trt_max_shape=1280, trt_opt_shape=640, - threshold=0.5): + trt_calib_mode=False): self.pred_config = pred_config self.predictor = load_predictor( model_dir, @@ -73,7 +74,10 @@ class Detector(object): use_dynamic_shape=use_dynamic_shape, trt_min_shape=trt_min_shape, trt_max_shape=trt_max_shape, - trt_opt_shape=trt_opt_shape) + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode) + self.det_times = Timer() + self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 def preprocess(self, im): preprocess_ops = [] @@ -102,12 +106,7 @@ class Detector(object): results['masks'] = np_masks return results - def predict(self, - image, - threshold=0.5, - warmup=0, - repeats=1, - run_benchmark=False): + def predict(self, image, threshold=0.5, warmup=0, repeats=1): ''' Args: image (str/np.ndarray): path of image/ np.ndarray read by cv2 @@ -118,13 +117,14 @@ class Detector(object): MaskRCNN's results include 'masks': np.ndarray: shape: [N, im_h, im_w] ''' + self.det_times.preprocess_time.start() inputs = self.preprocess(image) np_boxes, np_masks = None, None input_names = self.predictor.get_input_names() for i in range(len(input_names)): input_tensor = self.predictor.get_input_handle(input_names[i]) input_tensor.copy_from_cpu(inputs[input_names[i]]) - + self.det_times.preprocess_time.end() for i in range(warmup): self.predictor.run() output_names = self.predictor.get_output_names() @@ -134,7 +134,7 @@ class Detector(object): masks_tensor = self.predictor.get_output_handle(output_names[2]) np_masks = masks_tensor.copy_to_cpu() - t1 = time.time() + self.det_times.inference_time.start() for i in range(repeats): self.predictor.run() output_names = self.predictor.get_output_names() @@ -143,20 +143,18 @@ class Detector(object): if self.pred_config.mask: masks_tensor = self.predictor.get_output_handle(output_names[2]) np_masks = masks_tensor.copy_to_cpu() - t2 = time.time() - ms = (t2 - t1) * 1000.0 / repeats - 
print("Inference: {} ms per batch image".format(ms)) + self.det_times.inference_time.end(repeats=repeats) - # do not perform postprocess in benchmark mode + self.det_times.postprocess_time.start() results = [] - if not run_benchmark: - if reduce(lambda x, y: x * y, np_boxes.shape) < 6: - print('[WARNNING] No object detected.') - results = {'boxes': np.array([])} - else: - results = self.postprocess( - np_boxes, np_masks, inputs, threshold=threshold) - + if reduce(lambda x, y: x * y, np_boxes.shape) < 6: + print('[WARNNING] No object detected.') + results = {'boxes': np.array([])} + else: + results = self.postprocess( + np_boxes, np_masks, inputs, threshold=threshold) + self.det_times.postprocess_time.end() + self.det_times.img_num += 1 return results @@ -183,7 +181,7 @@ class DetectorSOLOv2(Detector): trt_min_shape=1, trt_max_shape=1280, trt_opt_shape=640, - threshold=0.5): + trt_calib_mode=False): self.pred_config = pred_config self.predictor = load_predictor( model_dir, @@ -193,14 +191,11 @@ class DetectorSOLOv2(Detector): use_dynamic_shape=use_dynamic_shape, trt_min_shape=trt_min_shape, trt_max_shape=trt_max_shape, - trt_opt_shape=trt_opt_shape) - - def predict(self, - image, - threshold=0.5, - warmup=0, - repeats=1, - run_benchmark=False): + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode) + self.det_times = Timer() + + def predict(self, image, threshold=0.5, warmup=0, repeats=1): ''' Args: image (str/np.ndarray): path of image/ np.ndarray read by cv2 @@ -210,13 +205,14 @@ class DetectorSOLOv2(Detector): 'cate_label': label of segm, shape:[N] 'cate_score': confidence score of segm, shape:[N] ''' + self.det_times.preprocess_time.start() inputs = self.preprocess(image) np_label, np_score, np_segms = None, None, None input_names = self.predictor.get_input_names() for i in range(len(input_names)): input_tensor = self.predictor.get_input_handle(input_names[i]) input_tensor.copy_from_cpu(inputs[input_names[i]]) - + self.det_times.preprocess_time.end() for i in range(warmup): self.predictor.run() output_names = self.predictor.get_output_names() @@ -227,7 +223,7 @@ class DetectorSOLOv2(Detector): np_segms = self.predictor.get_output_handle(output_names[ 3]).copy_to_cpu() - t1 = time.time() + self.det_times.inference_time.start() for i in range(repeats): self.predictor.run() output_names = self.predictor.get_output_names() @@ -237,15 +233,10 @@ class DetectorSOLOv2(Detector): 2]).copy_to_cpu() np_segms = self.predictor.get_output_handle(output_names[ 3]).copy_to_cpu() - t2 = time.time() - ms = (t2 - t1) * 1000.0 / repeats - print("Inference: {} ms per batch image".format(ms)) + self.det_times.inference_time.end(repeats=repeats) + self.det_times.img_num += 1 - # do not perform postprocess in benchmark mode - results = [] - if not run_benchmark: - return dict(segm=np_segms, label=np_label, score=np_score) - return results + return dict(segm=np_segms, label=np_label, score=np_score) def create_inputs(im, im_info): @@ -316,7 +307,8 @@ def load_predictor(model_dir, use_dynamic_shape=False, trt_min_shape=1, trt_max_shape=1280, - trt_opt_shape=640): + trt_opt_shape=640, + trt_calib_mode=False): """set AnalysisConfig, generate AnalysisPredictor Args: model_dir (str): root path of __model__ and __params__ @@ -326,6 +318,8 @@ def load_predictor(model_dir, trt_min_shape (int): min shape for dynamic shape in trt trt_max_shape (int): max shape for dynamic shape in trt trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline 
quantitative + calibration, trt_calib_mode need to set True Returns: predictor (PaddlePredictor): AnalysisPredictor Raises: @@ -335,7 +329,6 @@ def load_predictor(model_dir, raise ValueError( "Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}" .format(run_mode, use_gpu)) - use_calib_mode = True if run_mode == 'trt_int8' else False config = Config( os.path.join(model_dir, 'model.pdmodel'), os.path.join(model_dir, 'model.pdiparams')) @@ -351,6 +344,17 @@ def load_predictor(model_dir, config.switch_ir_optim(True) else: config.disable_gpu() + config.set_cpu_math_library_num_threads(FLAGS.cpu_threads) + if FLAGS.enable_mkldnn: + try: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + except Exception as e: + print( + "The current environment does not support `mkldnn`, so disable mkldnn." + ) + pass if run_mode in precision_map.keys(): config.enable_tensorrt_engine( @@ -359,10 +363,9 @@ def load_predictor(model_dir, min_subgraph_size=min_subgraph_size, precision_mode=precision_map[run_mode], use_static=False, - use_calib_mode=use_calib_mode) + use_calib_mode=trt_calib_mode) if use_dynamic_shape: - print('use_dynamic_shape') min_input_shape = {'image': [1, 3, trt_min_shape, trt_min_shape]} max_input_shape = {'image': [1, 3, trt_max_shape, trt_max_shape]} opt_input_shape = {'image': [1, 3, trt_opt_shape, trt_opt_shape]} @@ -380,6 +383,37 @@ def load_predictor(model_dir, return predictor +def get_test_images(infer_dir, infer_img): + """ + Get image path list in TEST mode + """ + assert infer_img is not None or infer_dir is not None, \ + "--infer_img or --infer_dir should be set" + assert infer_img is None or os.path.isfile(infer_img), \ + "{} is not a file".format(infer_img) + assert infer_dir is None or os.path.isdir(infer_dir), \ + "{} is not a directory".format(infer_dir) + + # infer_img has a higher priority + if infer_img and os.path.isfile(infer_img): + return [infer_img] + + images = set() + infer_dir = os.path.abspath(infer_dir) + assert os.path.isdir(infer_dir), \ + "infer_dir {} is not a directory".format(infer_dir) + exts = ['jpg', 'jpeg', 'png', 'bmp'] + exts += [ext.upper() for ext in exts] + for ext in exts: + images.update(glob.glob('{}/*.{}'.format(infer_dir, ext))) + images = list(images) + + assert len(images) > 0, "no image found in {}".format(infer_dir) + print("Found {} inference images in total.".format(len(images))) + + return images + + def visualize(image_file, results, labels, output_dir='output/', threshold=0.5): # visualize the predict result im = visualize_box_mask(image_file, results, labels, threshold=threshold) @@ -398,22 +432,23 @@ def print_arguments(args): print('------------------------------------------') -def predict_image(detector): - if FLAGS.run_benchmark: - detector.predict( - FLAGS.image_file, - FLAGS.threshold, - warmup=100, - repeats=100, - run_benchmark=True) - else: - results = detector.predict(FLAGS.image_file, FLAGS.threshold) - visualize( - FLAGS.image_file, - results, - detector.pred_config.labels, - output_dir=FLAGS.output_dir, - threshold=FLAGS.threshold) +def predict_image(detector, image_list): + for i, img_file in enumerate(image_list): + if FLAGS.run_benchmark: + detector.predict(img_file, FLAGS.threshold, warmup=10, repeats=10) + cm, gm, gu = get_current_memory_mb() + detector.cpu_mem += cm + detector.gpu_mem += gm + detector.gpu_util += gu + print('Test iter {}, file name:{}'.format(i, img_file)) + else: + results = detector.predict(img_file, 
FLAGS.threshold) + visualize( + img_file, + results, + detector.pred_config.labels, + output_dir=FLAGS.output_dir, + threshold=FLAGS.threshold) def predict_video(detector, camera_id): @@ -465,7 +500,8 @@ def main(): use_dynamic_shape=FLAGS.use_dynamic_shape, trt_min_shape=FLAGS.trt_min_shape, trt_max_shape=FLAGS.trt_max_shape, - trt_opt_shape=FLAGS.trt_opt_shape) + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode) if pred_config.arch == 'SOLOv2': detector = DetectorSOLOv2( pred_config, @@ -475,77 +511,33 @@ def main(): use_dynamic_shape=FLAGS.use_dynamic_shape, trt_min_shape=FLAGS.trt_min_shape, trt_max_shape=FLAGS.trt_max_shape, - trt_opt_shape=FLAGS.trt_opt_shape) - # predict from image - if FLAGS.image_file != '': - predict_image(detector) + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode) + # predict from video file or camera video stream - if FLAGS.video_file != '' or FLAGS.camera_id != -1: + if FLAGS.video_file is not None or FLAGS.camera_id != -1: predict_video(detector, FLAGS.camera_id) + else: + # predict from image + img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) + predict_image(detector, img_list) + if not FLAGS.run_benchmark: + detector.det_times.info(average=True) + else: + mems = { + 'cpu_rss': detector.cpu_mem / len(img_list), + 'gpu_rss': detector.gpu_mem / len(img_list), + 'gpu_util': detector.gpu_util * 100 / len(img_list) + } + det_logger = LoggerHelper( + FLAGS, detector.det_times.report(average=True), mems) + det_logger.report() if __name__ == '__main__': paddle.enable_static() - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--model_dir", - type=str, - default=None, - help=("Directory include:'model.pdiparams', 'model.pdmodel', " - "'infer_cfg.yml', created by tools/export_model.py."), - required=True) - parser.add_argument( - "--image_file", type=str, default='', help="Path of image file.") - parser.add_argument( - "--video_file", type=str, default='', help="Path of video file.") - parser.add_argument( - "--camera_id", - type=int, - default=-1, - help="device id of camera to predict.") - parser.add_argument( - "--run_mode", - type=str, - default='fluid', - help="mode of running(fluid/trt_fp32/trt_fp16/trt_int8)") - parser.add_argument( - "--use_gpu", - type=ast.literal_eval, - default=False, - help="Whether to predict with GPU.") - parser.add_argument( - "--run_benchmark", - type=ast.literal_eval, - default=False, - help="Whether to predict a image_file repeatedly for benchmark") - parser.add_argument( - "--threshold", type=float, default=0.5, help="Threshold of score.") - parser.add_argument( - "--output_dir", - type=str, - default="output", - help="Directory of output visualization files.") - parser.add_argument( - "--use_dynamic_shape", - type=ast.literal_eval, - default=False, - help="Dynamic_shape for TensorRT.") - parser.add_argument( - "--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.") - parser.add_argument( - "--trt_max_shape", - type=int, - default=1280, - help="max_shape for TensorRT.") - parser.add_argument( - "--trt_opt_shape", - type=int, - default=640, - help="opt_shape for TensorRT.") - + parser = argsparser() FLAGS = parser.parse_args() print_arguments(FLAGS) - if FLAGS.image_file != '' and FLAGS.video_file != '': - assert "Cannot predict image and video at the same time" main() diff --git a/deploy/python/preprocess.py b/deploy/python/preprocess.py index 3d0c1b9b1..700926ea8 100644 --- a/deploy/python/preprocess.py +++ 
b/deploy/python/preprocess.py @@ -71,9 +71,6 @@ class Resize(object): assert self.target_size[0] > 0 and self.target_size[1] > 0 im_channel = im.shape[2] im_scale_y, im_scale_x = self.generate_scale(im) - # set image_shape - im_info['input_shape'][1] = int(im_scale_y * im.shape[0]) - im_info['input_shape'][2] = int(im_scale_x * im.shape[1]) im = cv2.resize( im, None, @@ -84,6 +81,14 @@ class Resize(object): im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') im_info['scale_factor'] = np.array( [im_scale_y, im_scale_x]).astype('float32') + # padding im when image_shape fixed by infer_cfg.yml + if self.keep_ratio and im_info['input_shape'][1] != -1: + max_size = im_info['input_shape'][1] + padding_im = np.zeros( + (max_size, max_size, im_channel), dtype=np.float32) + im_h, im_w = im.shape[:2] + padding_im[:im_h, :im_w, :] = im + im = padding_im return im, im_info def generate_scale(self, im): diff --git a/deploy/python/utils.py b/deploy/python/utils.py new file mode 100644 index 000000000..66dbdbd0a --- /dev/null +++ b/deploy/python/utils.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import os +import ast +import argparse + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--model_dir", + type=str, + default=None, + help=("Directory include:'model.pdiparams', 'model.pdmodel', " + "'infer_cfg.yml', created by tools/export_model.py."), + required=True) + parser.add_argument( + "--image_file", type=str, default=None, help="Path of image file.") + parser.add_argument( + "--image_dir", + type=str, + default=None, + help="Dir of image file, `image_file` has a higher priority.") + parser.add_argument( + "--video_file", + type=str, + default=None, + help="Path of video file, `video_file` or `camera_id` has a highest priority." 
+    )
+    parser.add_argument(
+        "--camera_id",
+        type=int,
+        default=-1,
+        help="device id of camera to predict.")
+    parser.add_argument(
+        "--threshold", type=float, default=0.5, help="Threshold of score.")
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output",
+        help="Directory of output visualization files.")
+    parser.add_argument(
+        "--run_mode",
+        type=str,
+        default='fluid',
+        help="mode of running (fluid/trt_fp32/trt_fp16/trt_int8)")
+    parser.add_argument(
+        "--use_gpu",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to predict with GPU.")
+    parser.add_argument(
+        "--run_benchmark",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to predict an image_file repeatedly for benchmarking.")
+    parser.add_argument(
+        "--enable_mkldnn",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to use MKLDNN for CPU inference.")
+    parser.add_argument(
+        "--cpu_threads",
+        type=int,
+        default=1,
+        help="Number of CPU math library threads.")
+    parser.add_argument(
+        "--use_dynamic_shape",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to use dynamic shape with TensorRT.")
+    parser.add_argument(
+        "--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.")
+    parser.add_argument(
+        "--trt_max_shape",
+        type=int,
+        default=1280,
+        help="max_shape for TensorRT.")
+    parser.add_argument(
+        "--trt_opt_shape",
+        type=int,
+        default=640,
+        help="opt_shape for TensorRT.")
+    parser.add_argument(
+        "--trt_calib_mode",
+        type=ast.literal_eval,
+        default=False,
+        help="If the model is produced by TRT offline quantization "
+        "calibration, trt_calib_mode needs to be set to True.")
+
+    return parser
+
+
+class Times(object):
+    def __init__(self):
+        self.time = 0.
+        # start time
+        self.st = 0.
+        # end time
+        self.et = 0.
+
+    def start(self):
+        self.st = time.time()
+
+    def end(self, repeats=1, accumulative=True):
+        self.et = time.time()
+        if accumulative:
+            self.time += (self.et - self.st) / repeats
+        else:
+            self.time = (self.et - self.st) / repeats
+
+    def reset(self):
+        self.time = 0.
+        self.st = 0.
+        self.et = 0.
+
+    def value(self):
+        return round(self.time, 4)
+
+
+class Timer(Times):
+    def __init__(self):
+        super(Timer, self).__init__()
+        self.preprocess_time = Times()
+        self.inference_time = Times()
+        self.postprocess_time = Times()
+        self.img_num = 0
+
+    def info(self, average=False):
+        total_time = self.preprocess_time.value() + self.inference_time.value(
+        ) + self.postprocess_time.value()
+        total_time = round(total_time, 4)
+        print("------------------ Inference Time Info ----------------------")
+        print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
+                                                       self.img_num))
+        preprocess_time = round(self.preprocess_time.value() / self.img_num,
+                                4) if average else self.preprocess_time.value()
+        postprocess_time = round(
+            self.postprocess_time.value() / self.img_num,
+            4) if average else self.postprocess_time.value()
+        inference_time = round(self.inference_time.value() / self.img_num,
+                               4) if average else self.inference_time.value()
+
+        average_latency = total_time / self.img_num
+        print("average latency time(ms): {:.2f}, QPS: {:.2f}".format(
+            average_latency * 1000, 1 / average_latency))
+        print(
+            "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
+            format(preprocess_time * 1000, inference_time * 1000,
+                   postprocess_time * 1000))
+
+    def report(self, average=False):
+        dic = {}
+        dic['preprocess_time'] = round(
+            self.preprocess_time.value() / self.img_num,
+            4) if average else self.preprocess_time.value()
+        dic['postprocess_time'] = round(
+            self.postprocess_time.value() / self.img_num,
+            4) if average else self.postprocess_time.value()
+        dic['inference_time'] = round(
+            self.inference_time.value() / self.img_num,
+            4) if average else self.inference_time.value()
+        dic['img_num'] = self.img_num
+        total_time = self.preprocess_time.value() + self.inference_time.value(
+        ) + self.postprocess_time.value()
+        dic['total_time'] = round(total_time, 4)
+        return dic
+
+
+def get_current_memory_mb():
+    """
+    Obtain the CPU and GPU memory usage of the current process.
+    Note that this function is itself time-consuming.
+    """
+    import pynvml
+    import psutil
+    import GPUtil
+    # CUDA_VISIBLE_DEVICES is a string such as "0" or "0,1";
+    # use the first visible device as an integer index.
+    gpu_id = str(os.environ.get('CUDA_VISIBLE_DEVICES', 0)).split(',')[0]
+    gpu_id = int(gpu_id) if gpu_id.isdigit() else 0
+
+    pid = os.getpid()
+    p = psutil.Process(pid)
+    info = p.memory_full_info()
+    cpu_mem = info.uss / 1024. / 1024.
+    gpu_mem = 0
+    gpu_percent = 0
+    gpus = GPUtil.getGPUs()
+    if gpu_id is not None and len(gpus) > 0:
+        gpu_percent = gpus[gpu_id].load
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        gpu_mem = meminfo.used / 1024. / 1024.
+    return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
+
+
+class LoggerHelper(object):
+    def __init__(self, args, times, mem_info=None):
+        """
+        args: the parsed arguments returned by argsparser().parse_args()
+        times: the dict returned by Timer.report()
+        """
+        self.args = args
+        self.times = times
+        self.model_name = args.model_dir.strip('/').split('/')[-1]
+        self.batch_size = 1
+        self.shape = "dynamic shape"
+        if args.run_mode == 'fluid':
+            self.precision = "fp32"
+            self.use_tensorrt = False
+        else:
+            self.precision = args.run_mode.split('_')[-1]
+            self.use_tensorrt = True
+
+        self.device = "gpu" if args.use_gpu else "cpu"
+        self.preprocess_time = round(times['preprocess_time'], 4)
+        self.inference_time = round(times['inference_time'], 4)
+        self.postprocess_time = round(times['postprocess_time'], 4)
+        self.data_num = times['img_num']
+        self.total_time = round(times['total_time'], 4)
+        self.mem_info = {"cpu_rss": 0, "gpu_rss": 0, "gpu_util": 0}
+        if mem_info is not None:
+            self.mem_info = mem_info
+
+    def report(self):
+        print("\n")
+        print("----------------------- Config info -----------------------")
+        print("runtime_device:", self.device)
+        print("ir_optim:", True)
+        print("enable_memory_optim:", True)
+        print("enable_tensorrt:", self.use_tensorrt)
+        print("precision:", self.precision)
+        print("enable_mkldnn:", self.args.enable_mkldnn)
+        print("cpu_math_library_num_threads:", self.args.cpu_threads)
+
+        print("----------------------- Model info ----------------------")
+        print("model_name:", self.model_name)
+
+        print("------------------------ Data info ----------------------")
+        print("batch_size:", self.batch_size)
+        print("input_shape:", self.shape)
+
+        print("----------------------- Perf info -----------------------")
+        print("cpu_rss(MB): {}, gpu_rss(MB): {}, gpu_util: {}%".format(
+            round(self.mem_info['cpu_rss'], 4),
+            round(self.mem_info['gpu_rss'], 4),
+            round(self.mem_info['gpu_util'], 2)))
+        print("total number of predicted data: {} and total time spent(s): {}".
+              format(self.data_num, self.total_time))
+        print(
+            "preprocess_time(ms): {}, inference_time(ms): {}, postprocess_time(ms): {}".
+            format(self.preprocess_time * 1000, self.inference_time * 1000,
+                   self.postprocess_time * 1000))
diff --git a/deploy/imgs/input_shape.png b/docs/images/input_shape.png
similarity index 100%
rename from deploy/imgs/input_shape.png
rename to docs/images/input_shape.png
diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py
index e6763184c..87f6e2499 100644
--- a/ppdet/engine/export_utils.py
+++ b/ppdet/engine/export_utils.py
@@ -57,6 +57,11 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
     for st in sample_transforms[1:]:
         for key, value in st.items():
             p = {'type': key}
+            if key == 'Resize':
+                if value.get('keep_ratio', False) and int(image_shape[1]) != -1:
+                    max_size = max(image_shape[1:])
+                    image_shape = [3, max_size, max_size]
+                value['target_size'] = image_shape[1:]
             p.update(value)
             preprocess_list.append(p)
     batch_transforms = reader_cfg.get('batch_transforms', None)
--
GitLab
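The CPU branch that the patched `load_predictor()` sets up when `--use_gpu` is False can be reproduced in isolation. A minimal sketch, assuming a model directory exported by `tools/export_model.py` at the placeholder path below and a Paddle wheel built with MKLDNN support:

```python
import os

from paddle.inference import Config, create_predictor

model_dir = '/path/to/exported_model'  # placeholder, not a path from the patch
config = Config(
    os.path.join(model_dir, 'model.pdmodel'),
    os.path.join(model_dir, 'model.pdiparams'))

config.disable_gpu()
config.set_cpu_math_library_num_threads(4)  # what --cpu_threads controls
try:
    # cache 10 different input shapes so MKLDNN memory does not keep growing
    config.set_mkldnn_cache_capacity(10)
    config.enable_mkldnn()
except Exception:
    print("The current environment does not support MKLDNN, so it stays disabled.")

predictor = create_predictor(config)
```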
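For the GPU path with `--use_dynamic_shape`, the patch registers minimum/maximum/optimal shapes for the network input named `image`. A sketch assuming a TensorRT-enabled Paddle build; the shape values simply mirror the `--trt_min_shape`/`--trt_max_shape`/`--trt_opt_shape` defaults, and `min_subgraph_size=3` is an illustrative choice rather than a value read from an exported config:

```python
from paddle.inference import Config, PrecisionType, create_predictor

config = Config('/path/to/model.pdmodel', '/path/to/model.pdiparams')  # placeholders
config.enable_use_gpu(200, 0)              # 200 MB initial memory pool on GPU 0
config.switch_ir_optim(True)
config.enable_tensorrt_engine(
    workspace_size=1 << 10,
    max_batch_size=1,
    min_subgraph_size=3,                   # illustrative value
    precision_mode=PrecisionType.Half,     # i.e. --run_mode=trt_fp16
    use_static=False,
    use_calib_mode=False)                  # --trt_calib_mode only for offline-calibrated INT8 models

# dynamic-shape bounds for the 'image' input
min_input_shape = {'image': [1, 3, 1, 1]}
max_input_shape = {'image': [1, 3, 1280, 1280]}
opt_input_shape = {'image': [1, 3, 640, 640]}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                  opt_input_shape)

predictor = create_predictor(config)
```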
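The `Times`/`Timer` helpers in the new `deploy/python/utils.py` accumulate per-stage latency the same way `predict_image()` does. A usage sketch, assuming it is run from the repository root so `deploy/python` can be added to `sys.path`; the `time.sleep` calls stand in for real preprocessing, `predictor.run()` and post-processing:

```python
import sys
import time

sys.path.insert(0, 'deploy/python')  # assumption: executed from the PaddleDetection root
from utils import Timer

det_times = Timer()
for _ in range(3):                         # pretend we process 3 images
    det_times.preprocess_time.start()
    time.sleep(0.01)                       # stand-in for preprocessing
    det_times.preprocess_time.end()

    det_times.inference_time.start()
    time.sleep(0.02)                       # stand-in for predictor.run()
    det_times.inference_time.end(repeats=1)

    det_times.postprocess_time.start()
    time.sleep(0.005)                      # stand-in for post-processing
    det_times.postprocess_time.end()

    det_times.img_num += 1

det_times.info(average=True)               # per-image latency summary on stdout
report = det_times.report(average=True)    # dict later consumed by LoggerHelper
print(report)
```

In `main()`, this same `report(average=True)` dictionary, together with the memory statistics averaged over the image list, is what gets handed to `LoggerHelper` when `--run_benchmark` is set.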
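`get_current_memory_mb()` pulls its numbers from `psutil` (process memory), `GPUtil` (GPU utilization) and `pynvml` (GPU memory), so those packages need to be installed before `--run_benchmark` is used. A small check, assuming `pip install psutil GPUtil pynvml` has been run; on a CPU-only machine only the first value is meaningful:

```python
import sys

sys.path.insert(0, 'deploy/python')  # assumption: executed from the PaddleDetection root
from utils import get_current_memory_mb

cpu_rss_mb, gpu_rss_mb, gpu_util = get_current_memory_mb()
print("cpu_rss(MB):", cpu_rss_mb, "gpu_rss(MB):", gpu_rss_mb, "gpu_util:", gpu_util)
```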
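The padding that the patch adds to `Resize` in `deploy/python/preprocess.py` only applies when `keep_ratio` resizing is combined with a fixed `input_shape` from `infer_cfg.yml`. The effect can be shown with plain NumPy; the shapes below are made-up illustration values:

```python
import numpy as np

# pretend output of a keep_ratio resize (aspect ratio preserved)
resized = np.random.rand(608, 416, 3).astype(np.float32)
max_size = 640  # fixed input_shape side taken from infer_cfg.yml

padding_im = np.zeros((max_size, max_size, resized.shape[2]), dtype=np.float32)
im_h, im_w = resized.shape[:2]
padding_im[:im_h, :im_w, :] = resized   # image sits in the top-left corner, the rest stays zero
print(padding_im.shape)                 # (640, 640, 3)
```

Because `scale_factor` is recorded before the padding step and the image keeps its top-left position, boxes predicted on the padded input can still be mapped back to the original image with `scale_factor`.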