diff --git a/deploy/paddle2onnx/readme.md b/deploy/paddle2onnx/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..02c069b7ac69110b63c53cd88e6ab5a223174972 --- /dev/null +++ b/deploy/paddle2onnx/readme.md @@ -0,0 +1,72 @@ +# paddle2onnx 模型转化与预测 + +本章节介绍 PaddleOCR 模型如何转化为 ONNX 模型,并基于 ONNX 引擎预测。 + +## 1. 环境准备 + +需要准备 Paddle2ONNX 模型转化环境,和 ONNX 模型预测环境 + +### Paddle2ONNX +Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式,算子目前稳定支持导出 ONNX Opset 9~11,部分Paddle算子支持更低的ONNX Opset转换。 +更多细节可参考 [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/README_zh.md) + +- 安装 Paddle2ONNX +``` +python3.7 -m pip install paddle2onnx +``` + +- 安装 ONNX +``` +# 建议安装 1.4.0 版本,可根据环境更换版本号 +python3.7 -m pip install onnxruntime==1.4.0 +``` + +## 2. 模型转换 + + +- Paddle 模型下载 + +有两种方式获取Paddle静态图模型:在 [model_list](../../doc/doc_ch/models_list.md) 中下载PaddleOCR提供的预测模型; +参考[模型导出说明](../../doc/doc_ch/inference.md#训练模型转inference模型)把训练好的权重转为 inference_model。 + +以 ppocr 检测模型为例: + +``` +wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar +cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd .. +``` + +- 模型转换 + +使用 Paddle2ONNX 将Paddle静态图模型转换为ONNX模型格式: + +``` +paddle2onnx --model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ \ +--model_filename=inference.pdmodel \ +--params_filename=inference.pdiparams \ +--save_file=./inference/det_mobile_onnx/model.onnx \ +--opset_version=10 \ +--enable_onnx_checker=True +``` + +执行完毕后,ONNX 模型会被保存在 `./inference/det_mobile_onnx/` 路径下 + +## 3. onnx 预测 + +以检测模型为例,使用 ONNX 预测可执行如下命令: + +``` +python3.7 ../../tools/infer/predict_det.py --use_gpu=False --use_onnx=True \ +--det_model_dir=./inference/det_mobile_onnx/model.onnx \ +--image_dir=../../doc/imgs/1.jpg +``` + +执行命令后在终端会打印出预测的检测框坐标,并在 `./inference_results/` 下保存可视化结果。 + +``` +root INFO: 1.jpg [[[291, 295], [334, 292], [348, 844], [305, 847]], [[344, 296], [379, 294], [387, 669], [353, 671]]] +The predict time of ../../doc/imgs/1.jpg: 0.06162881851196289 +The visualized image saved in ./inference_results/det_res_1.jpg +``` + +* 注意:ONNX暂时不支持变长预测,因为需要将输入resize到固定输入,预测结果可能与直接使用Paddle预测有细微不同。 diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index af883de86c798babe6ca1616710c0e13546e1045..1ffb8f340305249d5f2c8aeb2fef37dc8f590eff 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -68,4 +68,5 @@ PaddleOCR基于动态图开源的文本识别算法列表: |NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | |SAR|Resnet31| 87.2% | rec_r31_sar | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SEED| Aster_Resnet | 85.2% | rec_resnet_stn_bilstm_att | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar)| + PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/test_tipc/configs/ppocr_det_mobile_params.txt b/test_tipc/configs/ppocr_det_mobile_params.txt index e3e17a960c10eec24331b8f233d1a67dc1d5a554..d7e9cf95c2e9b4b2e18265e5f8b4a65cd6bdf518 100644 --- a/test_tipc/configs/ppocr_det_mobile_params.txt +++ b/test_tipc/configs/ppocr_det_mobile_params.txt @@ -108,3 +108,15 @@ infer_model:./models/ch_ppocr_mobile_v2.0_det_opt.nb|./models/ch_ppocr_mobile_v2 --config_dir:./config.txt --rec_dict_dir:./ppocr_keys_v1.txt --benchmark:True +===========================paddle2onnx_params=========================== +2onnx: paddle2onnx +--model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--save_file:./inference/det_mobile_onnx/model.onnx +--opset_version:10 +--enable_onnx_checker:True +inference:tools/infer/predict_det.py +--use_gpu:False +--det_model_dir: +--image_dir:./inference/ch_det_data_50/all-sum-510/ \ No newline at end of file diff --git a/test_tipc/docs/test_paddle2onnx.md b/test_tipc/docs/test_paddle2onnx.md new file mode 100644 index 0000000000000000000000000000000000000000..5d784c5e93c3a93d00c256004de582dcbf357c45 --- /dev/null +++ b/test_tipc/docs/test_paddle2onnx.md @@ -0,0 +1,47 @@ +# Paddle2onnx预测功能测试 + +PaddleServing预测功能测试的主程序为`test_paddle2onnx.sh`,可以测试Paddle2ONNX的模型转化功能,并验证正确性。 + +## 1. 测试结论汇总 + +基于训练是否使用量化,进行本测试的模型可以分为`正常模型`和`量化模型`,这两类模型对应的Paddle2ONNX预测功能汇总如下: + +| 模型类型 |device | +| ---- | ---- | +| 正常模型 | GPU | +| 正常模型 | CPU | +| 量化模型 | GPU | +| 量化模型 | CPU | + +## 2. 测试流程 +### 2.1 功能测试 +先运行`prepare.sh`准备数据和模型,然后运行`test_paddle2onnx.sh`进行测试,最终在```test_tipc/output```目录下生成`paddle2onnx_infer_*.log`后缀的日志文件。 + +```shell +bash test_tipc/prepare.sh ./test_tipc/configs/ppocr_det_mobile_params.txt "paddle2onnx_infer" + +# 用法: +bash test_tipc/test_paddle2onnx.sh ./test_tipc/configs/ppocr_det_mobile_params.txt +``` + +#### 运行结果 + +各测试的运行情况会打印在 `test_tipc/output/results_paddle2onnx.log` 中: +运行成功时会输出: + +``` +Run successfully with command - paddle2onnx --model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ --model_filename=inference.pdmodel --params_filename=inference.pdiparams --save_file=./inference/det_mobile_onnx/model.onnx --opset_version=10 --enable_onnx_checker=True! +Run successfully with command - python test_tipc/onnx_inference/predict_det.py --use_gpu=False --image_dir=./inference/ch_det_data_50/all-sum-510/ --det_model_dir=./inference/det_mobile_onnx/model.onnx 2>&1 ! +``` + +运行失败时会输出: + +``` +Run failed with command - paddle2onnx --model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ --model_filename=inference.pdmodel --params_filename=inference.pdiparams --save_file=./inference/det_mobile_onnx/model.onnx --opset_version=10 --enable_onnx_checker=True! +... +``` + + +## 3. 更多教程 + +本文档为功能测试用,更详细的Paddle2onnx预测使用教程请参考:[Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) diff --git a/test_tipc/docs/test_serving.md b/test_tipc/docs/test_serving.md index ce508af09ce505466b78db57304904fa84c733af..f63d6c7107ce92807c53d81a22a582b09178a712 100644 --- a/test_tipc/docs/test_serving.md +++ b/test_tipc/docs/test_serving.md @@ -4,7 +4,7 @@ PaddleServing预测功能测试的主程序为`test_serving.sh`,可以测试 ## 1. 测试结论汇总 -基于训练是否使用量化,进行本测试的模型可以分为`正常模型`和`量化模型`,这两类模型对应的C++预测功能汇总如下: +基于训练是否使用量化,进行本测试的模型可以分为`正常模型`和`量化模型`,这两类模型对应的Serving预测功能汇总如下: | 模型类型 |device | batchsize | tensorrt | mkldnn | cpu多线程 | | ---- | ---- | ---- | :----: | :----: | :----: | diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index c9a38e01c93f9aac6cd06f4207f23f9c9f451c71..f3ad242538a9471af237c804eae343da06e2b9dd 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -201,3 +201,20 @@ if [ ${MODE} = "lite_infer" ];then tar -cf test_lite.tar ./test_lite && cp test_lite.tar ${current_dir} && cd ${current_dir} fi + +if [ ${MODE} = "paddle2onnx_infer" ];then + # prepare serving env + python_name=$(func_parser_value "${lines[2]}") + ${python_name} -m pip install install paddle2onnx + ${python_name} -m pip install onnxruntime==1.4.0 + # wget model + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar + # wget data + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar + cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && tar xf rec_inference.tar && cd ../ + +fi diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh new file mode 100644 index 0000000000000000000000000000000000000000..5dc6e65ec81e6b8674877fc686c8b3650ce93a59 --- /dev/null +++ b/test_tipc/test_paddle2onnx.sh @@ -0,0 +1,76 @@ +#!/bin/bash +source test_tipc/common_func.sh + +FILENAME=$1 + +dataline=$(cat ${FILENAME}) +lines=(${dataline}) +# common params +model_name=$(func_parser_value "${lines[1]}") +python=$(func_parser_value "${lines[2]}") + + +# parser params +dataline=$(awk 'NR==111, NR==123{print}' $FILENAME) +IFS=$'\n' +lines=(${dataline}) + +# parser paddle2onnx +padlle2onnx_cmd=$(func_parser_value "${lines[1]}") +infer_model_dir_key=$(func_parser_key "${lines[2]}") +infer_model_dir_value=$(func_parser_value "${lines[2]}") +model_filename_key=$(func_parser_key "${lines[3]}") +model_filename_value=$(func_parser_value "${lines[3]}") +params_filename_key=$(func_parser_key "${lines[4]}") +params_filename_value=$(func_parser_value "${lines[4]}") +save_file_key=$(func_parser_key "${lines[5]}") +save_file_value=$(func_parser_value "${lines[5]}") +opset_version_key=$(func_parser_key "${lines[6]}") +opset_version_value=$(func_parser_value "${lines[6]}") +enable_onnx_checker_key=$(func_parser_key "${lines[7]}") +enable_onnx_checker_value=$(func_parser_value "${lines[7]}") +# parser onnx inference +inference_py=$(func_parser_value "${lines[8]}") +use_gpu_key=$(func_parser_key "${lines[9]}") +use_gpu_value=$(func_parser_value "${lines[9]}") +det_model_key=$(func_parser_key "${lines[10]}") +image_dir_key=$(func_parser_key "${lines[11]}") +image_dir_value=$(func_parser_value "${lines[11]}") + + +LOG_PATH="./test_tipc/output" +mkdir -p ./test_tipc/output +status_log="${LOG_PATH}/results_paddle2onnx.log" + + +function func_paddle2onnx(){ + IFS='|' + _script=$1 + + # paddle2onnx + _save_log_path="${LOG_PATH}/paddle2onnx_infer_cpu.log" + set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}") + set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}") + set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}") + set_save_model=$(func_set_params "${save_file_key}" "${save_file_value}") + set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}") + set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}") + trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker}" + eval $trans_model_cmd + last_status=${PIPESTATUS[0]} + status_check $last_status "${trans_model_cmd}" "${status_log}" + # python inference + set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu_value}") + set_model_dir=$(func_set_params "${det_model_key}" "${save_file_value}") + set_img_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}") + infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " + eval $infer_model_cmd + status_check $last_status "${infer_model_cmd}" "${status_log}" +} + + +echo "################### run test ###################" + +export Count=0 +IFS="|" +func_paddle2onnx \ No newline at end of file diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 1c68494861e60b4aaef541a4e247071944cf420c..a25cac2600e67667badc76c648c1fcda12981a0f 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -47,6 +47,7 @@ class TextClassifier(object): self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, _ = \ utility.create_predictor(args, 'cls', logger) + self.use_onnx = args.use_onnx def resize_norm_img(self, img): imgC, imgH, imgW = self.cls_image_shape @@ -100,10 +101,16 @@ class TextClassifier(object): norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.run() - prob_out = self.output_tensors[0].copy_to_cpu() - self.predictor.try_shrink_memory() + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, input_dict) + prob_out = outputs[0] + else: + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + prob_out = self.output_tensors[0].copy_to_cpu() + self.predictor.try_shrink_memory() cls_result = self.postprocess_op(prob_out) elapse += time.time() - starttime for rno in range(len(cls_result)): diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index b24ad2bbb504caf1f262b4e47625348ce32d6fce..5dfe8d648f06f6382e8e101a6002f7f1b7441323 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -38,6 +38,7 @@ class TextDetector(object): def __init__(self, args): self.args = args self.det_algorithm = args.det_algorithm + self.use_onnx = args.use_onnx pre_process_list = [{ 'DetResizeForTest': { 'limit_side_len': args.det_limit_side_len, @@ -100,7 +101,12 @@ class TextDetector(object): else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) - + if self.use_onnx: + pre_process_list[0] = { + 'DetResizeForTest': { + 'image_shape': [640, 640] + } + } self.preprocess_op = create_operators(pre_process_list) self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( @@ -198,15 +204,19 @@ class TextDetector(object): if self.args.benchmark: self.autolog.times.stamp() - - self.input_tensor.copy_from_cpu(img) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - if self.args.benchmark: - self.autolog.times.stamp() + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = img + outputs = self.predictor.run(self.output_tensors, input_dict) + else: + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.args.benchmark: + self.autolog.times.stamp() preds = {} if self.det_algorithm == "EAST": diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 6cc91b56a31708986ffbe35b649a67c4385c22b2..41982e3403b11dd4a1893f89af11a9201e0e15d7 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -73,6 +73,7 @@ class TextRecognizer(object): self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) self.benchmark = args.benchmark + self.use_onnx = args.use_onnx if args.benchmark: import auto_log pid = os.getpid() @@ -107,6 +108,8 @@ class TextRecognizer(object): assert imgC == img.shape[2] imgW = int((32 * max_wh_ratio)) + if self.use_onnx: + imgW = 100 h, w = img.shape[:2] ratio = w / float(h) if math.ceil(imgH * ratio) > imgW: @@ -296,51 +299,72 @@ class TextRecognizer(object): gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list, ] - input_names = self.predictor.get_input_names() - for i in range(len(input_names)): - input_tensor = self.predictor.get_input_handle(input_names[ - i]) - input_tensor.copy_from_cpu(inputs[i]) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - if self.benchmark: - self.autolog.times.stamp() - preds = {"predict": outputs[2]} + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = {"predict": outputs[2]} + else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle( + input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = {"predict": outputs[2]} elif self.rec_algorithm == "SAR": valid_ratios = np.concatenate(valid_ratios) inputs = [ norm_img_batch, valid_ratios, ] - input_names = self.predictor.get_input_names() - for i in range(len(input_names)): - input_tensor = self.predictor.get_input_handle(input_names[ - i]) - input_tensor.copy_from_cpu(inputs[i]) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - if self.benchmark: - self.autolog.times.stamp() - preds = outputs[0] - else: - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - if self.benchmark: - self.autolog.times.stamp() - if len(outputs) != 1: - preds = outputs + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = outputs[0] else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle( + input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() preds = outputs[0] + else: + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = outputs[0] + else: + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + if len(outputs) != 1: + preds = outputs + else: + preds = outputs[0] rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 41a3c0f14b6378751a367a3709ad7943ee981a4e..a3cac647982b77f5ee54d8681b07e677987d9ccb 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -121,6 +121,7 @@ def init_args(): parser.add_argument("--save_log_path", type=str, default="./log_output/") parser.add_argument("--show_log", type=str2bool, default=True) + parser.add_argument("--use_onnx", type=str2bool, default=False) return parser @@ -144,152 +145,163 @@ def create_predictor(args, mode, logger): if model_dir is None: logger.info("not find {} model file path {}".format(mode, model_dir)) sys.exit(0) - model_file_path = model_dir + "/inference.pdmodel" - params_file_path = model_dir + "/inference.pdiparams" - if not os.path.exists(model_file_path): - raise ValueError("not find model file path {}".format(model_file_path)) - if not os.path.exists(params_file_path): - raise ValueError("not find params file path {}".format( - params_file_path)) - - config = inference.Config(model_file_path, params_file_path) - - if hasattr(args, 'precision'): - if args.precision == "fp16" and args.use_tensorrt: - precision = inference.PrecisionType.Half - elif args.precision == "int8": - precision = inference.PrecisionType.Int8 - else: - precision = inference.PrecisionType.Float32 + if args.use_onnx: + import onnxruntime as ort + model_file_path = model_dir + if not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format( + model_file_path)) + sess = ort.InferenceSession(model_file_path) + return sess, sess.get_inputs()[0], None, None + else: - precision = inference.PrecisionType.Float32 - - if args.use_gpu: - gpu_id = get_infer_gpuid() - if gpu_id is None: - raise ValueError( - "Not found GPU in current device. Please check your device or set args.use_gpu as False" - ) - config.enable_use_gpu(args.gpu_mem, 0) - if args.use_tensorrt: - config.enable_tensorrt_engine( - precision_mode=precision, - max_batch_size=args.max_batch_size, - min_subgraph_size=args.min_subgraph_size) - # skip the minmum trt subgraph - if mode == "det": - min_input_shape = { - "x": [1, 3, 50, 50], - "conv2d_92.tmp_0": [1, 120, 20, 20], - "conv2d_91.tmp_0": [1, 24, 10, 10], - "conv2d_59.tmp_0": [1, 96, 20, 20], - "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10], - "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20], - "conv2d_124.tmp_0": [1, 256, 20, 20], - "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20], - "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20], - "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20], - "elementwise_add_7": [1, 56, 2, 2], - "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2] - } - max_input_shape = { - "x": [1, 3, 2000, 2000], - "conv2d_92.tmp_0": [1, 120, 400, 400], - "conv2d_91.tmp_0": [1, 24, 200, 200], - "conv2d_59.tmp_0": [1, 96, 400, 400], - "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200], - "conv2d_124.tmp_0": [1, 256, 400, 400], - "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400], - "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400], - "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400], - "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400], - "elementwise_add_7": [1, 56, 400, 400], - "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400] - } - opt_input_shape = { - "x": [1, 3, 640, 640], - "conv2d_92.tmp_0": [1, 120, 160, 160], - "conv2d_91.tmp_0": [1, 24, 80, 80], - "conv2d_59.tmp_0": [1, 96, 160, 160], - "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80], - "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160], - "conv2d_124.tmp_0": [1, 256, 160, 160], - "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160], - "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160], - "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160], - "elementwise_add_7": [1, 56, 40, 40], - "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40] - } - min_pact_shape = { - "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20], - "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20], - "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20], - "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20] - } - max_pact_shape = { - "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400], - "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400], - "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400], - "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400] - } - opt_pact_shape = { - "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160], - "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160], - "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160], - "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160] - } - min_input_shape.update(min_pact_shape) - max_input_shape.update(max_pact_shape) - opt_input_shape.update(opt_pact_shape) - elif mode == "rec": - min_input_shape = {"x": [1, 3, 32, 10]} - max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} - opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]} - elif mode == "cls": - min_input_shape = {"x": [1, 3, 48, 10]} - max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]} - opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} + model_file_path = model_dir + "/inference.pdmodel" + params_file_path = model_dir + "/inference.pdiparams" + if not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format( + model_file_path)) + if not os.path.exists(params_file_path): + raise ValueError("not find params file path {}".format( + params_file_path)) + + config = inference.Config(model_file_path, params_file_path) + + if hasattr(args, 'precision'): + if args.precision == "fp16" and args.use_tensorrt: + precision = inference.PrecisionType.Half + elif args.precision == "int8": + precision = inference.PrecisionType.Int8 + else: + precision = inference.PrecisionType.Float32 else: - min_input_shape = {"x": [1, 3, 10, 10]} - max_input_shape = {"x": [1, 3, 1000, 1000]} - opt_input_shape = {"x": [1, 3, 500, 500]} - config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, - opt_input_shape) + precision = inference.PrecisionType.Float32 + + if args.use_gpu: + gpu_id = get_infer_gpuid() + if gpu_id is None: + raise ValueError( + "Not found GPU in current device. Please check your device or set args.use_gpu as False" + ) + config.enable_use_gpu(args.gpu_mem, 0) + if args.use_tensorrt: + config.enable_tensorrt_engine( + precision_mode=precision, + max_batch_size=args.max_batch_size, + min_subgraph_size=args.min_subgraph_size) + # skip the minmum trt subgraph + if mode == "det": + min_input_shape = { + "x": [1, 3, 50, 50], + "conv2d_92.tmp_0": [1, 120, 20, 20], + "conv2d_91.tmp_0": [1, 24, 10, 10], + "conv2d_59.tmp_0": [1, 96, 20, 20], + "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10], + "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20], + "conv2d_124.tmp_0": [1, 256, 20, 20], + "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20], + "elementwise_add_7": [1, 56, 2, 2], + "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2] + } + max_input_shape = { + "x": [1, 3, 2000, 2000], + "conv2d_92.tmp_0": [1, 120, 400, 400], + "conv2d_91.tmp_0": [1, 24, 200, 200], + "conv2d_59.tmp_0": [1, 96, 400, 400], + "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200], + "conv2d_124.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400], + "elementwise_add_7": [1, 56, 400, 400], + "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400] + } + opt_input_shape = { + "x": [1, 3, 640, 640], + "conv2d_92.tmp_0": [1, 120, 160, 160], + "conv2d_91.tmp_0": [1, 24, 80, 80], + "conv2d_59.tmp_0": [1, 96, 160, 160], + "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80], + "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160], + "conv2d_124.tmp_0": [1, 256, 160, 160], + "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160], + "elementwise_add_7": [1, 56, 40, 40], + "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40] + } + min_pact_shape = { + "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20], + "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20] + } + max_pact_shape = { + "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400] + } + opt_pact_shape = { + "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160], + "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160] + } + min_input_shape.update(min_pact_shape) + max_input_shape.update(max_pact_shape) + opt_input_shape.update(opt_pact_shape) + elif mode == "rec": + min_input_shape = {"x": [1, 3, 32, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} + opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]} + elif mode == "cls": + min_input_shape = {"x": [1, 3, 48, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]} + opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} + else: + min_input_shape = {"x": [1, 3, 10, 10]} + max_input_shape = {"x": [1, 3, 1000, 1000]} + opt_input_shape = {"x": [1, 3, 500, 500]} + config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, + opt_input_shape) - else: - config.disable_gpu() - if hasattr(args, "cpu_threads"): - config.set_cpu_math_library_num_threads(args.cpu_threads) else: - # default cpu threads as 10 - config.set_cpu_math_library_num_threads(10) - if args.enable_mkldnn: - # cache 10 different shapes for mkldnn to avoid memory leak - config.set_mkldnn_cache_capacity(10) - config.enable_mkldnn() - if args.precision == "fp16": - config.enable_mkldnn_bfloat16() - # enable memory optim - config.enable_memory_optim() - config.disable_glog_info() - - config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") - if mode == 'table': - config.delete_pass("fc_fuse_pass") # not supported for table - config.switch_use_feed_fetch_ops(False) - config.switch_ir_optim(True) - - # create predictor - predictor = inference.create_predictor(config) - input_names = predictor.get_input_names() - for name in input_names: - input_tensor = predictor.get_input_handle(name) - output_names = predictor.get_output_names() - output_tensors = [] - for output_name in output_names: - output_tensor = predictor.get_output_handle(output_name) - output_tensors.append(output_tensor) - return predictor, input_tensor, output_tensors, config + config.disable_gpu() + if hasattr(args, "cpu_threads"): + config.set_cpu_math_library_num_threads(args.cpu_threads) + else: + # default cpu threads as 10 + config.set_cpu_math_library_num_threads(10) + if args.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + if args.precision == "fp16": + config.enable_mkldnn_bfloat16() + # enable memory optim + config.enable_memory_optim() + config.disable_glog_info() + + config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") + if mode == 'table': + config.delete_pass("fc_fuse_pass") # not supported for table + config.switch_use_feed_fetch_ops(False) + config.switch_ir_optim(True) + + # create predictor + predictor = inference.create_predictor(config) + input_names = predictor.get_input_names() + for name in input_names: + input_tensor = predictor.get_input_handle(name) + output_names = predictor.get_output_names() + output_tensors = [] + for output_name in output_names: + output_tensor = predictor.get_output_handle(output_name) + output_tensors.append(output_tensor) + return predictor, input_tensor, output_tensors, config def get_infer_gpuid():