提交 417753e9 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!608 Fix LSTM output size

Merge pull request !608 from zjun/fix_lstm_output
......@@ -149,7 +149,7 @@ class LSTM(Cell):
if self.batch_first:
x = self.transpose1(x, (1, 0, 2))
h0, c0 = hx
output, hn, cn, _ = self.lstm(x, h0, c0, self.weight)
output, hn, cn, _, _ = self.lstm(x, h0, c0, self.weight)
if self.batch_first:
output = self.transpose2(output, (1, 0, 2))
return (output, (hn, cn))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册