Commit 477eb586 authored by T tink2123

fix attn loss for ce

Parent 8b3b7879
@@ -318,7 +318,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
         text = self.encode(text)
         if text is None:
             return None
-        if len(text) >= self.max_text_len:
+        if len(text) >= self.max_text_len - 1:
             return None
         data['length'] = np.array(len(text))
         text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
...
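Note on this hunk: AttnLabelEncode wraps the encoded text as [SOS] + text + [EOS] and then pads the label out to a fixed max_text_len, so at most max_text_len - 2 real characters fit. The old check (len(text) >= self.max_text_len) still admitted texts of length max_text_len - 1, which overflow the fixed-size label once the two boundary tokens are appended; tightening the bound to max_text_len - 1 rejects them. A minimal sketch of the length accounting follows; the helper name and the sos/eos/pad index values are illustrative assumptions, not PaddleOCR API:

# Hypothetical sketch of AttnLabelEncode's padding scheme.
# Layout assumed: [SOS] + text + [EOS] + [PAD] * k, total length max_text_len.
max_text_len = 25
sos, eos, pad = 0, 36, 0  # example indices; eos plays the role of len(character) - 1

def encode_attn_label(text_ids):
    # Reject texts that cannot fit once SOS and EOS are added:
    # a fixed-size label holds at most max_text_len - 2 real characters.
    if len(text_ids) >= max_text_len - 1:
        return None
    label = [sos] + list(text_ids) + [eos]
    label += [pad] * (max_text_len - len(label))
    assert len(label) == max_text_len
    return label

With the old bound, a text of length max_text_len - 1 would produce a label of length max_text_len + 1 and a negative padding count, breaking the fixed label shape the loss expects.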
@@ -75,6 +75,7 @@ class AttentionHead(nn.Layer):
                         probs_step, axis=1)], axis=1)
                 next_input = probs_step.argmax(axis=1)
                 targets = next_input
-        probs = paddle.nn.functional.softmax(probs, axis=2)
+        if not self.training:
+            probs = paddle.nn.functional.softmax(probs, axis=2)
         return probs
...
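Note on this hunk: the "ce" in the commit title most plausibly refers to the cross-entropy attention loss (in PaddlePaddle projects it can also abbreviate the "continuous evaluation" CI; the title leaves this ambiguous). Paddle's cross-entropy loss applies softmax internally, so the head should return raw logits during training and only normalize with softmax at inference time for decoding; applying softmax in the head and again inside the loss would squash the logits twice and flatten the gradients. A minimal sketch of the difference, with illustrative shapes (batch 8, 25 decode steps, 37 classes are assumptions):

# Sketch: paddle.nn.functional.cross_entropy softmaxes its input internally
# (use_softmax defaults to True), so pre-softmaxed inputs get softmaxed twice.
import paddle
import paddle.nn.functional as F

logits = paddle.randn([8, 25, 37])        # [batch, steps, num_classes]
labels = paddle.randint(0, 37, [8, 25])   # target character indices

good = F.cross_entropy(logits, labels)                     # softmax applied once
bad = F.cross_entropy(F.softmax(logits, axis=2), labels)   # softmax applied twice
print(float(good.mean()), float(bad.mean()))

Guarding the softmax with `if not self.training:` keeps the training path on raw logits while still returning proper probabilities for argmax decoding at inference.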