diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md old mode 100644 new mode 100755 index 663533c492ab5dc0bd22cc79bd95c9d1d194d854..10c01666404c1a66a14485ed30195954ed881b6f --- a/doc/doc_ch/inference.md +++ b/doc/doc_ch/inference.md @@ -173,7 +173,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_ ### 3. EAST文本检测模型推理 -首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址 (coming soon)](link) ),可以使用如下命令进行转换: +首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar) ),可以使用如下命令进行转换: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east @@ -186,7 +186,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img ``` 可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: -(coming soon) +![](../imgs_results/det_res_img_10_east.jpg) **注意**:本代码库中,EAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。 @@ -194,7 +194,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img ### 4. SAST文本检测模型推理 #### (1). 四边形文本检测模型(ICDAR2015) -首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址(coming soon)](link)),可以使用如下命令进行转换: +首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15 @@ -205,10 +205,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img ``` 可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: -(coming soon) +![](../imgs_results/det_res_img_10_sast.jpg) #### (2). 弯曲文本检测模型(Total-Text) -首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址(coming soon)](link)),可以使用如下命令进行转换: +首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt @@ -221,7 +221,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img ``` 可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: -(coming soon) +![](../imgs_results/det_res_img623_sast.jpg) **注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。 diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md old mode 100644 new mode 100755 index 411a733dd062cf347d7a2e5d5d067739bda36819..606565275ba243969bf919bfb91a0a2067a7a8cd --- a/doc/doc_en/inference_en.md +++ b/doc/doc_en/inference_en.md @@ -179,7 +179,7 @@ The visualized text detection results are saved to the `./inference_results` fol ### 3. EAST TEXT DETECTION MODEL INFERENCE -First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert: +First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can use the following command to convert: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east @@ -192,7 +192,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_ The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: -(coming soon) +![](../imgs_results/det_res_img_10_east.jpg) **Note**: EAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases. @@ -200,7 +200,7 @@ The visualized text detection results are saved to the `./inference_results` fol ### 4. SAST TEXT DETECTION MODEL INFERENCE #### (1). Quadrangle text detection model (ICDAR2015) -First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert: +First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can use the following command to convert: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15 @@ -214,10 +214,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: -(coming soon) +![](../imgs_results/det_res_img_10_sast.jpg) #### (2). Curved text detection model (Total-Text) -First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link (coming soon)](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)), you can use the following command to convert: +First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can use the following command to convert: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt @@ -231,7 +231,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: -(coming soon) +![](../imgs_results/det_res_img623_sast.jpg) **Note**: SAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases. diff --git a/doc/imgs_results/det_res_img623_sast.jpg b/doc/imgs_results/det_res_img623_sast.jpg index b2dd538f7729724a33516091d11c081c7c2c1bd7..af5e2d6e2c5643ee71a29bc015e8ae88a8d20058 100644 Binary files a/doc/imgs_results/det_res_img623_sast.jpg and b/doc/imgs_results/det_res_img623_sast.jpg differ diff --git a/doc/imgs_results/det_res_img_10_east.jpg b/doc/imgs_results/det_res_img_10_east.jpg index 400b9e6fcd75f886fc99964cd0793e2ff07693d2..908d077c3eabcb95eabf4afc54ce0bed1b54f355 100644 Binary files a/doc/imgs_results/det_res_img_10_east.jpg and b/doc/imgs_results/det_res_img_10_east.jpg differ diff --git a/doc/imgs_results/det_res_img_10_sast.jpg b/doc/imgs_results/det_res_img_10_sast.jpg index c63faf1354601f25cedb57a3b87f4467999f5457..702f773e68fe339e9acbc4d21c98cd0aa4536ef5 100644 Binary files a/doc/imgs_results/det_res_img_10_sast.jpg and b/doc/imgs_results/det_res_img_10_sast.jpg differ diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 16c789dcd7e9740ca8ddf613d0f2567c9af22820..b2deb3dc29613def43533d412872b65c9f589cce 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -132,7 +132,8 @@ class DBPostProcess(object): cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] - def __call__(self, pred, shape_list): + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'] if isinstance(pred, paddle.Tensor): pred = pred.numpy() pred = pred[:, 0, :, :] diff --git a/ppocr/postprocess/east_postprocess.py b/ppocr/postprocess/east_postprocess.py old mode 100644 new mode 100755 index 0b669405562aef9812b9771977bf82f362beb75e..ceee727aa3df052041aee925c6c856773c8a288e --- a/ppocr/postprocess/east_postprocess.py +++ b/ppocr/postprocess/east_postprocess.py @@ -19,12 +19,10 @@ from __future__ import print_function import numpy as np from .locality_aware_nms import nms_locality import cv2 +import paddle import os import sys -# __dir__ = os.path.dirname(os.path.abspath(__file__)) -# sys.path.append(__dir__) -# sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) class EASTPostProcess(object): @@ -113,11 +111,14 @@ class EASTPostProcess(object): def __call__(self, outs_dict, shape_list): score_list = outs_dict['f_score'] geo_list = outs_dict['f_geo'] + if isinstance(score_list, paddle.Tensor): + score_list = score_list.numpy() + geo_list = geo_list.numpy() img_num = len(shape_list) dt_boxes_list = [] for ino in range(img_num): - score = score_list[ino].numpy() - geo = geo_list[ino].numpy() + score = score_list[ino] + geo = geo_list[ino] boxes = self.detect( score_map=score, geo_map=geo, diff --git a/ppocr/postprocess/sast_postprocess.py b/ppocr/postprocess/sast_postprocess.py old mode 100644 new mode 100755 index 03b0e8f17d60a8863a2d1d5900e01dcd4874d5a9..f011e7e571cf4c2297a81a7f7772aa0c09f0aaf1 --- a/ppocr/postprocess/sast_postprocess.py +++ b/ppocr/postprocess/sast_postprocess.py @@ -24,7 +24,7 @@ sys.path.append(os.path.join(__dir__, '..')) import numpy as np from .locality_aware_nms import nms_locality -# import lanms +import paddle import cv2 import time @@ -276,14 +276,19 @@ class SASTPostProcess(object): border_list = outs_dict['f_border'] tvo_list = outs_dict['f_tvo'] tco_list = outs_dict['f_tco'] + if isinstance(score_list, paddle.Tensor): + score_list = score_list.numpy() + border_list = border_list.numpy() + tvo_list = tvo_list.numpy() + tco_list = tco_list.numpy() img_num = len(shape_list) poly_lists = [] for ino in range(img_num): - p_score = score_list[ino].transpose((1,2,0)).numpy() - p_border = border_list[ino].transpose((1,2,0)).numpy() - p_tvo = tvo_list[ino].transpose((1,2,0)).numpy() - p_tco = tco_list[ino].transpose((1,2,0)).numpy() + p_score = score_list[ino].transpose((1,2,0)) + p_border = border_list[ino].transpose((1,2,0)) + p_tvo = tvo_list[ino].transpose((1,2,0)) + p_tco = tco_list[ino].transpose((1,2,0)) src_h, src_w, ratio_h, ratio_w = shape_list[ino] poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h, diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 6f98ded8295dabbd5edf05913245e5d94d856689..f07655a86ba637d84e1dc709d94a61c509929c00 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -37,33 +37,51 @@ class TextDetector(object): def __init__(self, args): self.det_algorithm = args.det_algorithm self.use_zero_copy_run = args.use_zero_copy_run + pre_process_list = [{ + 'DetResizeForTest': { + 'limit_side_len': args.det_limit_side_len, + 'limit_type': args.det_limit_type + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image', 'shape'] + } + }] postprocess_params = {} if self.det_algorithm == "DB": - pre_process_list = [{ - 'DetResizeForTest': { - 'limit_side_len': args.det_limit_side_len, - 'limit_type': args.det_limit_type - } - }, { - 'NormalizeImage': { - 'std': [0.229, 0.224, 0.225], - 'mean': [0.485, 0.456, 0.406], - 'scale': '1./255.', - 'order': 'hwc' - } - }, { - 'ToCHWImage': None - }, { - 'KeepKeys': { - 'keep_keys': ['image', 'shape'] - } - }] postprocess_params['name'] = 'DBPostProcess' postprocess_params["thresh"] = args.det_db_thresh postprocess_params["box_thresh"] = args.det_db_box_thresh postprocess_params["max_candidates"] = 1000 postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = True + elif self.det_algorithm == "EAST": + postprocess_params['name'] = 'EASTPostProcess' + postprocess_params["score_thresh"] = args.det_east_score_thresh + postprocess_params["cover_thresh"] = args.det_east_cover_thresh + postprocess_params["nms_thresh"] = args.det_east_nms_thresh + elif self.det_algorithm == "SAST": + postprocess_params['name'] = 'SASTPostProcess' + postprocess_params["score_thresh"] = args.det_sast_score_thresh + postprocess_params["nms_thresh"] = args.det_sast_nms_thresh + self.det_sast_polygon = args.det_sast_polygon + if self.det_sast_polygon: + postprocess_params["sample_pts_num"] = 6 + postprocess_params["expand_scale"] = 1.2 + postprocess_params["shrink_ratio_of_width"] = 0.2 + else: + postprocess_params["sample_pts_num"] = 2 + postprocess_params["expand_scale"] = 1.0 + postprocess_params["shrink_ratio_of_width"] = 0.3 else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) @@ -149,12 +167,25 @@ class TextDetector(object): for output_tensor in self.output_tensors: output = output_tensor.copy_to_cpu() outputs.append(output) - preds = outputs[0] - # preds = self.predictor(img) + preds = {} + if self.det_algorithm == "EAST": + preds['f_geo'] = outputs[0] + preds['f_score'] = outputs[1] + elif self.det_algorithm == 'SAST': + preds['f_border'] = outputs[0] + preds['f_score'] = outputs[1] + preds['f_tco'] = outputs[2] + preds['f_tvo'] = outputs[3] + else: + preds['maps'] = outputs[0] + post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] - dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) + if self.det_algorithm == "SAST" and self.det_sast_polygon: + dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) + else: + dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) elapse = time.time() - starttime return dt_boxes, elapse