From b3f9f681d936489a27ebc47ffb02e2228a34c6fb Mon Sep 17 00:00:00 2001
From: ToddBear <43341135+ToddBear@users.noreply.github.com>
Date: Thu, 10 Aug 2023 15:12:01 +0800
Subject: [PATCH] =?UTF-8?q?CV=E5=A5=97=E4=BB=B6=E5=BB=BA=E8=AE=BE=E4=B8=93?=
 =?UTF-8?q?=E9=A1=B9=E6=B4=BB=E5=8A=A8=20-=20=E6=96=87=E5=AD=97=E8=AF=86?=
 =?UTF-8?q?=E5=88=AB=E8=BF=94=E5=9B=9E=E5=8D=95=E5=AD=97=E8=AF=86=E5=88=AB?=
 =?UTF-8?q?=E5=9D=90=E6=A0=87=20(#10515)=20(#10537)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* modification of return word box

* update_implements

* Update rec_postprocess.py

* Update utility.py
---
 ppocr/postprocess/rec_postprocess.py | 78 ++++++++++++++++++++++++++--
 ppstructure/predict_system.py        | 26 +++++++---
 ppstructure/utility.py               | 62 ++++++++++++++++++++++
 tools/infer/predict_rec.py           | 10 +++-
 tools/infer/predict_system.py        |  2 +-
 tools/infer/utility.py               |  4 ++
 6 files changed, 167 insertions(+), 15 deletions(-)

diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index f64ea1ce..ce2e9f8b 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -67,7 +67,66 @@ class BaseRecLabelDecode(object):
     def add_special_char(self, dict_character):
         return dict_character
 
-    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+    def get_word_info(self, text, selection):
+        """
+        Group the decoded characters and record the corresponding decoded positions.
+
+        Args:
+            text: the decoded text
+            selection: the bool array identifying which feature columns are decoded as kept (non-separator) characters
+        Returns:
+            word_list: list of the grouped words
+            word_col_list: list of decoding positions corresponding to each character in each grouped word
+            state_list: list of markers identifying the type of each word group; two types of groups are used:
+                        - 'cn': continuous Chinese characters (e.g., 你好啊)
+                        - 'en&num': continuous English characters (e.g., hello), numbers (e.g., 123, 1.123), or a mix of them connected by '-' (e.g., VGG-16)
+                        The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
+        """
+        state = None
+        word_content = []
+        word_col_content = []
+        word_list = []
+        word_col_list = []
+        state_list = []
+        valid_col = np.where(selection==True)[0]
+
+        for c_i, char in enumerate(text):
+            if '\u4e00' <= char <= '\u9fff':
+                c_state = 'cn'
+            elif bool(re.search('[a-zA-Z0-9]', char)):
+                c_state = 'en&num'
+            else:
+                c_state = 'splitter'
+
+            if char == '.' and state == 'en&num' and c_i + 1 < len(text) and bool(re.search('[0-9]', text[c_i+1])):  # grouping floating-point numbers
+                c_state = 'en&num'
+            if char == '-' and state == "en&num":  # grouping words joined by '-', such as 'state-of-the-art'
+                c_state = 'en&num'
+
+            if state == None:
+                state = c_state
+
+            if state != c_state:
+                if len(word_content) != 0:
+                    word_list.append(word_content)
+                    word_col_list.append(word_col_content)
+                    state_list.append(state)
+                    word_content = []
+                    word_col_content = []
+                state = c_state
+
+            if state != "splitter":
+                word_content.append(char)
+                word_col_content.append(valid_col[c_i])
+
+        if len(word_content) != 0:
+            word_list.append(word_content)
+            word_col_list.append(word_col_content)
+            state_list.append(state)
+
+        return word_list, word_col_list, state_list
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False, return_word_box=False):
         """ convert text-index into text-label. """
""" result_list = [] ignored_tokens = self.get_ignored_tokens() @@ -95,8 +154,12 @@ class BaseRecLabelDecode(object): if self.reverse: # for arabic rec text = self.pred_reverse(text) - - result_list.append((text, np.mean(conf_list).tolist())) + + if return_word_box: + word_list, word_col_list, state_list = self.get_word_info(text, selection) + result_list.append((text, np.mean(conf_list).tolist(), [len(text_index[batch_idx]), word_list, word_col_list, state_list])) + else: + result_list.append((text, np.mean(conf_list).tolist())) return result_list def get_ignored_tokens(self): @@ -111,14 +174,19 @@ class CTCLabelDecode(BaseRecLabelDecode): super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char) - def __call__(self, preds, label=None, *args, **kwargs): + def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs): if isinstance(preds, tuple) or isinstance(preds, list): preds = preds[-1] if isinstance(preds, paddle.Tensor): preds = preds.numpy() preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True, return_word_box=return_word_box) + if return_word_box: + for rec_idx, rec in enumerate(text): + wh_ratio = kwargs['wh_ratio_list'][rec_idx] + max_wh_ratio = kwargs['max_wh_ratio'] + rec[2][0] = rec[2][0]*(wh_ratio/max_wh_ratio) if label is None: return text label = self.decode(label) diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index b32b7062..b8b87168 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -34,7 +34,7 @@ from ppocr.utils.visual import draw_ser_results, draw_re_results from tools.infer.predict_system import TextSystem from ppstructure.layout.predict_layout import LayoutPredictor from ppstructure.table.predict_table import TableSystem, to_excel -from ppstructure.utility import parse_args, draw_structure_result +from ppstructure.utility import parse_args, draw_structure_result, cal_ocr_word_box logger = get_logger() @@ -79,6 +79,8 @@ class StructureSystem(object): from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor self.kie_predictor = SerRePredictor(args) + self.return_word_box = args.return_word_box + def __call__(self, img, return_ocr_result_in_table=False, img_idx=0): time_dict = { 'image_orientation': 0, @@ -156,17 +158,27 @@ class StructureSystem(object): ] res = [] for box, rec_res in zip(filter_boxes, filter_rec_res): - rec_str, rec_conf = rec_res + rec_str, rec_conf = rec_res[0], rec_res[1] for token in style_token: if token in rec_str: rec_str = rec_str.replace(token, '') if not self.recovery: box += [x1, y1] - res.append({ - 'text': rec_str, - 'confidence': float(rec_conf), - 'text_region': box.tolist() - }) + if self.return_word_box: + word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2]) + res.append({ + 'text': rec_str, + 'confidence': float(rec_conf), + 'text_region': box.tolist(), + 'text_word': word_box_content_list, + 'text_word_region': word_box_list + }) + else: + res.append({ + 'text': rec_str, + 'confidence': float(rec_conf), + 'text_region': box.tolist() + }) res_list.append({ 'type': region['label'].lower(), 'bbox': [x1, y1, x2, y2], diff --git a/ppstructure/utility.py b/ppstructure/utility.py index aa34ee17..e51fdf0d 100644 --- a/ppstructure/utility.py +++ b/ppstructure/utility.py @@ -15,8 +15,13 @@ import random import ast from PIL import Image, ImageDraw, ImageFont import 
@@ -15,8 +15,9 @@
 import random
 import ast
 from PIL import Image, ImageDraw, ImageFont
 import numpy as np
-from tools.infer.utility import draw_ocr_box_txt, str2bool, init_args as infer_args
+from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args
+import math
 
 def init_args():
     parser = infer_args()
@@ -152,6 +153,63 @@
             txts.append(text_result['text'])
             scores.append(text_result['confidence'])
 
+            if 'text_word_region' in text_result:
+                for word_region in text_result['text_word_region']:
+                    char_box = word_region
+                    box_height = int(
+                        math.sqrt((char_box[0][0] - char_box[3][0])**2 + (char_box[0][1] - char_box[3][1])**2))
+                    box_width = int(
+                        math.sqrt((char_box[0][0] - char_box[1][0])**2 + (char_box[0][1] - char_box[1][1])**2))
+                    if box_height == 0 or box_width == 0:
+                        continue
+                    boxes.append(word_region)
+                    txts.append("")
+                    scores.append(1.0)
+
     im_show = draw_ocr_box_txt(
         img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
     return im_show
+
+def cal_ocr_word_box(rec_str, box, rec_word_info):
+    '''Calculate the detection box for each word based on the OCR recognition and detection results.'''
+
+    col_num, word_list, word_col_list, state_list = rec_word_info
+    box = box.tolist()
+    bbox_x_start = box[0][0]
+    bbox_x_end = box[1][0]
+    bbox_y_start = box[0][1]
+    bbox_y_end = box[2][1]
+
+    cell_width = (bbox_x_end - bbox_x_start)/col_num
+
+    word_box_list = []
+    word_box_content_list = []
+    cn_width_list = []
+    cn_col_list = []
+    for word, word_col, state in zip(word_list, word_col_list, state_list):
+        if state == 'cn':
+            if len(word_col) != 1:
+                char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
+                char_width = char_seq_length/(len(word_col)-1)
+                cn_width_list.append(char_width)
+            cn_col_list += word_col
+            word_box_content_list += word
+        else:
+            cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
+            cell_x_end = bbox_x_start + int((word_col[-1]+1) * cell_width)
+            cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
+            word_box_list.append(cell)
+            word_box_content_list.append("".join(word))
+    if len(cn_col_list) != 0:
+        if len(cn_width_list) != 0:
+            avg_char_width = np.mean(cn_width_list)
+        else:
+            avg_char_width = (bbox_x_end - bbox_x_start)/len(rec_str)
+        for center_idx in cn_col_list:
+            center_x = (center_idx+0.5)*cell_width
+            cell_x_start = max(int(center_x - avg_char_width/2), 0) + bbox_x_start
+            cell_x_end = min(int(center_x + avg_char_width/2), bbox_x_end-bbox_x_start) + bbox_x_start
+            cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
+            word_box_list.append(cell)
+
+    return word_box_content_list, word_box_list
\ No newline at end of file
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 991612ab..44232c4b 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -123,6 +123,7 @@ class TextRecognizer(object):
             "use_space_char": args.use_space_char
         }
         self.postprocess_op = build_post_process(postprocess_params)
+        self.postprocess_params = postprocess_params
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
         self.benchmark = args.benchmark
@@ -146,6 +147,7 @@ class TextRecognizer(object):
                 ],
                 warmup=0,
                 logger=logger)
+        self.return_word_box = args.return_word_box
 
     def resize_norm_img(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
@@ -415,11 +417,12 @@ class TextRecognizer(object):
         valid_ratios = []
         imgC, imgH, imgW = self.rec_image_shape[:3]
         max_wh_ratio = imgW / imgH
-        # max_wh_ratio = 0
+        wh_ratio_list = []
         for ino in range(beg_img_no, end_img_no):
             h, w = img_list[indices[ino]].shape[0:2]
             wh_ratio = w * 1.0 / h
             max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            wh_ratio_list.append(wh_ratio)
         for ino in range(beg_img_no, end_img_no):
             if self.rec_algorithm == "SAR":
                 norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
@@ -624,7 +627,10 @@ class TextRecognizer(object):
                     preds = outputs
                 else:
                     preds = outputs[0]
-            rec_result = self.postprocess_op(preds)
+            if self.postprocess_params['name'] == 'CTCLabelDecode':
+                rec_result = self.postprocess_op(preds, return_word_box=self.return_word_box, wh_ratio_list=wh_ratio_list, max_wh_ratio=max_wh_ratio)
+            else:
+                rec_result = self.postprocess_op(preds)
             for rno in range(len(rec_result)):
                 rec_res[indices[beg_img_no + rno]] = rec_result[rno]
             if self.benchmark:
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 1f9e2e1d..3ddcfda6 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -101,7 +101,7 @@ class TextSystem(object):
                                    rec_res)
         filter_boxes, filter_rec_res = [], []
         for box, rec_result in zip(dt_boxes, rec_res):
-            text, score = rec_result
+            text, score = rec_result[0], rec_result[1]
             if score >= self.drop_score:
                 filter_boxes.append(box)
                 filter_rec_res.append(rec_result)
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 40f8eab8..31630c0c 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -145,6 +145,10 @@ def init_args():
 
     parser.add_argument("--show_log", type=str2bool, default=True)
     parser.add_argument("--use_onnx", type=str2bool, default=False)
+
+    # extended function
+    parser.add_argument("--return_word_box", type=str2bool, default=False, help='Whether to return the bounding box of each word (split by spaces) or of each Chinese character. Only used in ppstructure for layout recovery')
+
     return parser
 
-- 
GitLab
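
A minimal sketch of how the per-word output added by this patch can be exercised, assuming the patched repository is importable (e.g. its root is on PYTHONPATH). All inputs below are fabricated for illustration only; in the real pipeline rec_word_info is produced by CTCLabelDecode when return_word_box=True and then passed to cal_ocr_word_box by StructureSystem.

    # Sketch only: fabricated inputs, not part of the patch above.
    import numpy as np
    from ppstructure.utility import cal_ocr_word_box

    rec_str = "VGG-16 你好"
    # Detected text box as four (x, y) points: top-left, top-right, bottom-right, bottom-left.
    box = np.array([[0, 0], [200, 0], [200, 32], [0, 32]], dtype=np.float32)
    # rec_word_info = [col_num, word_list, word_col_list, state_list], the layout
    # emitted by BaseRecLabelDecode.get_word_info / CTCLabelDecode in the patch.
    rec_word_info = [
        20,                               # number of CTC output columns (already wh-ratio scaled)
        [list("VGG-16"), ["你", "好"]],    # grouped words
        [[0, 1, 2, 3, 4, 5], [8, 10]],    # decoding column of each kept character
        ["en&num", "cn"],                 # group type of each word
    ]
    word_texts, word_boxes = cal_ocr_word_box(rec_str, box, rec_word_info)
    print(word_texts)   # one entry per 'en&num' group and one per Chinese character: ['VGG-16', '你', '好']
    print(word_boxes)   # one quadrilateral (four (x, y) corners) per entry in word_texts

End to end, the same information is exposed by running ppstructure's predict_system.py with the new --return_word_box flag, which adds the 'text_word' and 'text_word_region' fields shown in the patched res.append calls above.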