提交 a5280c0f 编写于 作者: T tink2123

polist seed code

上级 13961868
...@@ -75,7 +75,7 @@ Train: ...@@ -75,7 +75,7 @@ Train:
channel_first: False channel_first: False
- SEEDLabelEncode: # Class handling label - SEEDLabelEncode: # Class handling label
- RecResizeImg: - RecResizeImg:
character_type: en character_dict_path:
image_shape: [3, 64, 256] image_shape: [3, 64, 256]
padding: False padding: False
- KeepKeys: - KeepKeys:
...@@ -96,7 +96,7 @@ Eval: ...@@ -96,7 +96,7 @@ Eval:
channel_first: False channel_first: False
- SEEDLabelEncode: # Class handling label - SEEDLabelEncode: # Class handling label
- RecResizeImg: - RecResizeImg:
character_type: en character_dict_path:
image_shape: [3, 64, 256] image_shape: [3, 64, 256]
padding: False padding: False
- KeepKeys: - KeepKeys:
......
...@@ -344,8 +344,12 @@ class SEEDLabelEncode(BaseRecLabelEncode): ...@@ -344,8 +344,12 @@ class SEEDLabelEncode(BaseRecLabelEncode):
max_text_length, character_dict_path, use_space_char) max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
self.padding = "padding"
self.end_str = "eos" self.end_str = "eos"
dict_character = dict_character + [self.end_str] self.unknown = "unknown"
dict_character = dict_character + [
self.end_str, self.padding, self.unknown
]
return dict_character return dict_character
def __call__(self, data): def __call__(self, data):
...@@ -356,8 +360,8 @@ class SEEDLabelEncode(BaseRecLabelEncode): ...@@ -356,8 +360,8 @@ class SEEDLabelEncode(BaseRecLabelEncode):
if len(text) >= self.max_text_len: if len(text) >= self.max_text_len:
return None return None
data['length'] = np.array(len(text)) + 1 # conclude eos data['length'] = np.array(len(text)) + 1 # conclude eos
text = text + [len(self.character) - 1] * (self.max_text_len - len(text) text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
) self.max_text_len - len(text) - 1)
data['label'] = np.array(text) data['label'] = np.array(text)
return data return data
......
...@@ -47,7 +47,7 @@ class AsterHead(nn.Layer): ...@@ -47,7 +47,7 @@ class AsterHead(nn.Layer):
self.time_step = time_step self.time_step = time_step
self.embeder = Embedding(self.time_step, in_channels) self.embeder = Embedding(self.time_step, in_channels)
self.beam_width = beam_width self.beam_width = beam_width
self.eos = self.num_classes - 1 self.eos = self.num_classes - 3
def forward(self, x, targets=None, embed=None): def forward(self, x, targets=None, embed=None):
return_dict = {} return_dict = {}
......
...@@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode): ...@@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode):
use_space_char) use_space_char)
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
self.beg_str = "sos" self.padding_str = "padding"
self.end_str = "eos" self.end_str = "eos"
dict_character = dict_character + [self.end_str] self.unknown = "unknown"
dict_character = dict_character + [
self.end_str, self.padding_str, self.unknown
]
return dict_character return dict_character
def get_ignored_tokens(self): def get_ignored_tokens(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册