Commit b840f462 authored by Egor Smirnov

add lstm tests

Parent commit: b2e7c8ac
......@@ -859,6 +859,39 @@ hidden_lstm = HiddenLSTM(features, hidden, num_layers=3, is_bidirectional=True)
save_data_and_model("hidden_lstm_bi", input, hidden_lstm, version=11, export_params=True)
# Hyper-parameters shared by the LSTM export tests below.  The LSTM class
# reads these module-level globals when it is constructed.
batch = 5             # batch size
features = 4          # input feature size
hidden = 3            # hidden-state size
seq_len = 2           # sequence length
num_layers = 1        # number of stacked LSTM layers
bidirectional = True  # flipped to False later for the forward-only export
class LSTM(nn.Module):
    """LSTM wrapper used to generate ONNX conformance test data.

    Reads the module-level globals ``features``, ``hidden``, ``num_layers``,
    ``bidirectional`` and ``batch`` at construction time, so those must be
    set before instantiation.
    """

    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(features, hidden, num_layers, bidirectional=bidirectional)
        # Per the PyTorch docs, h_0/c_0 have shape
        # (num_layers * num_directions, batch, hidden).  The original code used
        # ``num_layers + int(bidirectional)``, which only equals the correct
        # value when num_layers == 1; use the general formula instead
        # (identical result for this file's num_layers = 1).
        num_directions = 2 if bidirectional else 1
        state_shape = (num_layers * num_directions, batch, hidden)
        self.h0 = torch.from_numpy(np.ones(state_shape, dtype=np.float32))
        self.c0 = torch.from_numpy(np.ones(state_shape, dtype=np.float32))

    def forward(self, x):
        # a: full output sequence; b, c: final hidden and cell states.
        a, (b, c) = self.lstm(x, (self.h0, self.c0))
        if bidirectional:
            # Concatenating along dim=2 relies on a, b and c sharing their
            # leading dims: seq_len == num_layers * num_directions == 2 here.
            return torch.cat((a, b, c), dim=2)
        else:
            return torch.cat((a, b, c), dim=0)
# Export the model twice: bidirectional first (flag already True above), then
# forward-only.  LSTM reads the ``bidirectional`` global at construction time,
# so the flag is flipped between the two exports.  Note the ordering: the
# random input is drawn BEFORE the model is built, so the RNG stream matches
# the original script exactly.
data = torch.randn(seq_len, batch, features)
input_ = Variable(data)
lstm = LSTM()
save_data_and_model("lstm_cell_bidirectional", input_, lstm, export_params=True)

bidirectional = False
data = torch.randn(seq_len, batch, features)
input_ = Variable(data)
lstm = LSTM()
save_data_and_model("lstm_cell_forward", input_, lstm, export_params=True)
class MatMul(nn.Module):
def __init__(self):
super(MatMul, self).__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.