diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
old mode 100644
new mode 100755
index 663533c492ab5dc0bd22cc79bd95c9d1d194d854..10c01666404c1a66a14485ed30195954ed881b6f
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -173,7 +173,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
### 3. EAST文本检测模型推理
-首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址 (coming soon)](link) ),可以使用如下命令进行转换:
+首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar) ),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east
@@ -186,7 +186,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
-(coming soon)
+![](../imgs_results/det_res_img_10_east.jpg)
**注意**:本代码库中,EAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
@@ -194,7 +194,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img
### 4. SAST文本检测模型推理
#### (1). 四边形文本检测模型(ICDAR2015)
-首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址(coming soon)](link)),可以使用如下命令进行转换:
+首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15
@@ -205,10 +205,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
-(coming soon)
+![](../imgs_results/det_res_img_10_sast.jpg)
#### (2). 弯曲文本检测模型(Total-Text)
-首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址(coming soon)](link)),可以使用如下命令进行转换:
+首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt
@@ -221,7 +221,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
-(coming soon)
+![](../imgs_results/det_res_img623_sast.jpg)
**注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
old mode 100644
new mode 100755
index 411a733dd062cf347d7a2e5d5d067739bda36819..606565275ba243969bf919bfb91a0a2067a7a8cd
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -179,7 +179,7 @@ The visualized text detection results are saved to the `./inference_results` fol
### 3. EAST TEXT DETECTION MODEL INFERENCE
-First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert:
+First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can use the following command to convert:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east
@@ -192,7 +192,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
-(coming soon)
+![](../imgs_results/det_res_img_10_east.jpg)
**Note**: EAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases.
@@ -200,7 +200,7 @@ The visualized text detection results are saved to the `./inference_results` fol
### 4. SAST TEXT DETECTION MODEL INFERENCE
#### (1). Quadrangle text detection model (ICDAR2015)
-First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert:
+First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can use the following command to convert:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15
@@ -214,10 +214,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
-(coming soon)
+![](../imgs_results/det_res_img_10_sast.jpg)
#### (2). Curved text detection model (Total-Text)
-First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link (coming soon)](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)), you can use the following command to convert:
+First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can use the following command to convert:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt
@@ -231,7 +231,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
-(coming soon)
+![](../imgs_results/det_res_img623_sast.jpg)
**Note**: SAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases.
diff --git a/doc/imgs_results/det_res_img623_sast.jpg b/doc/imgs_results/det_res_img623_sast.jpg
index b2dd538f7729724a33516091d11c081c7c2c1bd7..af5e2d6e2c5643ee71a29bc015e8ae88a8d20058 100644
Binary files a/doc/imgs_results/det_res_img623_sast.jpg and b/doc/imgs_results/det_res_img623_sast.jpg differ
diff --git a/doc/imgs_results/det_res_img_10_east.jpg b/doc/imgs_results/det_res_img_10_east.jpg
index 400b9e6fcd75f886fc99964cd0793e2ff07693d2..908d077c3eabcb95eabf4afc54ce0bed1b54f355 100644
Binary files a/doc/imgs_results/det_res_img_10_east.jpg and b/doc/imgs_results/det_res_img_10_east.jpg differ
diff --git a/doc/imgs_results/det_res_img_10_sast.jpg b/doc/imgs_results/det_res_img_10_sast.jpg
index c63faf1354601f25cedb57a3b87f4467999f5457..702f773e68fe339e9acbc4d21c98cd0aa4536ef5 100644
Binary files a/doc/imgs_results/det_res_img_10_sast.jpg and b/doc/imgs_results/det_res_img_10_sast.jpg differ
diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py
index 16c789dcd7e9740ca8ddf613d0f2567c9af22820..b2deb3dc29613def43533d412872b65c9f589cce 100755
--- a/ppocr/postprocess/db_postprocess.py
+++ b/ppocr/postprocess/db_postprocess.py
@@ -132,7 +132,8 @@ class DBPostProcess(object):
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
- def __call__(self, pred, shape_list):
+ def __call__(self, outs_dict, shape_list):
+ pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = pred[:, 0, :, :]
diff --git a/ppocr/postprocess/east_postprocess.py b/ppocr/postprocess/east_postprocess.py
old mode 100644
new mode 100755
index 0b669405562aef9812b9771977bf82f362beb75e..ceee727aa3df052041aee925c6c856773c8a288e
--- a/ppocr/postprocess/east_postprocess.py
+++ b/ppocr/postprocess/east_postprocess.py
@@ -19,12 +19,10 @@ from __future__ import print_function
import numpy as np
from .locality_aware_nms import nms_locality
import cv2
+import paddle
import os
import sys
-# __dir__ = os.path.dirname(os.path.abspath(__file__))
-# sys.path.append(__dir__)
-# sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
class EASTPostProcess(object):
@@ -113,11 +111,14 @@ class EASTPostProcess(object):
def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score']
geo_list = outs_dict['f_geo']
+ if isinstance(score_list, paddle.Tensor):
+ score_list = score_list.numpy()
+ geo_list = geo_list.numpy()
img_num = len(shape_list)
dt_boxes_list = []
for ino in range(img_num):
- score = score_list[ino].numpy()
- geo = geo_list[ino].numpy()
+ score = score_list[ino]
+ geo = geo_list[ino]
boxes = self.detect(
score_map=score,
geo_map=geo,
diff --git a/ppocr/postprocess/sast_postprocess.py b/ppocr/postprocess/sast_postprocess.py
old mode 100644
new mode 100755
index 03b0e8f17d60a8863a2d1d5900e01dcd4874d5a9..f011e7e571cf4c2297a81a7f7772aa0c09f0aaf1
--- a/ppocr/postprocess/sast_postprocess.py
+++ b/ppocr/postprocess/sast_postprocess.py
@@ -24,7 +24,7 @@ sys.path.append(os.path.join(__dir__, '..'))
import numpy as np
from .locality_aware_nms import nms_locality
-# import lanms
+import paddle
import cv2
import time
@@ -276,14 +276,19 @@ class SASTPostProcess(object):
border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo']
tco_list = outs_dict['f_tco']
+ if isinstance(score_list, paddle.Tensor):
+ score_list = score_list.numpy()
+ border_list = border_list.numpy()
+ tvo_list = tvo_list.numpy()
+ tco_list = tco_list.numpy()
img_num = len(shape_list)
poly_lists = []
for ino in range(img_num):
- p_score = score_list[ino].transpose((1,2,0)).numpy()
- p_border = border_list[ino].transpose((1,2,0)).numpy()
- p_tvo = tvo_list[ino].transpose((1,2,0)).numpy()
- p_tco = tco_list[ino].transpose((1,2,0)).numpy()
+ p_score = score_list[ino].transpose((1,2,0))
+ p_border = border_list[ino].transpose((1,2,0))
+ p_tvo = tvo_list[ino].transpose((1,2,0))
+ p_tco = tco_list[ino].transpose((1,2,0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 6f98ded8295dabbd5edf05913245e5d94d856689..f07655a86ba637d84e1dc709d94a61c509929c00 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -37,33 +37,51 @@ class TextDetector(object):
def __init__(self, args):
self.det_algorithm = args.det_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
+ pre_process_list = [{
+ 'DetResizeForTest': {
+ 'limit_side_len': args.det_limit_side_len,
+ 'limit_type': args.det_limit_type
+ }
+ }, {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225],
+ 'mean': [0.485, 0.456, 0.406],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': ['image', 'shape']
+ }
+ }]
postprocess_params = {}
if self.det_algorithm == "DB":
- pre_process_list = [{
- 'DetResizeForTest': {
- 'limit_side_len': args.det_limit_side_len,
- 'limit_type': args.det_limit_type
- }
- }, {
- 'NormalizeImage': {
- 'std': [0.229, 0.224, 0.225],
- 'mean': [0.485, 0.456, 0.406],
- 'scale': '1./255.',
- 'order': 'hwc'
- }
- }, {
- 'ToCHWImage': None
- }, {
- 'KeepKeys': {
- 'keep_keys': ['image', 'shape']
- }
- }]
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = True
+ elif self.det_algorithm == "EAST":
+ postprocess_params['name'] = 'EASTPostProcess'
+ postprocess_params["score_thresh"] = args.det_east_score_thresh
+ postprocess_params["cover_thresh"] = args.det_east_cover_thresh
+ postprocess_params["nms_thresh"] = args.det_east_nms_thresh
+ elif self.det_algorithm == "SAST":
+ postprocess_params['name'] = 'SASTPostProcess'
+ postprocess_params["score_thresh"] = args.det_sast_score_thresh
+ postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
+ self.det_sast_polygon = args.det_sast_polygon
+ if self.det_sast_polygon:
+ postprocess_params["sample_pts_num"] = 6
+ postprocess_params["expand_scale"] = 1.2
+ postprocess_params["shrink_ratio_of_width"] = 0.2
+ else:
+ postprocess_params["sample_pts_num"] = 2
+ postprocess_params["expand_scale"] = 1.0
+ postprocess_params["shrink_ratio_of_width"] = 0.3
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
@@ -149,12 +167,25 @@ class TextDetector(object):
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
- preds = outputs[0]
- # preds = self.predictor(img)
+ preds = {}
+ if self.det_algorithm == "EAST":
+ preds['f_geo'] = outputs[0]
+ preds['f_score'] = outputs[1]
+ elif self.det_algorithm == 'SAST':
+ preds['f_border'] = outputs[0]
+ preds['f_score'] = outputs[1]
+ preds['f_tco'] = outputs[2]
+ preds['f_tvo'] = outputs[3]
+ else:
+ preds['maps'] = outputs[0]
+
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
- dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
+ if self.det_algorithm == "SAST" and self.det_sast_polygon:
+ dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
+ else:
+ dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime
return dt_boxes, elapse