diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 2bc606488201bec6296436ba1291213b58ab304d..39ef16ecfb14bb7a87fd03645b75b9abf095dce3 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -166,21 +166,21 @@ class NRTRLabelDecode(BaseRecLabelDecode): use_space_char=True, **kwargs): super(NRTRLabelDecode, self).__init__(character_dict_path, - character_type, use_space_char) + character_type, use_space_char) def __call__(self, preds, label=None, *args, **kwargs): if preds.dtype == paddle.int64: if isinstance(preds, paddle.Tensor): preds = preds.numpy() - if preds[0][0]==2: - preds_idx = preds[:,1:] + if preds[0][0] == 2: + preds_idx = preds[:, 1:] else: preds_idx = preds text = self.decode(preds_idx) if label is None: return text - label = self.decode(label[:,1:]) + label = self.decode(label[:, 1:]) else: if isinstance(preds, paddle.Tensor): preds = preds.numpy() @@ -189,13 +189,13 @@ class NRTRLabelDecode(BaseRecLabelDecode): text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) if label is None: return text - label = self.decode(label[:,1:]) + label = self.decode(label[:, 1:]) return text, label def add_special_char(self, dict_character): - dict_character = ['blank','','',''] + dict_character + dict_character = ['blank', '', '', ''] + dict_character return dict_character - + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): """ convert text-index into text-label. """ result_list = [] @@ -204,10 +204,11 @@ class NRTRLabelDecode(BaseRecLabelDecode): char_list = [] conf_list = [] for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] == 3: # end + if text_index[batch_idx][idx] == 3: # end break try: - char_list.append(self.character[int(text_index[batch_idx][idx])]) + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) except: continue if text_prob is not None: @@ -219,7 +220,6 @@ class NRTRLabelDecode(BaseRecLabelDecode): return result_list - class AttnLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ @@ -257,7 +257,8 @@ class AttnLabelDecode(BaseRecLabelDecode): if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ batch_idx][idx]: continue - char_list.append(self.character[int(text_index[batch_idx][idx])]) + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) if text_prob is not None: conf_list.append(text_prob[batch_idx][idx]) else: @@ -387,10 +388,9 @@ class SRNLabelDecode(BaseRecLabelDecode): class TableLabelDecode(object): """ """ - def __init__(self, - character_dict_path, - **kwargs): - list_character, list_elem = self.load_char_elem_dict(character_dict_path) + def __init__(self, character_dict_path, **kwargs): + list_character, list_elem = self.load_char_elem_dict( + character_dict_path) list_character = self.add_special_char(list_character) list_elem = self.add_special_char(list_elem) self.dict_character = {} @@ -409,7 +409,8 @@ class TableLabelDecode(object): list_elem = [] with open(character_dict_path, "rb") as fin: lines = fin.readlines() - substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t") + substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split( + "\t") character_num = int(substr[0]) elem_num = int(substr[1]) for cno in range(1, 1 + character_num): @@ -429,14 +430,14 @@ class TableLabelDecode(object): def __call__(self, preds): structure_probs = preds['structure_probs'] loc_preds = preds['loc_preds'] - if isinstance(structure_probs,paddle.Tensor): + if isinstance(structure_probs, paddle.Tensor): structure_probs = structure_probs.numpy() - if isinstance(loc_preds,paddle.Tensor): + if isinstance(loc_preds, paddle.Tensor): loc_preds = loc_preds.numpy() structure_idx = structure_probs.argmax(axis=2) structure_probs = structure_probs.max(axis=2) - structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx, - structure_probs, 'elem') + structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode( + structure_idx, structure_probs, 'elem') res_html_code_list = [] res_loc_list = [] batch_num = len(structure_str) @@ -451,8 +452,13 @@ class TableLabelDecode(object): res_loc = np.array(res_loc) res_html_code_list.append(res_html_code) res_loc_list.append(res_loc) - return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list, - 'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str} + return { + 'res_html_code': res_html_code_list, + 'res_loc': res_loc_list, + 'res_score_list': result_score_list, + 'res_elem_idx_list': result_elem_idx_list, + 'structure_str_list': structure_str + } def decode(self, text_index, structure_probs, char_or_elem): """convert text-label into text-index. @@ -528,9 +534,9 @@ class SARLabelDecode(BaseRecLabelDecode): use_space_char=False, **kwargs): super(SARLabelDecode, self).__init__(character_dict_path, - character_type, use_space_char) - - self.rm_symbol = kwargs.get('rm_symbol', True) + character_type, use_space_char) + + self.rm_symbol = kwargs.get('rm_symbol', True) def add_special_char(self, dict_character): beg_end_str = "" @@ -549,7 +555,7 @@ class SARLabelDecode(BaseRecLabelDecode): """ convert text-index into text-label. """ result_list = [] ignored_tokens = self.get_ignored_tokens() - + batch_size = len(text_index) for batch_idx in range(batch_size): char_list = [] @@ -558,7 +564,7 @@ class SARLabelDecode(BaseRecLabelDecode): if text_index[batch_idx][idx] in ignored_tokens: continue if int(text_index[batch_idx][idx]) == int(self.end_idx): - if text_prob is None and idx ==0: + if text_prob is None and idx == 0: continue else: break @@ -586,7 +592,7 @@ class SARLabelDecode(BaseRecLabelDecode): preds = preds.numpy() preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) - + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) if label is None: