diff --git a/README.md b/README.md index 1c15ad256c666c7880242c4649a98c341fc0136c..42ec76d5b1e4cdd7106ea4a3a8a6108af5293588 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,9 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_mode # 预测image_dir指定的图像集合 python3 tools/infer/predict_system.py --image_dir="./doc/imgs/" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/" + +# 如果想使用CPU进行预测,执行命令如下 +python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/" --use_gpu=False ``` 更多的文本检测、识别串联推理使用方式请参考文档教程中[基于预测引擎推理](./doc/inference.md)。 diff --git a/doc/imgs_en/img_10.jpg b/doc/imgs_en/img_10.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c0055513aef3bb0a3697081d89c752663d9ef17b Binary files /dev/null and b/doc/imgs_en/img_10.jpg differ diff --git a/doc/imgs_en/img_11.jpg b/doc/imgs_en/img_11.jpg new file mode 100644 index 0000000000000000000000000000000000000000..34397beb6b1ca6bc351efa5d2c05661a2e49f681 Binary files /dev/null and b/doc/imgs_en/img_11.jpg differ diff --git a/doc/imgs_en/img_195.jpg b/doc/imgs_en/img_195.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5d5546caddc8e3aa115d4895779e36e00eb91da1 Binary files /dev/null and b/doc/imgs_en/img_195.jpg differ diff --git a/doc/imgs_results/det_res_2.jpg b/doc/imgs_results/det_res_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..aebcd8ccaca02db7ed4a09cd63ade422abc4735f Binary files /dev/null and b/doc/imgs_results/det_res_2.jpg differ diff --git a/doc/imgs_results/det_res_img_10_db.jpg b/doc/imgs_results/det_res_img_10_db.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bde1585cb50137ae1fd33ce7edfa59e7224ddc96 Binary files /dev/null and b/doc/imgs_results/det_res_img_10_db.jpg differ diff --git a/doc/imgs_results/det_res_img_10_east.jpg b/doc/imgs_results/det_res_img_10_east.jpg new file mode 100644 index 0000000000000000000000000000000000000000..400b9e6fcd75f886fc99964cd0793e2ff07693d2 Binary files /dev/null and b/doc/imgs_results/det_res_img_10_east.jpg differ diff --git a/doc/imgs_results/img_10.jpg b/doc/imgs_results/img_10.jpg new file mode 100644 index 0000000000000000000000000000000000000000..41da69babf22aebcc7671afb678626fc97c2a21d Binary files /dev/null and b/doc/imgs_results/img_10.jpg differ diff --git a/doc/imgs_words_en/.DS_Store b/doc/imgs_words_en/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/doc/imgs_words_en/.DS_Store differ diff --git a/doc/imgs_words_en/word_10.png b/doc/imgs_words_en/word_10.png new file mode 100644 index 0000000000000000000000000000000000000000..07370f757ea83d5e3c5b1f7498b1f95d3aec2d18 Binary files /dev/null and b/doc/imgs_words_en/word_10.png differ diff --git a/doc/imgs_words_en/word_116.png b/doc/imgs_words_en/word_116.png new file mode 100644 index 0000000000000000000000000000000000000000..fd000ff60b0a7a87c8c16eb1409aa0548228e3f9 Binary files /dev/null and b/doc/imgs_words_en/word_116.png differ diff --git a/doc/imgs_words_en/word_19.png b/doc/imgs_words_en/word_19.png new file mode 100644 index 0000000000000000000000000000000000000000..d2a2859d02fd5475c7f48b57d5d607eece7cb258 Binary files /dev/null and b/doc/imgs_words_en/word_19.png differ diff --git a/doc/imgs_words_en/word_201.png b/doc/imgs_words_en/word_201.png new file mode 100644 index 0000000000000000000000000000000000000000..99abd7330998d097e0ce454c50393cc72c6ad24b Binary files /dev/null and b/doc/imgs_words_en/word_201.png differ diff --git a/doc/imgs_words_en/word_308.png b/doc/imgs_words_en/word_308.png new file mode 100644 index 0000000000000000000000000000000000000000..a8d094faff083217352524c6db2f6844fd9f6ab3 Binary files /dev/null and b/doc/imgs_words_en/word_308.png differ diff --git a/doc/imgs_words_en/word_336.png b/doc/imgs_words_en/word_336.png new file mode 100644 index 0000000000000000000000000000000000000000..3bddd294ed7a04fc9b8e55d603c38734c6134753 Binary files /dev/null and b/doc/imgs_words_en/word_336.png differ diff --git a/doc/imgs_words_en/word_401.png b/doc/imgs_words_en/word_401.png new file mode 100644 index 0000000000000000000000000000000000000000..0a4ee6935f80192e734f94a7ac07233c5db5de40 Binary files /dev/null and b/doc/imgs_words_en/word_401.png differ diff --git a/doc/imgs_words_en/word_461.png b/doc/imgs_words_en/word_461.png new file mode 100644 index 0000000000000000000000000000000000000000..a73e5c494ba280726539fe6a00c97c5968061b63 Binary files /dev/null and b/doc/imgs_words_en/word_461.png differ diff --git a/doc/imgs_words_en/word_52.png b/doc/imgs_words_en/word_52.png new file mode 100644 index 0000000000000000000000000000000000000000..493c5901835ba43dd4fd1a0f749d5f01622b0627 Binary files /dev/null and b/doc/imgs_words_en/word_52.png differ diff --git a/doc/imgs_words_en/word_545.png b/doc/imgs_words_en/word_545.png new file mode 100644 index 0000000000000000000000000000000000000000..5d4a2a7deaf929a35a41f01746656fe1e3e1b585 Binary files /dev/null and b/doc/imgs_words_en/word_545.png differ diff --git a/doc/inference.md b/doc/inference.md index 5fadef389bf7763d44962ec2b57492bd22999bd7..a9f2d590611955c03d5a82be29fcc73fa3c4cedf 100644 --- a/doc/inference.md +++ b/doc/inference.md @@ -6,51 +6,155 @@ inference 模型(fluid.io.save_inference_model保存的模型) 训练过程中保存的模型是checkpoints模型,保存的是模型的参数,多用于恢复训练等。 与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合与实际系统集成。更详细的介绍请参考文档[分类预测框架](https://paddleclas.readthedocs.io/zh_CN/latest/extension/paddle_inference.html). 接下来将依次介绍文本检测、文本识别以及两者串联基于预测引擎推理。与此同时也会介绍checkpoints转换成inference model的实现。 - ## 文本检测模型推理 -将文本检测模型训练过程中保存的模型,转换成inference model,可以使用如下命令: +下面将介绍超轻量中文检测模型推理、DB文本检测模型推理和EAST文本检测模型推理。默认配置是根据DB文本检测模型推理设置的。由于EAST和DB算法差别很大,在推理时,需要通过传入相应的参数适配EAST文本检测算法。 + +### 1.超轻量中文检测模型推理 + +超轻量中文检测模型推理,可以执行如下命令: ``` -python tools/export_model.py -c configs/det/det_db_mv3.yml -o Global.checkpoints="./output/best_accuracy" \ - Global.save_inference_dir="./inference/det/" +python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det/" ``` -推理模型保存在$./inference/det/model$, $./inference/det/params$ +可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: + +![](imgs_results/det_res_2.jpg) -使用保存的inference model实现在单张图像上的预测: +通过设置参数det_max_side_len的大小,改变检测算法中图片规范化的最大值。当图片的长宽都小于det_max_side_len,则使用原图预测,否则将图片等比例缩放到最大值,进行预测。该参数默认设置为det_max_side_len=960. 如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以执行如下命令: ``` -python tools/infer/predict_det.py --image_dir="/demo.jpg" --det_model_dir="./inference/det/" +python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det/" --det_max_side_len=1200 +``` + +### 2.DB文本检测模型推理 + +首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)),可以使用如下命令进行转换: + ``` +# -c后面设置训练算法的yml配置文件 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.checkpoints="./models/det_r50_vd_db/best_accuracy" Global.save_inference_dir="./inference/det_db" +``` + +DB文本检测模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/" +``` + +可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: + +![](imgs_results/det_res_img_10_db.jpg) + +**注意**:由于ICDAR2015数据集只有1000张训练图像,主要针对英文场景,所以上述模型对中文文本图像检测效果非常差。 + +### 3.EAST文本检测模型推理 + +首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)),可以使用如下命令进行转换: + +``` +# -c后面设置训练算法的yml配置文件 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.checkpoints="./models/det_r50_vd_east/best_accuracy" Global.save_inference_dir="./inference/det_east" +``` + +EAST文本检测模型推理,需要设置参数det_algorithm,指定检测算法类型为EAST,可以执行如下命令: + +``` +python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" +``` +可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: + +![](imgs_results/det_res_img_10_east.jpg) + +**注意**:本代码库中EAST后处理中NMS采用的Python版本,所以预测速度比较耗时。如果采用C++版本,会有明显加速。 ## 文本识别模型推理 -将文本识别模型训练过程中保存的模型,转换成inference model,可以使用如下命令: +下面将介绍超轻量中文检测模型推理和基于CTC损失的识别模型推理。**而基于Attention损失的识别模型推理还在调试中**。对于中文文本识别,建议优先选择基于CTC损失的识别模型,实践中也发现基于Attention损失的效果不如基于CTC损失的识别模型。 + + +### 1.超轻量中文识别模型推理 + +超轻量中文识别模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/word_4.jpg" --rec_model_dir="./inference/rec/" +``` + +![](imgs_words/word_4.jpg) + +执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下: + +Predicts of ./doc/imgs_words/word_4.jpg:['实力活力', 0.9504319] + + +### 2.基于CTC损失的识别模型推理 + +我们以STAR-Net为例,介绍基于CTC损失的识别模型推理。 CRNN和Rosetta使用方式类似,不用设置识别算法参数rec_algorithm。 + +首先将STAR-Net文本识别训练过程中保存的模型,转换成inference model。以基于Resnet34_vd骨干网络,使用MJSynth和SynthText两个英文文本识别合成数据集训练 +的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_ctc.tar)),可以使用如下命令进行转换: + +``` +# -c后面设置训练算法的yml配置文件 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_model.py -c configs/rec/rec_r34_vd_tps_bilstm_ctc.yml -o Global.checkpoints="./models/rec_r34_vd_tps_bilstm_ctc/best_accuracy" Global.save_inference_dir="./inference/starnet" +``` + +STAR-Net文本识别模型推理,可以执行如下命令: ``` -python tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints="./output/best_accuracy" \ - Global.save_inference_dir="./inference/rec/" +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_401.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" ``` +![](imgs_words_en/word_401.png) + +执行命令后,上面图像的识别结果如下: + +Predicts of ./doc/imgs_words_en/word_401.png:['burgen', 0.9008867] + +**注意**:由于上述模型是参考[DTRB](https://arxiv.org/abs/1904.01906)文本识别训练和评估流程,与超轻量级中文识别模型训练有两方面不同: -推理模型保存在$./inference/rec/model$, $./inference/rec/params$ +- 训练时采用的图像分辨率不同,训练上述模型采用的图像分辨率是[3,32,100],而中文模型训练时,为了保证长文本的识别效果,训练时采用的图像分辨率是[3, 32, 320]。预测推理程序默认的的形状参数是训练中文采用的图像分辨率,即[3, 32, 320]。因此,这里推理上述英文模型时,需要通过参数rec_image_shape设置识别图像的形状。 -使用保存的inference model实现在单张图像上的预测: +- 字符列表,DTRB论文中实验只是针对26个小写英文本母和10个数字进行实验,总共36个字符。所有大小字符都转成了小写字符,不在上面列表的字符都忽略,认为是空格。因此这里没有输入字符字典,而是通过如下命令生成字典.因此在推理时需要设置参数rec_char_type,指定为英文"en"。 ``` -python tools/infer/predict_rec.py --image_dir="/demo.jpg" --rec_model_dir="./inference/rec/" +self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" +dict_character = list(self.character_str) ``` ## 文本检测、识别串联推理 -实现文本检测、识别串联推理,预测$image_dir$指定的单张图像: +### 1.超轻量中文OCR模型推理 + +在执行预测时,需要通过参数image_dir指定单张图像或者图像集合的路径、参数det_model_dir指定检测inference模型的路径和参数rec_model_dir指定识别inference模型的路径。可视化识别结果默认保存到 ./inference_results 文件夹里面。 + ``` -python tools/infer/predict_eval.py --image_dir="/Demo.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/" +python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/" ``` -实现文本检测、识别串联推理,预测$image_dir$指指定文件夹下的所有图像: +执行命令后,识别结果图像如下: + +![](imgs_results/2.jpg) + +### 2.其他模型推理 + +如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型,下面给出基于EAST文本检测和STAR-Net文本识别执行命令: ``` -python tools/infer/predict_eval.py --image_dir="/test_imgs/" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/" +python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/rec/" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" ``` + +执行命令后,识别结果图像如下: + +![](imgs_results/img_10.jpg) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 2a446cdff2e6231881d8d4334fc4342cdd4bb512..82fed48feec063fc4379bc469372f876aabaad6f 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -15,19 +15,17 @@ import utility from ppocr.utils.utility import initial_logger logger = initial_logger() +from ppocr.utils.utility import get_image_file_list import cv2 from ppocr.data.det.east_process import EASTProcessTest from ppocr.data.det.db_process import DBProcessTest from ppocr.postprocess.db_postprocess import DBPostProcess from ppocr.postprocess.east_postprocess import EASTPostPocess -from ppocr.utils.utility import get_image_file_list -from tools.infer.utility import draw_ocr import copy import numpy as np import math import time import sys -import os class TextDetector(object): @@ -79,27 +77,10 @@ class TextDetector(object): rect = np.array([tl, tr, br, bl], dtype="float32") return rect - def expand_det_res(self, points, bbox_height, bbox_width, img_height, - img_width): - if bbox_height * 1.0 / bbox_width >= 2.0: - expand_w = bbox_width * 0.20 - expand_h = bbox_width * 0.20 - elif bbox_width * 1.0 / bbox_height >= 3.0: - expand_w = bbox_height * 0.20 - expand_h = bbox_height * 0.20 - else: - expand_w = bbox_height * 0.1 - expand_h = bbox_height * 0.1 - - points[0, 0] = int(max((points[0, 0] - expand_w), 0)) - points[1, 0] = int(min((points[1, 0] + expand_w), img_width)) - points[3, 0] = int(max((points[3, 0] - expand_w), 0)) - points[2, 0] = int(min((points[2, 0] + expand_w), img_width)) - - points[0, 1] = int(max((points[0, 1] - expand_h), 0)) - points[1, 1] = int(max((points[1, 1] - expand_h), 0)) - points[3, 1] = int(min((points[3, 1] + expand_h), img_height)) - points[2, 1] = int(min((points[2, 1] + expand_h), img_height)) + def clip_det_res(self, points, img_height, img_width): + for pno in range(4): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) return points def filter_tag_det_res(self, dt_boxes, image_shape): @@ -107,22 +88,11 @@ class TextDetector(object): dt_boxes_new = [] for box in dt_boxes: box = self.order_points_clockwise(box) - left = int(np.min(box[:, 0])) - right = int(np.max(box[:, 0])) - top = int(np.min(box[:, 1])) - bottom = int(np.max(box[:, 1])) - bbox_height = bottom - top - bbox_width = right - left - diffh = math.fabs(box[0, 1] - box[1, 1]) - diffw = math.fabs(box[0, 0] - box[3, 0]) + box = self.clip_det_res(box, img_height, img_width) rect_width = int(np.linalg.norm(box[0] - box[1])) rect_height = int(np.linalg.norm(box[0] - box[3])) if rect_width <= 10 or rect_height <= 10: continue - # if diffh <= 10 and diffw <= 10: - # box = self.expand_det_res( - # copy.deepcopy(box), bbox_height, bbox_width, img_height, - # img_width) dt_boxes_new.append(box) dt_boxes = np.array(dt_boxes_new) return dt_boxes @@ -153,8 +123,6 @@ class TextDetector(object): return dt_boxes, elapse -from tools.infer.utility import draw_text_det_res - if __name__ == "__main__": args = utility.parse_args() image_file_list = get_image_file_list(args.image_dir) @@ -171,9 +139,8 @@ if __name__ == "__main__": total_time += elapse count += 1 print("Predict time of %s:" % image_file, elapse) - img_draw = draw_text_det_res(dt_boxes, image_file, return_img=True) - save_path = os.path.join("./inference_det/", - os.path.basename(image_file)) - print("The visualized image saved in {}".format(save_path)) - - print("Avg Time:", total_time / (count - 1)) + src_im = utility.draw_text_det_res(dt_boxes, image_file) + img_name_pure = image_file.split("/")[-1] + cv2.imwrite("./inference_results/det_res_%s" % img_name_pure, src_im) + if count > 1: + print("Avg Time:", total_time / (count - 1)) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 0681cfb5510b5a4c3869d4c815ec9a029554578a..c8c0797b499055ec681d0362a054a30d7322b65a 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -15,8 +15,8 @@ import utility from ppocr.utils.utility import initial_logger logger = initial_logger() +from ppocr.utils.utility import get_image_file_list import cv2 - import copy import numpy as np import math @@ -30,6 +30,7 @@ class TextRecognizer(object): utility.create_predictor(args, mode="rec") image_shape = [int(v) for v in args.rec_image_shape.split(",")] self.rec_image_shape = image_shape + self.character_type = args.rec_char_type char_ops_params = {} char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_dict_path"] = args.rec_char_dict_path @@ -38,7 +39,8 @@ class TextRecognizer(object): def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape - imgW = int(32 * max_wh_ratio) + if self.character_type == "ch": + imgW = int(32 * max_wh_ratio) h = img.shape[0] w = img.shape[1] ratio = w / float(h) @@ -102,7 +104,7 @@ class TextRecognizer(object): if __name__ == "__main__": args = utility.parse_args() - image_file_list = utility.get_image_file_list(args.image_dir) + image_file_list = get_image_file_list(args.image_dir) text_recognizer = TextRecognizer(args) valid_image_file_list = [] img_list = [] @@ -114,6 +116,7 @@ if __name__ == "__main__": valid_image_file_list.append(image_file) img_list.append(img) rec_res, predict_time = text_recognizer(img_list) + rec_res, predict_time = text_recognizer(img_list) for ino in range(len(img_list)): print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Total predict time for %d images:%.3f" % diff --git a/tools/program.py b/tools/program.py index b6318c3bbf99c670752de9c855d8bdbdd16bf09b..a114b1cbffa21358202c657caf90a107609ff9d1 100755 --- a/tools/program.py +++ b/tools/program.py @@ -191,8 +191,8 @@ def build_export(config, main_prog, startup_prog): func_infor = config['Architecture']['function'] model = create_module(func_infor)(params=config) image, outputs = model(mode='export') - fetches_var = sorted([outputs[name] for name in outputs]) - fetches_var_name = [name for name in fetches_var] + fetches_var_name = sorted([name for name in outputs]) + fetches_var = [outputs[name] for name in fetches_var_name] feeded_var_names = [image.name] target_vars = fetches_var return feeded_var_names, target_vars, fetches_var_name