diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml index 0f599258d46e2ce89a6b7deccf8287a2ec0f7e4e..0bb90b35264b424c58a45685f5a2a066843298a6 100644 --- a/configs/rec/rec_resnet_stn_bilstm_att.yml +++ b/configs/rec/rec_resnet_stn_bilstm_att.yml @@ -75,7 +75,7 @@ Train: channel_first: False - SEEDLabelEncode: # Class handling label - RecResizeImg: - character_type: en + character_dict_path: image_shape: [3, 64, 256] padding: False - KeepKeys: @@ -96,7 +96,7 @@ Eval: channel_first: False - SEEDLabelEncode: # Class handling label - RecResizeImg: - character_type: en + character_dict_path: image_shape: [3, 64, 256] padding: False - KeepKeys: diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 58d06fd4f28c9b1a68e37ae5b316a45896d4cdc5..fc14fdbcf13a61b591d9ea6c2535aefe6e437ec6 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -507,8 +507,12 @@ class SEEDLabelEncode(BaseRecLabelEncode): max_text_length, character_dict_path, use_space_char) def add_special_char(self, dict_character): + self.padding = "padding" self.end_str = "eos" - dict_character = dict_character + [self.end_str] + self.unknown = "unknown" + dict_character = dict_character + [ + self.end_str, self.padding, self.unknown + ] return dict_character def __call__(self, data): @@ -519,8 +523,8 @@ class SEEDLabelEncode(BaseRecLabelEncode): if len(text) >= self.max_text_len: return None data['length'] = np.array(len(text)) + 1 # conclude eos - text = text + [len(self.character) - 1] * (self.max_text_len - len(text) - ) + text = text + [len(self.character) - 3] + [len(self.character) - 2] * ( + self.max_text_len - len(text) - 1) data['label'] = np.array(text) return data diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py index 9240f002d3a8bcbde517142be6b45559430de610..c95e8fd31f84c26cf58f7fbbdaab6c825b10eea8 100644 --- a/ppocr/modeling/heads/rec_aster_head.py +++ b/ppocr/modeling/heads/rec_aster_head.py @@ -47,7 +47,7 @@ class AsterHead(nn.Layer): self.time_step = time_step self.embeder = Embedding(self.time_step, in_channels) self.beam_width = beam_width - self.eos = self.num_classes - 1 + self.eos = self.num_classes - 3 def forward(self, x, targets=None, embed=None): return_dict = {} diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index ef1a43fd0ee65f3e55a8f72dfd2f96c478da1a9a..caaa2948522cb6ea7ed74b8ab79a3d0b465059a3 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode): use_space_char) def add_special_char(self, dict_character): - self.beg_str = "sos" + self.padding_str = "padding" self.end_str = "eos" - dict_character = dict_character + [self.end_str] + self.unknown = "unknown" + dict_character = dict_character + [ + self.end_str, self.padding_str, self.unknown + ] return dict_character def get_ignored_tokens(self):