From d036c91af18e9616e6f1e587f2c55c3816526c52 Mon Sep 17 00:00:00 2001
From: Jethong <1147925384@qq.com>
Date: Sun, 11 Apr 2021 16:40:46 +0800
Subject: [PATCH] support two postprocess
---
configs/e2e/e2e_r50_vd_pg.yml | 1 +
doc/doc_ch/inference.md | 41 +-
ppocr/data/imaug/label_ops.py | 16 +-
ppocr/metrics/e2e_metric.py | 2 +-
ppocr/postprocess/pg_postprocess.py | 124 +----
.../utils/e2e_utils/extract_textpoint_fast.py | 458 ++++++++++++++++++
...textpoint.py => extract_textpoint_slow.py} | 16 +-
ppocr/utils/e2e_utils/pgnet_pp_utils.py | 176 +++++++
8 files changed, 665 insertions(+), 169 deletions(-)
create mode 100644 ppocr/utils/e2e_utils/extract_textpoint_fast.py
rename ppocr/utils/e2e_utils/{extract_textpoint.py => extract_textpoint_slow.py} (98%)
create mode 100644 ppocr/utils/e2e_utils/pgnet_pp_utils.py
diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml
index 5a593ad8..e4d868f9 100644
--- a/configs/e2e/e2e_r50_vd_pg.yml
+++ b/configs/e2e/e2e_r50_vd_pg.yml
@@ -59,6 +59,7 @@ Optimizer:
PostProcess:
name: PGPostProcess
score_thresh: 0.5
+ mode: fast # fast or slow two ways
Metric:
name: E2EMetric
gt_mat_dir: # the dir of gt_mat
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index 1288d906..0b082c56 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -28,13 +28,10 @@ inference 模型(`paddle.jit.save`保存的模型)
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
- [5. 多语言模型的推理](#多语言模型的推理)
-- [四、端到端模型推理](#端到端模型推理)
- - [1. PGNet端到端模型推理](#PGNet端到端模型推理)
-
-- [五、方向分类模型推理](#方向识别模型推理)
+- [四、方向分类模型推理](#方向识别模型推理)
- [1. 方向分类模型推理](#方向分类模型推理)
-- [六、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
+- [五、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
- [2. 其他模型推理](#其他模型推理)
@@ -362,38 +359,8 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
```
-
-## 四、端到端模型推理
-
-端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
-
-### 1. PGNet端到端模型推理
-#### (1). 四边形文本检测模型(ICDAR2015)
-首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)),可以使用如下命令进行转换:
-```
-python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
-```
-**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
-```
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
-```
-可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
-
-![](../imgs_results/e2e_res_img_10_pgnet.jpg)
-
-#### (2). 弯曲文本检测模型(Total-Text)
-和四边形文本检测模型共用一个推理模型
-**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
-```
-python3.7 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
-```
-可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
-
-![](../imgs_results/e2e_res_img623_pgnet.jpg)
-
-
-## 五、方向分类模型推理
+## 四、方向分类模型推理
下面将介绍方向分类模型推理。
@@ -418,7 +385,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
```
-## 六、文本检测、方向分类和文字识别串联推理
+## 五、文本检测、方向分类和文字识别串联推理
### 1. 超轻量中文OCR模型推理
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 47e0cbf0..cbb11009 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -200,18 +200,16 @@ class E2ELabelEncode(BaseRecLabelEncode):
self.pad_num = len(self.dict) # the length to pad
def __call__(self, data):
- text_label_index_list, temp_text = [], []
texts = data['strs']
+ temp_texts = []
for text in texts:
text = text.lower()
- temp_text = []
- for c_ in text:
- if c_ in self.dict:
- temp_text.append(self.dict[c_])
- temp_text = temp_text + [self.pad_num] * (self.max_text_len -
- len(temp_text))
- text_label_index_list.append(temp_text)
- data['strs'] = np.array(text_label_index_list)
+ text = self.encode(text)
+ if text is None:
+ return None
+ text = text + [self.pad_num] * (self.max_text_len - len(text))
+ temp_texts.append(text)
+ data['strs'] = np.array(temp_texts)
return data
diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py
index ef14ad48..8a604192 100644
--- a/ppocr/metrics/e2e_metric.py
+++ b/ppocr/metrics/e2e_metric.py
@@ -19,7 +19,7 @@ from __future__ import print_function
__all__ = ['E2EMetric']
from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
-from ppocr.utils.e2e_utils.extract_textpoint import get_dict
+from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
class E2EMetric(object):
diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py
index f9118d87..0b145518 100644
--- a/ppocr/postprocess/pg_postprocess.py
+++ b/ppocr/postprocess/pg_postprocess.py
@@ -22,10 +22,7 @@ import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
-
-from ppocr.utils.e2e_utils.extract_textpoint import *
-from ppocr.utils.e2e_utils.visual import *
-import paddle
+from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess
class PGPostProcess(object):
@@ -33,10 +30,12 @@ class PGPostProcess(object):
The post process for PGNet.
"""
- def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
- self.Lexicon_Table = get_dict(character_dict_path)
+ def __init__(self, character_dict_path, valid_set, score_thresh, mode,
+ **kwargs):
+ self.character_dict_path = character_dict_path
self.valid_set = valid_set
self.score_thresh = score_thresh
+ self.mode = mode
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
@@ -44,113 +43,10 @@ class PGPostProcess(object):
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
- p_score = outs_dict['f_score']
- p_border = outs_dict['f_border']
- p_char = outs_dict['f_char']
- p_direction = outs_dict['f_direction']
- if isinstance(p_score, paddle.Tensor):
- p_score = p_score[0].numpy()
- p_border = p_border[0].numpy()
- p_direction = p_direction[0].numpy()
- p_char = p_char[0].numpy()
+ post = PGNet_PostProcess(self.character_dict_path, self.valid_set,
+ self.score_thresh, outs_dict, shape_list)
+ if self.mode == 'fast':
+ data = post.pg_postprocess_fast()
else:
- p_score = p_score[0]
- p_border = p_border[0]
- p_direction = p_direction[0]
- p_char = p_char[0]
- src_h, src_w, ratio_h, ratio_w = shape_list[0]
- is_curved = self.valid_set == "totaltext"
- instance_yxs_list = generate_pivot_list(
- p_score,
- p_char,
- p_direction,
- score_thresh=self.score_thresh,
- is_backbone=True,
- is_curved=is_curved)
- p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
- char_seq_idx_set = []
- for i in range(len(instance_yxs_list)):
- gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
- f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
- feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
- feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
- feature_len = [len(feature_seq[0])]
- featyre_seq = paddle.to_tensor(feature_seq)
- feature_len = np.array([feature_len]).astype(np.int64)
- length = paddle.to_tensor(feature_len)
- seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
- input=featyre_seq, blank=36, input_length=length)
- seq_pred_str = seq_pred[0].numpy().tolist()[0]
- seq_len = seq_pred[1].numpy()[0][0]
- temp_t = []
- for c in seq_pred_str[:seq_len]:
- temp_t.append(c)
- char_seq_idx_set.append(temp_t)
- seq_strs = []
- for char_idx_set in char_seq_idx_set:
- pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
- seq_strs.append(pr_str)
- poly_list = []
- keep_str_list = []
- all_point_list = []
- all_point_pair_list = []
- for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
- if len(yx_center_line) == 1:
- yx_center_line.append(yx_center_line[-1])
-
- offset_expand = 1.0
- if self.valid_set == 'totaltext':
- offset_expand = 1.2
-
- point_pair_list = []
- for batch_id, y, x in yx_center_line:
- offset = p_border[:, y, x].reshape(2, 2)
- if offset_expand != 1.0:
- offset_length = np.linalg.norm(
- offset, axis=1, keepdims=True)
- expand_length = np.clip(
- offset_length * (offset_expand - 1),
- a_min=0.5,
- a_max=3.0)
- offset_detal = offset / offset_length * expand_length
- offset = offset + offset_detal
- ori_yx = np.array([y, x], dtype=np.float32)
- point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
- [ratio_w, ratio_h]).reshape(-1, 2)
- point_pair_list.append(point_pair)
-
- all_point_list.append([
- int(round(x * 4.0 / ratio_w)),
- int(round(y * 4.0 / ratio_h))
- ])
- all_point_pair_list.append(point_pair.round().astype(np.int32)
- .tolist())
-
- detected_poly, pair_length_info = point_pair2poly(point_pair_list)
- detected_poly = expand_poly_along_width(
- detected_poly, shrink_ratio_of_width=0.2)
- detected_poly[:, 0] = np.clip(
- detected_poly[:, 0], a_min=0, a_max=src_w)
- detected_poly[:, 1] = np.clip(
- detected_poly[:, 1], a_min=0, a_max=src_h)
-
- if len(keep_str) < 2:
- continue
-
- keep_str_list.append(keep_str)
- detected_poly = np.round(detected_poly).astype('int32')
- if self.valid_set == 'partvgg':
- middle_point = len(detected_poly) // 2
- detected_poly = detected_poly[
- [0, middle_point - 1, middle_point, -1], :]
- poly_list.append(detected_poly)
- elif self.valid_set == 'totaltext':
- poly_list.append(detected_poly)
- else:
- print('--> Not supported format.')
- exit(-1)
- data = {
- 'points': poly_list,
- 'strs': keep_str_list,
- }
+ data = post.pg_postprocess_slow()
return data
diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py
new file mode 100644
index 00000000..9635ac55
--- /dev/null
+++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py
@@ -0,0 +1,458 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains various CTC decoders."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cv2
+import math
+
+import numpy as np
+from itertools import groupby
+from cv2.ximgproc import thinning as thin
+
+
+def get_dict(character_dict_path):
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
+
+def softmax(logits):
+ """
+ logits: N x d
+ """
+ max_value = np.max(logits, axis=1, keepdims=True)
+ exp = np.exp(logits - max_value)
+ exp_sum = np.sum(exp, axis=1, keepdims=True)
+ dist = exp / exp_sum
+ return dist
+
+
+def get_keep_pos_idxs(labels, remove_blank=None):
+ """
+ Remove duplicate and get pos idxs of keep items.
+ The value of keep_blank should be [None, 95].
+ """
+ duplicate_len_list = []
+ keep_pos_idx_list = []
+ keep_char_idx_list = []
+ for k, v_ in groupby(labels):
+ current_len = len(list(v_))
+ if k != remove_blank:
+ current_idx = int(sum(duplicate_len_list) + current_len // 2)
+ keep_pos_idx_list.append(current_idx)
+ keep_char_idx_list.append(k)
+ duplicate_len_list.append(current_len)
+ return keep_char_idx_list, keep_pos_idx_list
+
+
+def remove_blank(labels, blank=0):
+ new_labels = [x for x in labels if x != blank]
+ return new_labels
+
+
+def insert_blank(labels, blank=0):
+ new_labels = [blank]
+ for l in labels:
+ new_labels += [l, blank]
+ return new_labels
+
+
+def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
+ """
+ CTC greedy (best path) decoder.
+ """
+ raw_str = np.argmax(np.array(probs_seq), axis=1)
+ remove_blank_in_pos = None if keep_blank_in_idxs else blank
+ dedup_str, keep_idx_list = get_keep_pos_idxs(
+ raw_str, remove_blank=remove_blank_in_pos)
+ dst_str = remove_blank(dedup_str, blank=blank)
+ return dst_str, keep_idx_list
+
+
+def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
+ _, _, C = logits_map.shape
+ ys, xs = zip(*gather_info)
+ logits_seq = logits_map[list(ys), list(xs)]
+ probs_seq = logits_seq
+ labels = np.argmax(probs_seq, axis=1)
+ dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
+ detal = len(gather_info) // (pts_num - 1)
+ keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
+ keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
+ return dst_str, keep_gather_list
+
+
+def ctc_decoder_for_image(gather_info_list,
+ logits_map,
+ Lexicon_Table,
+ pts_num=6):
+ """
+ CTC decoder using multiple processes.
+ """
+ decoder_str = []
+ decoder_xys = []
+ for gather_info in gather_info_list:
+ if len(gather_info) < pts_num:
+ continue
+ dst_str, xys_list = instance_ctc_greedy_decoder(
+ gather_info, logits_map, pts_num=pts_num)
+ dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
+ if len(dst_str_readable) < 2:
+ continue
+ decoder_str.append(dst_str_readable)
+ decoder_xys.append(xys_list)
+ return decoder_str, decoder_xys
+
+
+def sort_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list, point_direction):
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+ sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
+ first_part_point, first_point_direction)
+
+ last_part_point = sorted_point[middle_num:]
+ last_point_direction = sorted_direction[middle_num:]
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+ last_part_point, last_point_direction)
+ sorted_point = sorted_fist_part_point + sorted_last_part_point
+ sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
+
+ return sorted_point, np.array(sorted_direction)
+
+
+def add_id(pos_list, image_id=0):
+ """
+ Add id for gather feature, for inference.
+ """
+ new_list = []
+ for item in pos_list:
+ new_list.append((image_id, item[0], item[1]))
+ return new_list
+
+
+def sort_and_expand_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+ right_dirction = point_direction[point_num - sub_direction_len:, :]
+
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+ left_average_len = np.linalg.norm(left_average_direction)
+ left_start = np.array(sorted_list[0])
+ left_step = left_average_direction / (left_average_len + 1e-6)
+
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ left_list = []
+ right_list = []
+ for i in range(append_num):
+ ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ left_list.append((ly, lx))
+ ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ right_list.append((ry, rx))
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ binary_tcl_map: h x w
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+ right_dirction = point_direction[point_num - sub_direction_len:, :]
+
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+ left_average_len = np.linalg.norm(left_average_direction)
+ left_start = np.array(sorted_list[0])
+ left_step = left_average_direction / (left_average_len + 1e-6)
+
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ max_append_num = 2 * append_num
+
+ left_list = []
+ right_list = []
+ for i in range(max_append_num):
+ ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ if binary_tcl_map[ly, lx] > 0.5:
+ left_list.append((ly, lx))
+ else:
+ break
+
+ for i in range(max_append_num):
+ ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ if binary_tcl_map[ry, rx] > 0.5:
+ right_list.append((ry, rx))
+ else:
+ break
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def point_pair2poly(point_pair_list):
+ """
+ Transfer vertical point_pairs into poly point in clockwise.
+ """
+ point_num = len(point_pair_list) * 2
+ point_list = [0] * point_num
+ for idx, point_pair in enumerate(point_pair_list):
+ point_list[idx] = point_pair[0]
+ point_list[point_num - 1 - idx] = point_pair[1]
+ return np.array(point_list).reshape(-1, 2)
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+ """
+ expand poly along width.
+ """
+ point_num = poly.shape[0]
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+ poly[point_num // 2], poly[point_num // 2 + 1]
+ ],
+ dtype=np.float32)
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+ (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ poly[0] = left_quad_expand[0]
+ poly[-1] = left_quad_expand[-1]
+ poly[point_num // 2 - 1] = right_quad_expand[1]
+ poly[point_num // 2] = right_quad_expand[2]
+ return poly
+
+
+def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
+ src_h, valid_set):
+ poly_list = []
+ keep_str_list = []
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+ if len(keep_str) < 2:
+ print('--> too short, {}'.format(keep_str))
+ continue
+
+ offset_expand = 1.0
+ if valid_set == 'totaltext':
+ offset_expand = 1.2
+
+ point_pair_list = []
+ for y, x in yx_center_line:
+ offset = p_border[:, y, x].reshape(2, 2) * offset_expand
+ ori_yx = np.array([y, x], dtype=np.float32)
+ point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
+ [ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair_list.append(point_pair)
+
+ detected_poly = point_pair2poly(point_pair_list)
+ detected_poly = expand_poly_along_width(
+ detected_poly, shrink_ratio_of_width=0.2)
+ detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
+
+ keep_str_list.append(keep_str)
+ if valid_set == 'partvgg':
+ middle_point = len(detected_poly) // 2
+ detected_poly = detected_poly[
+ [0, middle_point - 1, middle_point, -1], :]
+ poly_list.append(detected_poly)
+ elif valid_set == 'totaltext':
+ poly_list.append(detected_poly)
+ else:
+ print('--> Not supported format.')
+ exit(-1)
+ return poly_list, keep_str_list
+
+
+def generate_pivot_list_fast(p_score,
+ p_char_maps,
+ f_direction,
+ Lexicon_Table,
+ score_thresh=0.5):
+ """
+ return center point and end point of TCL instance; filter with the char maps;
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255,
+ cv2.THRESH_BINARY)
+ skeleton_map = thin(p_tcl_map.astype('uint8'))
+ instance_count, instance_label_map = cv2.connectedComponents(
+ skeleton_map, connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+
+ if len(pos_list) < 3:
+ continue
+
+ pos_list_sorted = sort_and_expand_with_direction_v2(
+ pos_list, f_direction, p_tcl_map)
+ all_pos_yxs.append(pos_list_sorted)
+
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decoded_str, keep_yxs_list = ctc_decoder_for_image(
+ all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
+ return keep_yxs_list, decoded_str
+
+
+def extract_main_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ pos_list = np.array(pos_list)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ average_direction = average_direction / (
+ np.linalg.norm(average_direction) + 1e-6)
+ return average_direction
+
+
+def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+ """
+ pos_list_full = np.array(pos_list).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list
+
+
+def sort_by_direction_with_image_id(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list_full, point_direction):
+ pos_list_full = np.array(pos_list_full).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 3)
+ point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+ sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
+ first_part_point, first_point_direction)
+
+ last_part_point = sorted_point[middle_num:]
+ last_point_direction = sorted_direction[middle_num:]
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+ last_part_point, last_point_direction)
+ sorted_point = sorted_fist_part_point + sorted_last_part_point
+ sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
+
+ return sorted_point
diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint_slow.py
similarity index 98%
rename from ppocr/utils/e2e_utils/extract_textpoint.py
rename to ppocr/utils/e2e_utils/extract_textpoint_slow.py
index 975ca161..3c83fb46 100644
--- a/ppocr/utils/e2e_utils/extract_textpoint.py
+++ b/ppocr/utils/e2e_utils/extract_textpoint_slow.py
@@ -21,7 +21,7 @@ import math
import numpy as np
from itertools import groupby
-from skimage.morphology._skeletonize import thin
+from cv2.ximgproc import thinning as thin
def get_dict(character_dict_path):
@@ -399,13 +399,13 @@ def generate_pivot_list_horizontal(p_score,
return center_pos_yxs, end_points_yxs
-def generate_pivot_list(p_score,
- p_char_maps,
- f_direction,
- score_thresh=0.5,
- is_backbone=False,
- is_curved=True,
- image_id=0):
+def generate_pivot_list_slow(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ is_curved=True,
+ image_id=0):
"""
Warp all the function together.
"""
diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py
new file mode 100644
index 00000000..e1bc38cb
--- /dev/null
+++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import paddle
+
+from extract_textpoint_slow import *
+from extract_textpoint_fast import *
+
+
+class PGNet_PostProcess(object):
+ # two different post-process
+ def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
+ shape_list):
+ self.Lexicon_Table = get_dict(character_dict_path)
+ self.valid_set = valid_set
+ self.score_thresh = score_thresh
+ self.outs_dict = outs_dict
+ self.shape_list = shape_list
+
+ def pg_postprocess_fast(self):
+ p_score = self.outs_dict['f_score']
+ p_border = self.outs_dict['f_border']
+ p_char = self.outs_dict['f_char']
+ p_direction = self.outs_dict['f_direction']
+ if isinstance(p_score, paddle.Tensor):
+ p_score = p_score[0].numpy()
+ p_border = p_border[0].numpy()
+ p_direction = p_direction[0].numpy()
+ p_char = p_char[0].numpy()
+ else:
+ p_score = p_score[0]
+ p_border = p_border[0]
+ p_direction = p_direction[0]
+ p_char = p_char[0]
+
+ src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
+ instance_yxs_list, seq_strs = generate_pivot_list_fast(
+ p_score,
+ p_char,
+ p_direction,
+ self.Lexicon_Table,
+ score_thresh=self.score_thresh)
+ poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
+ p_border, ratio_w, ratio_h,
+ src_w, src_h, self.valid_set)
+ data = {
+ 'points': poly_list,
+ 'strs': keep_str_list,
+ }
+ return data
+
+ def pg_postprocess_slow(self):
+ p_score = self.outs_dict['f_score']
+ p_border = self.outs_dict['f_border']
+ p_char = self.outs_dict['f_char']
+ p_direction = self.outs_dict['f_direction']
+ if isinstance(p_score, paddle.Tensor):
+ p_score = p_score[0].numpy()
+ p_border = p_border[0].numpy()
+ p_direction = p_direction[0].numpy()
+ p_char = p_char[0].numpy()
+ else:
+ p_score = p_score[0]
+ p_border = p_border[0]
+ p_direction = p_direction[0]
+ p_char = p_char[0]
+ src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
+ is_curved = self.valid_set == "totaltext"
+ instance_yxs_list = generate_pivot_list_slow(
+ p_score,
+ p_char,
+ p_direction,
+ score_thresh=self.score_thresh,
+ is_backbone=True,
+ is_curved=is_curved)
+ p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
+ char_seq_idx_set = []
+ for i in range(len(instance_yxs_list)):
+ gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
+ f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
+ feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
+ feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
+ feature_len = [len(feature_seq[0])]
+ featyre_seq = paddle.to_tensor(feature_seq)
+ feature_len = np.array([feature_len]).astype(np.int64)
+ length = paddle.to_tensor(feature_len)
+ seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
+ input=featyre_seq, blank=36, input_length=length)
+ seq_pred_str = seq_pred[0].numpy().tolist()[0]
+ seq_len = seq_pred[1].numpy()[0][0]
+ temp_t = []
+ for c in seq_pred_str[:seq_len]:
+ temp_t.append(c)
+ char_seq_idx_set.append(temp_t)
+ seq_strs = []
+ for char_idx_set in char_seq_idx_set:
+ pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
+ seq_strs.append(pr_str)
+ poly_list = []
+ keep_str_list = []
+ all_point_list = []
+ all_point_pair_list = []
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+ if len(yx_center_line) == 1:
+ yx_center_line.append(yx_center_line[-1])
+
+ offset_expand = 1.0
+ if self.valid_set == 'totaltext':
+ offset_expand = 1.2
+
+ point_pair_list = []
+ for batch_id, y, x in yx_center_line:
+ offset = p_border[:, y, x].reshape(2, 2)
+ if offset_expand != 1.0:
+ offset_length = np.linalg.norm(
+ offset, axis=1, keepdims=True)
+ expand_length = np.clip(
+ offset_length * (offset_expand - 1),
+ a_min=0.5,
+ a_max=3.0)
+ offset_detal = offset / offset_length * expand_length
+ offset = offset + offset_detal
+ ori_yx = np.array([y, x], dtype=np.float32)
+ point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
+ [ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair_list.append(point_pair)
+
+ all_point_list.append([
+ int(round(x * 4.0 / ratio_w)),
+ int(round(y * 4.0 / ratio_h))
+ ])
+ all_point_pair_list.append(point_pair.round().astype(np.int32)
+ .tolist())
+
+ detected_poly, pair_length_info = point_pair2poly(point_pair_list)
+ detected_poly = expand_poly_along_width(
+ detected_poly, shrink_ratio_of_width=0.2)
+ detected_poly[:, 0] = np.clip(
+ detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(
+ detected_poly[:, 1], a_min=0, a_max=src_h)
+
+ if len(keep_str) < 2:
+ continue
+
+ keep_str_list.append(keep_str)
+ detected_poly = np.round(detected_poly).astype('int32')
+ if self.valid_set == 'partvgg':
+ middle_point = len(detected_poly) // 2
+ detected_poly = detected_poly[
+ [0, middle_point - 1, middle_point, -1], :]
+ poly_list.append(detected_poly)
+ elif self.valid_set == 'totaltext':
+ poly_list.append(detected_poly)
+ else:
+ print('--> Not supported format.')
+ exit(-1)
+ data = {
+ 'points': poly_list,
+ 'strs': keep_str_list,
+ }
+ return data
--
GitLab