Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
book
提交
f3b77bfe
B
book
项目概览
PaddlePaddle
/
book
通知
16
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
40
列表
看板
标记
里程碑
合并请求
37
Wiki
5
Wiki
分析
仓库
DevOps
项目成员
Pages
B
book
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
40
Issue
40
列表
看板
标记
里程碑
合并请求
37
合并请求
37
Pages
分析
分析
仓库分析
DevOps
Wiki
5
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f3b77bfe
编写于
9月 04, 2020
作者:
J
jzhang533
提交者:
GitHub
9月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
seq2seq with attention updated (#884)
上级
8dda1694
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
101 addition
and
136 deletion
+101
-136
paddle2.0_docs/seq2seq_with_attention/seq2seq_with_attention.ipynb
..._docs/seq2seq_with_attention/seq2seq_with_attention.ipynb
+101
-136
未找到文件。
paddle2.0_docs/seq2seq_with_attention/seq2seq_with_attention.ipynb
浏览文件 @
f3b77bfe
...
...
@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count":
3
,
"execution_count":
2
,
"metadata": {},
"outputs": [
{
...
...
@@ -54,26 +54,32 @@
},
{
"cell_type": "code",
"execution_count":
2
,
"execution_count":
3
,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2020-09-04 14:06:10-- https://www.manythings.org/anki/cmn-eng.zip\n",
"Resolving www.manythings.org (www.manythings.org)... 104.24.108.196, 104.24.109.196, 172.67.173.198, ...\n",
"Connecting to www.manythings.org (www.manythings.org)|104.24.108.196|:443... connected.\n",
"HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n",
"--2020-09-04 16:13:35-- https://www.manythings.org/anki/cmn-eng.zip\n",
"Resolving www.manythings.org (www.manythings.org)... 104.24.109.196, 172.67.173.198, 2606:4700:3037::6818:6cc4, ...\n",
"Connecting to www.manythings.org (www.manythings.org)|104.24.109.196|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1030722 (1007K) [application/zip]\n",
"Saving to: ‘cmn-eng.zip’\n",
"\n",
"
The file is already fully retrieved; nothing to do.
\n",
"
cmn-eng.zip 100%[===================>] 1007K 520KB/s in 1.9s
\n",
"\n",
"Archive: cmn-eng.zip\n"
"2020-09-04 16:13:38 (520 KB/s) - ‘cmn-eng.zip’ saved [1030722/1030722]\n",
"\n",
"Archive: cmn-eng.zip\n",
" inflating: cmn.txt \n",
" inflating: _about.txt \n"
]
}
],
"source": [
"!wget -c https://www.manythings.org/anki/cmn-eng.zip && unzip
-f
cmn-eng.zip"
"!wget -c https://www.manythings.org/anki/cmn-eng.zip && unzip cmn-eng.zip"
]
},
{
...
...
@@ -108,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count":
5
,
"execution_count":
3
,
"metadata": {},
"outputs": [],
"source": [
...
...
@@ -117,7 +123,7 @@
},
{
"cell_type": "code",
"execution_count":
6
,
"execution_count":
4
,
"metadata": {},
"outputs": [
{
...
...
@@ -139,7 +145,6 @@
}
],
"source": [
"\n",
"lines = open('cmn.txt', encoding='utf-8').read().strip().split('\\n')\n",
"words_re = re.compile(r'\\w+')\n",
"\n",
...
...
@@ -155,8 +160,7 @@
" if len(x[0]) < MAX_LEN and len(x[1]) < MAX_LEN and \\\n",
" x[0][0] in ('i', 'you', 'he', 'she', 'we', 'they'):\n",
" filtered_pairs.append(x)\n",
"\n",
" \n",
" \n",
"print(len(filtered_pairs))\n",
"for x in filtered_pairs[:10]: print(x) "
]
...
...
@@ -177,7 +181,7 @@
},
{
"cell_type": "code",
"execution_count":
7
,
"execution_count":
5
,
"metadata": {},
"outputs": [
{
...
...
@@ -193,14 +197,11 @@
"en_vocab = {}\n",
"cn_vocab = {}\n",
"\n",
"# create special token for
unkown
, begin of sentence, end of sentence\n",
"# create special token for
pad
, begin of sentence, end of sentence\n",
"en_vocab['<pad>'], en_vocab['<bos>'], en_vocab['<eos>'] = 0, 1, 2\n",
"cn_vocab['<pad>'], cn_vocab['<bos>'], cn_vocab['<eos>'] = 0, 1, 2\n",
"\n",
"#print(en_vocab, cn_vocab)\n",
"\n",
"en_idx, cn_idx = 3, 3\n",
"\n",
"for en, cn in filtered_pairs:\n",
" for w in en: \n",
" if w not in en_vocab: \n",
...
...
@@ -229,7 +230,7 @@
},
{
"cell_type": "code",
"execution_count":
11
,
"execution_count":
6
,
"metadata": {},
"outputs": [
{
...
...
@@ -243,7 +244,6 @@
}
],
"source": [
"# create padded datasets\n",
"padded_en_sents = []\n",
"padded_cn_sents = []\n",
"padded_cn_label_sents = []\n",
...
...
@@ -262,7 +262,6 @@
"train_cn_sents = np.array(padded_cn_sents)\n",
"train_cn_label_sents = np.array(padded_cn_label_sents)\n",
"\n",
"\n",
"print(train_en_sents.shape)\n",
"print(train_cn_sents.shape)\n",
"print(train_cn_label_sents.shape)"
...
...
@@ -280,7 +279,7 @@
},
{
"cell_type": "code",
"execution_count":
12
,
"execution_count":
7
,
"metadata": {},
"outputs": [],
"source": [
...
...
@@ -289,7 +288,7 @@
"num_encoder_lstm_layers = 1\n",
"en_vocab_size = len(list(en_vocab))\n",
"cn_vocab_size = len(list(cn_vocab))\n",
"epochs =
3
0\n",
"epochs =
2
0\n",
"batch_size = 16"
]
},
...
...
@@ -301,12 +300,12 @@
"\n",
"在编码器的部分,我们通过查找完Embedding之后接一个LSTM的方式构建一个对源语言编码的网络。飞桨的RNN系列的API,除了LSTM之外,还提供了SimleRNN, GRU供使用,同时,还可以使用反向RNN,双向RNN,多层RNN等形式。也可以通过`dropout`参数设置是否对多层RNN的中间层进行`dropout`处理,来防止过拟合。\n",
"\n",
"除了使用序列到序列的RNN操作之外,也可以通过SimpleRNN, GRUCell, LSTMCell等API更灵活的创建单步的RNN计算,甚至通过
集成
RNNCellBase来实现自己的RNN计算单元。"
"除了使用序列到序列的RNN操作之外,也可以通过SimpleRNN, GRUCell, LSTMCell等API更灵活的创建单步的RNN计算,甚至通过
继承
RNNCellBase来实现自己的RNN计算单元。"
]
},
{
"cell_type": "code",
"execution_count":
16
,
"execution_count":
8
,
"metadata": {},
"outputs": [],
"source": [
...
...
@@ -340,7 +339,7 @@
},
{
"cell_type": "code",
"execution_count":
1
9,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
...
...
@@ -350,11 +349,9 @@
" def __init__(self):\n",
" super(AttentionDecoder, self).__init__()\n",
" self.emb = paddle.nn.Embedding(cn_vocab_size, embedding_size)\n",
" \n",
" # the lstm layer for to generate target sentence representation\n",
" self.lstm = paddle.nn.LSTM(input_size=embedding_size + hidden_size, \n",
" hidden_size=hidden_size)\n",
"
\n",
"\n",
" # for computing attention weights\n",
" self.attention_linear1 = paddle.nn.Linear(hidden_size * 2, hidden_size)\n",
" self.attention_linear2 = paddle.nn.Linear(hidden_size, 1)\n",
...
...
@@ -362,7 +359,6 @@
" # for computing output logits\n",
" self.outlinear =paddle.nn.Linear(hidden_size, cn_vocab_size)\n",
"\n",
"\n",
" def forward(self, x, previous_hidden, previous_cell, encoder_outputs):\n",
" x = self.emb(x)\n",
" \n",
...
...
@@ -376,7 +372,6 @@
" attention_logits = self.attention_linear2(attention_hidden)\n",
" attention_logits = paddle.squeeze(attention_logits)\n",
"\n",
" \n",
" attention_weights = F.softmax(attention_logits) \n",
" attention_weights = paddle.expand_as(paddle.unsqueeze(attention_weights, -1), \n",
" encoder_outputs)\n",
...
...
@@ -418,7 +413,7 @@
},
{
"cell_type": "code",
"execution_count":
20
,
"execution_count":
11
,
"metadata": {},
"outputs": [
{
...
...
@@ -426,95 +421,65 @@
"output_type": "stream",
"text": [
"epoch:0\n",
"iter 0, loss:[7.61
8719
]\n",
"iter 200, loss:[
2.9712436
]\n",
"iter 0, loss:[7.61
94725
]\n",
"iter 200, loss:[
3.4147663
]\n",
"epoch:1\n",
"iter 0, loss:[
2.926154
]\n",
"iter 200, loss:[2.
8847036
]\n",
"iter 0, loss:[
3.0931656
]\n",
"iter 200, loss:[2.
7543137
]\n",
"epoch:2\n",
"iter 0, loss:[2.
9981458
]\n",
"iter 200, loss:[
3.099761
]\n",
"iter 0, loss:[2.
8413522
]\n",
"iter 200, loss:[
2.340513
]\n",
"epoch:3\n",
"iter 0, loss:[2.
6152773
]\n",
"iter 200, loss:[2.5
736806
]\n",
"iter 0, loss:[2.
597812
]\n",
"iter 200, loss:[2.5
552855
]\n",
"epoch:4\n",
"iter 0, loss:[2.
418916
]\n",
"iter 200, loss:[2.
020410
5]\n",
"iter 0, loss:[2.
0783448
]\n",
"iter 200, loss:[2.
454478
5]\n",
"epoch:5\n",
"iter 0, loss:[
2.0660372
]\n",
"iter 200, loss:[1.
997014
]\n",
"iter 0, loss:[
1.8709135
]\n",
"iter 200, loss:[1.
8736631
]\n",
"epoch:6\n",
"iter 0, loss:[1.
7394348
]\n",
"iter 200, loss:[
1.9713217
]\n",
"iter 0, loss:[1.
9589291
]\n",
"iter 200, loss:[
2.119414
]\n",
"epoch:7\n",
"iter 0, loss:[
2.2450879
]\n",
"iter 200, loss:[1.
8005365
]\n",
"iter 0, loss:[
1.5829577
]\n",
"iter 200, loss:[1.
6002902
]\n",
"epoch:8\n",
"iter 0, loss:[1.
7562586
]\n",
"iter 200, loss:[1.
8237668
]\n",
"iter 0, loss:[1.
6022769
]\n",
"iter 200, loss:[1.
52694
]\n",
"epoch:9\n",
"iter 0, loss:[1.36
32518
]\n",
"iter 200, loss:[1.
641327
3]\n",
"iter 0, loss:[1.36
16685
]\n",
"iter 200, loss:[1.
542044
3]\n",
"epoch:10\n",
"iter 0, loss:[1.0
960134
]\n",
"iter 200, loss:[1.
4547268
]\n",
"iter 0, loss:[1.0
397792
]\n",
"iter 200, loss:[1.
2458231
]\n",
"epoch:11\n",
"iter 0, loss:[1.
4081496
]\n",
"iter 200, loss:[1.4
078153
]\n",
"iter 0, loss:[1.
2107158
]\n",
"iter 200, loss:[1.4
26417
]\n",
"epoch:12\n",
"iter 0, loss:[1.1
659987
]\n",
"iter 200, loss:[1.
185811
4]\n",
"iter 0, loss:[1.1
840894
]\n",
"iter 200, loss:[1.
099966
4]\n",
"epoch:13\n",
"iter 0, loss:[1.
3759178
]\n",
"iter 200, loss:[
1.2046292
]\n",
"iter 0, loss:[1.
0968472
]\n",
"iter 200, loss:[
0.8149167
]\n",
"epoch:14\n",
"iter 0, loss:[0.
8987882
]\n",
"iter 200, loss:[1.
1897587
]\n",
"iter 0, loss:[0.
95585203
]\n",
"iter 200, loss:[1.
0070628
]\n",
"epoch:15\n",
"iter 0, loss:[0.8
3738756
]\n",
"iter 200, loss:[0.
78109366
]\n",
"iter 0, loss:[0.8
9463925
]\n",
"iter 200, loss:[0.
8288595
]\n",
"epoch:16\n",
"iter 0, loss:[0.
84268856
]\n",
"iter 200, loss:[0.
9557387
]\n",
"iter 0, loss:[0.
5672495
]\n",
"iter 200, loss:[0.
7317069
]\n",
"epoch:17\n",
"iter 0, loss:[0.
64364
7]\n",
"iter 200, loss:[0.
9286504
]\n",
"iter 0, loss:[0.
7678517
7]\n",
"iter 200, loss:[0.
5319323
]\n",
"epoch:18\n",
"iter 0, loss:[0.5
729206
]\n",
"iter 200, loss:[0.
6324647
]\n",
"iter 0, loss:[0.5
250005
]\n",
"iter 200, loss:[0.
4182841
]\n",
"epoch:19\n",
"iter 0, loss:[0.6614718]\n",
"iter 200, loss:[0.5292754]\n",
"epoch:20\n",
"iter 0, loss:[0.45713213]\n",
"iter 200, loss:[0.6192503]\n",
"epoch:21\n",
"iter 0, loss:[0.36670336]\n",
"iter 200, loss:[0.41927388]\n",
"epoch:22\n",
"iter 0, loss:[0.3294798]\n",
"iter 200, loss:[0.4599006]\n",
"epoch:23\n",
"iter 0, loss:[0.29158494]\n",
"iter 200, loss:[0.27783182]\n",
"epoch:24\n",
"iter 0, loss:[0.24686475]\n",
"iter 200, loss:[0.34916434]\n",
"epoch:25\n",
"iter 0, loss:[0.26881775]\n",
"iter 200, loss:[0.2400788]\n",
"epoch:26\n",
"iter 0, loss:[0.20649]\n",
"iter 200, loss:[0.212987]\n",
"epoch:27\n",
"iter 0, loss:[0.12560298]\n",
"iter 200, loss:[0.17958683]\n",
"epoch:28\n",
"iter 0, loss:[0.13129365]\n",
"iter 200, loss:[0.14788578]\n",
"epoch:29\n",
"iter 0, loss:[0.07885154]\n",
"iter 200, loss:[0.14729765]\n"
"iter 0, loss:[0.52320284]\n",
"iter 200, loss:[0.47618982]\n"
]
}
],
...
...
@@ -542,7 +507,7 @@
" x_cn_data = train_cn_sents_shuffled[(batch_size*iteration):(batch_size*(iteration+1))]\n",
" x_cn_label_data = train_cn_label_sents_shuffled[(batch_size*iteration):(batch_size*(iteration+1))]\n",
"\n",
" # shape: (batch, num_layer(=1 here) * num_of_direction(=1 here)
*
hidden_size)\n",
" # shape: (batch, num_layer(=1 here) * num_of_direction(=1 here)
,
hidden_size)\n",
" hidden = paddle.zeros([batch_size, 1, hidden_size])\n",
" cell = paddle.zeros([batch_size, 1, hidden_size])\n",
"\n",
...
...
@@ -573,48 +538,49 @@
"source": [
"# 使用模型进行机器翻译\n",
"\n",
"根据你所使用的计算设备的不同,上面的训练过程可能需要不等的时间。(在一台Mac笔记本上,大约耗时15~20分钟)\n",
"完成上面的模型训练之后,我们可以得到一个能够从英文翻译成中文的机器翻译模型。接下来我们通过一个greedy search来实现使用该模型完成实际的机器翻译。(实际的任务中,你可能需要用beam search算法来提升效果)"
]
},
{
"cell_type": "code",
"execution_count":
29
,
"execution_count":
18
,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"
he is poor
\n",
"true:
他很穷
。\n",
"pred:
他很穷
。\n",
"i
lent him a cd
\n",
"true: 我
借给他一盘CD
。\n",
"pred: 我
借给他一盘CD
。\n",
"
i m not so brave
\n",
"true:
我没那么勇敢
。\n",
"pred:
我没那么勇敢
。\n",
"
he goes to bed at eight o clock
\n",
"true:
他八點上床睡覺
。\n",
"pred:
他八點鐘也會遲到
。\n",
"i
know how old you are
\n",
"true: 我
知道你多大
了。\n",
"pred: 我
知道你多大
了。\n",
"
i m a detective
\n",
"true:
我是个侦探
。\n",
"pred:
我是个侦探
。\n",
"
i am the fastest runner
\n",
"true: 我
是跑得最快的人
。\n",
"pred: 我
是最快的跑者
。\n",
"
he got down the book from the shelf
\n",
"true:
他從架上拿下書
。\n",
"pred:
他從架上拿下書
。\n",
"he
arrived at the station at seven
\n",
"true: 他
7点到了火车站
。\n",
"pred: 他
7点到了火车站
。\n",
"he
fell down on the floor
\n",
"true: 他
摔倒在地
。\n",
"pred: 他
摔倒在地
。\n"
"
i agree with him
\n",
"true:
我同意他
。\n",
"pred:
我同意他
。\n",
"i
think i ll take a bath tonight
\n",
"true: 我
想我今晚會洗澡
。\n",
"pred: 我
想我今晚會洗澡
。\n",
"
he asked for a drink of water
\n",
"true:
他要了水喝
。\n",
"pred:
他喝了一杯水
。\n",
"
i began running
\n",
"true:
我開始跑
。\n",
"pred:
我開始跑
。\n",
"i
m sick
\n",
"true: 我
生病
了。\n",
"pred: 我
生病
了。\n",
"
you had better go to the dentist s
\n",
"true:
你最好去看牙醫
。\n",
"pred:
你最好去看牙醫
。\n",
"
we went for a walk in the forest
\n",
"true: 我
们去了林中散步
。\n",
"pred: 我
們去公园散步
。\n",
"
you ve arrived very early
\n",
"true:
你來得很早
。\n",
"pred:
你去早个
。\n",
"he
pretended not to be listening
\n",
"true: 他
裝作沒在聽
。\n",
"pred: 他
假装聽到它
。\n",
"he
always wanted to study japanese
\n",
"true: 他
一直想學日語
。\n",
"pred: 他
一直想學日語
。\n"
]
}
],
...
...
@@ -640,7 +606,6 @@
"decoded_sent = []\n",
"for i in range(MAX_LEN + 2):\n",
" logits, (hidden, cell) = atten_decoder(word, hidden, cell, en_repr)\n",
"\n",
" word = paddle.argmax(logits, axis=1)\n",
" decoded_sent.append(word.numpy())\n",
" word = paddle.unsqueeze(word, axis=-1)\n",
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录