From b2e2bb98fb8a46898723281e0e69b18d0e7286dc Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Tue, 12 May 2020 21:12:52 +0800
Subject: [PATCH] fix problems related to inference

---
 README.md                           | 54 +++++++++++++++++++--------
 configs/det/det_db_mv3.yml          |  1 +
 doc/inference.md                    | 58 +++++++++++++++++++++++++++++
 ppocr/data/det/dataset_traversal.py |  2 +-
 ppocr/utils/utility.py              | 18 +++++++++
 tools/export_model.py               |  7 +++-
 tools/infer/utility.py              | 17 ---------
 7 files changed, 123 insertions(+), 34 deletions(-)
 create mode 100644 doc/inference.md

diff --git a/README.md b/README.md
index 2b889c28..6a9adfa6 100644
--- a/README.md
+++ b/README.md
@@ -15,39 +15,63 @@ PaddleOCR aims to provide a rich, leading, and practical OCR toolkit that helps
 ## Documentation Tutorials
 - [Quick installation](./doc/installation.md)
-- [Quick start]()
 - [Text detection model training/evaluation/prediction](./doc/detection.md)
 - [Text recognition model training/evaluation/prediction](./doc/recognition.md)
 - [Prediction based on the inference model](./doc/)
+### **Quick Start**
+
+Download the inference models:
+```
+# Create directories for the inference models
+mkdir inference && cd inference && mkdir det && mkdir rec
+# Download the detection inference model
+wget -P ./inference/det <URL of the detection inference model>
+# Download the recognition inference model
+wget -P ./inference/rec <URL of the recognition inference model>
+```
+
+Run text detection and recognition in sequence on the single image specified by `image_dir`:
+```
+export PYTHONPATH=.
+python tools/infer/predict_eval.py --image_dir="/Demo.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
+```
+When running prediction, use the det_model_dir and rec_model_dir arguments to point to the directories that store the inference models.
+
+Run text detection and recognition in sequence on all images in the folder specified by `image_dir`:
+```
+python tools/infer/predict_eval.py --image_dir="/test_imgs/" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
+```
+
+
 ## Text Detection Algorithms
 
 Text detection algorithms open-sourced by PaddleOCR:
-- [x] [EAST](https://arxiv.org/abs/1704.03155)
-- [x] [DB](https://arxiv.org/abs/1911.08947)
-- [x] [SAST](https://arxiv.org/abs/1908.05498)
-- []
+- [x] [EAST](https://arxiv.org/abs/1704.03155)
+- [x] [DB](https://arxiv.org/abs/1911.08947)
+- [ ] [SAST](https://arxiv.org/abs/1908.05498)
 
 Algorithm performance:
 |Model|Backbone|Hmean|
 |-|-|-|
-|EAST^[1]^|ResNet50_vd|85.85%|
-|EAST^[1]^|MobileNetV3|79.08%|
-|DB^[2]^|ResNet50_vd|83.30%|
-|DB^[2]^|MobileNetV3|73.00%|
+|EAST|ResNet50_vd|85.85%|
+|EAST|MobileNetV3|79.08%|
+|DB|ResNet50_vd|83.30%|
+|DB|MobileNetV3|73.00%|
 
 For training and using the PaddleOCR text detection algorithms, please refer to the [documentation](./doc/detection.md).
 
 ## Text Recognition Algorithms
 
 Text recognition algorithms open-sourced by PaddleOCR:
-- [CRNN](https://arxiv.org/abs/1507.05717)
-- [Rosetta](https://arxiv.org/abs/1910.05085)
-- [STAR-Net](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)
-- [RARE](https://arxiv.org/abs/1603.03915v1)
-- [SRN]((https://arxiv.org/abs/2003.12294)) (developed by Baidu)
+- [x] [CRNN](https://arxiv.org/abs/1507.05717)
+- [x] [DTRB](https://arxiv.org/abs/1904.01906)
+- [ ] [Rosetta](https://arxiv.org/abs/1910.05085)
+- [ ] [STAR-Net](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)
+- [ ] [RARE](https://arxiv.org/abs/1603.03915v1)
+- [ ] [SRN](https://arxiv.org/abs/2003.12294) (developed by Baidu)
 
 The results are shown in the table below; the accuracy metric is the average of the evaluation results on the IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE datasets.
 
@@ -67,7 +91,7 @@ For training and using the PaddleOCR text recognition algorithms, please refer to the [documentation](./doc/recognition
 ## TODO
 **End-to-end OCR algorithms**
 PaddleOCR will soon open-source [End2End-PSL](https://arxiv.org/abs/1909.07808), an end-to-end OCR model developed by Baidu. Stay tuned.
-- End2End-PSL (comming soon)
+- [ ] End2End-PSL (coming soon)
diff --git a/configs/det/det_db_mv3.yml b/configs/det/det_db_mv3.yml
index 197b1204..45e2ee17 100755
--- a/configs/det/det_db_mv3.yml
+++ b/configs/det/det_db_mv3.yml
@@ -14,6 +14,7 @@ Global:
   pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
   checkpoints:
   save_res_path: ./output/predicts_db.txt
+  save_inference_dir:
 
 Architecture:
   function: ppocr.modeling.architectures.det_model,DetModel
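The quick-start section above drives `tools/infer/predict_eval.py` straight from the shell. For repeated runs it can help to wrap that command in a small script; the sketch below is a hypothetical helper, not part of this patch, that calls the documented CLI with the same `--image_dir`, `--det_model_dir`, and `--rec_model_dir` flags, assuming the `./inference/det/` and `./inference/rec/` layout created by the download commands.

```
# Hypothetical wrapper, not included in PaddleOCR: shells out to the
# predict_eval.py command documented in the quick start above.
import os
import subprocess

# Assumed locations, matching the ./inference layout from the quick start.
DET_MODEL_DIR = "./inference/det/"
REC_MODEL_DIR = "./inference/rec/"


def run_predict_eval(image_dir):
    """Run detection + recognition on a single image or an image folder."""
    cmd = [
        "python", "tools/infer/predict_eval.py",
        "--image_dir={}".format(image_dir),
        "--det_model_dir={}".format(DET_MODEL_DIR),
        "--rec_model_dir={}".format(REC_MODEL_DIR),
    ]
    # The README sets PYTHONPATH=. so that the ppocr package resolves.
    env = dict(os.environ, PYTHONPATH=".")
    subprocess.check_call(cmd, env=env)


if __name__ == "__main__":
    run_predict_eval("./test_imgs/")
```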
diff --git a/doc/inference.md b/doc/inference.md
new file mode 100644
index 00000000..d44a96d7
--- /dev/null
+++ b/doc/inference.md
@@ -0,0 +1,58 @@
+
+# Inference based on the inference model
+
+An inference model (a model saved with fluid.io.save_inference_model)
+is the frozen model saved after training finishes and is mainly used for prediction and deployment.
+The models saved during training are checkpoints models, which store only the model parameters and are mostly used to resume training.
+Compared with a checkpoints model, an inference model additionally stores the model structure, which makes it better suited for deployment and faster inference.
+
+PaddleOCR provides tooling to convert a checkpoints model into an inference model.
+
+
+## Text detection model inference
+
+To convert a text detection model saved during training into an inference model, run:
+
+```
+python tools/export_model.py -c configs/det/det_db_mv3.yml -o Global.checkpoints="./output/best_accuracy" \
+        Global.save_inference_dir="./inference/det/"
+```
+
+The inference model is saved as `./inference/det/model` and `./inference/det/params`.
+
+To predict a single image with the saved inference model:
+
+```
+python tools/infer/predict_det.py --image_dir="/demo.jpg" --det_model_dir="./inference/det/"
+```
+
+
+## Text recognition model inference
+
+To convert a text recognition model saved during training into an inference model, run:
+
+```
+python tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints="./output/best_accuracy" \
+        Global.save_inference_dir="./inference/rec/"
+```
+
+The inference model is saved as `./inference/rec/model` and `./inference/rec/params`.
+
+To predict a single image with the saved inference model:
+
+```
+python tools/infer/predict_rec.py --image_dir="/demo.jpg" --rec_model_dir="./inference/rec/"
+```
+
+## Combined text detection and recognition inference
+
+To run text detection and recognition in sequence on the single image specified by `image_dir`:
+```
+python tools/infer/predict_eval.py --image_dir="/Demo.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
+```
+
+To run text detection and recognition in sequence on all images in the folder specified by `image_dir`:
+
+```
+python tools/infer/predict_eval.py --image_dir="/test_imgs/" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
+```
diff --git a/ppocr/data/det/dataset_traversal.py b/ppocr/data/det/dataset_traversal.py
index 2e68d91d..0feedeeb 100755
--- a/ppocr/data/det/dataset_traversal.py
+++ b/ppocr/data/det/dataset_traversal.py
@@ -22,7 +22,7 @@ import string
 from ppocr.utils.utility import initial_logger
 logger = initial_logger()
 from ppocr.utils.utility import create_module
-from tools.infer.utility import get_image_file_list
+from ppocr.utils.utility import get_image_file_list
 import time
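The doc/inference.md added above notes that `tools/export_model.py` writes the frozen graph to `model` and the weights to `params` inside `Global.save_inference_dir`. As a rough illustration of how such a directory is consumed, here is a minimal sketch that loads it with `fluid.io.load_inference_model`; it assumes a PaddlePaddle 1.x (fluid) environment and the `./inference/det/` path from the examples, and it is not the actual predictor code in `tools/infer/`.

```
# Minimal sketch, not PaddleOCR's predictor: load an inference model that
# tools/export_model.py exported with model_filename='model' and
# params_filename='params'.
import paddle.fluid as fluid

place = fluid.CPUPlace()  # use fluid.CUDAPlace(0) for GPU inference
exe = fluid.Executor(place)

# Path assumed to match the detection export example above.
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname="./inference/det/",
    executor=exe,
    model_filename="model",
    params_filename="params")

print("feed vars:", feed_names)
print("fetch vars:", [v.name for v in fetch_targets])
# A real predictor would preprocess an image into the feed dict and then call
# exe.run(program, feed={...}, fetch_list=fetch_targets).
```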
diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py
index 6a81465c..c824d440 100755
--- a/ppocr/utils/utility.py
+++ b/ppocr/utils/utility.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import os
 
 
 def initial_logger():
@@ -55,6 +56,23 @@ def get_check_reader_params(mode):
     return check_params
 
 
+def get_image_file_list(img_file):
+    imgs_lists = []
+    if img_file is None or not os.path.exists(img_file):
+        raise Exception("not found any img file in {}".format(img_file))
+
+    img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
+    if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
+        imgs_lists.append(img_file)
+    elif os.path.isdir(img_file):
+        for single_file in os.listdir(img_file):
+            if single_file.split('.')[-1] in img_end:
+                imgs_lists.append(os.path.join(img_file, single_file))
+    if len(imgs_lists) == 0:
+        raise Exception("not found any img file in {}".format(img_file))
+    return imgs_lists
+
+
 from paddle import fluid
diff --git a/tools/export_model.py b/tools/export_model.py
index 6c924f3f..11a2744b 100644
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -71,14 +71,19 @@ def main():
 
     init_model(config, eval_program, exe)
 
+    save_inference_dir = config['Global']['save_inference_dir']
+    if not os.path.exists(save_inference_dir):
+        os.makedirs(save_inference_dir)
     fluid.io.save_inference_model(
-        dirname="./output/",
+        dirname=save_inference_dir,
         feeded_var_names=feeded_var_names,
         main_program=eval_program,
         target_vars=target_vars,
         executor=exe,
         model_filename='model',
         params_filename='params')
+    print("inference model saved in {}/model and {}/params".format(
+        save_inference_dir, save_inference_dir))
     print("save success, output_name_list:", fetches_var_name)
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 01477a5c..6d56a990 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -61,23 +61,6 @@ def parse_args():
     return parser.parse_args()
 
 
-def get_image_file_list(img_file):
-    imgs_lists = []
-    if img_file is None or not os.path.exists(img_file):
-        raise Exception("not found any img file in {}".format(img_file))
-
-    img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
-    if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
-        imgs_lists.append(img_file)
-    elif os.path.isdir(img_file):
-        for single_file in os.listdir(img_file):
-            if single_file.split('.')[-1] in img_end:
-                imgs_lists.append(os.path.join(img_file, single_file))
-    if len(imgs_lists) == 0:
-        raise Exception("not found any img file in {}".format(img_file))
-    return imgs_lists
-
-
 def create_predictor(args, mode):
     if mode == "det":
         model_dir = args.det_model_dir
-- 
GitLab
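After this patch, `get_image_file_list` lives in `ppocr/utils/utility.py`, and both the dataset reader and the inference tools import it from there rather than from `tools/infer/utility.py`. A minimal usage sketch, assuming the repository root is on `PYTHONPATH` and using a placeholder image directory:

```
# Usage sketch for the relocated helper; the path below is a placeholder.
from ppocr.utils.utility import get_image_file_list

# Accepts either a single image file or a directory containing
# jpg/png/jpeg/bmp files and returns the resolved image paths.
for image_path in get_image_file_list("./test_imgs/"):
    print(image_path)
```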