Unverified commit b3f9f681, authored by ToddBear, committed by GitHub

CV Suite Development Campaign - text recognition returns per-word recognition coordinates (#10515) (#10537)

* modification of return word box

* update_implements

* Update rec_postprocess.py

* Update utility.py
Parent bf6ff0b6
@@ -67,7 +67,66 @@ class BaseRecLabelDecode(object):
def add_special_char(self, dict_character):
return dict_character
def get_word_info(self, text, selection):
"""
Group the decoded characters and record the corresponding decoded positions.
Args:
text: the decoded text
selection: the boolean array marking which feature columns were decoded into the characters of text
Returns:
word_list: list of the grouped words
word_col_list: list of decoding column positions for each character in each grouped word
state_list: list of markers identifying the type of each group; two group types are produced:
- 'cn': continuous Chinese characters (e.g., 你好啊)
- 'en&num': continuous English characters (e.g., hello), numbers (e.g., 123, 1.123), or a mix of them joined by '-' (e.g., VGG-16)
The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
"""
state = None
word_content = []
word_col_content = []
word_list = []
word_col_list = []
state_list = []
valid_col = np.where(selection==True)[0]
for c_i, char in enumerate(text):
if '\u4e00' <= char <= '\u9fff':
c_state = 'cn'
elif bool(re.search('[a-zA-Z0-9]', char)):
c_state = 'en&num'
else:
c_state = 'splitter'
if char == '.' and state == 'en&num' and c_i + 1 < len(text) and bool(re.search('[0-9]', text[c_i+1])): # keep '.' inside a floating-point number
c_state = 'en&num'
if char == '-' and state == "en&num": # keep '-' inside hyphenated words, such as 'state-of-the-art'
c_state = 'en&num'
if state == None:
state = c_state
if state != c_state:
if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)
word_content = []
word_col_content = []
state = c_state
if state != "splitter":
word_content.append(char)
word_col_content.append(valid_col[c_i])
if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)
return word_list, word_col_list, state_list
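
A quick worked example may help clarify the grouping. The snippet below is a minimal sketch calling get_word_info on a hypothetical decode result; the selection mask and column indices are invented for illustration.

import numpy as np
from ppocr.postprocess.rec_postprocess import CTCLabelDecode

decoder = CTCLabelDecode()  # default character set; get_word_info itself does not depend on it

# Hypothetical decode result: 9 characters kept out of 16 CTC columns.
text = "VGG-16 你好"
selection = np.zeros(16, dtype=bool)
selection[[1, 2, 4, 6, 7, 9, 11, 13, 14]] = True  # columns that survived duplicate/blank removal

word_list, word_col_list, state_list = decoder.get_word_info(text, selection)
# word_list     -> [['V', 'G', 'G', '-', '1', '6'], ['你', '好']]   (the space acts as a splitter)
# word_col_list -> [[1, 2, 4, 6, 7, 9], [13, 14]]
# state_list    -> ['en&num', 'cn']
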
def decode(self, text_index, text_prob=None, is_remove_duplicate=False, return_word_box=False):
""" convert text-index into text-label. """ """ convert text-index into text-label. """
result_list = [] result_list = []
ignored_tokens = self.get_ignored_tokens() ignored_tokens = self.get_ignored_tokens()
...@@ -96,6 +155,10 @@ class BaseRecLabelDecode(object): ...@@ -96,6 +155,10 @@ class BaseRecLabelDecode(object):
if self.reverse: # for arabic rec if self.reverse: # for arabic rec
text = self.pred_reverse(text) text = self.pred_reverse(text)
if return_word_box:
word_list, word_col_list, state_list = self.get_word_info(text, selection)
result_list.append((text, np.mean(conf_list).tolist(), [len(text_index[batch_idx]), word_list, word_col_list, state_list]))
else:
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
@@ -111,14 +174,19 @@ class CTCLabelDecode(BaseRecLabelDecode):
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True, return_word_box=return_word_box)
if return_word_box:
for rec_idx, rec in enumerate(text):
wh_ratio = kwargs['wh_ratio_list'][rec_idx]
max_wh_ratio = kwargs['max_wh_ratio']
rec[2][0] = rec[2][0]*(wh_ratio/max_wh_ratio)
if label is None:
return text
label = self.decode(label)
...
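
The rec[2][0] rescaling above compensates for batch padding: every crop in a batch is resized to the width implied by max_wh_ratio, so for a narrower crop only the leading wh_ratio / max_wh_ratio fraction of the CTC columns covers actual text. Scaling the stored column count keeps the per-column cell width computed later in cal_ocr_word_box aligned with the unpadded text region. A small numeric sketch (values are purely illustrative):

col_num = 80            # CTC output columns for the padded batch width
wh_ratio = 5.0          # this crop's width/height ratio
max_wh_ratio = 10.0     # the widest crop in the batch sets the padded width

effective_cols = col_num * (wh_ratio / max_wh_ratio)   # 40.0 columns actually cover the text
# cal_ocr_word_box later divides the detected box width by this effective column
# count, so each column maps onto the real text span rather than the padding.
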
@@ -34,7 +34,7 @@ from ppocr.utils.visual import draw_ser_results, draw_re_results
from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result, cal_ocr_word_box
logger = get_logger()
@@ -79,6 +79,8 @@ class StructureSystem(object):
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
self.kie_predictor = SerRePredictor(args)
self.return_word_box = args.return_word_box
def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
time_dict = {
'image_orientation': 0,
@@ -156,12 +158,22 @@ class StructureSystem(object):
]
res = []
for box, rec_res in zip(filter_boxes, filter_rec_res):
rec_str, rec_conf = rec_res[0], rec_res[1]
for token in style_token:
if token in rec_str:
rec_str = rec_str.replace(token, '')
if not self.recovery:
box += [x1, y1]
if self.return_word_box:
word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist(),
'text_word': word_box_content_list,
'text_word_region': word_box_list
})
else:
res.append({
'text': rec_str,
'confidence': float(rec_conf),
...
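
With return_word_box enabled, each text entry in res carries two extra fields, text_word and text_word_region. A hedged sketch of what one entry might look like (all coordinates are made up):

{
    'text': 'VGG-16',
    'confidence': 0.98,
    'text_region': [[60, 12], [188, 12], [188, 44], [60, 44]],        # whole-line box
    'text_word': ['VGG-16'],                                          # grouped words / characters
    'text_word_region': [((60, 12), (188, 12), (188, 44), (60, 44))]  # one quadrilateral per group
}
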
@@ -15,8 +15,13 @@ import random
import ast
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args
import math
def init_args():
parser = infer_args()
@@ -152,6 +157,63 @@ def draw_structure_result(image, result, font_path):
txts.append(text_result['text'])
scores.append(text_result['confidence'])
if 'text_word_region' in text_result:
for word_region in text_result['text_word_region']:
char_box = word_region
box_height = int(
math.sqrt((char_box[0][0] - char_box[3][0])**2 + (char_box[0][1] - char_box[3][1])**2))
box_width = int(
math.sqrt((char_box[0][0] - char_box[1][0])**2 + (char_box[0][1] - char_box[1][1])**2))
if box_height == 0 or box_width == 0:
continue
boxes.append(word_region)
txts.append("")
scores.append(1.0)
im_show = draw_ocr_box_txt(
img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
return im_show
def cal_ocr_word_box(rec_str, box, rec_word_info):
'''Calculate the detection box for each word based on the OCR detection and recognition results.'''
col_num, word_list, word_col_list, state_list = rec_word_info
box = box.tolist()
bbox_x_start = box[0][0]
bbox_x_end = box[1][0]
bbox_y_start = box[0][1]
bbox_y_end = box[2][1]
cell_width = (bbox_x_end - bbox_x_start)/col_num
word_box_list = []
word_box_content_list = []
cn_width_list = []
cn_col_list = []
for word, word_col, state in zip(word_list, word_col_list, state_list):
if state == 'cn':
if len(word_col) != 1:
char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
char_width = char_seq_length/(len(word_col)-1)
cn_width_list.append(char_width)
cn_col_list += word_col
word_box_content_list += word
else:
cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
cell_x_end = bbox_x_start + int((word_col[-1]+1) * cell_width)
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)
word_box_content_list.append("".join(word))
if len(cn_col_list) != 0:
if len(cn_width_list) != 0:
avg_char_width = np.mean(cn_width_list)
else:
avg_char_width = (bbox_x_end - bbox_x_start)/len(rec_str)
for center_idx in cn_col_list:
center_x = (center_idx+0.5)*cell_width
cell_x_start = max(int(center_x - avg_char_width/2), 0) + bbox_x_start
cell_x_end = min(int(center_x + avg_char_width/2), bbox_x_end-bbox_x_start) + bbox_x_start
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)
return word_box_content_list, word_box_list
\ No newline at end of file
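
A minimal usage sketch of cal_ocr_word_box with a hypothetical detection box and word info; in practice rec_word_info comes from CTCLabelDecode called with return_word_box=True, and all numbers below are invented:

import numpy as np
from ppstructure.utility import cal_ocr_word_box

box = np.array([[0, 0], [200, 0], [200, 40], [0, 40]])   # axis-aligned text box, 200 px wide
rec_str = "hi 你"
rec_word_info = [
    20,                    # col_num: effective CTC columns covering the box
    [['h', 'i'], ['你']],  # word_list
    [[0, 1], [5]],         # word_col_list: decode columns of each group
    ['en&num', 'cn'],      # state_list
]

contents, word_boxes = cal_ocr_word_box(rec_str, box, rec_word_info)
# contents   -> ['hi', '你']
# word_boxes -> [((0, 0), (20, 0), (20, 40), (0, 40)),      # 'hi' spans columns 0-1
#                ((30, 0), (80, 0), (80, 40), (30, 40))]    # '你' centred on column 5
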
@@ -123,6 +123,7 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.postprocess_params = postprocess_params
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
self.benchmark = args.benchmark
@@ -146,6 +147,7 @@ class TextRecognizer(object):
],
warmup=0,
logger=logger)
self.return_word_box = args.return_word_box
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
@@ -415,11 +417,12 @@ class TextRecognizer(object):
valid_ratios = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
wh_ratio_list = []
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
wh_ratio_list.append(wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
@@ -624,6 +627,9 @@ class TextRecognizer(object):
preds = outputs
else:
preds = outputs[0]
if self.postprocess_params['name'] == 'CTCLabelDecode':
rec_result = self.postprocess_op(preds, return_word_box=self.return_word_box, wh_ratio_list=wh_ratio_list, max_wh_ratio=max_wh_ratio)
else:
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
...
@@ -101,7 +101,7 @@ class TextSystem(object):
rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result[0], rec_result[1]
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
...
@@ -145,6 +145,10 @@ def init_args():
parser.add_argument("--show_log", type=str2bool, default=True)
parser.add_argument("--use_onnx", type=str2bool, default=False)
# extended function
parser.add_argument("--return_word_box", type=str2bool, default=False, help='Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery')
return parser
...
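
For reference, a minimal sketch of exercising the new flag from Python. The model paths, sample image, and the exact return signature of StructureSystem.__call__ are assumptions for illustration, not something this commit specifies:

import cv2
from ppstructure.utility import parse_args
from ppstructure.predict_system import StructureSystem

args = parse_args()            # picks up --return_word_box from the command line
args.return_word_box = True    # or force it programmatically
engine = StructureSystem(args)

img = cv2.imread("ppstructure/docs/table/1.png")   # hypothetical sample image
res, _ = engine(img)           # assumed to return (layout_results, time_dict)
print(res)                     # text regions now also contain 'text_word' and 'text_word_region'
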