提交 cf6a8dc5 编写于 作者: M muli

update lstm

上级 d538969f
......@@ -6,13 +6,12 @@
## 长短期记忆
我们先介绍长短期记忆的设计。它修改了循环神经网络隐藏状态的计算方式,并引入了与隐藏状态形状相同的记忆细胞(某些文献把记忆细胞当成一种特殊的隐藏状态),从而记录额外的历史信息。
LSTM 中引入了三个门:输入门(input gate)、遗忘门(forget gate)和输出门(output gate);以及与隐藏状态形状相同的记忆细胞(某些文献把记忆细胞当成一种特殊的隐藏状态),从而记录额外的历史信息。
### 输入门、遗忘门和输出门
同门控循环单元中的重置门和更新门一样,如图6.7所示,长短期记忆的输入门(input gate)、遗忘门(forget gate)和输出门(output gate)均由输入为当前时间步输入$\boldsymbol{X}_t$与上一时间步隐藏状态$\boldsymbol{H}_{t-1}$,且激活函数为sigmoid函数的全连接层计算得出。如此一来,这三个门元素的值域均为$[0,1]$。
同门控循环单元中的重置门和更新门一样,如图6.7所示,LSTM的门均由输入为当前时间步输入$\boldsymbol{X}_t$与上一时间步隐藏状态$\boldsymbol{H}_{t-1}$,且激活函数为sigmoid函数的全连接层计算得出。如此一来,这三个门元素的值域均为$[0,1]$。
![长短期记忆中输入门、遗忘门和输出门的计算。](../img/lstm_0.svg)
......@@ -67,11 +66,6 @@ $$\boldsymbol{H}_t = \boldsymbol{O}_t \odot \text{tanh}(\boldsymbol{C}_t).$$
![长短期记忆中隐藏状态的计算。这里的乘号是按元素乘法。](../img/lstm_3.svg)
### 输出层
在时间步$t$,长短期记忆的输出层计算和之前描述的循环神经网络输出层计算一样:我们只需将该时刻的隐藏状态$\boldsymbol{H}_t$传递进输出层,从而计算时间步$t$的输出。
## 实验
和前几节中的实验一样,我们依然使用周杰伦歌词数据集来训练模型作词。
......@@ -88,11 +82,9 @@ import gluonbook as gb
from mxnet import nd
import zipfile
with zipfile.ZipFile('../data/jaychou_lyrics.txt.zip', 'r') as zin:
zin.extractall('../data/')
with open('../data/jaychou_lyrics.txt', encoding='utf-8') as f:
corpus_chars = f.read()
with zipfile.ZipFile('../data/jaychou_lyrics.txt.zip') as zin:
with zin.open('jaychou_lyrics.txt') as f:
corpus_chars = f.read().decode('utf-8')
corpus_chars = corpus_chars.replace('\n', ' ').replace('\r', ' ')
corpus_chars = corpus_chars[0:20000]
idx_to_char = list(set(corpus_chars))
......@@ -106,41 +98,24 @@ vocab_size = len(char_to_idx)
以下部分对模型参数进行初始化。超参数`num_hiddens`定义了隐藏单元的个数。
```{.python .input n=3}
ctx = gb.try_gpu()
input_dim = vocab_size
num_inputs = vocab_size
num_hiddens = 256
output_dim = vocab_size
num_outputs = vocab_size
ctx = gb.try_gpu()
def get_params():
# 输入门参数.
W_xi = nd.random_normal(scale=0.01, shape=(input_dim, num_hiddens),
ctx=ctx)
W_hi = nd.random_normal(scale=0.01, shape=(num_hiddens, num_hiddens),
ctx=ctx)
b_i = nd.zeros(num_hiddens, ctx=ctx)
# 遗忘门参数。
W_xf = nd.random_normal(scale=0.01, shape=(input_dim, num_hiddens),
ctx=ctx)
W_hf = nd.random_normal(scale=0.01, shape=(num_hiddens, num_hiddens),
ctx=ctx)
b_f = nd.zeros(num_hiddens, ctx=ctx)
# 输出门参数。
W_xo = nd.random_normal(scale=0.01, shape=(input_dim, num_hiddens),
ctx=ctx)
W_ho = nd.random_normal(scale=0.01, shape=(num_hiddens, num_hiddens),
ctx=ctx)
b_o = nd.zeros(num_hiddens, ctx=ctx)
# 候选细胞参数。
W_xc = nd.random_normal(scale=0.01, shape=(input_dim, num_hiddens),
ctx=ctx)
W_hc = nd.random_normal(scale=0.01, shape=(num_hiddens, num_hiddens),
ctx=ctx)
b_c = nd.zeros(num_hiddens, ctx=ctx)
_one = lambda shape: nd.random.normal(scale=0.01, shape=shape, ctx=ctx)
_three = lambda : (_one((num_inputs, num_hiddens)),
_one((num_hiddens, num_hiddens)),
nd.zeros(num_hiddens, ctx=ctx))
W_xi, W_hi, b_i = _three() # 输入门参数。
W_xf, W_hf, b_f = _three() # 遗忘门参数。
W_xo, W_ho, b_o = _three() # 输出门参数。
W_xc, W_hc, b_c = _three() # 候选细胞参数。
# 输出层参数。
W_hy = nd.random_normal(scale=0.01, shape=(num_hiddens, output_dim),
ctx=ctx)
W_hy = _one((num_hiddens, output_dim))
b_y = nd.zeros(output_dim, ctx=ctx)
# 创建梯度。
params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
b_c, W_hy, b_y]
for param in params:
......@@ -153,11 +128,9 @@ def get_params():
下面根据长短期记忆的计算表达式定义模型。
```{.python .input n=4}
def lstm_rnn(inputs, state_h, state_c, *params):
def lstm_rnn(inputs, H, C, *params):
[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
W_hy, b_y] = params
H = state_h
C = state_c
outputs = []
for X in inputs:
I = nd.sigmoid(nd.dot(X, W_xi) + nd.dot(H, W_hi) + b_i)
......@@ -173,7 +146,7 @@ def lstm_rnn(inputs, state_h, state_c, *params):
### 训练模型并创作歌词
设置好超参数后,我们将训练模型并跟据前缀“分开”和“不分开”分别创作长度为100个字符的一段歌词。我们每过30个迭代周期便根据当前训练的模型创作一段歌词。训练模型时采用了相邻采样。
设置好超参数后,我们将训练模型并跟据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。我们每过30个迭代周期便根据当前训练的模型创作一段歌词。训练模型时采用了相邻采样。
```{.python .input n=5}
get_inputs = gb.to_onehot
......@@ -184,7 +157,7 @@ lr = 0.25
clipping_theta = 5
prefixes = ['分开', '不分开']
pred_period = 30
pred_len = 100
pred_len = 50
gb.train_and_predict_rnn(lstm_rnn, False, num_epochs, num_steps, num_hiddens,
lr, clipping_theta, batch_size, vocab_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册