提交 01b7b11f 编写于 作者: S ShusenTang

fix bug

ls
上级 4a087ef2
......@@ -497,7 +497,7 @@ def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
state = None
for epoch in range(num_epochs):
l_sum, n, start = 0.0, 0, time.time()
data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size, num_steps, device) # 相邻采样
data_iter = data_iter_consecutive(corpus_indices, batch_size, num_steps, device) # 相邻采样
for X, Y in data_iter:
if state is not None:
# 使用detach函数从计算图分离隐藏状态, 这是为了
......@@ -517,7 +517,7 @@ def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
optimizer.zero_grad()
l.backward()
# 梯度裁剪
d2l.grad_clipping(model.parameters(), clipping_theta, device)
grad_clipping(model.parameters(), clipping_theta, device)
optimizer.step()
l_sum += l.item() * y.shape[0]
n += y.shape[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册