From a5280c0f4026e16bd8be491221d0db8a2ec323d7 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 17 Dec 2021 21:42:53 +0800 Subject: [PATCH] polist seed code --- configs/rec/rec_resnet_stn_bilstm_att.yml | 4 ++-- ppocr/data/imaug/label_ops.py | 10 +++++++--- ppocr/modeling/heads/rec_aster_head.py | 2 +- ppocr/postprocess/rec_postprocess.py | 7 +++++-- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml index 0f599258..0bb90b35 100644 --- a/configs/rec/rec_resnet_stn_bilstm_att.yml +++ b/configs/rec/rec_resnet_stn_bilstm_att.yml @@ -75,7 +75,7 @@ Train: channel_first: False - SEEDLabelEncode: # Class handling label - RecResizeImg: - character_type: en + character_dict_path: image_shape: [3, 64, 256] padding: False - KeepKeys: @@ -96,7 +96,7 @@ Eval: channel_first: False - SEEDLabelEncode: # Class handling label - RecResizeImg: - character_type: en + character_dict_path: image_shape: [3, 64, 256] padding: False - KeepKeys: diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 0a4fad62..79fffbd5 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -344,8 +344,12 @@ class SEEDLabelEncode(BaseRecLabelEncode): max_text_length, character_dict_path, use_space_char) def add_special_char(self, dict_character): + self.padding = "padding" self.end_str = "eos" - dict_character = dict_character + [self.end_str] + self.unknown = "unknown" + dict_character = dict_character + [ + self.end_str, self.padding, self.unknown + ] return dict_character def __call__(self, data): @@ -356,8 +360,8 @@ class SEEDLabelEncode(BaseRecLabelEncode): if len(text) >= self.max_text_len: return None data['length'] = np.array(len(text)) + 1 # conclude eos - text = text + [len(self.character) - 1] * (self.max_text_len - len(text) - ) + text = text + [len(self.character) - 3] + [len(self.character) - 2] * ( + self.max_text_len - len(text) - 1) data['label'] = np.array(text) return data diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py index 9240f002..c95e8fd3 100644 --- a/ppocr/modeling/heads/rec_aster_head.py +++ b/ppocr/modeling/heads/rec_aster_head.py @@ -47,7 +47,7 @@ class AsterHead(nn.Layer): self.time_step = time_step self.embeder = Embedding(self.time_step, in_channels) self.beam_width = beam_width - self.eos = self.num_classes - 1 + self.eos = self.num_classes - 3 def forward(self, x, targets=None, embed=None): return_dict = {} diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index ef1a43fd..caaa2948 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode): use_space_char) def add_special_char(self, dict_character): - self.beg_str = "sos" + self.padding_str = "padding" self.end_str = "eos" - dict_character = dict_character + [self.end_str] + self.unknown = "unknown" + dict_character = dict_character + [ + self.end_str, self.padding_str, self.unknown + ] return dict_character def get_ignored_tokens(self): -- GitLab