diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
old mode 100644
new mode 100755
index c61a9234bb29d1562112c47b5ca34118c12331d6..997bc1f1978b93e45472b5698341fa91dd9e7e84
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -17,17 +17,17 @@ List of text detection algorithms open-sourced by PaddleOCR:
|Model|Backbone|precision|recall|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
-|EAST|ResNet50_vd|88.18%|85.51%|86.82%|[Download link (coming soon)](link)|
-|EAST|MobileNetV3|81.67%|79.83%|80.74%|[Download link (coming soon)](coming soon)|
-|DB|ResNet50_vd|83.79%|80.65%|82.19%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
-|DB|MobileNetV3|75.92%|73.18%|74.53%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
-|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[Download link (coming soon)](link)|
+|EAST|ResNet50_vd|88.76%|81.36%|84.90%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
+|EAST|MobileNetV3|78.24%|79.15%|78.69%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
+|DB|ResNet50_vd|86.41%|78.72%|82.38%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
+|DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
+|SAST|ResNet50_vd|91.83%|81.80%|86.52%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
On the Total-Text public text detection dataset, the results are as follows:
|Model|Backbone|precision|recall|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
-|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[Download link (coming soon)](link)|
+|SAST|ResNet50_vd|89.05%|76.80%|82.47%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)|
**Note:** The SAST model was additionally fine-tuned on public datasets such as icdar2013, icdar2017, COCO-Text, and ArT. English public datasets organized into the format used by PaddleOCR can be downloaded from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (extraction code: 2bpi)
@@ -47,13 +47,10 @@ List of text recognition algorithms open-sourced by PaddleOCR (dynamic graph):
Following the [DTRB](https://arxiv.org/abs/1904.01906) text recognition training and evaluation pipeline, models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE; the results are as follows:
|Model|Backbone|Avg Accuracy|Model storage name|Download link|
-| --- | --- | --- | --- | --- |
-|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
+| --- | --- | --- | --- | --- |
|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)|
-|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
+|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)|
-|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[Download link (coming soon)]()|
-|STAR-Net|Resnet34_vd|83.93%|rec_r34_vd_tps_bilstm_ctc|[Download link (coming soon)]()|
-
+|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
For training and usage of PaddleOCR text recognition algorithms, please refer to the [text recognition section of model training/evaluation](./recognition.md) in the documentation tutorials.
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
old mode 100644
new mode 100755
index 962b4ae748922ce6067a7ee4e609366ba1f3e01d..aea7ff1de242dec75cae26a2bf3d6838d7559882
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -180,7 +180,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
### 3. EAST text detection model inference
-First, convert the model saved during EAST text detection training into an inference model. Taking the model trained on the ICDAR2015 English dataset with a Resnet50_vd backbone as an example ([model download link (coming soon)](link)), you can convert it with the following command:
+First, convert the model saved during EAST text detection training into an inference model. Taking the model trained on the ICDAR2015 English dataset with a Resnet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can convert it with the following command:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east
@@ -193,7 +193,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img
```
The visualized text detection results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'det_res'. An example result:
-(coming soon)
+![](../imgs_results/det_res_img_10_east.jpg)
**Note**: In this codebase, the EAST post-processing Locality-Aware NMS has both Python and C++ implementations, and the C++ version is significantly faster. Due to a compilation compatibility issue with the C++ NMS, the C++ version is only invoked under Python 3.5; the Python version is used in all other environments.
@@ -201,7 +201,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img
### 4. SAST text detection model inference
#### (1). Quadrangle text detection model (ICDAR2015)
-First, convert the model saved during SAST text detection training into an inference model. Taking the model trained on the ICDAR2015 English dataset with a Resnet50_vd backbone as an example ([model download link (coming soon)](link)), you can convert it with the following command:
+First, convert the model saved during SAST text detection training into an inference model. Taking the model trained on the ICDAR2015 English dataset with a Resnet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can convert it with the following command:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15
@@ -212,10 +212,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
```
The visualized text detection results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'det_res'. An example result:
-(coming soon)
+![](../imgs_results/det_res_img_10_sast.jpg)
#### (2). Curved text detection model (Total-Text)
-First, convert the model saved during SAST text detection training into an inference model. Taking the model trained on the Total-Text English dataset with a Resnet50_vd backbone as an example ([model download link (coming soon)](link)), you can convert it with the following command:
+First, convert the model saved during SAST text detection training into an inference model. Taking the model trained on the Total-Text English dataset with a Resnet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can convert it with the following command:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt
@@ -228,7 +228,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
```
The visualized text detection results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'det_res'. An example result:
-(coming soon)
+![](../imgs_results/det_res_img623_sast.jpg)
**Note**: In this codebase, the SAST post-processing Locality-Aware NMS has both Python and C++ implementations, and the C++ version is significantly faster. Due to a compilation compatibility issue with the C++ NMS, the C++ version is only invoked under Python 3.5; the Python version is used in all other environments.
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
old mode 100644
new mode 100755
index 847b9288b96adf8c32b6486cb2798b838745c85f..6e4140e22813c0965009ffdef2eefddcec77489d
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -19,17 +19,17 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|Model|Backbone|precision|recall|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
-|EAST|ResNet50_vd|88.18%|85.51%|86.82%|[download link (coming soon)](link)|
-|EAST|MobileNetV3|81.67%|79.83%|80.74%|[download link (coming soon)](coming soon)|
-|DB|ResNet50_vd|83.79%|80.65%|82.19%|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
-|DB|MobileNetV3|75.92%|73.18%|74.53%|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
-|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[download link (coming soon)](link)|
+|EAST|ResNet50_vd|88.76%|81.36%|84.90%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
+|EAST|MobileNetV3|78.24%|79.15%|78.69%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
+|DB|ResNet50_vd|86.41%|78.72%|82.38%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
+|DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
+|SAST|ResNet50_vd|91.83%|81.80%|86.52%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
On Total-Text dataset, the text detection result is as follows:
|Model|Backbone|precision|recall|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
-|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[download link (coming soon)](link)|
+|SAST|ResNet50_vd|89.05%|76.80%|82.47%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)|
**Note:** Additional public datasets, such as icdar2013, icdar2017, COCO-Text, and ArT, were added to the SAST model training. Download the English public datasets organized into the format used by PaddleOCR from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
@@ -48,14 +48,10 @@ PaddleOCR open-source text recognition algorithms list:
Following the [DTRB](https://arxiv.org/abs/1904.01906) text recognition training and evaluation pipeline, the models above are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE; the results are as follows:
|Model|Backbone|Avg Accuracy|Module combination|Download link|
-| --- | --- | --- | --- | --- |
-|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
-|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)|
-|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
-|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)|
-|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[download link (coming soon )]()|
-|STAR-Net|Resnet34_vd|83.93%|rec_r34_vd_tps_bilstm_ctc|[download link (coming soon )]()|
-
-
+| --- | --- | --- | --- | --- |
+|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)|
+|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
+|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)|
+|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
Please refer to [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md) for the training guide and usage of PaddleOCR text recognition algorithms.
diff --git a/doc/doc_en/benchmark_en.md b/doc/doc_en/benchmark_en.md
old mode 100644
new mode 100755
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
old mode 100644
new mode 100755
index 85dc261c89098b7579336b429d7663cb1c0864f5..db86b109d1a13d00aab833aa31d0279622e7c7f8
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -187,7 +187,7 @@ The visualized text detection results are saved to the `./inference_results` fol
### 3. EAST TEXT DETECTION MODEL INFERENCE
-First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert:
+First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can use the following command to convert:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east
@@ -200,7 +200,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
-(coming soon)
+![](../imgs_results/det_res_img_10_east.jpg)
**Note**: EAST post-processing Locality-Aware NMS has Python and C++ implementations, and the C++ version is significantly faster. Due to a compilation compatibility issue with the C++ NMS, the C++ version is only invoked under Python 3.5; the Python version is used in all other environments.
@@ -208,7 +208,7 @@ The visualized text detection results are saved to the `./inference_results` fol
### 4. SAST TEXT DETECTION MODEL INFERENCE
#### (1). Quadrangle text detection model (ICDAR2015)
-First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert:
+First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can use the following command to convert:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15
@@ -222,10 +222,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
-(coming soon)
+![](../imgs_results/det_res_img_10_sast.jpg)
#### (2). Curved text detection model (Total-Text)
-First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link (coming soon)](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)), you can use the following command to convert:
+First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can use the following command to convert:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt
@@ -239,7 +239,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
-(coming soon)
+![](../imgs_results/det_res_img623_sast.jpg)
**Note**: SAST post-processing Locality-Aware NMS has Python and C++ implementations, and the C++ version is significantly faster. Due to a compilation compatibility issue with the C++ NMS, the C++ version is only invoked under Python 3.5; the Python version is used in all other environments.
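The version gate described in this note can be sketched as follows. This is a minimal illustration of the dispatch behaviour only, not the repository's actual code, and the `lanms` module name for the compiled C++ NMS binding is an assumption:

```python
import sys

def pick_nms_backend():
    """Select an NMS implementation based on the interpreter version.

    Mirrors the behaviour described in the note: the compiled C++
    locality-aware NMS is assumed importable only under Python 3.5,
    so any other interpreter falls back to the pure-Python version.
    """
    if sys.version_info[:2] == (3, 5):
        try:
            import lanms  # hypothetical compiled C++ NMS binding
            return "c++"
        except ImportError:
            return "python"
    return "python"
```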
diff --git a/doc/imgs_results/det_res_img623_sast.jpg b/doc/imgs_results/det_res_img623_sast.jpg
index b2dd538f7729724a33516091d11c081c7c2c1bd7..af5e2d6e2c5643ee71a29bc015e8ae88a8d20058 100644
Binary files a/doc/imgs_results/det_res_img623_sast.jpg and b/doc/imgs_results/det_res_img623_sast.jpg differ
diff --git a/doc/imgs_results/det_res_img_10_east.jpg b/doc/imgs_results/det_res_img_10_east.jpg
index 400b9e6fcd75f886fc99964cd0793e2ff07693d2..908d077c3eabcb95eabf4afc54ce0bed1b54f355 100644
Binary files a/doc/imgs_results/det_res_img_10_east.jpg and b/doc/imgs_results/det_res_img_10_east.jpg differ
diff --git a/doc/imgs_results/det_res_img_10_sast.jpg b/doc/imgs_results/det_res_img_10_sast.jpg
index c63faf1354601f25cedb57a3b87f4467999f5457..702f773e68fe339e9acbc4d21c98cd0aa4536ef5 100644
Binary files a/doc/imgs_results/det_res_img_10_sast.jpg and b/doc/imgs_results/det_res_img_10_sast.jpg differ
diff --git a/ppocr/postprocess/east_postprocess.py b/ppocr/postprocess/east_postprocess.py
old mode 100644
new mode 100755
index 0b669405562aef9812b9771977bf82f362beb75e..ceee727aa3df052041aee925c6c856773c8a288e
--- a/ppocr/postprocess/east_postprocess.py
+++ b/ppocr/postprocess/east_postprocess.py
@@ -19,12 +19,10 @@ from __future__ import print_function
import numpy as np
from .locality_aware_nms import nms_locality
import cv2
+import paddle
import os
import sys
-# __dir__ = os.path.dirname(os.path.abspath(__file__))
-# sys.path.append(__dir__)
-# sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
class EASTPostProcess(object):
@@ -113,11 +111,14 @@ class EASTPostProcess(object):
def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score']
geo_list = outs_dict['f_geo']
+ if isinstance(score_list, paddle.Tensor):
+ score_list = score_list.numpy()
+ geo_list = geo_list.numpy()
img_num = len(shape_list)
dt_boxes_list = []
for ino in range(img_num):
- score = score_list[ino].numpy()
- geo = geo_list[ino].numpy()
+ score = score_list[ino]
+ geo = geo_list[ino]
boxes = self.detect(
score_map=score,
geo_map=geo,
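The `isinstance(..., paddle.Tensor)` guard added above lets `EASTPostProcess` accept either dynamic-graph tensors or the plain numpy arrays produced by exported inference models. The same normalization pattern can be sketched with duck typing, so the example runs without Paddle installed (`to_numpy` is a hypothetical helper, not part of the codebase):

```python
import numpy as np

def to_numpy(x):
    """Return x as a numpy array, converting framework tensors if needed.

    Framework tensors (e.g. paddle.Tensor) expose a .numpy() method;
    plain arrays and array-likes are passed through np.asarray unchanged.
    """
    return x.numpy() if hasattr(x, "numpy") else np.asarray(x)

# Works on a plain ndarray exactly as on a tensor.
score = to_numpy(np.zeros((1, 128, 128), dtype=np.float32))
```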
diff --git a/ppocr/postprocess/sast_postprocess.py b/ppocr/postprocess/sast_postprocess.py
old mode 100644
new mode 100755
index 03b0e8f17d60a8863a2d1d5900e01dcd4874d5a9..f011e7e571cf4c2297a81a7f7772aa0c09f0aaf1
--- a/ppocr/postprocess/sast_postprocess.py
+++ b/ppocr/postprocess/sast_postprocess.py
@@ -24,7 +24,7 @@ sys.path.append(os.path.join(__dir__, '..'))
import numpy as np
from .locality_aware_nms import nms_locality
-# import lanms
+import paddle
import cv2
import time
@@ -276,14 +276,19 @@ class SASTPostProcess(object):
border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo']
tco_list = outs_dict['f_tco']
+ if isinstance(score_list, paddle.Tensor):
+ score_list = score_list.numpy()
+ border_list = border_list.numpy()
+ tvo_list = tvo_list.numpy()
+ tco_list = tco_list.numpy()
img_num = len(shape_list)
poly_lists = []
for ino in range(img_num):
- p_score = score_list[ino].transpose((1,2,0)).numpy()
- p_border = border_list[ino].transpose((1,2,0)).numpy()
- p_tvo = tvo_list[ino].transpose((1,2,0)).numpy()
- p_tco = tco_list[ino].transpose((1,2,0)).numpy()
+ p_score = score_list[ino].transpose((1,2,0))
+ p_border = border_list[ino].transpose((1,2,0))
+ p_tvo = tvo_list[ino].transpose((1,2,0))
+ p_tco = tco_list[ino].transpose((1,2,0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
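The `transpose((1, 2, 0))` calls above move each feature map from the network's CHW layout into the HWC layout that the geometric decoding expects; the conversion itself is plain numpy:

```python
import numpy as np

# A 2-channel score map in CHW layout, as produced by the network head.
p_score_chw = np.random.rand(2, 4, 6).astype(np.float32)

# Move channels last: (C, H, W) -> (H, W, C).
p_score_hwc = p_score_chw.transpose((1, 2, 0))
```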
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 6f98ded8295dabbd5edf05913245e5d94d856689..d389ca393dc94a7ece69e6f59f999073ae4b1773 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -37,33 +37,51 @@ class TextDetector(object):
def __init__(self, args):
self.det_algorithm = args.det_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
+ pre_process_list = [{
+ 'DetResizeForTest': {
+ 'limit_side_len': args.det_limit_side_len,
+ 'limit_type': args.det_limit_type
+ }
+ }, {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225],
+ 'mean': [0.485, 0.456, 0.406],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': ['image', 'shape']
+ }
+ }]
postprocess_params = {}
if self.det_algorithm == "DB":
- pre_process_list = [{
- 'DetResizeForTest': {
- 'limit_side_len': args.det_limit_side_len,
- 'limit_type': args.det_limit_type
- }
- }, {
- 'NormalizeImage': {
- 'std': [0.229, 0.224, 0.225],
- 'mean': [0.485, 0.456, 0.406],
- 'scale': '1./255.',
- 'order': 'hwc'
- }
- }, {
- 'ToCHWImage': None
- }, {
- 'KeepKeys': {
- 'keep_keys': ['image', 'shape']
- }
- }]
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = True
+ elif self.det_algorithm == "EAST":
+ postprocess_params['name'] = 'EASTPostProcess'
+ postprocess_params["score_thresh"] = args.det_east_score_thresh
+ postprocess_params["cover_thresh"] = args.det_east_cover_thresh
+ postprocess_params["nms_thresh"] = args.det_east_nms_thresh
+ elif self.det_algorithm == "SAST":
+ postprocess_params['name'] = 'SASTPostProcess'
+ postprocess_params["score_thresh"] = args.det_sast_score_thresh
+ postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
+ self.det_sast_polygon = args.det_sast_polygon
+ if self.det_sast_polygon:
+ postprocess_params["sample_pts_num"] = 6
+ postprocess_params["expand_scale"] = 1.2
+ postprocess_params["shrink_ratio_of_width"] = 0.2
+ else:
+ postprocess_params["sample_pts_num"] = 2
+ postprocess_params["expand_scale"] = 1.0
+ postprocess_params["shrink_ratio_of_width"] = 0.3
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
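The hoisted `pre_process_list` is a declarative pipeline: each entry maps an operator name to its arguments, and a factory turns the list into callables. A minimal sketch of that idea with a toy registry (illustrative stand-ins only; PaddleOCR's real operator factory lives in `ppocr/data/imaug`):

```python
import numpy as np

def build_ops(config):
    """Instantiate one callable per {op_name: kwargs} config entry."""
    registry = {
        "ToCHWImage": lambda **kw: (lambda img: img.transpose((2, 0, 1))),
        "NormalizeImage": lambda mean, std, **kw: (
            lambda img: (img / 255.0 - np.array(mean)) / np.array(std)),
    }
    ops = []
    for entry in config:
        name, kwargs = next(iter(entry.items()))
        ops.append(registry[name](**(kwargs or {})))
    return ops

ops = build_ops([
    {"NormalizeImage": {"mean": [0.485, 0.456, 0.406],
                        "std": [0.229, 0.224, 0.225]}},
    {"ToCHWImage": None},
])
img = np.ones((8, 8, 3), dtype=np.float32)
for op in ops:
    img = op(img)
```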
@@ -149,12 +167,25 @@ class TextDetector(object):
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
- preds = outputs[0]
- # preds = self.predictor(img)
+ preds = {}
+ if self.det_algorithm == "EAST":
+ preds['f_geo'] = outputs[0]
+ preds['f_score'] = outputs[1]
+ elif self.det_algorithm == 'SAST':
+ preds['f_border'] = outputs[0]
+ preds['f_score'] = outputs[1]
+ preds['f_tco'] = outputs[2]
+ preds['f_tvo'] = outputs[3]
+ else:
+ preds = outputs[0]
+
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
- dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
+ if self.det_algorithm == "SAST" and self.det_sast_polygon:
+ dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
+ else:
+ dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime
return dt_boxes, elapse
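The branch added above names the predictor's positional outputs so that each post-processor receives the dict keys it expects, while DB keeps a bare array. That mapping can be sketched on its own (the output ordering is copied from the diff above and is otherwise an assumption of this sketch):

```python
import numpy as np

# Per-algorithm output ordering, as wired up in predict_det.py.
OUTPUT_KEYS = {
    "EAST": ["f_geo", "f_score"],
    "SAST": ["f_border", "f_score", "f_tco", "f_tvo"],
}

def name_outputs(det_algorithm, outputs):
    """Zip positional predictor outputs into the dict the matching
    post-processor expects; unknown algorithms keep a bare array."""
    keys = OUTPUT_KEYS.get(det_algorithm)
    if keys is None:
        return outputs[0]
    return dict(zip(keys, outputs))

preds = name_outputs("EAST", [np.zeros((1, 8, 4, 4)), np.zeros((1, 1, 4, 4))])
```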