diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 32abe94deffa41c6f51ef62f011d0e4d58811065..d222c4109c3723bc1adb71ee7c21a27a010f8f45 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -19,6 +19,7 @@ from __future__ import unicode_literals import numpy as np import string +import json class ClsLabelEncode(object): @@ -39,7 +40,6 @@ class DetLabelEncode(object): pass def __call__(self, data): - import json label = data['label'] label = json.loads(label) nBox = len(label) @@ -53,6 +53,8 @@ class DetLabelEncode(object): txt_tags.append(True) else: txt_tags.append(False) + if len(boxes) == 0: + return None boxes = self.expand_points_num(boxes) boxes = np.array(boxes, dtype=np.float32) txt_tags = np.array(txt_tags, dtype=np.bool) @@ -352,19 +354,22 @@ class SRNLabelEncode(BaseRecLabelEncode): % beg_or_end return idx + class TableLabelEncode(object): """ Convert between text-label and text-index """ - def __init__(self, - max_text_length, - max_elem_length, - max_cell_num, - character_dict_path, - span_weight = 1.0, - **kwargs): + + def __init__(self, + max_text_length, + max_elem_length, + max_cell_num, + character_dict_path, + span_weight=1.0, + **kwargs): self.max_text_length = max_text_length self.max_elem_length = max_elem_length self.max_cell_num = max_cell_num - list_character, list_elem = self.load_char_elem_dict(character_dict_path) + list_character, list_elem = self.load_char_elem_dict( + character_dict_path) list_character = self.add_special_char(list_character) list_elem = self.add_special_char(list_elem) self.dict_character = {} @@ -374,7 +379,7 @@ class TableLabelEncode(object): for i, elem in enumerate(list_elem): self.dict_elem[elem] = i self.span_weight = span_weight - + def load_char_elem_dict(self, character_dict_path): list_character = [] list_elem = [] @@ -383,27 +388,28 @@ class TableLabelEncode(object): substr = lines[0].decode('utf-8').strip("\r\n").split("\t") character_num = int(substr[0]) elem_num = int(substr[1]) - for cno in range(1, 1+character_num): + + for cno in range(1, 1 + character_num): character = lines[cno].decode('utf-8').strip("\r\n") list_character.append(character) - for eno in range(1+character_num, 1+character_num+elem_num): + for eno in range(1 + character_num, 1 + character_num + elem_num): elem = lines[eno].decode('utf-8').strip("\r\n") list_elem.append(elem) return list_character, list_elem - + def add_special_char(self, list_character): self.beg_str = "sos" self.end_str = "eos" list_character = [self.beg_str] + list_character + [self.end_str] return list_character - + def get_span_idx_list(self): span_idx_list = [] for elem in self.dict_elem: if 'span' in elem: span_idx_list.append(self.dict_elem[elem]) return span_idx_list - + def __call__(self, data): cells = data['cells'] structure = data['structure']['tokens'] @@ -412,18 +418,22 @@ class TableLabelEncode(object): return None elem_num = len(structure) structure = [0] + structure + [len(self.dict_elem) - 1] - structure = structure + [0] * (self.max_elem_length + 2 - len(structure)) + structure = structure + [0] * (self.max_elem_length + 2 - len(structure) + ) structure = np.array(structure) data['structure'] = structure elem_char_idx1 = self.dict_elem[''] elem_char_idx2 = self.dict_elem[' 0: span_weight = len(td_idx_list) * 1.0 / len(span_idx_list) @@ -450,9 +460,11 @@ class TableLabelEncode(object): char_end_idx = self.get_beg_end_flag_idx('end', 'char') elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem') elem_end_idx = self.get_beg_end_flag_idx('end', 'elem') - data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx, - elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length, - self.max_elem_length, self.max_cell_num, elem_num]) + data['sp_tokens'] = np.array([ + char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx, + elem_char_idx1, elem_char_idx2, self.max_text_length, + self.max_elem_length, self.max_cell_num, elem_num + ]) return data def encode(self, text, char_or_elem): @@ -504,9 +516,8 @@ class TableLabelEncode(object): idx = np.array(self.dict_elem[self.end_str]) else: assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ - % beg_or_end + % beg_or_end else: assert False, "Unsupport type %s in char_or_elem" \ - % char_or_elem + % char_or_elem return idx - \ No newline at end of file diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 28e9818ba604fbf71de356dfce23d8a02ce3d9dd..c8e8d5b3349dd40fd15fe7c6f4a14c362f931f97 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -24,6 +24,7 @@ from paddle import inference import time from ppocr.utils.logging import get_logger + def str2bool(v): return v.lower() in ("true", "t", "1") @@ -47,8 +48,8 @@ def init_args(): # DB parmas parser.add_argument("--det_db_thresh", type=float, default=0.3) - parser.add_argument("--det_db_box_thresh", type=float, default=0.5) - parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) + parser.add_argument("--det_db_box_thresh", type=float, default=0.6) + parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5) parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--use_dilation", type=bool, default=False) parser.add_argument("--det_db_score_mode", type=str, default="fast")