未验证 提交 683c86ed 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #4960 from tink2123/fix_seed_padding

polish seed code
......@@ -75,7 +75,7 @@ Train:
channel_first: False
- SEEDLabelEncode: # Class handling label
- RecResizeImg:
character_type: en
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
......@@ -96,7 +96,7 @@ Eval:
channel_first: False
- SEEDLabelEncode: # Class handling label
- RecResizeImg:
character_type: en
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
......
......@@ -507,8 +507,12 @@ class SEEDLabelEncode(BaseRecLabelEncode):
max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.padding = "padding"
self.end_str = "eos"
dict_character = dict_character + [self.end_str]
self.unknown = "unknown"
dict_character = dict_character + [
self.end_str, self.padding, self.unknown
]
return dict_character
def __call__(self, data):
......@@ -519,8 +523,8 @@ class SEEDLabelEncode(BaseRecLabelEncode):
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text)) + 1 # conclude eos
text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
)
text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
self.max_text_len - len(text) - 1)
data['label'] = np.array(text)
return data
......
......@@ -47,7 +47,7 @@ class AsterHead(nn.Layer):
self.time_step = time_step
self.embeder = Embedding(self.time_step, in_channels)
self.beam_width = beam_width
self.eos = self.num_classes - 1
self.eos = self.num_classes - 3
def forward(self, x, targets=None, embed=None):
return_dict = {}
......
......@@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode):
use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.padding_str = "padding"
self.end_str = "eos"
dict_character = dict_character + [self.end_str]
self.unknown = "unknown"
dict_character = dict_character + [
self.end_str, self.padding_str, self.unknown
]
return dict_character
def get_ignored_tokens(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册