Commit a094d277 authored by LDOUBLEV

opt rec_att_head

Parent 0d89f3f9
@@ -64,8 +64,10 @@ class AttentionHead(nn.Layer):
             (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                            char_onehots)
             probs_step = self.generator(outputs)
-            probs = paddle.unsqueeze(
-                probs_step, axis=1) if probs is None else paddle.concat(
-                    [probs, paddle.unsqueeze(
-                        probs_step, axis=1)], axis=1)
+            if probs is None:
+                probs = paddle.unsqueeze(probs_step, axis=1)
+            else:
+                probs = paddle.concat(
+                    [probs, paddle.unsqueeze(
+                        probs_step, axis=1)], axis=1)
             next_input = probs_step.argmax(axis=1)
@@ -152,8 +154,10 @@ class AttentionLSTM(nn.Layer):
                                                  char_onehots)
             probs_step = self.generator(hidden[0])
             hidden = (hidden[1][0], hidden[1][1])
-            probs = paddle.unsqueeze(
-                probs_step, axis=1) if probs is None else paddle.concat(
-                    [probs, paddle.unsqueeze(
-                        probs_step, axis=1)], axis=1)
+            if probs is None:
+                probs = paddle.unsqueeze(probs_step, axis=1)
+            else:
+                probs = paddle.concat(
+                    [probs, paddle.unsqueeze(
+                        probs_step, axis=1)], axis=1)
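Note: both hunks make the same refactor, replacing a nested conditional expression with an explicit if/else; the accumulation logic itself is unchanged. For context, a minimal self-contained sketch of that pattern, where the batch/class/step sizes and the paddle.rand stand-in for self.generator(...) are illustrative assumptions, not part of this commit:

    import paddle

    # Sketch of the accumulation pattern used in both decode loops:
    # each step produces probs_step of shape [batch, num_classes], and the
    # steps are stacked along a new time axis -> [batch, num_steps, num_classes].
    probs = None
    batch, num_classes, num_steps = 2, 38, 25  # illustrative sizes only
    for _ in range(num_steps):
        probs_step = paddle.rand([batch, num_classes])  # stand-in for self.generator(...)
        if probs is None:
            probs = paddle.unsqueeze(probs_step, axis=1)
        else:
            probs = paddle.concat(
                [probs, paddle.unsqueeze(probs_step, axis=1)], axis=1)
    print(probs.shape)  # [2, 25, 38]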