diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 94e0dd2264627589c6ba48294bcbf22a901952f2..36bc29793a67e38583b9eb0fb342c1595629540f 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1217,7 +1217,7 @@ class ABINetLabelEncode(BaseRecLabelEncode): dict_character = [''] + dict_character return dict_character -class SPINAttnLabelEncode(BaseRecLabelEncode): +class SPINAttnLabelEncode(AttnLabelEncode): """ Convert between text-label and text-index """ def __init__(self, diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py index 8f92d1ef45e87dd711847977a35ca7006ae52d49..86e35e4339d8e1006cfe43d6cf4f2f7d231082c4 100644 --- a/ppocr/modeling/heads/rec_spin_att_head.py +++ b/ppocr/modeling/heads/rec_spin_att_head.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py +""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 26ea71fd0d1d072c4febf3d3ae26935f611e9a69..6f64899b7fbda527b24e1f7891d89eae19817c11 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -669,7 +669,86 @@ class ABINetLabelDecode(NRTRLabelDecode): return dict_character -class SPINAttnLabelDecode(BaseRecLabelDecode): +# class SPINAttnLabelDecode(BaseRecLabelDecode): +# """ Convert between text-label and text-index """ + +# def __init__(self, character_dict_path=None, use_space_char=False, +# **kwargs): +# super(SPINAttnLabelDecode, self).__init__(character_dict_path, +# use_space_char) + +# def add_special_char(self, dict_character): +# self.beg_str = "sos" +# self.end_str = "eos" +# dict_character = dict_character +# dict_character = [self.beg_str] + [self.end_str] + dict_character +# return dict_character + +# def decode(self, text_index, text_prob=None, is_remove_duplicate=False): +# """ convert text-index into text-label. """ +# result_list = [] +# ignored_tokens = self.get_ignored_tokens() +# [beg_idx, end_idx] = self.get_ignored_tokens() +# batch_size = len(text_index) +# for batch_idx in range(batch_size): +# char_list = [] +# conf_list = [] +# for idx in range(len(text_index[batch_idx])): +# if text_index[batch_idx][idx] == int(beg_idx): +# continue +# if int(text_index[batch_idx][idx]) == int(end_idx): +# break +# if is_remove_duplicate: +# # only for predict +# if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ +# batch_idx][idx]: +# continue +# char_list.append(self.character[int(text_index[batch_idx][ +# idx])]) +# if text_prob is not None: +# conf_list.append(text_prob[batch_idx][idx]) +# else: +# conf_list.append(1) +# text = ''.join(char_list) +# result_list.append((text.lower(), np.mean(conf_list).tolist())) +# return result_list + +# def __call__(self, preds, label=None, *args, **kwargs): +# """ +# text = self.decode(text) +# if label is None: +# return text +# else: +# label = self.decode(label, is_remove_duplicate=False) +# return text, label +# """ +# if isinstance(preds, paddle.Tensor): +# preds = preds.numpy() + +# preds_idx = preds.argmax(axis=2) +# preds_prob = preds.max(axis=2) +# text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) +# if label is None: +# return text +# label = self.decode(label, is_remove_duplicate=False) +# return text, label + +# def get_ignored_tokens(self): +# beg_idx = self.get_beg_end_flag_idx("beg") +# end_idx = self.get_beg_end_flag_idx("end") +# return [beg_idx, end_idx] + +# def get_beg_end_flag_idx(self, beg_or_end): +# if beg_or_end == "beg": +# idx = np.array(self.dict[self.beg_str]) +# elif beg_or_end == "end": +# idx = np.array(self.dict[self.end_str]) +# else: +# assert False, "unsupport type %s in get_beg_end_flag_idx" \ +# % beg_or_end +# return idx + +class SPINAttnLabelDecode(AttnLabelDecode): """ Convert between text-label and text-index """ def __init__(self, character_dict_path=None, use_space_char=False, @@ -682,68 +761,4 @@ class SPINAttnLabelDecode(BaseRecLabelDecode): self.end_str = "eos" dict_character = dict_character dict_character = [self.beg_str] + [self.end_str] + dict_character - return dict_character - - def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ - result_list = [] - ignored_tokens = self.get_ignored_tokens() - [beg_idx, end_idx] = self.get_ignored_tokens() - batch_size = len(text_index) - for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] == int(beg_idx): - continue - if int(text_index[batch_idx][idx]) == int(end_idx): - break - if is_remove_duplicate: - # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: - continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) - text = ''.join(char_list) - result_list.append((text.lower(), np.mean(conf_list).tolist())) - return result_list - - def __call__(self, preds, label=None, *args, **kwargs): - """ - text = self.decode(text) - if label is None: - return text - else: - label = self.decode(label, is_remove_duplicate=False) - return text, label - """ - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - - preds_idx = preds.argmax(axis=2) - preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) - if label is None: - return text - label = self.decode(label, is_remove_duplicate=False) - return text, label - - def get_ignored_tokens(self): - beg_idx = self.get_beg_end_flag_idx("beg") - end_idx = self.get_beg_end_flag_idx("end") - return [beg_idx, end_idx] - - def get_beg_end_flag_idx(self, beg_or_end): - if beg_or_end == "beg": - idx = np.array(self.dict[self.beg_str]) - elif beg_or_end == "end": - idx = np.array(self.dict[self.end_str]) - else: - assert False, "unsupport type %s in get_beg_end_flag_idx" \ - % beg_or_end - return idx \ No newline at end of file + return dict_character \ No newline at end of file diff --git a/tools/export_model.py b/tools/export_model.py index 4855c53a978706c52feaebeb7b3649a71bd66b8e..69ac904c661fad77255c70563fdf1f16c5c29875 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -91,7 +91,7 @@ def export_single_model(model, ] # print([None, 3, 32, 128]) model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN": + elif arch_config["algorithm"] in ["NRTR", "SPIN"]: other_shape = [ paddle.static.InputSpec( shape=[None, 1, 32, 100], dtype="float32"),