提交 ae09ef60 编写于 作者: A andyjpaddle

fix code style

上级 8123688a
...@@ -166,21 +166,21 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -166,21 +166,21 @@ class NRTRLabelDecode(BaseRecLabelDecode):
use_space_char=True, use_space_char=True,
**kwargs): **kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path, super(NRTRLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char) character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if preds.dtype == paddle.int64: if preds.dtype == paddle.int64:
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
if preds[0][0]==2: if preds[0][0] == 2:
preds_idx = preds[:,1:] preds_idx = preds[:, 1:]
else: else:
preds_idx = preds preds_idx = preds
text = self.decode(preds_idx) text = self.decode(preds_idx)
if label is None: if label is None:
return text return text
label = self.decode(label[:,1:]) label = self.decode(label[:, 1:])
else: else:
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
...@@ -189,13 +189,13 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -189,13 +189,13 @@ class NRTRLabelDecode(BaseRecLabelDecode):
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None: if label is None:
return text return text
label = self.decode(label[:,1:]) label = self.decode(label[:, 1:])
return text, label return text, label
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False): def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """ """ convert text-index into text-label. """
result_list = [] result_list = []
...@@ -204,10 +204,11 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -204,10 +204,11 @@ class NRTRLabelDecode(BaseRecLabelDecode):
char_list = [] char_list = []
conf_list = [] conf_list = []
for idx in range(len(text_index[batch_idx])): for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] == 3: # end if text_index[batch_idx][idx] == 3: # end
break break
try: try:
char_list.append(self.character[int(text_index[batch_idx][idx])]) char_list.append(self.character[int(text_index[batch_idx][
idx])])
except: except:
continue continue
if text_prob is not None: if text_prob is not None:
...@@ -219,7 +220,6 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -219,7 +220,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
return result_list return result_list
class AttnLabelDecode(BaseRecLabelDecode): class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
...@@ -257,7 +257,8 @@ class AttnLabelDecode(BaseRecLabelDecode): ...@@ -257,7 +257,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]: batch_idx][idx]:
continue continue
char_list.append(self.character[int(text_index[batch_idx][idx])]) char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None: if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx]) conf_list.append(text_prob[batch_idx][idx])
else: else:
...@@ -387,10 +388,9 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -387,10 +388,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
class TableLabelDecode(object): class TableLabelDecode(object):
""" """ """ """
def __init__(self, def __init__(self, character_dict_path, **kwargs):
character_dict_path, list_character, list_elem = self.load_char_elem_dict(
**kwargs): character_dict_path)
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character) list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem) list_elem = self.add_special_char(list_elem)
self.dict_character = {} self.dict_character = {}
...@@ -409,7 +409,8 @@ class TableLabelDecode(object): ...@@ -409,7 +409,8 @@ class TableLabelDecode(object):
list_elem = [] list_elem = []
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t") substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
"\t")
character_num = int(substr[0]) character_num = int(substr[0])
elem_num = int(substr[1]) elem_num = int(substr[1])
for cno in range(1, 1 + character_num): for cno in range(1, 1 + character_num):
...@@ -429,14 +430,14 @@ class TableLabelDecode(object): ...@@ -429,14 +430,14 @@ class TableLabelDecode(object):
def __call__(self, preds): def __call__(self, preds):
structure_probs = preds['structure_probs'] structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds'] loc_preds = preds['loc_preds']
if isinstance(structure_probs,paddle.Tensor): if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy() structure_probs = structure_probs.numpy()
if isinstance(loc_preds,paddle.Tensor): if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy() loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2) structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2) structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx, structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_probs, 'elem') structure_idx, structure_probs, 'elem')
res_html_code_list = [] res_html_code_list = []
res_loc_list = [] res_loc_list = []
batch_num = len(structure_str) batch_num = len(structure_str)
...@@ -451,8 +452,13 @@ class TableLabelDecode(object): ...@@ -451,8 +452,13 @@ class TableLabelDecode(object):
res_loc = np.array(res_loc) res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code) res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc) res_loc_list.append(res_loc)
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list, return {
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str} 'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem): def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index. """convert text-label into text-index.
...@@ -528,9 +534,9 @@ class SARLabelDecode(BaseRecLabelDecode): ...@@ -528,9 +534,9 @@ class SARLabelDecode(BaseRecLabelDecode):
use_space_char=False, use_space_char=False,
**kwargs): **kwargs):
super(SARLabelDecode, self).__init__(character_dict_path, super(SARLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char) character_type, use_space_char)
self.rm_symbol = kwargs.get('rm_symbol', True) self.rm_symbol = kwargs.get('rm_symbol', True)
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>" beg_end_str = "<BOS/EOS>"
...@@ -549,7 +555,7 @@ class SARLabelDecode(BaseRecLabelDecode): ...@@ -549,7 +555,7 @@ class SARLabelDecode(BaseRecLabelDecode):
""" convert text-index into text-label. """ """ convert text-index into text-label. """
result_list = [] result_list = []
ignored_tokens = self.get_ignored_tokens() ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index) batch_size = len(text_index)
for batch_idx in range(batch_size): for batch_idx in range(batch_size):
char_list = [] char_list = []
...@@ -558,7 +564,7 @@ class SARLabelDecode(BaseRecLabelDecode): ...@@ -558,7 +564,7 @@ class SARLabelDecode(BaseRecLabelDecode):
if text_index[batch_idx][idx] in ignored_tokens: if text_index[batch_idx][idx] in ignored_tokens:
continue continue
if int(text_index[batch_idx][idx]) == int(self.end_idx): if int(text_index[batch_idx][idx]) == int(self.end_idx):
if text_prob is None and idx ==0: if text_prob is None and idx == 0:
continue continue
else: else:
break break
...@@ -586,7 +592,7 @@ class SARLabelDecode(BaseRecLabelDecode): ...@@ -586,7 +592,7 @@ class SARLabelDecode(BaseRecLabelDecode):
preds = preds.numpy() preds = preds.numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None: if label is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册