From 03895497fad84e35f678418186840ea3946bba9f Mon Sep 17 00:00:00 2001
From: Jethong <1147925384@qq.com>
Date: Sun, 11 Apr 2021 18:43:42 +0800
Subject: [PATCH] add fast postprocess

---
 doc/doc_ch/pgnet.md                           |  9 ++-
 doc/doc_en/pgnet_en.md                        |  4 ++
 .../utils/e2e_utils/extract_textpoint_fast.py |  9 ++-
 .../utils/e2e_utils/extract_textpoint_slow.py | 60 ++++++++++++++++++-
 ppocr/utils/e2e_utils/pgnet_pp_utils.py       |  7 ++-
 5 files changed, 80 insertions(+), 9 deletions(-)

diff --git a/doc/doc_ch/pgnet.md b/doc/doc_ch/pgnet.md
index d82bb796..71d17e56 100644
--- a/doc/doc_ch/pgnet.md
+++ b/doc/doc_ch/pgnet.md
@@ -146,7 +146,7 @@ python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img=
 ```
 
 ### Inference
-#### (1).Quadrangle Text Detection Model (ICDAR2015)
+#### (1). Quadrangle Text Detection Model (ICDAR2015)
 First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the English dataset with the Resnet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), run the following command to convert it:
 ```
 wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
@@ -160,7 +160,7 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
 
 ![](../imgs_results/e2e_res_img_10_pgnet.jpg)
 
-#### (2).Curved Text Detection Model (Total-Text)
+#### (2). Curved Text Detection Model (Total-Text)
 For curved text samples
 
 **For PGNet end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"` and also add the parameter `--e2e_pgnet_polygon=True`.** Then run the following command:
@@ -170,3 +170,8 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
 The visualized end-to-end text results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'e2e_res'. An example result is shown below:
 
 ![](../imgs_results/e2e_res_img623_pgnet.jpg)
+
+#### (3). Accuracy and FPS
+|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|
+| --- | --- | --- | --- | --- | --- | --- |
+|87.03|82.48|84.69|61.71|58.43|60.03|62.61|
diff --git a/doc/doc_en/pgnet_en.md b/doc/doc_en/pgnet_en.md
index 0f47f0e6..b5865257 100644
--- a/doc/doc_en/pgnet_en.md
+++ b/doc/doc_en/pgnet_en.md
@@ -173,3 +173,7 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
 The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
 
 ![](../imgs_results/e2e_res_img623_pgnet.jpg)
+#### (3). Metrics and FPS
+|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|
+| --- | --- | --- | --- | --- | --- | --- |
+|87.03|82.48|84.69|61.71|58.43|60.03|62.61|
diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py
index 9635ac55..787cd301 100644
--- a/ppocr/utils/e2e_utils/extract_textpoint_fast.py
+++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py
@@ -21,7 +21,7 @@ import math
 
 import numpy as np
 from itertools import groupby
-from cv2.ximgproc import thinning as thin
+from skimage.morphology._skeletonize import thin
 
 
 def get_dict(character_dict_path):
@@ -362,11 +362,10 @@ def generate_pivot_list_fast(p_score,
     """
     p_score = p_score[0]
     f_direction = f_direction.transpose(1, 2, 0)
-    ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255,
-                                   cv2.THRESH_BINARY)
-    skeleton_map = thin(p_tcl_map.astype('uint8'))
+    p_tcl_map = (p_score > score_thresh) * 1.0
+    skeleton_map = thin(p_tcl_map.astype(np.uint8))
     instance_count, instance_label_map = cv2.connectedComponents(
-        skeleton_map, connectivity=8)
+        skeleton_map.astype(np.uint8), connectivity=8)
 
     # get TCL Instance
     all_pos_yxs = []
diff --git a/ppocr/utils/e2e_utils/extract_textpoint_slow.py b/ppocr/utils/e2e_utils/extract_textpoint_slow.py
index 3c83fb46..db0c30e6 100644
--- a/ppocr/utils/e2e_utils/extract_textpoint_slow.py
+++ b/ppocr/utils/e2e_utils/extract_textpoint_slow.py
@@ -21,7 +21,7 @@ import math
 
 import numpy as np
 from itertools import groupby
-from cv2.ximgproc import thinning as thin
+from skimage.morphology._skeletonize import thin
 
 
 def get_dict(character_dict_path):
@@ -35,6 +35,64 @@ def get_dict(character_dict_path):
     return dict_character
 
 
+def point_pair2poly(point_pair_list):
+    """
+    Transfer vertical point pairs into polygon points in clockwise order.
+    """
+    pair_length_list = []
+    for point_pair in point_pair_list:
+        pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
+        pair_length_list.append(pair_length)
+    pair_length_list = np.array(pair_length_list)
+    pair_info = (pair_length_list.max(), pair_length_list.min(),
+                 pair_length_list.mean())
+
+    point_num = len(point_pair_list) * 2
+    point_list = [0] * point_num
+    for idx, point_pair in enumerate(point_pair_list):
+        point_list[idx] = point_pair[0]
+        point_list[point_num - 1 - idx] = point_pair[1]
+    return np.array(point_list).reshape(-1, 2), pair_info
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+    """
+    Shrink a quad along its width between the given width ratios.
+    """
+    ratio_pair = np.array(
+        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+    p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+    p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+    return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+    """
+    Expand a polygon along its width.
+ """ + point_num = poly.shape[0] + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) + right_ratio = 1.0 + \ + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + def softmax(logits): """ logits: N x d diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py index e1bc38cb..64bfd372 100644 --- a/ppocr/utils/e2e_utils/pgnet_pp_utils.py +++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py @@ -16,9 +16,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function import paddle +import os +import sys +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) from extract_textpoint_slow import * -from extract_textpoint_fast import * +from extract_textpoint_fast import generate_pivot_list_fast, restore_poly class PGNet_PostProcess(object): -- GitLab