提交 962f608b 编写于 作者: A Aston Zhang

till grad clip

上级 2fd0f372
......@@ -24,19 +24,6 @@ import time
nd.one_hot(nd.array([0, 2]), vocab_size)
```
```{.json .output n=2}
[
{
"data": {
"text/plain": "\n[[ 1. 0. 0. ..., 0. 0. 0.]\n [ 0. 0. 1. ..., 0. 0. 0.]]\n<NDArray 2x1027 @cpu(0)>"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
]
```
我们每次采样的小批量的形状是(`batch_size`, `num_steps`)。下面的函数将这样的小批量转换成`num_steps`个可以输入进网络的形状为(`batch_size`, `vocab_size`)的矩阵。也就是总时间步$T=$ `num_steps`,时间步$t$的输入$\boldsymbol{X}_t \in \mathbb{R}^{n \times d}$,其中$n=$ `batch_size`,$d=$ `vocab_size`(one-hot向量长度)。
```{.python .input n=3}
......@@ -48,19 +35,6 @@ inputs = to_onehot(X, vocab_size)
len(inputs), inputs[0].shape
```
```{.json .output n=3}
[
{
"data": {
"text/plain": "(5, (2, 1027))"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
]
```
## 初始化模型参数
接下来,我们初始化模型参数。隐藏单元个数 `num_hiddens`是一个超参数。
......@@ -88,16 +62,6 @@ def get_params():
return params
```
```{.json .output n=4}
[
{
"name": "stdout",
"output_type": "stream",
"text": "will use cpu(0)\n"
}
]
```
## 定义模型
我们根据循环神经网络的计算表达式实现该模型。首先定义`init_rnn_state`函数来返回初始化的隐藏状态。它返回由一个形状为(`batch_size``num_hiddens`)的值为0的NDArray组成的元组。使用元组是为了更方便处理隐藏状态含有多个NDArray的情况。
......@@ -132,19 +96,6 @@ outputs, state_new = rnn(inputs, state, params)
len(outputs), outputs[0].shape, state_new[0].shape
```
```{.json .output n=7}
[
{
"data": {
"text/plain": "(5, (2, 1027), (2, 256))"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
]
```
## 定义预测函数
以下函数基于前缀`prefix`(含有数个字符的字符串)来预测接下来的`num_chars`个字符。这个函数稍显复杂,其中我们将循环神经单元`rnn`设置成了函数参数,这样在后面小节介绍其他循环神经网络时能重复使用这个函数。
......@@ -175,19 +126,6 @@ predict_rnn('分开', 10, rnn, params, init_rnn_state, num_hiddens, vocab_size,
ctx, idx_to_char, char_to_idx)
```
```{.json .output n=9}
[
{
"data": {
"text/plain": "'\u5206\u5f00\u6597\u4e24\u6696\u7238\u574a\u513f\u7b49\u4e0a\u5f77\u666f\u661f'"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
]
```
## 裁剪梯度
循环神经网络中较容易出现梯度衰减或爆炸,其原因我们会在[下一节](bptt.md)解释。为了应对梯度爆炸,我们可以裁剪梯度(clipping gradient)。假设我们把所有模型参数梯度的元素拼接成一个向量 $\boldsymbol{g}$,并设裁剪的阈值是$\theta$。裁剪后梯度
......@@ -298,19 +236,9 @@ train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
prefixes)
```
```{.json .output n=13}
[
{
"name": "stdout",
"output_type": "stream",
"text": "epoch 50, perplexity 69.470822, time 1.64 sec\n - \u5206\u5f00 \u6211\u60f3\u8981\u8fd9\u60f3 \u6211\u6709\u4f60\u7684\u53ef\u7231 \u6211\u60f3\u8981\u4f60\u60f3 \u6211\u6709 \u4f60\u60f3 \u6211\u60f3 \u8fd9\u60f3 \u6211\u4e0d\u80fd \u60f3\u4f60 \u4f60\u7684\u4f60 \u4f60\u7684\u4f60 \u6211\n - \u4e0d\u5206\u5f00 \u6211\u60f3\u60f3\u4f60\u60f3 \u6211\u6709 \u4f60\u60f3 \u6211\u60f3 \u8fd9\u60f3 \u6211\u4e0d\u80fd \u60f3\u4f60 \u4f60\u7684\u4f60 \u4f60\u7684\u4f60 \u6211\u6709 \u8fd9\u751f \u6211\u6709\u8981\u4f60 \u6211\u6709\u60f3\nepoch 100, perplexity 9.787908, time 1.66 sec\n - \u5206\u5f00 \u6211\u60f3\u60f3\u597d\u751f\u6d3b \u4e0d\u77e5\u4e0d\u89c9 \u4f60\u5df2\u7ecf\u79bb\u5f00\u6211 \u4e0d\u77e5\u4e0d\u89c9 \u6211\u5df2\u4e86\u8fd9\u751f\u594f \u6211\u77e5\u9053\u597d\u751f\u5c0f \u4e0d\u77e5\u4e0d\u89c9 \u4f60\u5df2\u7ecf\u79bb\u5f00\u6211 \n - \u4e0d\u5206\u5f00\u67f3 \u6211\u4e0d\u80fd\u518d\u60f3\u4f60 \u4e0d\u77e5\u4e0d\u89c9 \u4f60\u5df2\u7ecf\u79bb\u5f00\u6211 \u4e0d\u77e5\u4e0d\u89c9 \u6211\u5df2\u4e86\u8fd9\u751f\u594f \u6211\u77e5\u9053\u597d\u751f\u5c0f \u4e0d\u77e5\u4e0d\u89c9 \u4f60\u5df2\u7ecf\u79bb\u5f00\u6211\nepoch 150, perplexity 2.838760, time 1.70 sec\n - \u5206\u5f00 \u5feb\u4f7f\u7528\u53cc\u622a\u68cd \u54fc\u591a\u54c8 \u5a18\u624b\u8d70 \u6211\u60f3\u5c31\u8fd9\u6837\u7275\u7740\u4f60\u7684\u624b\u4e0d\u653e\u5f00 \u7231\u80fd\u4e0d\u80fd\u591f\u6c38\u8fdc\u5355\u7eaf\u6ca1\u6709\u60b2\u5bb3 \u6211 \u60f3\u5e26\u4f60\u9a91\u5355\n - \u4e0d\u5206\u5f00\u5417 \u6211\u53eb\u4f60\u7238 \u4f60\u6253\u6211\u6709 \u8fd9\u6837\u7b11\u4e00\u53ea\u534a\u4f1a \u6211\u4e0d\u80fd\u518d\u8fdc\u7275\u770b\u8457 \u6211\u7231 \u6211\u4e0d\u80fd \u60f3\u60c5\u5c31\u7684\u592a\u5feb\u5c31\u50cf\u9f99\u5377\u98ce \u4e0d\u80fd\u627f\nepoch 200, perplexity 1.563728, time 1.61 sec\n - \u5206\u5f00 \u5feb\u65f6\u7684\u5728\u7b49\u8457 \u6709\u4f60\u4e86 \u6709\u679c\u5e03\u542c\u4e86\u5427? \u6211\u7ed9\u800d\u7684\u8ba9\u6a21\u6709\u6837 \u4ec0\u4e48\u5175\u5668\u6700\u559c\u6b22 \u53cc\u622a\u68cd\u67d4\u4e2d\u5e26\u521a \u60f3\u8981\u53bb\u6cb3\u5357\n - \u4e0d\u5206\u5f00\u671f \u6211\u53eb\u4f60\u7238 \u4f60\u6253\u6211\u5988 \u8fd9\u6837\u5bf9\u5417\u5e72\u561b\u8fd9\u6837 \u4f55\u5fc5\u8ba9\u9152\u7275\u9f3b\u5b50\u8d70 \u778e \u8bf4\u5e95\u6211\u7684\u80a9\u8180 \u4f60 \u5728\u6211\u80f8\u53e3\u7761\u8457 \u50cf\u8fd9\u6837\n"
}
]
```
接下来采用相邻采样训练模型并创作歌词。
```{.python .input n=14}
```{.python .input n=19}
train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
vocab_size, ctx, corpus_indices, idx_to_char,
char_to_idx, False, num_epochs, num_steps, lr,
......@@ -318,31 +246,6 @@ train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
prefixes)
```
```{.json .output n=14}
[
{
"name": "stdout",
"output_type": "stream",
"text": "epoch 50, perplexity 59.698101, time 1.62 sec\n - \u5206\u5f00 \u6211\u4e0d\u8981\u518d\u7231 \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d\u8981\u518d\u7231 \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d\u8981\u518d\u7231 \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d\u8981\u518d\u7231\n - \u4e0d\u5206\u5f00 \u6211\u4e0d\u8981\u8fd9\u7231 \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d\u8981\u518d\u7231 \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d\u8981\u518d\u7231 \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d \u6211\u4e0d\u8981\u518d\u7231\n"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-14-5dcd34b7b161>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mchar_to_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mclipping_theta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_period\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_len\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m prefixes)\n\u001b[0m",
"\u001b[0;32m<ipython-input-11-54792c9ce232>\u001b[0m in \u001b[0;36mtrain_and_predict_rnn\u001b[0;34m(rnn, get_params, init_rnn_state, num_hiddens, vocab_size, ctx, corpus_indices, idx_to_char, char_to_idx, is_random_iter, num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, prefixes)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;31m# \u88c1\u526a\u68af\u5ea6\u540e\u4f7f\u7528 SGD \u66f4\u65b0\u6743\u91cd\u3002\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0mgrad_clipping\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclipping_theta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0mgb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msgd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# \u56e0\u4e3a\u5df2\u7ecf\u8bef\u5dee\u53d6\u8fc7\u5747\u503c\uff0c\u68af\u5ea6\u4e0d\u7528\u518d\u505a\u5e73\u5747\u3002\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mloss_sum\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masscalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-10-2034c1ff1fda>\u001b[0m in \u001b[0;36mgrad_clipping\u001b[0;34m(params, theta, ctx)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mnorm\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mnorm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnorm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masscalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnorm\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mtheta\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py\u001b[0m in \u001b[0;36masscalar\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1892\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1893\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The current array is not a scalar\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1894\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1895\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1896\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py\u001b[0m in \u001b[0;36masnumpy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1874\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1875\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mctypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_as\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_void_p\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1876\u001b[0;31m ctypes.c_size_t(data.size)))\n\u001b[0m\u001b[1;32m 1877\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1878\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
```
## 小结
* 我们可以应用基于字符级循环神经网络的语言模型来创作歌词。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册