diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md old mode 100644 new mode 100755 index c61a9234bb29d1562112c47b5ca34118c12331d6..997bc1f1978b93e45472b5698341fa91dd9e7e84 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -17,17 +17,17 @@ PaddleOCR开源的文本检测算法列表: |模型|骨干网络|precision|recall|Hmean|下载链接| | --- | --- | --- | --- | --- | --- | -|EAST|ResNet50_vd|88.18%|85.51%|86.82%|[下载链接 (coming soon)](link)| -|EAST|MobileNetV3|81.67%|79.83%|80.74%|[下载链接 (coming soon)](coming soon)| -|DB|ResNet50_vd|83.79%|80.65%|82.19%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| -|DB|MobileNetV3|75.92%|73.18%|74.53%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| -|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[下载链接 (coming soon)](link)| +|EAST|ResNet50_vd|88.76%|81.36%|84.90%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)| +|EAST|MobileNetV3|78.24%|79.15%|78.69%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)| +|DB|ResNet50_vd|86.41%|78.72%|82.38%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| +|DB|MobileNetV3|77.29%|73.08%|75.12%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| +|SAST|ResNet50_vd|91.83%|81.80%|86.52%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar))| 在Total-text文本检测公开数据集上,算法效果如下: |模型|骨干网络|precision|recall|Hmean|下载链接| | --- | --- | --- | --- | --- | --- | -|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[下载链接 (coming soon)](link)| +|SAST|ResNet50_vd|89.05%|76.80%|82.47%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)| **说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载:[百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi) @@ -47,13 +47,10 @@ PaddleOCR基于动态图开源的文本识别算法列表: 参考[DTRB](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: |模型|骨干网络|Avg Accuracy|模型存储命名|下载链接| -| --- | --- | --- | --- | --- | -|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)| +|-|-|-|-|-| |Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)| -|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| +|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)| |CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)| -|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[下载链接 (coming soon )]()| -|STAR-Net|Resnet34_vd|83.93%|rec_r34_vd_tps_bilstm_ctc|[下载链接 (coming soon )]()| - +|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md old mode 100644 new mode 100755 index 962b4ae748922ce6067a7ee4e609366ba1f3e01d..aea7ff1de242dec75cae26a2bf3d6838d7559882 --- a/doc/doc_ch/inference.md +++ b/doc/doc_ch/inference.md @@ -180,7 +180,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_ ### 3. EAST文本检测模型推理 -首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址 (coming soon)](link) ),可以使用如下命令进行转换: +首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar) ),可以使用如下命令进行转换: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east @@ -193,7 +193,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img ``` 可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: -(coming soon) +![](../imgs_results/det_res_img_10_east.jpg) **注意**:本代码库中,EAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。 @@ -201,7 +201,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img ### 4. SAST文本检测模型推理 #### (1). 四边形文本检测模型(ICDAR2015) -首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址(coming soon)](link)),可以使用如下命令进行转换: +首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15 @@ -212,10 +212,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img ``` 可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: -(coming soon) +![](../imgs_results/det_res_img_10_sast.jpg) #### (2). 弯曲文本检测模型(Total-Text) -首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址(coming soon)](link)),可以使用如下命令进行转换: +首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt @@ -228,7 +228,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img ``` 可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: -(coming soon) +![](../imgs_results/det_res_img623_sast.jpg) **注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。 diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md old mode 100644 new mode 100755 index 847b9288b96adf8c32b6486cb2798b838745c85f..6e4140e22813c0965009ffdef2eefddcec77489d --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -19,17 +19,17 @@ On the ICDAR2015 dataset, the text detection result is as follows: |Model|Backbone|precision|recall|Hmean|Download link| | --- | --- | --- | --- | --- | --- | -|EAST|ResNet50_vd|88.18%|85.51%|86.82%|[download link (coming soon)](link)| -|EAST|MobileNetV3|81.67%|79.83%|80.74%|[download link (coming soon)](coming soon)| -|DB|ResNet50_vd|83.79%|80.65%|82.19%|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| -|DB|MobileNetV3|75.92%|73.18%|74.53%|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| -|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[download link (coming soon)](link)| +|EAST|ResNet50_vd|88.76%|81.36%|84.90%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)| +|EAST|MobileNetV3|78.24%|79.15%|78.69%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)| +|DB|ResNet50_vd|86.41%|78.72%|82.38%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| +|DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| +|SAST|ResNet50_vd|91.83%|81.80%|86.52%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar))| On Total-Text dataset, the text detection result is as follows: |Model|Backbone|precision|recall|Hmean|Download link| | --- | --- | --- | --- | --- | --- | -|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[download link (coming soon)](link)| +|SAST|ResNet50_vd|89.05%|76.80%|82.47%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)| **Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi). @@ -48,14 +48,10 @@ PaddleOCR open-source text recognition algorithms list: Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: |Model|Backbone|Avg Accuracy|Module combination|Download link| -| --- | --- | --- | --- | --- | -|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)| -|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)| -|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| -|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)| -|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[download link (coming soon )]()| -|STAR-Net|Resnet34_vd|83.93%|rec_r34_vd_tps_bilstm_ctc|[download link (coming soon )]()| - - +|-|-|-|-|-| +|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)| +|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)| +|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)| +|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md) diff --git a/doc/doc_en/benchmark_en.md b/doc/doc_en/benchmark_en.md old mode 100644 new mode 100755 diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md old mode 100644 new mode 100755 index 85dc261c89098b7579336b429d7663cb1c0864f5..db86b109d1a13d00aab833aa31d0279622e7c7f8 --- a/doc/doc_en/inference_en.md +++ b/doc/doc_en/inference_en.md @@ -187,7 +187,7 @@ The visualized text detection results are saved to the `./inference_results` fol ### 3. EAST TEXT DETECTION MODEL INFERENCE -First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert: +First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can use the following command to convert: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east @@ -200,7 +200,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_ The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: -(coming soon) +![](../imgs_results/det_res_img_10_east.jpg) **Note**: EAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases. @@ -208,7 +208,7 @@ The visualized text detection results are saved to the `./inference_results` fol ### 4. SAST TEXT DETECTION MODEL INFERENCE #### (1). Quadrangle text detection model (ICDAR2015) -First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert: +First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can use the following command to convert: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15 @@ -222,10 +222,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: -(coming soon) +![](../imgs_results/det_res_img_10_sast.jpg) #### (2). Curved text detection model (Total-Text) -First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link (coming soon)](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)), you can use the following command to convert: +First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can use the following command to convert: ``` python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt @@ -239,7 +239,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: -(coming soon) +![](../imgs_results/det_res_img623_sast.jpg) **Note**: SAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases. diff --git a/doc/imgs_results/det_res_img623_sast.jpg b/doc/imgs_results/det_res_img623_sast.jpg index b2dd538f7729724a33516091d11c081c7c2c1bd7..af5e2d6e2c5643ee71a29bc015e8ae88a8d20058 100644 Binary files a/doc/imgs_results/det_res_img623_sast.jpg and b/doc/imgs_results/det_res_img623_sast.jpg differ diff --git a/doc/imgs_results/det_res_img_10_east.jpg b/doc/imgs_results/det_res_img_10_east.jpg index 400b9e6fcd75f886fc99964cd0793e2ff07693d2..908d077c3eabcb95eabf4afc54ce0bed1b54f355 100644 Binary files a/doc/imgs_results/det_res_img_10_east.jpg and b/doc/imgs_results/det_res_img_10_east.jpg differ diff --git a/doc/imgs_results/det_res_img_10_sast.jpg b/doc/imgs_results/det_res_img_10_sast.jpg index c63faf1354601f25cedb57a3b87f4467999f5457..702f773e68fe339e9acbc4d21c98cd0aa4536ef5 100644 Binary files a/doc/imgs_results/det_res_img_10_sast.jpg and b/doc/imgs_results/det_res_img_10_sast.jpg differ diff --git a/ppocr/postprocess/east_postprocess.py b/ppocr/postprocess/east_postprocess.py old mode 100644 new mode 100755 index 0b669405562aef9812b9771977bf82f362beb75e..ceee727aa3df052041aee925c6c856773c8a288e --- a/ppocr/postprocess/east_postprocess.py +++ b/ppocr/postprocess/east_postprocess.py @@ -19,12 +19,10 @@ from __future__ import print_function import numpy as np from .locality_aware_nms import nms_locality import cv2 +import paddle import os import sys -# __dir__ = os.path.dirname(os.path.abspath(__file__)) -# sys.path.append(__dir__) -# sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) class EASTPostProcess(object): @@ -113,11 +111,14 @@ class EASTPostProcess(object): def __call__(self, outs_dict, shape_list): score_list = outs_dict['f_score'] geo_list = outs_dict['f_geo'] + if isinstance(score_list, paddle.Tensor): + score_list = score_list.numpy() + geo_list = geo_list.numpy() img_num = len(shape_list) dt_boxes_list = [] for ino in range(img_num): - score = score_list[ino].numpy() - geo = geo_list[ino].numpy() + score = score_list[ino] + geo = geo_list[ino] boxes = self.detect( score_map=score, geo_map=geo, diff --git a/ppocr/postprocess/sast_postprocess.py b/ppocr/postprocess/sast_postprocess.py old mode 100644 new mode 100755 index 03b0e8f17d60a8863a2d1d5900e01dcd4874d5a9..f011e7e571cf4c2297a81a7f7772aa0c09f0aaf1 --- a/ppocr/postprocess/sast_postprocess.py +++ b/ppocr/postprocess/sast_postprocess.py @@ -24,7 +24,7 @@ sys.path.append(os.path.join(__dir__, '..')) import numpy as np from .locality_aware_nms import nms_locality -# import lanms +import paddle import cv2 import time @@ -276,14 +276,19 @@ class SASTPostProcess(object): border_list = outs_dict['f_border'] tvo_list = outs_dict['f_tvo'] tco_list = outs_dict['f_tco'] + if isinstance(score_list, paddle.Tensor): + score_list = score_list.numpy() + border_list = border_list.numpy() + tvo_list = tvo_list.numpy() + tco_list = tco_list.numpy() img_num = len(shape_list) poly_lists = [] for ino in range(img_num): - p_score = score_list[ino].transpose((1,2,0)).numpy() - p_border = border_list[ino].transpose((1,2,0)).numpy() - p_tvo = tvo_list[ino].transpose((1,2,0)).numpy() - p_tco = tco_list[ino].transpose((1,2,0)).numpy() + p_score = score_list[ino].transpose((1,2,0)) + p_border = border_list[ino].transpose((1,2,0)) + p_tvo = tvo_list[ino].transpose((1,2,0)) + p_tco = tco_list[ino].transpose((1,2,0)) src_h, src_w, ratio_h, ratio_w = shape_list[ino] poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h, diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 6f98ded8295dabbd5edf05913245e5d94d856689..d389ca393dc94a7ece69e6f59f999073ae4b1773 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -37,33 +37,51 @@ class TextDetector(object): def __init__(self, args): self.det_algorithm = args.det_algorithm self.use_zero_copy_run = args.use_zero_copy_run + pre_process_list = [{ + 'DetResizeForTest': { + 'limit_side_len': args.det_limit_side_len, + 'limit_type': args.det_limit_type + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image', 'shape'] + } + }] postprocess_params = {} if self.det_algorithm == "DB": - pre_process_list = [{ - 'DetResizeForTest': { - 'limit_side_len': args.det_limit_side_len, - 'limit_type': args.det_limit_type - } - }, { - 'NormalizeImage': { - 'std': [0.229, 0.224, 0.225], - 'mean': [0.485, 0.456, 0.406], - 'scale': '1./255.', - 'order': 'hwc' - } - }, { - 'ToCHWImage': None - }, { - 'KeepKeys': { - 'keep_keys': ['image', 'shape'] - } - }] postprocess_params['name'] = 'DBPostProcess' postprocess_params["thresh"] = args.det_db_thresh postprocess_params["box_thresh"] = args.det_db_box_thresh postprocess_params["max_candidates"] = 1000 postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = True + elif self.det_algorithm == "EAST": + postprocess_params['name'] = 'EASTPostProcess' + postprocess_params["score_thresh"] = args.det_east_score_thresh + postprocess_params["cover_thresh"] = args.det_east_cover_thresh + postprocess_params["nms_thresh"] = args.det_east_nms_thresh + elif self.det_algorithm == "SAST": + postprocess_params['name'] = 'SASTPostProcess' + postprocess_params["score_thresh"] = args.det_sast_score_thresh + postprocess_params["nms_thresh"] = args.det_sast_nms_thresh + self.det_sast_polygon = args.det_sast_polygon + if self.det_sast_polygon: + postprocess_params["sample_pts_num"] = 6 + postprocess_params["expand_scale"] = 1.2 + postprocess_params["shrink_ratio_of_width"] = 0.2 + else: + postprocess_params["sample_pts_num"] = 2 + postprocess_params["expand_scale"] = 1.0 + postprocess_params["shrink_ratio_of_width"] = 0.3 else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) @@ -149,12 +167,25 @@ class TextDetector(object): for output_tensor in self.output_tensors: output = output_tensor.copy_to_cpu() outputs.append(output) - preds = outputs[0] - # preds = self.predictor(img) + preds = {} + if self.det_algorithm == "EAST": + preds['f_geo'] = outputs[0] + preds['f_score'] = outputs[1] + elif self.det_algorithm == 'SAST': + preds['f_border'] = outputs[0] + preds['f_score'] = outputs[1] + preds['f_tco'] = outputs[2] + preds['f_tvo'] = outputs[3] + else: + preds = outputs[0] + post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] - dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) + if self.det_algorithm == "SAST" and self.det_sast_polygon: + dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) + else: + dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) elapse = time.time() - starttime return dt_boxes, elapse