提交 b2e2bb98 编写于 作者: L LDOUBLEV

fix problems responding to inference

上级 5b4675e0
...@@ -15,39 +15,63 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力 ...@@ -15,39 +15,63 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
## 文档教程 ## 文档教程
- [快速安装](./doc/installation.md) - [快速安装](./doc/installation.md)
- [快速开始]()
- [文本识别模型训练/评估/预测](./doc/detection.md) - [文本识别模型训练/评估/预测](./doc/detection.md)
- [文本预测模型训练/评估/预测](./doc/recognition.md) - [文本预测模型训练/评估/预测](./doc/recognition.md)
- [基于inference model预测](./doc/) - [基于inference model预测](./doc/)
### **快速开始**
下载inference模型
```
# 创建inference模型保存目录
mkdir inference && cd inference && mkdir det && mkdir rec
# 下载检测inference模型
wget -P ./inference/det 检测inference模型链接
# 下载识别inference模型
wget -P ./inferencee/rec 识别inference模型链接
```
实现文本检测、识别串联推理,预测$image_dir$指定的单张图像:
```
export PYTHONPATH=.
python tools/infer/predict_eval.py --image_dir="/Demo.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
```
在执行预测时,通过参数det_model_dir以及rec_model_dir设置存储inference 模型的路径。
实现文本检测、识别串联推理,预测$image_dir$指指定文件夹下的所有图像:
```
python tools/infer/predict_eval.py --image_dir="/test_imgs/" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
```
## 文本检测算法: ## 文本检测算法:
PaddleOCR开源的文本检测算法列表: PaddleOCR开源的文本检测算法列表:
- [x] [EAST](https://arxiv.org/abs/1704.03155) - [x] [EAST](https://arxiv.org/abs/1704.03155)
- [x] [DB](https://arxiv.org/abs/1911.08947) - [x] [DB](https://arxiv.org/abs/1911.08947)
- [x] [SAST](https://arxiv.org/abs/1908.05498) - [ ] [SAST](https://arxiv.org/abs/1908.05498)
- []
算法效果: 算法效果:
|模型|骨干网络|Hmean| |模型|骨干网络|Hmean|
|-|-|-| |-|-|-|
|EAST^[1]^|ResNet50_vd|85.85%| |EAST|ResNet50_vd|85.85%|
|EAST^[1]^|MobileNetV3|79.08%| |EAST|MobileNetV3|79.08%|
|DB^[2]^|ResNet50_vd|83.30%| |DB|ResNet50_vd|83.30%|
|DB^[2]^|MobileNetV3|73.00%| |DB|MobileNetV3|73.00%|
PaddleOCR文本检测算法的训练与使用请参考[文档](./doc/detection.md) PaddleOCR文本检测算法的训练与使用请参考[文档](./doc/detection.md)
## 文本识别算法: ## 文本识别算法:
PaddleOCR开源的文本识别算法列表: PaddleOCR开源的文本识别算法列表:
- [CRNN](https://arxiv.org/abs/1507.05717) - [x] [CRNN](https://arxiv.org/abs/1507.05717)
- [Rosetta](https://arxiv.org/abs/1910.05085) - [x] [DTRB](https://arxiv.org/abs/1904.01906)
- [STAR-Net](http://www.bmva.org/bmvc/2016/papers/paper043/index.html) - [ ] [Rosetta](https://arxiv.org/abs/1910.05085)
- [RARE](https://arxiv.org/abs/1603.03915v1) - [ ] [STAR-Net](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)
- [SRN]((https://arxiv.org/abs/2003.12294))(百度自研) - [ ] [RARE](https://arxiv.org/abs/1603.03915v1)
- [ ] [SRN]((https://arxiv.org/abs/2003.12294))(百度自研)
算法效果如下表所示,精度指标是在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上的评测结果的平均值。 算法效果如下表所示,精度指标是在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上的评测结果的平均值。
...@@ -67,7 +91,7 @@ PaddleOCR文本识别算法的训练与使用请参考[文档](./doc/recognition ...@@ -67,7 +91,7 @@ PaddleOCR文本识别算法的训练与使用请参考[文档](./doc/recognition
## TODO ## TODO
**端到端OCR算法** **端到端OCR算法**
PaddleOCR即将开源百度自研端对端OCR模型[End2End-PSL](https://arxiv.org/abs/1909.07808),敬请关注。 PaddleOCR即将开源百度自研端对端OCR模型[End2End-PSL](https://arxiv.org/abs/1909.07808),敬请关注。
- End2End-PSL (comming soon) - [ ] End2End-PSL (comming soon)
......
...@@ -14,6 +14,7 @@ Global: ...@@ -14,6 +14,7 @@ Global:
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/ pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
checkpoints: checkpoints:
save_res_path: ./output/predicts_db.txt save_res_path: ./output/predicts_db.txt
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.det_model,DetModel function: ppocr.modeling.architectures.det_model,DetModel
......
# 基于inference model的推理
inference 模型(fluid.io.save_inference_model保存的模型)
一般是模型训练完成后保存的固化模型,多用于预测部署。
训练过程中保存的模型是checkpoints模型,保存的是模型的参数,多用于恢复训练等。
与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越。
PaddleOCR提供了将checkpoints转换成inference model的实现。
## 文本检测模型推理
将文本检测模型训练过程中保存的模型,转换成inference model,可以使用如下命令:
```
python tools/export_model.py -c configs/det/det_db_mv3.yml -o Global.checkpoints="./output/best_accuracy" \
Global.save_inference_dir="./inference/det/"
```
推理模型保存在$./inference/det/model$, $./inference/det/params$
使用保存的inference model实现在单张图像上的预测:
```
python tools/infer/predict_det.py --image_dir="/demo.jpg" --det_model_dir="./inference/det/"
```
## 文本识别模型推理
将文本识别模型训练过程中保存的模型,转换成inference model,可以使用如下命令:
```
python tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints="./output/best_accuracy" \
Global.save_inference_dir="./inference/rec/"
```
推理模型保存在$./inference/rec/model$, $./inference/rec/params$
使用保存的inference model实现在单张图像上的预测:
```
python tools/infer/predict_rec.py --image_dir="/demo.jpg" --rec_model_dir="./inference/rec/"
```
## 文本检测、识别串联推理
实现文本检测、识别串联推理,预测$image_dir$指定的单张图像:
```
python tools/infer/predict_eval.py --image_dir="/Demo.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
```
实现文本检测、识别串联推理,预测$image_dir$指指定文件夹下的所有图像:
```
python tools/infer/predict_eval.py --image_dir="/test_imgs/" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
```
...@@ -22,7 +22,7 @@ import string ...@@ -22,7 +22,7 @@ import string
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
from ppocr.utils.utility import create_module from ppocr.utils.utility import create_module
from tools.infer.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
import time import time
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import os
def initial_logger(): def initial_logger():
...@@ -55,6 +56,23 @@ def get_check_reader_params(mode): ...@@ -55,6 +56,23 @@ def get_check_reader_params(mode):
return check_params return check_params
def get_image_file_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
if single_file.split('.')[-1] in img_end:
imgs_lists.append(os.path.join(img_file, single_file))
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
return imgs_lists
from paddle import fluid from paddle import fluid
......
...@@ -71,14 +71,19 @@ def main(): ...@@ -71,14 +71,19 @@ def main():
init_model(config, eval_program, exe) init_model(config, eval_program, exe)
save_inference_dir = config['Global']['save_inference_dir']
if not os.path.exists(save_inference_dir):
os.makedirs(save_inference_dir)
fluid.io.save_inference_model( fluid.io.save_inference_model(
dirname="./output/", dirname=save_inference_dir,
feeded_var_names=feeded_var_names, feeded_var_names=feeded_var_names,
main_program=eval_program, main_program=eval_program,
target_vars=target_vars, target_vars=target_vars,
executor=exe, executor=exe,
model_filename='model', model_filename='model',
params_filename='params') params_filename='params')
print("inference model saved in {}/model and {}/params".format(
save_inference_dir, save_inference_dir))
print("save success, output_name_list:", fetches_var_name) print("save success, output_name_list:", fetches_var_name)
......
...@@ -61,23 +61,6 @@ def parse_args(): ...@@ -61,23 +61,6 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def get_image_file_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
if single_file.split('.')[-1] in img_end:
imgs_lists.append(os.path.join(img_file, single_file))
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
return imgs_lists
def create_predictor(args, mode): def create_predictor(args, mode):
if mode == "det": if mode == "det":
model_dir = args.det_model_dir model_dir = args.det_model_dir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册