提交 f26846cc 编写于 作者: T tink2123

fix attention loss for ce

上级 1bbf6e6a
...@@ -45,6 +45,7 @@ class AttentionHead(nn.Layer): ...@@ -45,6 +45,7 @@ class AttentionHead(nn.Layer):
output_hiddens = [] output_hiddens = []
if targets is not None: if targets is not None:
print("target is not None")
for i in range(num_steps): for i in range(num_steps):
char_onehots = self._char_to_onehot( char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes) targets[:, i], onehot_dim=self.num_classes)
...@@ -53,8 +54,8 @@ class AttentionHead(nn.Layer): ...@@ -53,8 +54,8 @@ class AttentionHead(nn.Layer):
output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1) output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output) probs = self.generator(output)
else: else:
print("target is None")
targets = paddle.zeros(shape=[batch_size], dtype="int32") targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None probs = None
char_onehots = None char_onehots = None
...@@ -75,6 +76,7 @@ class AttentionHead(nn.Layer): ...@@ -75,6 +76,7 @@ class AttentionHead(nn.Layer):
probs_step, axis=1)], axis=1) probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1) next_input = probs_step.argmax(axis=1)
targets = next_input targets = next_input
if not self.training:
probs = paddle.nn.functional.softmax(probs, axis=2) probs = paddle.nn.functional.softmax(probs, axis=2)
return probs return probs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册