Commit 3f64d27b, authored by MissPenguin

update inference for east & sast

Parent 1f926f5f
...@@ -173,7 +173,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_ ...@@ -173,7 +173,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
<a name="EAST文本检测模型推理"></a> <a name="EAST文本检测模型推理"></a>
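### 3. EAST text detection model inference ### 3. EAST text detection model inference
First, convert the model saved during EAST text detection training into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can convert it with the following command: First, convert the model saved during EAST text detection training into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can convert it with the following command: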
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east
...@@ -186,7 +186,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img ...@@ -186,7 +186,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img
``` ```
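The visualized text detection results are saved to the `./inference_results` folder by default, and result file names are prefixed with 'det_res'. An example result is shown below: The visualized text detection results are saved to the `./inference_results` folder by default, and result file names are prefixed with 'det_res'. An example result is shown below: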
(coming soon) ![](../imgs_results/det_res_img_10_east.jpg)
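**Note**: In this repository, the Locality-Aware NMS used in EAST post-processing has Python and C++ versions, and the C++ version is noticeably faster. Because of a compilation-version issue with the C++ NMS, the C++ version is called only under Python 3.5; in all other cases the Python version is called. **Note**: In this repository, the Locality-Aware NMS used in EAST post-processing has Python and C++ versions, and the C++ version is noticeably faster. Because of a compilation-version issue with the C++ NMS, the C++ version is called only under Python 3.5; in all other cases the Python version is called.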
...@@ -194,7 +194,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img ...@@ -194,7 +194,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img
<a name="SAST文本检测模型推理"></a> <a name="SAST文本检测模型推理"></a>
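### 4. SAST text detection model inference ### 4. SAST text detection model inference
#### (1). Quadrangle text detection model (ICDAR2015) #### (1). Quadrangle text detection model (ICDAR2015)
First, convert the model saved during SAST text detection training into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can convert it with the following command: First, convert the model saved during SAST text detection training into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can convert it with the following command: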
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15 python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15
...@@ -205,10 +205,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img ...@@ -205,10 +205,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
``` ```
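The visualized text detection results are saved to the `./inference_results` folder by default, and result file names are prefixed with 'det_res'. An example result is shown below: The visualized text detection results are saved to the `./inference_results` folder by default, and result file names are prefixed with 'det_res'. An example result is shown below: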
(coming soon) ![](../imgs_results/det_res_img_10_sast.jpg)
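#### (2). Curved text detection model (Total-Text) #### (2). Curved text detection model (Total-Text)
First, convert the model saved during SAST text detection training into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link (coming soon)](link)), you can convert it with the following command: First, convert the model saved during SAST text detection training into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can convert it with the following command: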
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt
...@@ -221,7 +221,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img ...@@ -221,7 +221,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
``` ```
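The visualized text detection results are saved to the `./inference_results` folder by default, and result file names are prefixed with 'det_res'. An example result is shown below: The visualized text detection results are saved to the `./inference_results` folder by default, and result file names are prefixed with 'det_res'. An example result is shown below: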
(coming soon) ![](../imgs_results/det_res_img623_sast.jpg)
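**Note**: In this repository, the Locality-Aware NMS used in SAST post-processing has Python and C++ versions, and the C++ version is noticeably faster. Because of a compilation-version issue with the C++ NMS, the C++ version is called only under Python 3.5; in all other cases the Python version is called. **Note**: In this repository, the Locality-Aware NMS used in SAST post-processing has Python and C++ versions, and the C++ version is noticeably faster. Because of a compilation-version issue with the C++ NMS, the C++ version is called only under Python 3.5; in all other cases the Python version is called.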
......
...@@ -179,7 +179,7 @@ The visualized text detection results are saved to the `./inference_results` fol ...@@ -179,7 +179,7 @@ The visualized text detection results are saved to the `./inference_results` fol
<a name="EAST_DETECTION"></a> <a name="EAST_DETECTION"></a>
### 3. EAST TEXT DETECTION MODEL INFERENCE ### 3. EAST TEXT DETECTION MODEL INFERENCE
First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert: First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can use the following command to convert:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east
...@@ -192,7 +192,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_ ...@@ -192,7 +192,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
(coming soon) ![](../imgs_results/det_res_img_10_east.jpg)
**Note**: EAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases. **Note**: EAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases.
...@@ -200,7 +200,7 @@ The visualized text detection results are saved to the `./inference_results` fol ...@@ -200,7 +200,7 @@ The visualized text detection results are saved to the `./inference_results` fol
<a name="SAST_DETECTION"></a> <a name="SAST_DETECTION"></a>
### 4. SAST TEXT DETECTION MODEL INFERENCE ### 4. SAST TEXT DETECTION MODEL INFERENCE
#### (1). Quadrangle text detection model (ICDAR2015) #### (1). Quadrangle text detection model (ICDAR2015)
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert: First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can use the following command to convert:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15 python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15
...@@ -214,10 +214,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img ...@@ -214,10 +214,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
(coming soon) ![](../imgs_results/det_res_img_10_sast.jpg)
#### (2). Curved text detection model (Total-Text) #### (2). Curved text detection model (Total-Text)
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link (coming soon)](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)), you can use the following command to convert: First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can use the following command to convert:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt
...@@ -231,7 +231,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img ...@@ -231,7 +231,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
(coming soon) ![](../imgs_results/det_res_img623_sast.jpg)
**Note**: SAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases. **Note**: SAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases.
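The rule in the note above (C++ NMS only under Python 3.5, Python NMS otherwise) can be pictured as a simple version check. The sketch below is illustrative only: `pick_nms_backend` and the backend labels are hypothetical names, and the repository's real dispatch logic may be organized differently.

```python
import sys


def pick_nms_backend(version_info=sys.version_info):
    """Return "cpp" only for Python 3.5, otherwise "python".

    Hypothetical helper that mirrors the rule stated in the note;
    it does not import or call any real NMS implementation.
    """
    major, minor = version_info[0], version_info[1]
    if (major, minor) == (3, 5):
        return "cpp"
    return "python"


if __name__ == "__main__":
    print("NMS backend for this interpreter:", pick_nms_backend())
```

On any interpreter other than 3.5 this reports the Python backend, which is the slower path the note warns about.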
......
...@@ -132,7 +132,8 @@ class DBPostProcess(object): ...@@ -132,7 +132,8 @@ class DBPostProcess(object):
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, pred, shape_list): def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor): if isinstance(pred, paddle.Tensor):
pred = pred.numpy() pred = pred.numpy()
pred = pred[:, 0, :, :] pred = pred[:, 0, :, :]
......
...@@ -19,12 +19,10 @@ from __future__ import print_function ...@@ -19,12 +19,10 @@ from __future__ import print_function
import numpy as np import numpy as np
from .locality_aware_nms import nms_locality from .locality_aware_nms import nms_locality
import cv2 import cv2
import paddle
import os import os
import sys import sys
# __dir__ = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(__dir__)
# sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
class EASTPostProcess(object): class EASTPostProcess(object):
...@@ -113,11 +111,14 @@ class EASTPostProcess(object): ...@@ -113,11 +111,14 @@ class EASTPostProcess(object):
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score'] score_list = outs_dict['f_score']
geo_list = outs_dict['f_geo'] geo_list = outs_dict['f_geo']
if isinstance(score_list, paddle.Tensor):
score_list = score_list.numpy()
geo_list = geo_list.numpy()
img_num = len(shape_list) img_num = len(shape_list)
dt_boxes_list = [] dt_boxes_list = []
for ino in range(img_num): for ino in range(img_num):
score = score_list[ino].numpy() score = score_list[ino]
geo = geo_list[ino].numpy() geo = geo_list[ino]
boxes = self.detect( boxes = self.detect(
score_map=score, score_map=score,
geo_map=geo, geo_map=geo,
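The EAST hunk above moves the tensor-to-array conversion out of the per-image loop: the outputs are converted once if they are tensors, and left alone if they are already arrays (as they are when they come from the inference predictor). A minimal sketch of that pattern follows. It uses duck typing on a `.numpy()` method instead of the `isinstance(..., paddle.Tensor)` check in the diff, and `FakeTensor`, `to_host`, and `normalize_outputs` are hypothetical names introduced only for illustration.

```python
class FakeTensor:
    """Hypothetical stand-in for a framework tensor that exposes .numpy()."""

    def __init__(self, data):
        self._data = data

    def numpy(self):
        return self._data


def to_host(value):
    """Return value.numpy() when the object offers it, else value unchanged."""
    numpy_method = getattr(value, "numpy", None)
    return numpy_method() if callable(numpy_method) else value


def normalize_outputs(outs_dict, keys):
    """Convert the selected outputs once, before any per-image indexing."""
    return {key: to_host(outs_dict[key]) for key in keys}


if __name__ == "__main__":
    outs = {"f_score": FakeTensor([[0.9]]), "f_geo": [[1.0]]}
    print(normalize_outputs(outs, ["f_score", "f_geo"]))
```

Converting up front means the loop body can index plain arrays in both the training-time and inference-time call paths.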
......
...@@ -24,7 +24,7 @@ sys.path.append(os.path.join(__dir__, '..')) ...@@ -24,7 +24,7 @@ sys.path.append(os.path.join(__dir__, '..'))
import numpy as np import numpy as np
from .locality_aware_nms import nms_locality from .locality_aware_nms import nms_locality
# import lanms import paddle
import cv2 import cv2
import time import time
...@@ -276,14 +276,19 @@ class SASTPostProcess(object): ...@@ -276,14 +276,19 @@ class SASTPostProcess(object):
border_list = outs_dict['f_border'] border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo'] tvo_list = outs_dict['f_tvo']
tco_list = outs_dict['f_tco'] tco_list = outs_dict['f_tco']
if isinstance(score_list, paddle.Tensor):
score_list = score_list.numpy()
border_list = border_list.numpy()
tvo_list = tvo_list.numpy()
tco_list = tco_list.numpy()
img_num = len(shape_list) img_num = len(shape_list)
poly_lists = [] poly_lists = []
for ino in range(img_num): for ino in range(img_num):
p_score = score_list[ino].transpose((1,2,0)).numpy() p_score = score_list[ino].transpose((1,2,0))
p_border = border_list[ino].transpose((1,2,0)).numpy() p_border = border_list[ino].transpose((1,2,0))
p_tvo = tvo_list[ino].transpose((1,2,0)).numpy() p_tvo = tvo_list[ino].transpose((1,2,0))
p_tco = tco_list[ino].transpose((1,2,0)).numpy() p_tco = tco_list[ino].transpose((1,2,0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino] src_h, src_w, ratio_h, ratio_w = shape_list[ino]
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h, poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
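The `transpose((1,2,0))` calls in the SAST hunk reorder each per-image map from channel-first (C, H, W) to channel-last (H, W, C). To make the axis permutation concrete, here is a pure-Python sketch on nested lists; `chw_to_hwc` is a hypothetical helper, and the real code operates on arrays rather than lists.

```python
def chw_to_hwc(chw):
    """Reorder a nested list indexed [channel][row][col] into [row][col][channel]."""
    channels = len(chw)
    height = len(chw[0])
    width = len(chw[0][0])
    return [
        [[chw[c][h][w] for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


if __name__ == "__main__":
    # Two channels, one row, two columns.
    chw = [[[1, 2]], [[10, 20]]]
    print(chw_to_hwc(chw))
```

After the reordering, each pixel position holds all of its channel values together, which is the layout the downstream per-pixel processing indexes.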
......
...@@ -37,33 +37,51 @@ class TextDetector(object): ...@@ -37,33 +37,51 @@ class TextDetector(object):
def __init__(self, args): def __init__(self, args):
self.det_algorithm = args.det_algorithm self.det_algorithm = args.det_algorithm
self.use_zero_copy_run = args.use_zero_copy_run self.use_zero_copy_run = args.use_zero_copy_run
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
postprocess_params = {} postprocess_params = {}
if self.det_algorithm == "DB": if self.det_algorithm == "DB":
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
postprocess_params['name'] = 'DBPostProcess' postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000 postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = True postprocess_params["use_dilation"] = True
elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
elif self.det_algorithm == "SAST":
postprocess_params['name'] = 'SASTPostProcess'
postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
self.det_sast_polygon = args.det_sast_polygon
if self.det_sast_polygon:
postprocess_params["sample_pts_num"] = 6
postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2
else:
postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
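The SAST branch above picks one of two parameter sets depending on whether polygon output is requested. The sketch below restates that selection as a standalone function, using the values shown in the hunk; `build_sast_postprocess_params` is a hypothetical name, and the thresholds in the usage line are placeholder values, not the project's defaults.

```python
def build_sast_postprocess_params(score_thresh, nms_thresh, polygon):
    """Assemble SAST post-processing parameters as in the hunk above."""
    params = {
        "name": "SASTPostProcess",
        "score_thresh": score_thresh,
        "nms_thresh": nms_thresh,
    }
    if polygon:
        # Curved text: more sample points and a larger expansion.
        params.update(
            sample_pts_num=6, expand_scale=1.2, shrink_ratio_of_width=0.2
        )
    else:
        # Quadrangle text: two sample points and no expansion.
        params.update(
            sample_pts_num=2, expand_scale=1.0, shrink_ratio_of_width=0.3
        )
    return params


if __name__ == "__main__":
    # Placeholder thresholds, chosen only for the demonstration.
    print(build_sast_postprocess_params(0.5, 0.2, polygon=True))
```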
...@@ -149,12 +167,25 @@ class TextDetector(object): ...@@ -149,12 +167,25 @@ class TextDetector(object):
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
preds = outputs[0]
# preds = self.predictor(img) preds = {}
if self.det_algorithm == "EAST":
preds['f_geo'] = outputs[0]
preds['f_score'] = outputs[1]
elif self.det_algorithm == 'SAST':
preds['f_border'] = outputs[0]
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
else:
preds['maps'] = outputs[0]
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) if self.det_algorithm == "SAST" and self.det_sast_polygon:
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime elapse = time.time() - starttime
return dt_boxes, elapse return dt_boxes, elapse
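The second `TextDetector` hunk turns the predictor's positional outputs into a dict keyed by the names each post-processor expects. One way to express the same mapping is a lookup table, sketched below. The key order per algorithm is copied from the hunk; `OUTPUT_KEYS`, `DEFAULT_KEYS`, and `name_outputs` are hypothetical names, and the length check is an addition that the original code does not perform.

```python
# Output order per algorithm, as assigned in the hunk above.
OUTPUT_KEYS = {
    "EAST": ("f_geo", "f_score"),
    "SAST": ("f_border", "f_score", "f_tco", "f_tvo"),
}
# Any other algorithm (DB in the diff) uses a single 'maps' output.
DEFAULT_KEYS = ("maps",)


def name_outputs(det_algorithm, outputs):
    """Map a positional list of outputs to the keys a post-processor reads."""
    keys = OUTPUT_KEYS.get(det_algorithm, DEFAULT_KEYS)
    if len(outputs) < len(keys):
        raise ValueError(
            "expected at least {} outputs for {}, got {}".format(
                len(keys), det_algorithm, len(outputs)
            )
        )
    return dict(zip(keys, outputs))


if __name__ == "__main__":
    print(name_outputs("EAST", ["geo_array", "score_array"]))
```

A table keeps the output order for each algorithm in one place, which matters because the order depends on how the model was exported.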
......