From f614274672a0c0874c4749f51be0520babf170ca Mon Sep 17 00:00:00 2001 From: smilelite Date: Tue, 12 Jul 2022 22:15:00 +0800 Subject: [PATCH] modified label_ops --- ppocr/data/imaug/label_ops.py | 17 +----- ppocr/postprocess/rec_postprocess.py | 80 ---------------------------- 2 files changed, 1 insertion(+), 96 deletions(-) diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 36bc2979..775ceec8 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1248,19 +1248,4 @@ class SPINAttnLabelEncode(AttnLabelEncode): padded_text[:len(target)] = target data['label'] = np.array(padded_text) - return data - - def get_ignored_tokens(self): - beg_idx = self.get_beg_end_flag_idx("beg") - end_idx = self.get_beg_end_flag_idx("end") - return [beg_idx, end_idx] - - def get_beg_end_flag_idx(self, beg_or_end): - if beg_or_end == "beg": - idx = np.array(self.dict[self.beg_str]) - elif beg_or_end == "end": - idx = np.array(self.dict[self.end_str]) - else: - assert False, "Unsupport type %s in get_beg_end_flag_idx" \ - % beg_or_end - return \ No newline at end of file + return data \ No newline at end of file diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 6f64899b..3e7c29d8 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -668,86 +668,6 @@ class ABINetLabelDecode(NRTRLabelDecode): dict_character = [''] + dict_character return dict_character - -# class SPINAttnLabelDecode(BaseRecLabelDecode): -# """ Convert between text-label and text-index """ - -# def __init__(self, character_dict_path=None, use_space_char=False, -# **kwargs): -# super(SPINAttnLabelDecode, self).__init__(character_dict_path, -# use_space_char) - -# def add_special_char(self, dict_character): -# self.beg_str = "sos" -# self.end_str = "eos" -# dict_character = dict_character -# dict_character = [self.beg_str] + [self.end_str] + dict_character -# return dict_character - -# def decode(self, text_index, text_prob=None, is_remove_duplicate=False): -# """ convert text-index into text-label. """ -# result_list = [] -# ignored_tokens = self.get_ignored_tokens() -# [beg_idx, end_idx] = self.get_ignored_tokens() -# batch_size = len(text_index) -# for batch_idx in range(batch_size): -# char_list = [] -# conf_list = [] -# for idx in range(len(text_index[batch_idx])): -# if text_index[batch_idx][idx] == int(beg_idx): -# continue -# if int(text_index[batch_idx][idx]) == int(end_idx): -# break -# if is_remove_duplicate: -# # only for predict -# if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ -# batch_idx][idx]: -# continue -# char_list.append(self.character[int(text_index[batch_idx][ -# idx])]) -# if text_prob is not None: -# conf_list.append(text_prob[batch_idx][idx]) -# else: -# conf_list.append(1) -# text = ''.join(char_list) -# result_list.append((text.lower(), np.mean(conf_list).tolist())) -# return result_list - -# def __call__(self, preds, label=None, *args, **kwargs): -# """ -# text = self.decode(text) -# if label is None: -# return text -# else: -# label = self.decode(label, is_remove_duplicate=False) -# return text, label -# """ -# if isinstance(preds, paddle.Tensor): -# preds = preds.numpy() - -# preds_idx = preds.argmax(axis=2) -# preds_prob = preds.max(axis=2) -# text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) -# if label is None: -# return text -# label = self.decode(label, is_remove_duplicate=False) -# return text, label - -# def get_ignored_tokens(self): -# beg_idx = self.get_beg_end_flag_idx("beg") -# end_idx = self.get_beg_end_flag_idx("end") -# return [beg_idx, end_idx] - -# def get_beg_end_flag_idx(self, beg_or_end): -# if beg_or_end == "beg": -# idx = np.array(self.dict[self.beg_str]) -# elif beg_or_end == "end": -# idx = np.array(self.dict[self.end_str]) -# else: -# assert False, "unsupport type %s in get_beg_end_flag_idx" \ -# % beg_or_end -# return idx - class SPINAttnLabelDecode(AttnLabelDecode): """ Convert between text-label and text-index """ -- GitLab