From d036c91af18e9616e6f1e587f2c55c3816526c52 Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Sun, 11 Apr 2021 16:40:46 +0800 Subject: [PATCH] support two postprocess --- configs/e2e/e2e_r50_vd_pg.yml | 1 + doc/doc_ch/inference.md | 41 +- ppocr/data/imaug/label_ops.py | 16 +- ppocr/metrics/e2e_metric.py | 2 +- ppocr/postprocess/pg_postprocess.py | 124 +---- .../utils/e2e_utils/extract_textpoint_fast.py | 458 ++++++++++++++++++ ...textpoint.py => extract_textpoint_slow.py} | 16 +- ppocr/utils/e2e_utils/pgnet_pp_utils.py | 176 +++++++ 8 files changed, 665 insertions(+), 169 deletions(-) create mode 100644 ppocr/utils/e2e_utils/extract_textpoint_fast.py rename ppocr/utils/e2e_utils/{extract_textpoint.py => extract_textpoint_slow.py} (98%) create mode 100644 ppocr/utils/e2e_utils/pgnet_pp_utils.py diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index 5a593ad8..e4d868f9 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -59,6 +59,7 @@ Optimizer: PostProcess: name: PGPostProcess score_thresh: 0.5 + mode: fast # fast or slow two ways Metric: name: E2EMetric gt_mat_dir: # the dir of gt_mat diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md index 1288d906..0b082c56 100755 --- a/doc/doc_ch/inference.md +++ b/doc/doc_ch/inference.md @@ -28,13 +28,10 @@ inference 模型(`paddle.jit.save`保存的模型) - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理) - [5. 多语言模型的推理](#多语言模型的推理) -- [四、端到端模型推理](#端到端模型推理) - - [1. PGNet端到端模型推理](#PGNet端到端模型推理) - -- [五、方向分类模型推理](#方向识别模型推理) +- [四、方向分类模型推理](#方向识别模型推理) - [1. 方向分类模型推理](#方向分类模型推理) -- [六、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理) +- [五、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理) - [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理) - [2. 其他模型推理](#其他模型推理) @@ -362,38 +359,8 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" - Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904) ``` - -## 四、端到端模型推理 - -端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。 - -### 1. PGNet端到端模型推理 -#### (1). 四边形文本检测模型(ICDAR2015) -首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)),可以使用如下命令进行转换: -``` -python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e -``` -**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令: -``` -python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False -``` -可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下: - -![](../imgs_results/e2e_res_img_10_pgnet.jpg) - -#### (2). 弯曲文本检测模型(Total-Text) -和四边形文本检测模型共用一个推理模型 -**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令: -``` -python3.7 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True -``` -可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下: - -![](../imgs_results/e2e_res_img623_pgnet.jpg) - - -## 五、方向分类模型推理 +## 四、方向分类模型推理 下面将介绍方向分类模型推理。 @@ -418,7 +385,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982] ``` -## 六、文本检测、方向分类和文字识别串联推理 +## 五、文本检测、方向分类和文字识别串联推理 ### 1. 超轻量中文OCR模型推理 diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 47e0cbf0..cbb11009 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -200,18 +200,16 @@ class E2ELabelEncode(BaseRecLabelEncode): self.pad_num = len(self.dict) # the length to pad def __call__(self, data): - text_label_index_list, temp_text = [], [] texts = data['strs'] + temp_texts = [] for text in texts: text = text.lower() - temp_text = [] - for c_ in text: - if c_ in self.dict: - temp_text.append(self.dict[c_]) - temp_text = temp_text + [self.pad_num] * (self.max_text_len - - len(temp_text)) - text_label_index_list.append(temp_text) - data['strs'] = np.array(text_label_index_list) + text = self.encode(text) + if text is None: + return None + text = text + [self.pad_num] * (self.max_text_len - len(text)) + temp_texts.append(text) + data['strs'] = np.array(temp_texts) return data diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index ef14ad48..8a604192 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -19,7 +19,7 @@ from __future__ import print_function __all__ = ['E2EMetric'] from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results -from ppocr.utils.e2e_utils.extract_textpoint import get_dict +from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict class E2EMetric(object): diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index f9118d87..0b145518 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -22,10 +22,7 @@ import sys __dir__ = os.path.dirname(__file__) sys.path.append(__dir__) sys.path.append(os.path.join(__dir__, '..')) - -from ppocr.utils.e2e_utils.extract_textpoint import * -from ppocr.utils.e2e_utils.visual import * -import paddle +from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess class PGPostProcess(object): @@ -33,10 +30,12 @@ class PGPostProcess(object): The post process for PGNet. """ - def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs): - self.Lexicon_Table = get_dict(character_dict_path) + def __init__(self, character_dict_path, valid_set, score_thresh, mode, + **kwargs): + self.character_dict_path = character_dict_path self.valid_set = valid_set self.score_thresh = score_thresh + self.mode = mode # c++ la-nms is faster, but only support python 3.5 self.is_python35 = False @@ -44,113 +43,10 @@ class PGPostProcess(object): self.is_python35 = True def __call__(self, outs_dict, shape_list): - p_score = outs_dict['f_score'] - p_border = outs_dict['f_border'] - p_char = outs_dict['f_char'] - p_direction = outs_dict['f_direction'] - if isinstance(p_score, paddle.Tensor): - p_score = p_score[0].numpy() - p_border = p_border[0].numpy() - p_direction = p_direction[0].numpy() - p_char = p_char[0].numpy() + post = PGNet_PostProcess(self.character_dict_path, self.valid_set, + self.score_thresh, outs_dict, shape_list) + if self.mode == 'fast': + data = post.pg_postprocess_fast() else: - p_score = p_score[0] - p_border = p_border[0] - p_direction = p_direction[0] - p_char = p_char[0] - src_h, src_w, ratio_h, ratio_w = shape_list[0] - is_curved = self.valid_set == "totaltext" - instance_yxs_list = generate_pivot_list( - p_score, - p_char, - p_direction, - score_thresh=self.score_thresh, - is_backbone=True, - is_curved=is_curved) - p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0)) - char_seq_idx_set = [] - for i in range(len(instance_yxs_list)): - gather_info_lod = paddle.to_tensor(instance_yxs_list[i]) - f_char_map = paddle.transpose(p_char, [0, 2, 3, 1]) - feature_seq = paddle.gather_nd(f_char_map, gather_info_lod) - feature_seq = np.expand_dims(feature_seq.numpy(), axis=0) - feature_len = [len(feature_seq[0])] - featyre_seq = paddle.to_tensor(feature_seq) - feature_len = np.array([feature_len]).astype(np.int64) - length = paddle.to_tensor(feature_len) - seq_pred = paddle.fluid.layers.ctc_greedy_decoder( - input=featyre_seq, blank=36, input_length=length) - seq_pred_str = seq_pred[0].numpy().tolist()[0] - seq_len = seq_pred[1].numpy()[0][0] - temp_t = [] - for c in seq_pred_str[:seq_len]: - temp_t.append(c) - char_seq_idx_set.append(temp_t) - seq_strs = [] - for char_idx_set in char_seq_idx_set: - pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set]) - seq_strs.append(pr_str) - poly_list = [] - keep_str_list = [] - all_point_list = [] - all_point_pair_list = [] - for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): - if len(yx_center_line) == 1: - yx_center_line.append(yx_center_line[-1]) - - offset_expand = 1.0 - if self.valid_set == 'totaltext': - offset_expand = 1.2 - - point_pair_list = [] - for batch_id, y, x in yx_center_line: - offset = p_border[:, y, x].reshape(2, 2) - if offset_expand != 1.0: - offset_length = np.linalg.norm( - offset, axis=1, keepdims=True) - expand_length = np.clip( - offset_length * (offset_expand - 1), - a_min=0.5, - a_max=3.0) - offset_detal = offset / offset_length * expand_length - offset = offset + offset_detal - ori_yx = np.array([y, x], dtype=np.float32) - point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array( - [ratio_w, ratio_h]).reshape(-1, 2) - point_pair_list.append(point_pair) - - all_point_list.append([ - int(round(x * 4.0 / ratio_w)), - int(round(y * 4.0 / ratio_h)) - ]) - all_point_pair_list.append(point_pair.round().astype(np.int32) - .tolist()) - - detected_poly, pair_length_info = point_pair2poly(point_pair_list) - detected_poly = expand_poly_along_width( - detected_poly, shrink_ratio_of_width=0.2) - detected_poly[:, 0] = np.clip( - detected_poly[:, 0], a_min=0, a_max=src_w) - detected_poly[:, 1] = np.clip( - detected_poly[:, 1], a_min=0, a_max=src_h) - - if len(keep_str) < 2: - continue - - keep_str_list.append(keep_str) - detected_poly = np.round(detected_poly).astype('int32') - if self.valid_set == 'partvgg': - middle_point = len(detected_poly) // 2 - detected_poly = detected_poly[ - [0, middle_point - 1, middle_point, -1], :] - poly_list.append(detected_poly) - elif self.valid_set == 'totaltext': - poly_list.append(detected_poly) - else: - print('--> Not supported format.') - exit(-1) - data = { - 'points': poly_list, - 'strs': keep_str_list, - } + data = post.pg_postprocess_slow() return data diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py new file mode 100644 index 00000000..9635ac55 --- /dev/null +++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py @@ -0,0 +1,458 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains various CTC decoders.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cv2 +import math + +import numpy as np +from itertools import groupby +from cv2.ximgproc import thinning as thin + + +def get_dict(character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + +def softmax(logits): + """ + logits: N x d + """ + max_value = np.max(logits, axis=1, keepdims=True) + exp = np.exp(logits - max_value) + exp_sum = np.sum(exp, axis=1, keepdims=True) + dist = exp / exp_sum + return dist + + +def get_keep_pos_idxs(labels, remove_blank=None): + """ + Remove duplicate and get pos idxs of keep items. + The value of keep_blank should be [None, 95]. + """ + duplicate_len_list = [] + keep_pos_idx_list = [] + keep_char_idx_list = [] + for k, v_ in groupby(labels): + current_len = len(list(v_)) + if k != remove_blank: + current_idx = int(sum(duplicate_len_list) + current_len // 2) + keep_pos_idx_list.append(current_idx) + keep_char_idx_list.append(k) + duplicate_len_list.append(current_len) + return keep_char_idx_list, keep_pos_idx_list + + +def remove_blank(labels, blank=0): + new_labels = [x for x in labels if x != blank] + return new_labels + + +def insert_blank(labels, blank=0): + new_labels = [blank] + for l in labels: + new_labels += [l, blank] + return new_labels + + +def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): + """ + CTC greedy (best path) decoder. + """ + raw_str = np.argmax(np.array(probs_seq), axis=1) + remove_blank_in_pos = None if keep_blank_in_idxs else blank + dedup_str, keep_idx_list = get_keep_pos_idxs( + raw_str, remove_blank=remove_blank_in_pos) + dst_str = remove_blank(dedup_str, blank=blank) + return dst_str, keep_idx_list + + +def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): + _, _, C = logits_map.shape + ys, xs = zip(*gather_info) + logits_seq = logits_map[list(ys), list(xs)] + probs_seq = logits_seq + labels = np.argmax(probs_seq, axis=1) + dst_str = [k for k, v_ in groupby(labels) if k != C - 1] + detal = len(gather_info) // (pts_num - 1) + keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1] + keep_gather_list = [gather_info[idx] for idx in keep_idx_list] + return dst_str, keep_gather_list + + +def ctc_decoder_for_image(gather_info_list, + logits_map, + Lexicon_Table, + pts_num=6): + """ + CTC decoder using multiple processes. + """ + decoder_str = [] + decoder_xys = [] + for gather_info in gather_info_list: + if len(gather_info) < pts_num: + continue + dst_str, xys_list = instance_ctc_greedy_decoder( + gather_info, logits_map, pts_num=pts_num) + dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) + if len(dst_str_readable) < 2: + continue + decoder_str.append(dst_str_readable) + decoder_xys.append(xys_list) + return decoder_str, decoder_xys + + +def sort_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list, point_direction): + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point, np.array(sorted_direction) + + +def add_id(pos_list, image_id=0): + """ + Add id for gather feature, for inference. + """ + new_list = [] + for item in pos_list: + new_list.append((image_id, item[0], item[1])) + return new_list + + +def sort_and_expand_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + left_list = [] + right_list = [] + for i in range(append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + left_list.append((ly, lx)) + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + right_list.append((ry, rx)) + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + binary_tcl_map: h x w + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + max_append_num = 2 * append_num + + left_list = [] + right_list = [] + for i in range(max_append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + if binary_tcl_map[ly, lx] > 0.5: + left_list.append((ly, lx)) + else: + break + + for i in range(max_append_num): + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + if binary_tcl_map[ry, rx] > 0.5: + right_list.append((ry, rx)) + else: + break + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def point_pair2poly(point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. + """ + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2) + + +def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.): + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + +def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. + """ + point_num = poly.shape[0] + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) + right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + +def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, + src_h, valid_set): + poly_list = [] + keep_str_list = [] + for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): + if len(keep_str) < 2: + print('--> too short, {}'.format(keep_str)) + continue + + offset_expand = 1.0 + if valid_set == 'totaltext': + offset_expand = 1.2 + + point_pair_list = [] + for y, x in yx_center_line: + offset = p_border[:, y, x].reshape(2, 2) * offset_expand + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) + point_pair_list.append(point_pair) + + detected_poly = point_pair2poly(point_pair_list) + detected_poly = expand_poly_along_width( + detected_poly, shrink_ratio_of_width=0.2) + detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + + keep_str_list.append(keep_str) + if valid_set == 'partvgg': + middle_point = len(detected_poly) // 2 + detected_poly = detected_poly[ + [0, middle_point - 1, middle_point, -1], :] + poly_list.append(detected_poly) + elif valid_set == 'totaltext': + poly_list.append(detected_poly) + else: + print('--> Not supported format.') + exit(-1) + return poly_list, keep_str_list + + +def generate_pivot_list_fast(p_score, + p_char_maps, + f_direction, + Lexicon_Table, + score_thresh=0.5): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255, + cv2.THRESH_BINARY) + skeleton_map = thin(p_tcl_map.astype('uint8')) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map, connectivity=8) + + # get TCL Instance + all_pos_yxs = [] + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + if len(pos_list) < 3: + continue + + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, p_tcl_map) + all_pos_yxs.append(pos_list_sorted) + + p_char_maps = p_char_maps.transpose([1, 2, 0]) + decoded_str, keep_yxs_list = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table) + return keep_yxs_list, decoded_str + + +def extract_main_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + pos_list = np.array(pos_list) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + average_direction = average_direction / ( + np.linalg.norm(average_direction) + 1e-6) + return average_direction + + +def sort_by_direction_with_image_id_deprecated(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...] + """ + pos_list_full = np.array(pos_list).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + return sorted_list + + +def sort_by_direction_with_image_id(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list_full, point_direction): + pos_list_full = np.array(pos_list_full).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 3) + point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint_slow.py similarity index 98% rename from ppocr/utils/e2e_utils/extract_textpoint.py rename to ppocr/utils/e2e_utils/extract_textpoint_slow.py index 975ca161..3c83fb46 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint.py +++ b/ppocr/utils/e2e_utils/extract_textpoint_slow.py @@ -21,7 +21,7 @@ import math import numpy as np from itertools import groupby -from skimage.morphology._skeletonize import thin +from cv2.ximgproc import thinning as thin def get_dict(character_dict_path): @@ -399,13 +399,13 @@ def generate_pivot_list_horizontal(p_score, return center_pos_yxs, end_points_yxs -def generate_pivot_list(p_score, - p_char_maps, - f_direction, - score_thresh=0.5, - is_backbone=False, - is_curved=True, - image_id=0): +def generate_pivot_list_slow(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + is_curved=True, + image_id=0): """ Warp all the function together. """ diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py new file mode 100644 index 00000000..e1bc38cb --- /dev/null +++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py @@ -0,0 +1,176 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle + +from extract_textpoint_slow import * +from extract_textpoint_fast import * + + +class PGNet_PostProcess(object): + # two different post-process + def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict, + shape_list): + self.Lexicon_Table = get_dict(character_dict_path) + self.valid_set = valid_set + self.score_thresh = score_thresh + self.outs_dict = outs_dict + self.shape_list = shape_list + + def pg_postprocess_fast(self): + p_score = self.outs_dict['f_score'] + p_border = self.outs_dict['f_border'] + p_char = self.outs_dict['f_char'] + p_direction = self.outs_dict['f_direction'] + if isinstance(p_score, paddle.Tensor): + p_score = p_score[0].numpy() + p_border = p_border[0].numpy() + p_direction = p_direction[0].numpy() + p_char = p_char[0].numpy() + else: + p_score = p_score[0] + p_border = p_border[0] + p_direction = p_direction[0] + p_char = p_char[0] + + src_h, src_w, ratio_h, ratio_w = self.shape_list[0] + instance_yxs_list, seq_strs = generate_pivot_list_fast( + p_score, + p_char, + p_direction, + self.Lexicon_Table, + score_thresh=self.score_thresh) + poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs, + p_border, ratio_w, ratio_h, + src_w, src_h, self.valid_set) + data = { + 'points': poly_list, + 'strs': keep_str_list, + } + return data + + def pg_postprocess_slow(self): + p_score = self.outs_dict['f_score'] + p_border = self.outs_dict['f_border'] + p_char = self.outs_dict['f_char'] + p_direction = self.outs_dict['f_direction'] + if isinstance(p_score, paddle.Tensor): + p_score = p_score[0].numpy() + p_border = p_border[0].numpy() + p_direction = p_direction[0].numpy() + p_char = p_char[0].numpy() + else: + p_score = p_score[0] + p_border = p_border[0] + p_direction = p_direction[0] + p_char = p_char[0] + src_h, src_w, ratio_h, ratio_w = self.shape_list[0] + is_curved = self.valid_set == "totaltext" + instance_yxs_list = generate_pivot_list_slow( + p_score, + p_char, + p_direction, + score_thresh=self.score_thresh, + is_backbone=True, + is_curved=is_curved) + p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0)) + char_seq_idx_set = [] + for i in range(len(instance_yxs_list)): + gather_info_lod = paddle.to_tensor(instance_yxs_list[i]) + f_char_map = paddle.transpose(p_char, [0, 2, 3, 1]) + feature_seq = paddle.gather_nd(f_char_map, gather_info_lod) + feature_seq = np.expand_dims(feature_seq.numpy(), axis=0) + feature_len = [len(feature_seq[0])] + featyre_seq = paddle.to_tensor(feature_seq) + feature_len = np.array([feature_len]).astype(np.int64) + length = paddle.to_tensor(feature_len) + seq_pred = paddle.fluid.layers.ctc_greedy_decoder( + input=featyre_seq, blank=36, input_length=length) + seq_pred_str = seq_pred[0].numpy().tolist()[0] + seq_len = seq_pred[1].numpy()[0][0] + temp_t = [] + for c in seq_pred_str[:seq_len]: + temp_t.append(c) + char_seq_idx_set.append(temp_t) + seq_strs = [] + for char_idx_set in char_seq_idx_set: + pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set]) + seq_strs.append(pr_str) + poly_list = [] + keep_str_list = [] + all_point_list = [] + all_point_pair_list = [] + for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): + if len(yx_center_line) == 1: + yx_center_line.append(yx_center_line[-1]) + + offset_expand = 1.0 + if self.valid_set == 'totaltext': + offset_expand = 1.2 + + point_pair_list = [] + for batch_id, y, x in yx_center_line: + offset = p_border[:, y, x].reshape(2, 2) + if offset_expand != 1.0: + offset_length = np.linalg.norm( + offset, axis=1, keepdims=True) + expand_length = np.clip( + offset_length * (offset_expand - 1), + a_min=0.5, + a_max=3.0) + offset_detal = offset / offset_length * expand_length + offset = offset + offset_detal + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) + point_pair_list.append(point_pair) + + all_point_list.append([ + int(round(x * 4.0 / ratio_w)), + int(round(y * 4.0 / ratio_h)) + ]) + all_point_pair_list.append(point_pair.round().astype(np.int32) + .tolist()) + + detected_poly, pair_length_info = point_pair2poly(point_pair_list) + detected_poly = expand_poly_along_width( + detected_poly, shrink_ratio_of_width=0.2) + detected_poly[:, 0] = np.clip( + detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip( + detected_poly[:, 1], a_min=0, a_max=src_h) + + if len(keep_str) < 2: + continue + + keep_str_list.append(keep_str) + detected_poly = np.round(detected_poly).astype('int32') + if self.valid_set == 'partvgg': + middle_point = len(detected_poly) // 2 + detected_poly = detected_poly[ + [0, middle_point - 1, middle_point, -1], :] + poly_list.append(detected_poly) + elif self.valid_set == 'totaltext': + poly_list.append(detected_poly) + else: + print('--> Not supported format.') + exit(-1) + data = { + 'points': poly_list, + 'strs': keep_str_list, + } + return data -- GitLab