Unverified · Commit ca29bb50 · Authored by: Aston Zhang · Committed by: GitHub

Merge pull request #144 from astonzhang/rnn

add is_lstm flag
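Context for the change: an LSTM threads two recurrent states through time, a hidden state and a cell state, so its step function takes and returns one more array than a plain RNN's. The new `is_lstm` flag lets `predict_rnn` and `train_and_predict_rnn` allocate and pass that extra state. Below is a minimal sketch of the kind of LSTM step function this calling convention anticipates; the gate parameter names and the shape-check demo are illustrative assumptions, not part of this commit.

```python
# Sketch only: an LSTM step function matching the
# (inputs, state_h, state_c, *params) -> (outputs, state_h, state_c)
# contract that the is_lstm branches below assume.
from mxnet import nd

def lstm_rnn(inputs, state_h, state_c, *params):
    # inputs: list of num_steps arrays, each of shape (batch_size, vocab_size).
    # Parameter names are illustrative assumptions.
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
     W_xc, W_hc, b_c, W_hy, b_y] = params
    H, C = state_h, state_c
    outputs = []
    for X in inputs:
        I = nd.sigmoid(nd.dot(X, W_xi) + nd.dot(H, W_hi) + b_i)     # input gate
        F = nd.sigmoid(nd.dot(X, W_xf) + nd.dot(H, W_hf) + b_f)     # forget gate
        O = nd.sigmoid(nd.dot(X, W_xo) + nd.dot(H, W_ho) + b_o)     # output gate
        C_tilda = nd.tanh(nd.dot(X, W_xc) + nd.dot(H, W_hc) + b_c)  # candidate cell
        C = F * C + I * C_tilda   # new cell state
        H = O * nd.tanh(C)        # new hidden state
        outputs.append(nd.dot(H, W_hy) + b_y)
    return (outputs, H, C)

if __name__ == '__main__':
    # Quick shape check with random parameters (hypothetical values).
    vocab_size, hidden_dim, batch_size = 10, 8, 2
    shapes = ([(vocab_size, hidden_dim), (hidden_dim, hidden_dim), (hidden_dim,)] * 4
              + [(hidden_dim, vocab_size), (vocab_size,)])
    params = [nd.random.normal(scale=0.01, shape=s) for s in shapes]
    X = nd.one_hot(nd.array([1, 2]), vocab_size)  # one time step
    H0 = nd.zeros((batch_size, hidden_dim))
    C0 = nd.zeros((batch_size, hidden_dim))
    outputs, H, C = lstm_rnn([X], H0, C0, *params)
    print(outputs[0].shape, H.shape, C.shape)  # (2, 10) (2, 8) (2, 8)
```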
@@ -285,15 +285,22 @@ print('state shape: ', state_new.shape)
 ```{.python .input n=15}
 def predict_rnn(rnn, prefix, num_chars, params, hidden_dim, ctx, idx_to_char,
-                char_to_idx, get_inputs):
+                char_to_idx, get_inputs, is_lstm=False):
     # Predict the next num_chars characters following the given prefix.
     prefix = prefix.lower()
-    state = nd.zeros(shape=(1, hidden_dim), ctx=ctx)
+    state_h = nd.zeros(shape=(1, hidden_dim), ctx=ctx)
+    if is_lstm:
+        # Only used when the RNN is an LSTM; can be ignored otherwise.
+        state_c = nd.zeros(shape=(1, hidden_dim), ctx=ctx)
     output = [char_to_idx[prefix[0]]]
     for i in range(num_chars + len(prefix)):
         X = nd.array([output[-1]], ctx=ctx)
         # Iterate the hidden state through the sequence.
-        Y, state = rnn(get_inputs(X), state, *params)
+        if is_lstm:
+            # Only used when the RNN is an LSTM; can be ignored otherwise.
+            Y, state_h, state_c = rnn(get_inputs(X), state_h, state_c, *params)
+        else:
+            Y, state_h = rnn(get_inputs(X), state_h, *params)
         if i < len(prefix)-1:
             next_input = char_to_idx[prefix[i+1]]
         else:
@@ -358,7 +365,8 @@ from math import exp
 def train_and_predict_rnn(rnn, is_random_iter, epochs, num_steps, hidden_dim,
                           learning_rate, clipping_theta, batch_size,
                           pred_period, pred_len, seqs, get_params, get_inputs,
-                          ctx, corpus_indices, idx_to_char, char_to_idx):
+                          ctx, corpus_indices, idx_to_char, char_to_idx,
+                          is_lstm=False):
     if is_random_iter:
         data_iter = data_iter_random
     else:
@@ -370,16 +378,27 @@ def train_and_predict_rnn(rnn, is_random_iter, epochs, num_steps, hidden_dim,
     for e in range(1, epochs + 1):
         # With consecutive sampling, the hidden state only needs to be initialized at the beginning of each epoch.
         if not is_random_iter:
-            state = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
+            state_h = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
+            if is_lstm:
+                # Only used when the RNN is an LSTM; can be ignored otherwise.
+                state_c = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
         train_loss, num_examples = 0, 0
         for data, label in data_iter(corpus_indices, batch_size, num_steps,
                                      ctx):
             # With random sampling, the hidden state has to be initialized before each mini-batch.
             if is_random_iter:
-                state = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
+                state_h = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
+                if is_lstm:
+                    # Only used when the RNN is an LSTM; can be ignored otherwise.
+                    state_c = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
             with autograd.record():
                 # outputs shape: (batch_size, vocab_size)
-                outputs, state = rnn(get_inputs(data), state, *params)
+                if is_lstm:
+                    # Only used when the RNN is an LSTM; can be ignored otherwise.
+                    outputs, state_h, state_c = rnn(get_inputs(data), state_h,
+                                                    state_c, *params)
+                else:
+                    outputs, state_h = rnn(get_inputs(data), state_h, *params)
                 # Let t_ib_j be element j of the mini-batch at time step i:
                 # label shape: (batch_size * num_steps)
                 # label = [t_0b_0, t_0b_1, ..., t_1b_0, t_1b_1, ..., ]
@@ -401,7 +420,8 @@ def train_and_predict_rnn(rnn, is_random_iter, epochs, num_steps, hidden_dim,
                 exp(train_loss/num_examples)))
             for seq in seqs:
                 print(' - ', predict_rnn(rnn, seq, pred_len, params,
-                      hidden_dim, ctx, idx_to_char, char_to_idx, get_inputs))
+                      hidden_dim, ctx, idx_to_char, char_to_idx, get_inputs,
+                      is_lstm))
             print()
 ```
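A hypothetical call site for the updated trainer, for readers skimming the diff (the `lstm_rnn` and `get_lstm_params` names, the training sequences, and all hyperparameter values are assumptions for illustration, not taken from this commit):

```python
# Hypothetical usage sketch; every name and value below is an assumption.
train_and_predict_rnn(rnn=lstm_rnn, is_random_iter=False, epochs=200,
                      num_steps=35, hidden_dim=256, learning_rate=0.1,
                      clipping_theta=5, batch_size=32, pred_period=20,
                      pred_len=100, seqs=['the ', 'when '],
                      get_params=get_lstm_params, get_inputs=get_inputs,
                      ctx=ctx, corpus_indices=corpus_indices,
                      idx_to_char=idx_to_char, char_to_idx=char_to_idx,
                      is_lstm=True)
```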
The commit applies the same changes to the shared utility module (the second file in the diff):
@@ -238,6 +238,7 @@ def data_iter_consecutive(corpus_indices, batch_size, num_steps, ctx=None):
         label = indices[:, i + 1: i + num_steps + 1]
         yield data, label
 
+
 def grad_clipping(params, theta, ctx):
     """Gradient clipping."""
     if theta is not None:
@@ -249,15 +250,21 @@ def grad_clipping(params, theta, ctx):
         for p in params:
             p.grad[:] *= theta / norm
 
+
 def predict_rnn(rnn, prefix, num_chars, params, hidden_dim, ctx, idx_to_char,
-                char_to_idx, get_inputs):
+                char_to_idx, get_inputs, is_lstm=False):
     """Predict the next chars given the prefix."""
     prefix = prefix.lower()
-    state = nd.zeros(shape=(1, hidden_dim), ctx=ctx)
+    state_h = nd.zeros(shape=(1, hidden_dim), ctx=ctx)
+    if is_lstm:
+        state_c = nd.zeros(shape=(1, hidden_dim), ctx=ctx)
     output = [char_to_idx[prefix[0]]]
     for i in range(num_chars + len(prefix)):
         X = nd.array([output[-1]], ctx=ctx)
-        Y, state = rnn(get_inputs(X), state, *params)
+        if is_lstm:
+            Y, state_h, state_c = rnn(get_inputs(X), state_h, state_c, *params)
+        else:
+            Y, state_h = rnn(get_inputs(X), state_h, *params)
         if i < len(prefix)-1:
             next_input = char_to_idx[prefix[i+1]]
         else:
@@ -265,10 +272,12 @@ def predict_rnn(rnn, prefix, num_chars, params, hidden_dim, ctx, idx_to_char,
         output.append(next_input)
     return ''.join([idx_to_char[i] for i in output])
 
+
 def train_and_predict_rnn(rnn, is_random_iter, epochs, num_steps, hidden_dim,
                           learning_rate, clipping_theta, batch_size,
                           pred_period, pred_len, seqs, get_params, get_inputs,
-                          ctx, corpus_indices, idx_to_char, char_to_idx):
+                          ctx, corpus_indices, idx_to_char, char_to_idx,
+                          is_lstm=False):
     """Train an RNN model and predict the next item in the sequence."""
     if is_random_iter:
         data_iter = data_iter_random
@@ -278,21 +287,29 @@ def train_and_predict_rnn(rnn, is_random_iter, epochs, num_steps, hidden_dim,
     softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
 
     for e in range(1, epochs + 1):
         # If consecutive sampling is used, in the same epoch, the hidden state
         # is initialized only at the beginning of the epoch.
         if not is_random_iter:
-            state = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
+            state_h = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
+            if is_lstm:
+                state_c = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
         train_loss, num_examples = 0, 0
         for data, label in data_iter(corpus_indices, batch_size, num_steps,
                                      ctx):
             # If random sampling is used, the hidden state has to be
             # initialized for each mini-batch.
             if is_random_iter:
-                state = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
+                state_h = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
+                if is_lstm:
+                    state_c = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
             with autograd.record():
                 # outputs shape: (batch_size, vocab_size)
-                outputs, state = rnn(get_inputs(data), state, *params)
+                if is_lstm:
+                    outputs, state_h, state_c = rnn(get_inputs(data), state_h,
+                                                    state_c, *params)
+                else:
+                    outputs, state_h = rnn(get_inputs(data), state_h, *params)
                 # Let t_ib_j be the j-th element of the mini-batch at time i.
                 # label shape: (batch_size * num_steps)
                 # label = [t_0b_0, t_0b_1, ..., t_1b_0, t_1b_1, ..., ].
@@ -315,6 +332,6 @@ def train_and_predict_rnn(rnn, is_random_iter, epochs, num_steps, hidden_dim,
                 exp(train_loss/num_examples)))
             for seq in seqs:
                 print(' - ', predict_rnn(rnn, seq, pred_len, params,
-                      hidden_dim, ctx, idx_to_char, char_to_idx, get_inputs))
+                      hidden_dim, ctx, idx_to_char, char_to_idx, get_inputs,
+                      is_lstm))
             print()
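One detail this view collapses is the body of `grad_clipping`. Judging from the visible tail (`p.grad[:] *= theta / norm`), it is the standard clipping-by-global-norm recipe; the sketch below reconstructs it under that assumption rather than quoting the source verbatim.

```python
# Reconstruction sketch of the collapsed grad_clipping body (an assumption).
from mxnet import nd

def grad_clipping(params, theta, ctx):
    """Scale all gradients so their global L2 norm does not exceed theta."""
    if theta is not None:
        norm = nd.array([0.0], ctx)
        for p in params:
            norm += nd.sum(p.grad ** 2)    # accumulate squared gradient norms
        norm = nd.sqrt(norm).asscalar()
        if norm > theta:
            for p in params:
                p.grad[:] *= theta / norm  # in-place rescale
```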