提交 44826b51 编写于 作者: L LDOUBLEV

cherry-pick 3505

上级 6887f457
...@@ -19,6 +19,7 @@ from __future__ import unicode_literals ...@@ -19,6 +19,7 @@ from __future__ import unicode_literals
import numpy as np import numpy as np
import string import string
import json
class ClsLabelEncode(object): class ClsLabelEncode(object):
...@@ -39,7 +40,6 @@ class DetLabelEncode(object): ...@@ -39,7 +40,6 @@ class DetLabelEncode(object):
pass pass
def __call__(self, data): def __call__(self, data):
import json
label = data['label'] label = data['label']
label = json.loads(label) label = json.loads(label)
nBox = len(label) nBox = len(label)
...@@ -54,8 +54,8 @@ class DetLabelEncode(object): ...@@ -54,8 +54,8 @@ class DetLabelEncode(object):
else: else:
txt_tags.append(False) txt_tags.append(False)
boxes = self.expand_points_num(boxes) boxes = self.expand_points_num(boxes)
boxes = np.array(boxes, dtype=np.float32) #boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool) #txt_tags = np.array(txt_tags, dtype=np.bool)
data['polys'] = boxes data['polys'] = boxes
data['texts'] = txts data['texts'] = txts
...@@ -352,19 +352,22 @@ class SRNLabelEncode(BaseRecLabelEncode): ...@@ -352,19 +352,22 @@ class SRNLabelEncode(BaseRecLabelEncode):
% beg_or_end % beg_or_end
return idx return idx
class TableLabelEncode(object): class TableLabelEncode(object):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self,
max_text_length, def __init__(self,
max_elem_length, max_text_length,
max_cell_num, max_elem_length,
character_dict_path, max_cell_num,
span_weight = 1.0, character_dict_path,
**kwargs): span_weight=1.0,
**kwargs):
self.max_text_length = max_text_length self.max_text_length = max_text_length
self.max_elem_length = max_elem_length self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path) list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character) list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem) list_elem = self.add_special_char(list_elem)
self.dict_character = {} self.dict_character = {}
...@@ -374,7 +377,7 @@ class TableLabelEncode(object): ...@@ -374,7 +377,7 @@ class TableLabelEncode(object):
for i, elem in enumerate(list_elem): for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i self.dict_elem[elem] = i
self.span_weight = span_weight self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path): def load_char_elem_dict(self, character_dict_path):
list_character = [] list_character = []
list_elem = [] list_elem = []
...@@ -383,27 +386,27 @@ class TableLabelEncode(object): ...@@ -383,27 +386,27 @@ class TableLabelEncode(object):
substr = lines[0].decode('utf-8').strip("\n").split("\t") substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0]) character_num = int(substr[0])
elem_num = int(substr[1]) elem_num = int(substr[1])
for cno in range(1, 1+character_num): for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n") character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character) list_character.append(character)
for eno in range(1+character_num, 1+character_num+elem_num): for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n") elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem) list_elem.append(elem)
return list_character, list_elem return list_character, list_elem
def add_special_char(self, list_character): def add_special_char(self, list_character):
self.beg_str = "sos" self.beg_str = "sos"
self.end_str = "eos" self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str] list_character = [self.beg_str] + list_character + [self.end_str]
return list_character return list_character
def get_span_idx_list(self): def get_span_idx_list(self):
span_idx_list = [] span_idx_list = []
for elem in self.dict_elem: for elem in self.dict_elem:
if 'span' in elem: if 'span' in elem:
span_idx_list.append(self.dict_elem[elem]) span_idx_list.append(self.dict_elem[elem])
return span_idx_list return span_idx_list
def __call__(self, data): def __call__(self, data):
cells = data['cells'] cells = data['cells']
structure = data['structure']['tokens'] structure = data['structure']['tokens']
...@@ -412,18 +415,22 @@ class TableLabelEncode(object): ...@@ -412,18 +415,22 @@ class TableLabelEncode(object):
return None return None
elem_num = len(structure) elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1] structure = [0] + structure + [len(self.dict_elem) - 1]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)) structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
)
structure = np.array(structure) structure = np.array(structure)
data['structure'] = structure data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>'] elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td'] elem_char_idx2 = self.dict_elem['<td']
span_idx_list = self.get_span_idx_list() span_idx_list = self.get_span_idx_list()
td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2) td_idx_list = np.logical_or(structure == elem_char_idx1,
structure == elem_char_idx2)
td_idx_list = np.where(td_idx_list)[0] td_idx_list = np.where(td_idx_list)[0]
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32) structure_mask = np.ones(
(self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32) bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32) bbox_list_mask = np.zeros(
(self.max_elem_length + 2, 1), dtype=np.float32)
img_height, img_width, img_ch = data['image'].shape img_height, img_width, img_ch = data['image'].shape
if len(span_idx_list) > 0: if len(span_idx_list) > 0:
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list) span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
...@@ -450,9 +457,11 @@ class TableLabelEncode(object): ...@@ -450,9 +457,11 @@ class TableLabelEncode(object):
char_end_idx = self.get_beg_end_flag_idx('end', 'char') char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem') elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem') elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx, data['sp_tokens'] = np.array([
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length, char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
self.max_elem_length, self.max_cell_num, elem_num]) elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num
])
return data return data
def encode(self, text, char_or_elem): def encode(self, text, char_or_elem):
...@@ -504,9 +513,8 @@ class TableLabelEncode(object): ...@@ -504,9 +513,8 @@ class TableLabelEncode(object):
idx = np.array(self.dict_elem[self.end_str]) idx = np.array(self.dict_elem[self.end_str])
else: else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end % beg_or_end
else: else:
assert False, "Unsupport type %s in char_or_elem" \ assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem % char_or_elem
return idx return idx
\ No newline at end of file
...@@ -24,6 +24,7 @@ from paddle import inference ...@@ -24,6 +24,7 @@ from paddle import inference
import time import time
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
def str2bool(v): def str2bool(v):
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
...@@ -47,8 +48,8 @@ def init_args(): ...@@ -47,8 +48,8 @@ def init_args():
# DB parmas # DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=bool, default=False) parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast") parser.add_argument("--det_db_score_mode", type=str, default="fast")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册