diff --git a/configs/det/det_db_mv3.yml b/configs/det/det_db_mv3.yml index a41c901e75da275cbc7cd25209c1fb3cada2a595..45e2ee17b72207d98d06169f29a9d3d16ff0557b 100755 --- a/configs/det/det_db_mv3.yml +++ b/configs/det/det_db_mv3.yml @@ -12,7 +12,9 @@ Global: image_shape: [3, 640, 640] reader_yml: ./configs/det/det_db_icdar15_reader.yml pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/ + checkpoints: save_res_path: ./output/predicts_db.txt + save_inference_dir: Architecture: function: ppocr.modeling.architectures.det_model,DetModel diff --git a/doc/README.md b/doc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6a9adfa608aacbaa897a3625ebd6295c6234f860 --- /dev/null +++ b/doc/README.md @@ -0,0 +1,151 @@ + +# 简介 +PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。 + +## 特性: +- 超轻量级模型 + - (检测模型4.1M + 识别模型4.5M = 8.6M) +- 支持竖排文字识别 + - (单模型同时支持横排和竖排文字识别) +- 支持长文本识别 +- 支持中英文数字组合识别 +- 提供训练代码 +- 支持模型部署 + + +## 文档教程 +- [快速安装](./doc/installation.md) +- [文本识别模型训练/评估/预测](./doc/detection.md) +- [文本预测模型训练/评估/预测](./doc/recognition.md) +- [基于inference model预测](./doc/) + +### **快速开始** + +下载inference模型 +``` +# 创建inference模型保存目录 +mkdir inference && cd inference && mkdir det && mkdir rec +# 下载检测inference模型 +wget -P ./inference/det 检测inference模型链接 +# 下载识别inference模型 +wget -P ./inferencee/rec 识别inference模型链接 +``` + +实现文本检测、识别串联推理,预测$image_dir$指定的单张图像: +``` +export PYTHONPATH=. +python tools/infer/predict_eval.py --image_dir="/Demo.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/" +``` +在执行预测时,通过参数det_model_dir以及rec_model_dir设置存储inference 模型的路径。 + +实现文本检测、识别串联推理,预测$image_dir$指指定文件夹下的所有图像: +``` +python tools/infer/predict_eval.py --image_dir="/test_imgs/" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/" +``` + + + +## 文本检测算法: + +PaddleOCR开源的文本检测算法列表: +- [x] [EAST](https://arxiv.org/abs/1704.03155) +- [x] [DB](https://arxiv.org/abs/1911.08947) +- [ ] [SAST](https://arxiv.org/abs/1908.05498) + + +算法效果: +|模型|骨干网络|Hmean| +|-|-|-| +|EAST|ResNet50_vd|85.85%| +|EAST|MobileNetV3|79.08%| +|DB|ResNet50_vd|83.30%| +|DB|MobileNetV3|73.00%| + +PaddleOCR文本检测算法的训练与使用请参考[文档](./doc/detection.md)。 + +## 文本识别算法: + +PaddleOCR开源的文本识别算法列表: +- [x] [CRNN](https://arxiv.org/abs/1507.05717) +- [x] [DTRB](https://arxiv.org/abs/1904.01906) +- [ ] [Rosetta](https://arxiv.org/abs/1910.05085) +- [ ] [STAR-Net](http://www.bmva.org/bmvc/2016/papers/paper043/index.html) +- [ ] [RARE](https://arxiv.org/abs/1603.03915v1) +- [ ] [SRN]((https://arxiv.org/abs/2003.12294))(百度自研) + +算法效果如下表所示,精度指标是在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上的评测结果的平均值。 + +|模型|骨干网络|ACC| +|-|-|-| +|Rosetta|Resnet34_vd|80.24%| +|Rosetta|MobileNetV3|78.16%| +|CRNN|Resnet34_vd|82.20%| +|CRNN|MobileNetV3|79.37%| +|STAR-Net|Resnet34_vd|83.93%| +|STAR-Net|MobileNetV3|81.56%| +|RARE|Resnet34_vd|84.90%| +|RARE|MobileNetV3|83.32%| + +PaddleOCR文本识别算法的训练与使用请参考[文档](./doc/recognition.md)。 + +## TODO +**端到端OCR算法** +PaddleOCR即将开源百度自研端对端OCR模型[End2End-PSL](https://arxiv.org/abs/1909.07808),敬请关注。 +- [ ] End2End-PSL (comming soon) + + + +# 参考文献 +``` +1. EAST: +@inproceedings{zhou2017east, + title={EAST: an efficient and accurate scene text detector}, + author={Zhou, Xinyu and Yao, Cong and Wen, He and Wang, Yuzhi and Zhou, Shuchang and He, Weiran and Liang, Jiajun}, + booktitle={Proceedings of the IEEE conference on Computer Vision and Pattern Recognition}, + pages={5551--5560}, + year={2017} +} + +2. DB: +@article{liao2019real, + title={Real-time Scene Text Detection with Differentiable Binarization}, + author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang}, + journal={arXiv preprint arXiv:1911.08947}, + year={2019} +} + +3. DTRB: +@inproceedings{baek2019wrong, + title={What is wrong with scene text recognition model comparisons? dataset and model analysis}, + author={Baek, Jeonghun and Kim, Geewook and Lee, Junyeop and Park, Sungrae and Han, Dongyoon and Yun, Sangdoo and Oh, Seong Joon and Lee, Hwalsuk}, + booktitle={Proceedings of the IEEE International Conference on Computer Vision}, + pages={4715--4723}, + year={2019} +} + +4. SAST: +@inproceedings{wang2019single, + title={A Single-Shot Arbitrarily-Shaped Text Detector based on Context Attended Multi-Task Learning}, + author={Wang, Pengfei and Zhang, Chengquan and Qi, Fei and Huang, Zuming and En, Mengyi and Han, Junyu and Liu, Jingtuo and Ding, Errui and Shi, Guangming}, + booktitle={Proceedings of the 27th ACM International Conference on Multimedia}, + pages={1277--1285}, + year={2019} +} + +5. SRN: +@article{yu2020towards, + title={Towards Accurate Scene Text Recognition with Semantic Reasoning Networks}, + author={Yu, Deli and Li, Xuan and Zhang, Chengquan and Han, Junyu and Liu, Jingtuo and Ding, Errui}, + journal={arXiv preprint arXiv:2003.12294}, + year={2020} +} + +6. end2end-psl: +@inproceedings{sun2019chinese, + title={Chinese Street View Text: Large-scale Chinese Text Reading with Partially Supervised Learning}, + author={Sun, Yipeng and Liu, Jiaming and Liu, Wei and Han, Junyu and Ding, Errui and Liu, Jingtuo}, + booktitle={Proceedings of the IEEE International Conference on Computer Vision}, + pages={9086--9095}, + year={2019} +} +``` diff --git a/doc/detection.md b/doc/detection.md new file mode 100644 index 0000000000000000000000000000000000000000..5e5501103777deb8de541d45c0ec02e6d108236f --- /dev/null +++ b/doc/detection.md @@ -0,0 +1,79 @@ +# 文字检测 + +本节以icdar15数据集为例,介绍PaddleOCR中检测模型的训练、评估与测试。 + +## 数据准备 +icdar2015数据集可以从[官网](https://rrc.cvc.uab.es/?ch=4&com=downloads)下载到,首次下载需注册。 + +将下载到的数据集解压到工作目录下,假设解压在/PaddleOCR/train_data/ 下。另外,PaddleOCR将零散的标注文件整理成单独的标注文件 +,您可以通过wget的方式进行下载。 +``` +wget -P /PaddleOCR/train_data/ 训练标注文件链接 +wget -P /PaddleOCR/train_data/ 测试标注文件链接 +``` + +解压数据集和下载标注文件后,/PaddleOCR/train_data/ 有两个文件夹和两个文件,分别是: +``` +/PaddleOCR/train_data/ + └─ icdar_c4_train_imgs/ icdar数据集的训练数据 + └─ ch4_test_images/ icdar数据集的测试数据 + └─ train_icdar2015_label.txt icdar数据集的训练标注 + └─ test_icdar2015_label.txt icdar数据集的测试标注 +``` + +提供的标注文件格式为: +``` +" 图像文件名 json.dumps编码的图像标注信息" +ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}] +``` +json.dumps编码前的图像标注信息是包含多个字典的list,字典中的$points$表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 +$transcription$表示当前文本框的文字,在文本检测任务中并不需要这个信息。 +如果您想在其他数据集上训练PaddleOCR,可以按照上述形式构建标注文件。 + + +## 快速启动训练 + +首先下载pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet50_vd, +您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。 +``` +cd PaddleOCR/ +# 下载MobileNetV3的预训练模型 +wget -P /PaddleOCR/pretrain_models/ 模型链接 +# 下载ResNet50的预训练模型 +wget -P /PaddleOCR/pretrain_models/ 模型链接 +``` + +**启动训练** +``` +python3 tools/train.py -c configs/det/det_db_mv3.yml +``` + +上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。 +有关配置文件的详细解释,请参考[链接]()。 + +您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001 +``` +python3 tools/train.py -c configs/det/det_db_mv3.yml -o Optimizer.base_lr=0.0001 +``` + +## 指标评估 + +PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。 + +运行如下代码,根据配置文件det_db_mv3.yml中save_res_path指定的测试集检测结果文件,计算评估指标。 + +``` +python3 tools/eval.py -c configs/det/det_db_mv3.yml -o Gloabl.checkpoints="./output/best_accuracy" +``` + +## 测试检测效果 + +测试单张图像的检测效果 +``` +python3 tools/infer_det.py -c config/det/det_db_mv3.yml -o TestReader.single_img_path="./demo.jpg" +``` + +测试文件夹下所有图像的检测效果 +``` +python3 tools/infer_det.py -c config/det/det_db_mv3.yml -o TestReader.single_img_path="./demo_img/" +``` diff --git a/doc/inference.md b/doc/inference.md new file mode 100644 index 0000000000000000000000000000000000000000..d44a96d7e202da4cec60c4cf437c9b42a6a44356 --- /dev/null +++ b/doc/inference.md @@ -0,0 +1,58 @@ + +# 基于inference model的推理 + +inference 模型(fluid.io.save_inference_model保存的模型) +一般是模型训练完成后保存的固化模型,多用于预测部署。 +训练过程中保存的模型是checkpoints模型,保存的是模型的参数,多用于恢复训练等。 +与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越。 + +PaddleOCR提供了将checkpoints转换成inference model的实现。 + + +## 文本检测模型推理 + +将文本检测模型训练过程中保存的模型,转换成inference model,可以使用如下命令: + +``` +python tools/export_model.py -c configs/det/det_db_mv3.yml -o Global.checkpoints="./output/best_accuracy" \ + Global.save_inference_dir="./inference/det/" +``` + +推理模型保存在$./inference/det/model$, $./inference/det/params$ + +使用保存的inference model实现在单张图像上的预测: + +``` +python tools/infer/predict_det.py --image_dir="/demo.jpg" --det_model_dir="./inference/det/" +``` + + +## 文本识别模型推理 + +将文本识别模型训练过程中保存的模型,转换成inference model,可以使用如下命令: + +``` +python tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints="./output/best_accuracy" \ + Global.save_inference_dir="./inference/rec/" +``` + +推理模型保存在$./inference/rec/model$, $./inference/rec/params$ + +使用保存的inference model实现在单张图像上的预测: + +``` +python tools/infer/predict_rec.py --image_dir="/demo.jpg" --rec_model_dir="./inference/rec/" +``` + +## 文本检测、识别串联推理 + +实现文本检测、识别串联推理,预测$image_dir$指定的单张图像: +``` +python tools/infer/predict_eval.py --image_dir="/Demo.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/" +``` + +实现文本检测、识别串联推理,预测$image_dir$指指定文件夹下的所有图像: + +``` +python tools/infer/predict_eval.py --image_dir="/test_imgs/" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/" +``` diff --git a/doc/installation.md b/doc/installation.md new file mode 100644 index 0000000000000000000000000000000000000000..25e0d0d56cc502f355bec5f3cf2bcb34cb1fe2d0 --- /dev/null +++ b/doc/installation.md @@ -0,0 +1,27 @@ +## 快速安装 + +建议使用我们提供的docker运行PaddleOCR,有关docker使用请参考[链接](https://docs.docker.com/get-started/)。 +1. 准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。 +``` +# 切换到工作目录下 +cd /home/Projects +# 首次运行需创建一个docker容器,再次运行时不需要运行当前命令 +# 创建一个名字为pdocr的docker容器,并将当前目录映射到容器的/data目录下 +sudo nvidia-docker run --name pdocr -v $PWD:/data --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash + +# ctrl+P+Q可退出docker,重新进入docker使用如下命令 +sudo nvidia-docker container exec -it pdocr /bin/bash + +``` + +2. 克隆PaddleOCR repo代码 +``` +git clone https://github.com/PaddlePaddle/PaddleOCR +``` + +3. 安装第三方库 +``` +cd PaddleOCR +pip3 install --upgrade pip +pip3 install -r requirements.txt +``` diff --git a/ppocr/data/det/dataset_traversal.py b/ppocr/data/det/dataset_traversal.py index 5ba01ee790fb07108cf4717817b25770785214d7..0feedeeb65da67ec887425344ddbe0a59b528719 100755 --- a/ppocr/data/det/dataset_traversal.py +++ b/ppocr/data/det/dataset_traversal.py @@ -22,6 +22,7 @@ import string from ppocr.utils.utility import initial_logger logger = initial_logger() from ppocr.utils.utility import create_module +from ppocr.utils.utility import get_image_file_list import time @@ -72,16 +73,8 @@ class EvalTestReader(object): self.params) batch_size = self.params['test_batch_size_per_card'] - flag_test_single_img = False - if mode == "test": - single_img_path = self.params['single_img_path'] - if single_img_path is not None: - flag_test_single_img = True - img_list = [] - if flag_test_single_img: - img_list.append([single_img_path, single_img_path]) - else: + if mode != "test": img_set_dir = self.params['img_set_dir'] img_name_list_path = self.params['label_file_path'] with open(img_name_list_path, "rb") as fin: @@ -90,6 +83,9 @@ class EvalTestReader(object): img_name = line.decode().strip("\n").split("\t")[0] img_path = img_set_dir + "/" + img_name img_list.append([img_path, img_name]) + else: + img_path = self.params['single_img_path'] + img_list = get_image_file_list(img_path) def batch_iter_reader(): batch_outs = [] diff --git a/ppocr/data/det/db_process.py b/ppocr/data/det/db_process.py index 2a6393a166952311e08d1352d0b39b03dd273311..faca9ac20967623f62d64c5394d1ed397b55cbf1 100644 --- a/ppocr/data/det/db_process.py +++ b/ppocr/data/det/db_process.py @@ -124,9 +124,6 @@ class DBProcessTest(object): def resize_image_type0(self, im): """ resize image to a size multiple of 32 which is required by the network - :param im: the resized image - :param max_side_len: limit of max image size to avoid out of memory in gpu - :return: the resized image and the resize ratio """ max_side_len = self.max_side_len h, w, _ = im.shape diff --git a/ppocr/data/reader_main.py b/ppocr/data/reader_main.py index 323620bc5997d545a8c04b94a3097ce5060213cc..55bd1e0842558635533b6bf2d746a3ad8a7c5b9d 100755 --- a/ppocr/data/reader_main.py +++ b/ppocr/data/reader_main.py @@ -73,9 +73,3 @@ def reader_main(config=None, mode=None): return paddle.reader.multiprocess_reader(readers, False) else: return function(mode) - - -def test_reader(image_shape, img_path): - img = cv2.imread(img_path) - norm_img = process_image(img, image_shape) - return norm_img diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index 6a81465c1fcc096afdd87627cabd428bf15df5a8..c824d4404ad5f1fbee0d672fbc06d365f7d4eade 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +import os def initial_logger(): @@ -55,6 +56,23 @@ def get_check_reader_params(mode): return check_params +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp'] + if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end: + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + if single_file.split('.')[-1] in img_end: + imgs_lists.append(os.path.join(img_file, single_file)) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + return imgs_lists + + from paddle import fluid diff --git a/requirments.txt b/requirments.txt new file mode 100644 index 0000000000000000000000000000000000000000..94e8478ffad88a6e5cd69424c6aa485400cfae06 --- /dev/null +++ b/requirments.txt @@ -0,0 +1,4 @@ +shapely +imgaug +pyclipper +lmdb \ No newline at end of file diff --git a/tools/eval_utils/eval_det_iou.py b/tools/eval_utils/eval_det_iou.py index c6dacb3e1b27202c9e3bbac1cce4bbbf871b9a7b..2f5ff2f10f1d5ecd1d33d2c742633945288ffe4c 100644 --- a/tools/eval_utils/eval_det_iou.py +++ b/tools/eval_utils/eval_det_iou.py @@ -3,6 +3,10 @@ from collections import namedtuple import numpy as np from shapely.geometry import Polygon +""" +reference from : +https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8 +""" class DetectionIoUEvaluator(object): diff --git a/tools/eval_utils/eval_det_utils.py b/tools/eval_utils/eval_det_utils.py index 015cba99fbdb4fcd1e1daf83f6e6216d939a14c2..f0be714fc04f4afe2abf5be6f6e5fcb3a9803a66 100644 --- a/tools/eval_utils/eval_det_utils.py +++ b/tools/eval_utils/eval_det_utils.py @@ -98,6 +98,14 @@ def load_label_infor(label_file_path, do_ignore=False): def cal_det_metrics(gt_label_path, save_res_path): + """ + calculate the detection metrics + Args: + gt_label_path(string): The groundtruth detection label file path + save_res_path(string): The saved predicted detection label path + return: + claculated metrics including Hmean、precision and recall + """ evaluator = DetectionIoUEvaluator() gt_label_infor = load_label_infor(gt_label_path, do_ignore=True) dt_label_infor = load_label_infor(save_res_path) diff --git a/tools/export_model.py b/tools/export_model.py index 6c924f3fb227b3d9d47c2f7547b6f6b9f962f5f9..11a2744b399d6db2e1b07d258062ad1a47548ffb 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -71,14 +71,19 @@ def main(): init_model(config, eval_program, exe) + save_inference_dir = config['Global']['save_inference_dir'] + if not os.path.exists(save_inference_dir): + os.makedirs(save_inference_dir) fluid.io.save_inference_model( - dirname="./output/", + dirname=save_inference_dir, feeded_var_names=feeded_var_names, main_program=eval_program, target_vars=target_vars, executor=exe, model_filename='model', params_filename='params') + print("inference model saved in {}/model and {}/params".format( + save_inference_dir, save_inference_dir)) print("save success, output_name_list:", fetches_var_name) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index defa0615958b357800e35c3999d95cc98abbe34d..4907a7ccf0105b076e702e169af519b02091d654 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -64,15 +64,39 @@ class TextSystem(object): if dt_boxes is None: return None, None img_crop_list = [] + + dt_boxes = sorted_boxes(dt_boxes) + for bno in range(len(dt_boxes)): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) rec_res, elapse = self.text_recognizer(img_crop_list) - # self.print_draw_crop_rec_res(img_crop_list, rec_res) + # self.print_draw_crop_rec_res(img_crop_list, rec_res) return dt_boxes, rec_res +def sorted_boxes(dt_boxes): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: x[0][1]) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + + if __name__ == "__main__": args = utility.parse_args() image_file_list = utility.get_image_file_list(args.image_dir) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index a4f9f03daa8cc36e11551bc959ab487c2472fc84..6d56a99080c9fcc7dc5feade251956b28b965475 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -61,18 +61,6 @@ def parse_args(): return parser.parse_args() -def get_image_file_list(image_dir): - image_file_list = [] - if image_dir is None: - return image_file_list - if os.path.isfile(image_dir): - image_file_list = [image_dir] - elif os.path.isdir(image_dir): - for single_file in os.listdir(image_dir): - image_file_list.append(os.path.join(image_dir, single_file)) - return image_file_list - - def create_predictor(args, mode): if mode == "det": model_dir = args.det_model_dir @@ -99,14 +87,7 @@ def create_predictor(args, mode): config.disable_gpu() config.disable_glog_info() - config.switch_ir_optim(args.ir_optim) - # if args.use_tensorrt: - # config.enable_tensorrt_engine( - # precision_mode=AnalysisConfig.Precision.Half - # if args.use_fp16 else AnalysisConfig.Precision.Float32, - # max_batch_size=args.batch_size) - - config.enable_memory_optim() + # use zero copy config.switch_use_feed_fetch_ops(False) predictor = create_paddle_predictor(config) @@ -127,21 +108,3 @@ def draw_text_det_res(dt_boxes, img_path): cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) img_name_pure = img_path.split("/")[-1] cv2.imwrite("./output/%s" % img_name_pure, src_im) - - -if __name__ == '__main__': - args = parse_args() - args.use_gpu = False - root_path = "/Users/liuweiwei06/Desktop/TEST_CODES/icode/baidu/personal-code/PaddleOCR/" - args.det_model_dir = root_path + "test_models/public_v1/ch_det_mv3_db" - - predictor, input_tensor, output_tensors = create_predictor(args, mode='det') - print(predictor.get_input_names()) - print(predictor.get_output_names()) - print(predictor.program(), file=open("det_program.txt", 'w')) - - args.rec_model_dir = root_path + "test_models/public_v1/ch_rec_mv3_crnn/" - rec_predictor, input_tensor, output_tensors = create_predictor( - args, mode='rec') - print(rec_predictor.get_input_names()) - print(rec_predictor.get_output_names()) diff --git a/tools/infer_det.py b/tools/infer_det.py new file mode 100755 index 0000000000000000000000000000000000000000..8d591a654d4b461ab07a81e78cbd046b753f0e96 --- /dev/null +++ b/tools/infer_det.py @@ -0,0 +1,150 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import time +import numpy as np +from copy import deepcopy +import json + +# from paddle.fluid.contrib.model_stat import summary + + +def set_paddle_flags(**kwargs): + for key, value in kwargs.items(): + if os.environ.get(key, None) is None: + os.environ[key] = str(value) + + +# NOTE(paddle-dev): All of these flags should be +# set before `import paddle`. Otherwise, it would +# not take any effect. +set_paddle_flags( + FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory +) + +from paddle import fluid +from ppocr.utils.utility import create_module +import program +from ppocr.utils.save_load import init_model +from ppocr.data.reader_main import reader_main +import cv2 + +from ppocr.utils.utility import initial_logger +logger = initial_logger() + + +def draw_det_res(dt_boxes, config, img_name, ino): + if len(dt_boxes) > 0: + img_set_path = config['TestReader']['img_set_dir'] + img_path = img_set_path + img_name + import cv2 + src_im = cv2.imread(img_path) + for box in dt_boxes: + box = box.astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) + save_det_path = os.path.basename(config['Global'][ + 'save_res_path']) + "/det_results/" + if not os.path.exists(save_det_path): + os.makedirs(save_det_path) + save_path = os.path.join(save_det_path, "det_{}.jpg".format(img_name)) + cv2.imwrite(save_path, src_im) + logger.info("The detected Image saved in {}".format(save_path)) + + +def main(): + config = program.load_config(FLAGS.config) + program.merge_config(FLAGS.opt) + print(config) + + # check if set use_gpu=True in paddlepaddle cpu version + use_gpu = config['Global']['use_gpu'] + program.check_gpu(use_gpu) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + det_model = create_module(config['Architecture']['function'])(params=config) + + startup_prog = fluid.Program() + eval_prog = fluid.Program() + with fluid.program_guard(eval_prog, startup_prog): + with fluid.unique_name.guard(): + _, eval_outputs = det_model(mode="test") + fetch_name_list = list(eval_outputs.keys()) + eval_fetch_list = [eval_outputs[v].name for v in fetch_name_list] + + eval_prog = eval_prog.clone(for_test=True) + exe.run(startup_prog) + + # load checkpoints + checkpoints = config['Global'].get('checkpoints') + if checkpoints: + path = checkpoints + fluid.load(eval_prog, path, exe) + logger.info("Finish initing model from {}".format(path)) + else: + raise Exception("{} not exists!".format(checkpoints)) + + save_res_path = config['Global']['save_res_path'] + with open(save_res_path, "wb") as fout: + test_reader = reader_main(config=config, mode='test') + tackling_num = 0 + for data in test_reader(): + img_num = len(data) + tackling_num = tackling_num + img_num + logger.info("tackling_num:%d", tackling_num) + img_list = [] + ratio_list = [] + img_name_list = [] + for ino in range(img_num): + img_list.append(data[ino][0]) + ratio_list.append(data[ino][1]) + img_name_list.append(data[ino][2]) + + img_list = np.concatenate(img_list, axis=0) + outs = exe.run(eval_prog,\ + feed={'image': img_list},\ + fetch_list=eval_fetch_list) + + global_params = config['Global'] + postprocess_params = deepcopy(config["PostProcess"]) + postprocess_params.update(global_params) + postprocess = create_module(postprocess_params['function'])\ + (params=postprocess_params) + dt_boxes_list = postprocess({"maps": outs[0]}, ratio_list) + for ino in range(img_num): + dt_boxes = dt_boxes_list[ino] + img_name = img_name_list[ino] + dt_boxes_json = [] + for box in dt_boxes: + tmp_json = {"transcription": ""} + tmp_json['points'] = box.tolist() + dt_boxes_json.append(tmp_json) + otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n" + fout.write(otstr.encode()) + draw_det_res(dt_boxes, config, img_name, ino) + + logger.info("success!") + + +if __name__ == '__main__': + parser = program.ArgsParser() + FLAGS = parser.parse_args() + main() diff --git a/tools/program.py b/tools/program.py index a34e56ca00fa69500f780b287b5aa55244658f40..f74aacc738df09c7fd2d9a4bec9189030fa571ed 100755 --- a/tools/program.py +++ b/tools/program.py @@ -185,22 +185,6 @@ def build(config, main_prog, startup_prog, mode): def build_export(config, main_prog, startup_prog): """ - Build a program using a model and an optimizer - 1. create feeds - 2. create a dataloader - 3. create a model - 4. create fetchs - 5. create an optimizer - - Args: - config(dict): config - main_prog(): main program - startup_prog(): startup program - is_train(bool): train or valid - - Returns: - dataloader(): a bridge between the model and the data - fetchs(dict): dict of model outputs(included loss and measures) """ with fluid.program_guard(main_prog, startup_prog): with fluid.unique_name.guard():