提交 078ba9a4 编写于 作者: S shusentang

fix bug

上级 0863af17
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6.4 循环神经网络的从零开始实现"
]
},
{
"cell_type": "code",
"execution_count": 1,
......@@ -9,8 +16,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.1\n",
"cpu\n"
"0.4.0\n",
"cuda\n"
]
}
],
......@@ -34,14 +41,19 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6.4.1 one-hot向量"
]
},
{
"cell_type": "code",
"execution_count": 3,
......@@ -50,8 +62,8 @@
{
"data": {
"text/plain": [
"tensor([[1., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 1., ..., 0., 0., 0.]])"
"tensor([[ 1., 0., 0., ..., 0., 0., 0.],\n",
" [ 0., 0., 1., ..., 0., 0., 0.]])"
]
},
"execution_count": 3,
......@@ -63,7 +75,7 @@
"def one_hot(x, n_class, dtype=torch.float32): \n",
" # X shape: (batch), output shape: (batch, n_class)\n",
" x = x.long()\n",
" res = torch.zeros(x.shape[0], n_class, dtype=dtype)\n",
" res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)\n",
" res.scatter_(1, x.view(-1, 1), 1)\n",
" return res\n",
" \n",
......@@ -94,6 +106,13 @@
"print(len(inputs), inputs[0].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6.4.2 初始化模型参数"
]
},
{
"cell_type": "code",
"execution_count": 5,
......@@ -103,7 +122,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"will use cpu\n"
"will use cuda\n"
]
}
],
......@@ -126,12 +145,17 @@
" return nn.ParameterList([W_xh, W_hh, b_h, W_hq, b_q])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6.4.3 定义模型"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"def init_rnn_state(batch_size, num_hiddens, device):\n",
......@@ -141,9 +165,7 @@
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"def rnn(inputs, state, params):\n",
......@@ -179,12 +201,17 @@
"print(len(outputs), outputs[0].shape, state_new[0].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6.4.4 定义预测函数"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"# 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
......@@ -213,7 +240,7 @@
{
"data": {
"text/plain": [
"'分开爽只忆干走抄蝴配碑鹰'"
"'分开西圈绪升王凝瓜必客映'"
]
},
"execution_count": 10,
......@@ -226,12 +253,17 @@
" device, idx_to_char, char_to_idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6.4.5 裁剪梯度"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": true
},
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
......@@ -245,12 +277,18 @@
" param.grad.data *= (theta / norm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6.4.6 困惑度\n",
"## 6.4.7 定义模型训练函数"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": true
},
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# 本函数已保存在d2lzh包中方便以后使用\n",
......@@ -285,7 +323,7 @@
" outputs = torch.cat(outputs, dim=0)\n",
" # Y的形状是(batch_size, num_steps),转置后再变成长度为\n",
" # batch * num_steps 的向量,这样跟输出的行一一对应\n",
" y = torch.transpose(Y, 0, 1).flatten()\n",
" y = torch.transpose(Y, 0, 1).contiguous().view(-1)\n",
" # 使用交叉熵损失计算平均分类误差\n",
" l = loss(outputs, y.long())\n",
" \n",
......@@ -307,12 +345,17 @@
" num_hiddens, vocab_size, device, idx_to_char, char_to_idx))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6.4.8 训练模型并创作歌词"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": true
},
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"num_epochs, num_steps, batch_size, lr, clipping_theta = 250, 35, 32, 1e2, 1e-2\n",
......@@ -321,28 +364,28 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 50, perplexity 69.368188, time 0.70 sec\n",
" - 分开 我不要的可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂的可爱女人 坏\n",
" - 不分开 快颗的双 我不定空 我有一场的溪 一知哈觉 我不要 别怪两 我给就的可爱女人 坏坏的让我疯狂的可爱\n",
"epoch 100, perplexity 10.391687, time 0.57 sec\n",
" - 分开 一颗在双截棍 哼哼哈 一直两 我想就这样牵着你的手不放开 爱能不能够永远单纯没有 看星形 一颗两颗\n",
" - 不分开吗 我不要再想 我不能再想 我不能再想 我不能再想 我不能再想 我不能再想 我不能再想 我不能再想 \n",
"epoch 150, perplexity 2.892629, time 0.58 sec\n",
" - 分开 一颗会 干什么 我满就这手牵着你的手不放开 爱能不能够永远单纯没有悲哀 我 想带你骑单车 我 想和\n",
" - 不分开吗 我后能爸 你知我妈 这样了直不屈 一身正气 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 快使\n",
"epoch 200, perplexity 1.587798, time 0.59 sec\n",
" - 分开 一只会美牵球 我哼得起国 我灵魂失控 黑云在降落 我被它拖着走 如果我遇见你是一场悲剧 我可以让生\n",
" - 不分开吗 我叫你爸 你打我妈 这样对吗干嘛这样 何必让酒牵鼻子走 瞎 说也你在旧每 太的话说的忧剩 回说林\n",
"epoch 250, perplexity 1.336056, time 0.59 sec\n",
" - 分开 我爱想很样力 你作 失去开的玩笑 想通 却又再考倒我 说散 你想很久了吧? 我的认画败给黑色幽默 \n",
" - 不分开扫把的胖女巫 用拉丁文念咒语啦啦呜 她养的黑猫笑起来像哭 啦啦啦呜 刻多将痛 每一秒钟 你给没中 不\n"
"epoch 50, perplexity 70.039647, time 0.11 sec\n",
" - 分开 我不要再想 我不能 想你的让我 我的可 你怎么 一颗四 一颗四 我不要 一颗两 一颗四 一颗四 我\n",
" - 不分开 我不要再 你你的外 在人 别你的让我 狂的可 语人两 我不要 一颗两 一颗四 一颗四 我不要 一\n",
"epoch 100, perplexity 9.726828, time 0.12 sec\n",
" - 分开 一直的美栈人 一起看 我不要好生活 你知不觉 我已好好生活 我知道好生活 后知不觉 我跟了这生活 \n",
" - 不分开堡 我不要再想 我不 我不 我不要再想你 不知不觉 你已经离开我 不知不觉 我跟了好生活 我知道好生\n",
"epoch 150, perplexity 2.864874, time 0.11 sec\n",
" - 分开 一只会停留 有不它元羞 这蝪什么奇怪的事都有 包括像猫的狗 印地安老斑鸠 平常话不多 除非是乌鸦抢\n",
" - 不分开扫 我不你再想 我不能再想 我不 我不 我不要再想你 不知不觉 你已经离开我 不知不觉 我跟了这节奏\n",
"epoch 200, perplexity 1.597790, time 0.11 sec\n",
" - 分开 有杰伦 干 载颗拳满的让空美空主 相爱还有个人 再狠狠忘记 你爱过我的证 有晶莹的手滴 让说些人\n",
" - 不分开扫 我叫你爸 你打我妈 这样对吗干嘛这样 何必让它牵鼻子走 瞎 说底牵打我妈要 难道球耳 快使用双截\n",
"epoch 250, perplexity 1.303903, time 0.12 sec\n",
" - 分开 有杰人开留 仙唱它怕羞 蜥蝪横著走 这里什么奇怪的事都有 包括像猫的狗 印地安老斑鸠 平常话不多 \n",
" - 不分开简 我不能再想 我不 我不 我不能 爱情走的太快就像龙卷风 不能承受我已无处可躲 我不要再想 我不能\n"
]
}
],
......@@ -356,27 +399,27 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 50, perplexity 62.258360, time 0.60 sec\n",
" - 分开 我想要这 快使我有 我想一直 我想一空 我想一空 我想一空 我想一空 我想一空 我想一空 我想一空\n",
" - 不分开 我有你的可写 我有你的让写 我有你的 快使了双 我想一直 我想一空 我想一空 我想一空 我想一空 \n",
"epoch 100, perplexity 7.132725, time 0.59 sec\n",
" - 分开 一颗我 印地了人在江 还我在很医 我不要再想 我不 我不 我不要再想 我不 我不 我不要再想你 不\n",
" - 不分开觉你 我想要你想 我不要再想 我不 我不 我不要再想 我不 我不 我不要再想 我不 我不 我不要再想\n",
"epoch 150, perplexity 2.102736, time 0.65 sec\n",
" - 分开 问候我 谁地了的让 就颗莫衷身力 它所许秋 漫天黄沙凉过 塞北的客栈人多 牧草有没有 我马儿有些瘦\n",
" - 不分开觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后生后觉 快使用双截棍 哼哼哈兮 \n",
"epoch 200, perplexity 1.317644, time 0.63 sec\n",
" - 分开 问候我 谁地我 陪你了那信堡我 甩开球我满腔的怒火 我想揍你已经很久 别想躲 说你眼睛看着我 别发\n",
"epoch 50, perplexity 59.514416, time 0.11 sec\n",
" - 分开 我想要这 我想了空 我想了空 我想了空 我想了空 我想了空 我想了空 我想了空 我想了空 我想了空\n",
" - 不分开 我不要这 全使了双 我想了这 我想了空 我想了空 我想了空 我想了空 我想了空 我想了空 我想了空\n",
"epoch 100, perplexity 6.801417, time 0.11 sec\n",
" - 分开 我说的这样笑 想你都 不着我 我想就这样牵 你你的回不笑多难的 它在云实 有一条事 全你了空 \n",
" - 不分开觉 你已经离开我 不知不觉 我跟好这节活 我该好好生活 不知不觉 你跟了离开我 不知不觉 我跟好这节\n",
"epoch 150, perplexity 2.063730, time 0.16 sec\n",
" - 分开 我有到这样牵着你的手不放开 爱可不可以简简单单没有伤 古有你烦 我有多烦恼向 你知带悄 回我的外\n",
" - 不分开觉 你已经很个我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后哼哈兮 快使用双截棍 哼哼哈兮 \n",
"epoch 200, perplexity 1.300031, time 0.11 sec\n",
" - 分开 我想要这样牵着你的手不放开 爱能不能够永远单甜没有伤害 你 靠着我的肩膀 你 在我胸口睡著 像这样\n",
" - 不分开觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生\n",
"epoch 250, perplexity 1.155166, time 0.70 sec\n",
" - 分开 问候我 谁地神枪在币 悲伤得的隐密 后录那这对里 藤蔓都靠我 你拿着球现投 又下会掩护我 选你这种\n",
"epoch 250, perplexity 1.164455, time 0.11 sec\n",
" - 分开 我有一这样布 对你依依不舍 连隔壁邻居都猜到我现在的感受 河边的风 在吹着头发飘动 牵着你的手 一\n",
" - 不分开觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生\n"
]
}
......@@ -392,9 +435,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": []
}
......@@ -415,7 +456,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
"version": "3.6.4"
}
},
"nbformat": 4,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册