diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml index 0f599258d46e2ce89a6b7deccf8287a2ec0f7e4e..0bb90b35264b424c58a45685f5a2a066843298a6 100644 --- a/configs/rec/rec_resnet_stn_bilstm_att.yml +++ b/configs/rec/rec_resnet_stn_bilstm_att.yml @@ -75,7 +75,7 @@ Train: channel_first: False - SEEDLabelEncode: # Class handling label - RecResizeImg: - character_type: en + character_dict_path: image_shape: [3, 64, 256] padding: False - KeepKeys: @@ -96,7 +96,7 @@ Eval: channel_first: False - SEEDLabelEncode: # Class handling label - RecResizeImg: - character_type: en + character_dict_path: image_shape: [3, 64, 256] padding: False - KeepKeys: diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 0a4fad621a9038e71a9d43eb4e12f78e7e92d73d..79fffbd5e4921a1303ca39c93ea16464eea279a5 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -344,8 +344,12 @@ class SEEDLabelEncode(BaseRecLabelEncode): max_text_length, character_dict_path, use_space_char) def add_special_char(self, dict_character): + self.padding = "padding" self.end_str = "eos" - dict_character = dict_character + [self.end_str] + self.unknown = "unknown" + dict_character = dict_character + [ + self.end_str, self.padding, self.unknown + ] return dict_character def __call__(self, data): @@ -356,8 +360,8 @@ class SEEDLabelEncode(BaseRecLabelEncode): if len(text) >= self.max_text_len: return None data['length'] = np.array(len(text)) + 1 # conclude eos - text = text + [len(self.character) - 1] * (self.max_text_len - len(text) - ) + text = text + [len(self.character) - 3] + [len(self.character) - 2] * ( + self.max_text_len - len(text) - 1) data['label'] = np.array(text) return data diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py index 9240f002d3a8bcbde517142be6b45559430de610..c95e8fd31f84c26cf58f7fbbdaab6c825b10eea8 100644 --- a/ppocr/modeling/heads/rec_aster_head.py +++ b/ppocr/modeling/heads/rec_aster_head.py @@ -47,7 +47,7 @@ class AsterHead(nn.Layer): self.time_step = time_step self.embeder = Embedding(self.time_step, in_channels) self.beam_width = beam_width - self.eos = self.num_classes - 1 + self.eos = self.num_classes - 3 def forward(self, x, targets=None, embed=None): return_dict = {} diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index ef1a43fd0ee65f3e55a8f72dfd2f96c478da1a9a..caaa2948522cb6ea7ed74b8ab79a3d0b465059a3 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode): use_space_char) def add_special_char(self, dict_character): - self.beg_str = "sos" + self.padding_str = "padding" self.end_str = "eos" - dict_character = dict_character + [self.end_str] + self.unknown = "unknown" + dict_character = dict_character + [ + self.end_str, self.padding_str, self.unknown + ] return dict_character def get_ignored_tokens(self):