PaddlePaddle / book
Commit f3b77bfe
Authored Sep 04, 2020 by jzhang533, committed via GitHub on Sep 04, 2020

seq2seq with attention updated (#884)
Parent: 8dda1694

Showing 1 changed file with 101 additions and 136 deletions

paddle2.0_docs/seq2seq_with_attention/seq2seq_with_attention.ipynb (+101, -136)
@@ -20,7 +20,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 3,
+  "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
@@ -54,26 +54,32 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 2,
+  "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "--2020-09-04 14:06:10--  https://www.manythings.org/anki/cmn-eng.zip\n",
-     "Resolving www.manythings.org (www.manythings.org)... 104.24.108.196, 104.24.109.196, 172.67.173.198, ...\n",
-     "Connecting to www.manythings.org (www.manythings.org)|104.24.108.196|:443... connected.\n",
-     "HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n",
+     "--2020-09-04 16:13:35--  https://www.manythings.org/anki/cmn-eng.zip\n",
+     "Resolving www.manythings.org (www.manythings.org)... 104.24.109.196, 172.67.173.198, 2606:4700:3037::6818:6cc4, ...\n",
+     "Connecting to www.manythings.org (www.manythings.org)|104.24.109.196|:443... connected.\n",
+     "HTTP request sent, awaiting response... 200 OK\n",
+     "Length: 1030722 (1007K) [application/zip]\n",
+     "Saving to: ‘cmn-eng.zip’\n",
      "\n",
-     "The file is already fully retrieved; nothing to do.\n",
+     "cmn-eng.zip 100%[===================>] 1007K 520KB/s in 1.9s\n",
      "\n",
-     "Archive:  cmn-eng.zip\n"
+     "2020-09-04 16:13:38 (520 KB/s) - ‘cmn-eng.zip’ saved [1030722/1030722]\n",
+     "\n",
+     "Archive:  cmn-eng.zip\n",
+     "  inflating: cmn.txt                 \n",
+     "  inflating: _about.txt              \n"
     ]
    }
   ],
   "source": [
-   "!wget -c https://www.manythings.org/anki/cmn-eng.zip && unzip -f cmn-eng.zip"
+   "!wget -c https://www.manythings.org/anki/cmn-eng.zip && unzip cmn-eng.zip"
   ]
  },
  {
@@ -108,7 +114,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -117,7 +123,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
@@ -139,7 +145,6 @@
    }
   ],
   "source": [
-   "\n",
    "lines = open('cmn.txt', encoding='utf-8').read().strip().split('\\n')\n",
    "words_re = re.compile(r'\\w+')\n",
    "\n",
@@ -155,8 +160,7 @@
    "    if len(x[0]) < MAX_LEN and len(x[1]) < MAX_LEN and \\\n",
    "       x[0][0] in ('i', 'you', 'he', 'she', 'we', 'they'):\n",
    "        filtered_pairs.append(x)\n",
-   "\n",
    "    \n",
    "print(len(filtered_pairs))\n",
    "for x in filtered_pairs[:10]: print(x) "
   ]
@@ -177,7 +181,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
@@ -193,14 +197,11 @@
    "en_vocab = {}\n",
    "cn_vocab = {}\n",
    "\n",
-   "# create special token for unkown, begin of sentence, end of sentence\n",
+   "# create special token for pad, begin of sentence, end of sentence\n",
    "en_vocab['<pad>'], en_vocab['<bos>'], en_vocab['<eos>'] = 0, 1, 2\n",
    "cn_vocab['<pad>'], cn_vocab['<bos>'], cn_vocab['<eos>'] = 0, 1, 2\n",
    "\n",
-   "#print(en_vocab, cn_vocab)\n",
-   "\n",
    "en_idx, cn_idx = 3, 3\n",
-   "\n",
    "for en, cn in filtered_pairs:\n",
    "    for w in en: \n",
    "        if w not in en_vocab: \n",
@@ -229,7 +230,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 11,
+  "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
@@ -243,7 +244,6 @@
    }
   ],
   "source": [
-   "# create padded datasets\n",
    "padded_en_sents = []\n",
    "padded_cn_sents = []\n",
    "padded_cn_label_sents = []\n",
@@ -262,7 +262,6 @@
    "train_cn_sents = np.array(padded_cn_sents)\n",
    "train_cn_label_sents = np.array(padded_cn_label_sents)\n",
    "\n",
-   "\n",
    "print(train_en_sents.shape)\n",
    "print(train_cn_sents.shape)\n",
    "print(train_cn_label_sents.shape)"
@@ -280,7 +279,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 12,
+  "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -289,7 +288,7 @@
    "num_encoder_lstm_layers = 1\n",
    "en_vocab_size = len(list(en_vocab))\n",
    "cn_vocab_size = len(list(cn_vocab))\n",
-   "epochs = 30\n",
+   "epochs = 20\n",
    "batch_size = 16"
   ]
  },
@@ -301,12 +300,12 @@
    "\n",
    "In the encoder part, we build a network that encodes the source language by doing an embedding lookup followed by an LSTM. Besides LSTM, Paddle's RNN family of APIs also provides SimpleRNN and GRU, and supports reversed, bidirectional, and multi-layer RNNs. The `dropout` parameter controls whether dropout is applied between the intermediate layers of a multi-layer RNN to prevent overfitting.\n",
    "\n",
-   "Besides the sequence-to-sequence RNN operations, you can also use the SimpleRNN, GRUCell, and LSTMCell APIs to build single-step RNN computations more flexibly, and even implement your own RNN cell by integrating RNNCellBase."
+   "Besides the sequence-to-sequence RNN operations, you can also use the SimpleRNN, GRUCell, and LSTMCell APIs to build single-step RNN computations more flexibly, and even implement your own RNN cell by inheriting from RNNCellBase."
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 16,
+  "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -340,7 +339,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 19,
+  "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -350,11 +349,9 @@
    "    def __init__(self):\n",
    "        super(AttentionDecoder, self).__init__()\n",
    "        self.emb = paddle.nn.Embedding(cn_vocab_size, embedding_size)\n",
-   "        \n",
-   "        # the lstm layer for to generate target sentence representation\n",
    "        self.lstm = paddle.nn.LSTM(input_size=embedding_size + hidden_size, \n",
    "                                   hidden_size=hidden_size)\n",
-   "        \n",
+   "\n",
    "        # for computing attention weights\n",
    "        self.attention_linear1 = paddle.nn.Linear(hidden_size * 2, hidden_size)\n",
    "        self.attention_linear2 = paddle.nn.Linear(hidden_size, 1)\n",
@@ -362,7 +359,6 @@
    "        # for computing output logits\n",
    "        self.outlinear =paddle.nn.Linear(hidden_size, cn_vocab_size)\n",
    "\n",
-   "\n",
    "    def forward(self, x, previous_hidden, previous_cell, encoder_outputs):\n",
    "        x = self.emb(x)\n",
    "        \n",
@@ -376,7 +372,6 @@
    "        attention_logits = self.attention_linear2(attention_hidden)\n",
    "        attention_logits = paddle.squeeze(attention_logits)\n",
    "\n",
-   "        \n",
    "        attention_weights = F.softmax(attention_logits) \n",
    "        attention_weights = paddle.expand_as(paddle.unsqueeze(attention_weights, -1), \n",
    "                                             encoder_outputs)\n",
@@ -418,7 +413,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 20,
+  "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
@@ -426,95 +421,65 @@
     "output_type": "stream",
     "text": [
      "epoch:0\n",
-     "iter 0, loss:[7.618719]\n",
-     "iter 200, loss:[2.9712436]\n",
+     "iter 0, loss:[7.6194725]\n",
+     "iter 200, loss:[3.4147663]\n",
      "epoch:1\n",
-     "iter 0, loss:[2.926154]\n",
-     "iter 200, loss:[2.8847036]\n",
+     "iter 0, loss:[3.0931656]\n",
+     "iter 200, loss:[2.7543137]\n",
      "epoch:2\n",
-     "iter 0, loss:[2.9981458]\n",
-     "iter 200, loss:[3.099761]\n",
+     "iter 0, loss:[2.8413522]\n",
+     "iter 200, loss:[2.340513]\n",
      "epoch:3\n",
-     "iter 0, loss:[2.6152773]\n",
-     "iter 200, loss:[2.5736806]\n",
+     "iter 0, loss:[2.597812]\n",
+     "iter 200, loss:[2.5552855]\n",
      "epoch:4\n",
-     "iter 0, loss:[2.418916]\n",
-     "iter 200, loss:[2.0204105]\n",
+     "iter 0, loss:[2.0783448]\n",
+     "iter 200, loss:[2.4544785]\n",
      "epoch:5\n",
-     "iter 0, loss:[2.0660372]\n",
-     "iter 200, loss:[1.997014]\n",
+     "iter 0, loss:[1.8709135]\n",
+     "iter 200, loss:[1.8736631]\n",
      "epoch:6\n",
-     "iter 0, loss:[1.7394348]\n",
-     "iter 200, loss:[1.9713217]\n",
+     "iter 0, loss:[1.9589291]\n",
+     "iter 200, loss:[2.119414]\n",
      "epoch:7\n",
-     "iter 0, loss:[2.2450879]\n",
-     "iter 200, loss:[1.8005365]\n",
+     "iter 0, loss:[1.5829577]\n",
+     "iter 200, loss:[1.6002902]\n",
      "epoch:8\n",
-     "iter 0, loss:[1.7562586]\n",
-     "iter 200, loss:[1.8237668]\n",
+     "iter 0, loss:[1.6022769]\n",
+     "iter 200, loss:[1.52694]\n",
      "epoch:9\n",
-     "iter 0, loss:[1.3632518]\n",
-     "iter 200, loss:[1.6413273]\n",
+     "iter 0, loss:[1.3616685]\n",
+     "iter 200, loss:[1.5420443]\n",
      "epoch:10\n",
-     "iter 0, loss:[1.0960134]\n",
-     "iter 200, loss:[1.4547268]\n",
+     "iter 0, loss:[1.0397792]\n",
+     "iter 200, loss:[1.2458231]\n",
      "epoch:11\n",
-     "iter 0, loss:[1.4081496]\n",
-     "iter 200, loss:[1.4078153]\n",
+     "iter 0, loss:[1.2107158]\n",
+     "iter 200, loss:[1.426417]\n",
      "epoch:12\n",
-     "iter 0, loss:[1.1659987]\n",
-     "iter 200, loss:[1.1858114]\n",
+     "iter 0, loss:[1.1840894]\n",
+     "iter 200, loss:[1.0999664]\n",
      "epoch:13\n",
-     "iter 0, loss:[1.3759178]\n",
-     "iter 200, loss:[1.2046292]\n",
+     "iter 0, loss:[1.0968472]\n",
+     "iter 200, loss:[0.8149167]\n",
      "epoch:14\n",
-     "iter 0, loss:[0.8987882]\n",
-     "iter 200, loss:[1.1897587]\n",
+     "iter 0, loss:[0.95585203]\n",
+     "iter 200, loss:[1.0070628]\n",
      "epoch:15\n",
-     "iter 0, loss:[0.83738756]\n",
-     "iter 200, loss:[0.78109366]\n",
+     "iter 0, loss:[0.89463925]\n",
+     "iter 200, loss:[0.8288595]\n",
      "epoch:16\n",
-     "iter 0, loss:[0.84268856]\n",
-     "iter 200, loss:[0.9557387]\n",
+     "iter 0, loss:[0.5672495]\n",
+     "iter 200, loss:[0.7317069]\n",
      "epoch:17\n",
-     "iter 0, loss:[0.643647]\n",
-     "iter 200, loss:[0.9286504]\n",
+     "iter 0, loss:[0.76785177]\n",
+     "iter 200, loss:[0.5319323]\n",
      "epoch:18\n",
-     "iter 0, loss:[0.5729206]\n",
-     "iter 200, loss:[0.6324647]\n",
+     "iter 0, loss:[0.5250005]\n",
+     "iter 200, loss:[0.4182841]\n",
      "epoch:19\n",
-     "iter 0, loss:[0.6614718]\n",
-     "iter 200, loss:[0.5292754]\n",
-     "epoch:20\n",
-     "iter 0, loss:[0.45713213]\n",
-     "iter 200, loss:[0.6192503]\n",
-     "epoch:21\n",
-     "iter 0, loss:[0.36670336]\n",
-     "iter 200, loss:[0.41927388]\n",
-     "epoch:22\n",
-     "iter 0, loss:[0.3294798]\n",
-     "iter 200, loss:[0.4599006]\n",
-     "epoch:23\n",
-     "iter 0, loss:[0.29158494]\n",
-     "iter 200, loss:[0.27783182]\n",
-     "epoch:24\n",
-     "iter 0, loss:[0.24686475]\n",
-     "iter 200, loss:[0.34916434]\n",
-     "epoch:25\n",
-     "iter 0, loss:[0.26881775]\n",
-     "iter 200, loss:[0.2400788]\n",
-     "epoch:26\n",
-     "iter 0, loss:[0.20649]\n",
-     "iter 200, loss:[0.212987]\n",
-     "epoch:27\n",
-     "iter 0, loss:[0.12560298]\n",
-     "iter 200, loss:[0.17958683]\n",
-     "epoch:28\n",
-     "iter 0, loss:[0.13129365]\n",
-     "iter 200, loss:[0.14788578]\n",
-     "epoch:29\n",
-     "iter 0, loss:[0.07885154]\n",
-     "iter 200, loss:[0.14729765]\n"
+     "iter 0, loss:[0.52320284]\n",
+     "iter 200, loss:[0.47618982]\n"
     ]
    }
   ],
@@ -542,7 +507,7 @@
    "        x_cn_data = train_cn_sents_shuffled[(batch_size*iteration):(batch_size*(iteration+1))]\n",
    "        x_cn_label_data = train_cn_label_sents_shuffled[(batch_size*iteration):(batch_size*(iteration+1))]\n",
    "\n",
-   "        # shape: (batch, num_layer(=1 here) * num_of_direction(=1 here) * hidden_size)\n",
+   "        # shape: (batch, num_layer(=1 here) * num_of_direction(=1 here), hidden_size)\n",
    "        hidden = paddle.zeros([batch_size, 1, hidden_size])\n",
    "        cell = paddle.zeros([batch_size, 1, hidden_size])\n",
    "\n",
@@ -573,48 +538,49 @@
   "source": [
    "# Using the model for machine translation\n",
    "\n",
+   "Depending on the compute device you use, the training above may take a varying amount of time. (On a Mac laptop it takes roughly 15~20 minutes.)\n",
    "After finishing the training above, we have a machine translation model that can translate English into Chinese. Next we use a greedy search to carry out actual machine translation with this model. (In a real task you may need a beam search algorithm to improve the results.)"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 29,
+  "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "he is poor\n",
-     "true: 他很穷。\n",
-     "pred: 他很穷。\n",
+     "i agree with him\n",
+     "true: 我同意他。\n",
+     "pred: 我同意他。\n",
-     "i lent him a cd\n",
-     "true: 我借给他一盘CD。\n",
-     "pred: 我借给他一盘CD。\n",
+     "i think i ll take a bath tonight\n",
+     "true: 我想我今晚會洗澡。\n",
+     "pred: 我想我今晚會洗澡。\n",
-     "i m not so brave\n",
-     "true: 我没那么勇敢。\n",
-     "pred: 我没那么勇敢。\n",
+     "he asked for a drink of water\n",
+     "true: 他要了水喝。\n",
+     "pred: 他喝了一杯水。\n",
-     "he goes to bed at eight o clock\n",
-     "true: 他八點上床睡覺。\n",
-     "pred: 他八點鐘也會遲到。\n",
+     "i began running\n",
+     "true: 我開始跑。\n",
+     "pred: 我開始跑。\n",
-     "i know how old you are\n",
-     "true: 我知道你多大了。\n",
-     "pred: 我知道你多大了。\n",
+     "i m sick\n",
+     "true: 我生病了。\n",
+     "pred: 我生病了。\n",
-     "i m a detective\n",
-     "true: 我是个侦探。\n",
-     "pred: 我是个侦探。\n",
+     "you had better go to the dentist s\n",
+     "true: 你最好去看牙醫。\n",
+     "pred: 你最好去看牙醫。\n",
-     "i am the fastest runner\n",
-     "true: 我是跑得最快的人。\n",
-     "pred: 我是最快的跑者。\n",
+     "we went for a walk in the forest\n",
+     "true: 我们去了林中散步。\n",
+     "pred: 我們去公园散步。\n",
-     "he got down the book from the shelf\n",
-     "true: 他從架上拿下書。\n",
-     "pred: 他從架上拿下書。\n",
+     "you ve arrived very early\n",
+     "true: 你來得很早。\n",
+     "pred: 你去早个。\n",
-     "he arrived at the station at seven\n",
-     "true: 他7点到了火车站。\n",
-     "pred: 他7点到了火车站。\n",
+     "he pretended not to be listening\n",
+     "true: 他裝作沒在聽。\n",
+     "pred: 他假装聽到它。\n",
-     "he fell down on the floor\n",
-     "true: 他摔倒在地。\n",
-     "pred: 他摔倒在地。\n"
+     "he always wanted to study japanese\n",
+     "true: 他一直想學日語。\n",
+     "pred: 他一直想學日語。\n"
     ]
    }
   ],
@@ -640,7 +606,6 @@
    "decoded_sent = []\n",
    "for i in range(MAX_LEN + 2):\n",
    "    logits, (hidden, cell) = atten_decoder(word, hidden, cell, en_repr)\n",
-   "\n",
    "    word = paddle.argmax(logits, axis=1)\n",
    "    decoded_sent.append(word.numpy())\n",
    "    word = paddle.unsqueeze(word, axis=-1)\n",
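The hunk above tightens the notebook's greedy-search decode loop. As a rough, self-contained sketch (not part of this commit), greedy decoding with a decoder of this interface could be wrapped as below; the names encoder, atten_decoder, cn_vocab, MAX_LEN, and the <bos>/<eos> conventions are assumed to match the notebook, and the hidden size is taken from the encoder output.

import paddle

def greedy_decode(encoder, atten_decoder, en_ids, cn_vocab, max_len):
    # Greedy search: at each step feed back the single highest-scoring token,
    # instead of keeping several candidate prefixes as beam search would.
    # Assumed interfaces (matching the notebook): encoder(x) returns the source
    # representation; atten_decoder(word, hidden, cell, en_repr) returns
    # (logits, (hidden, cell)); cn_vocab maps '<bos>'/'<eos>' to ids.
    en_repr = encoder(paddle.to_tensor([en_ids]))
    hidden_size = en_repr.shape[-1]
    hidden = paddle.zeros([1, 1, hidden_size])
    cell = paddle.zeros([1, 1, hidden_size])
    word = paddle.to_tensor([[cn_vocab['<bos>']]])

    decoded_ids = []
    for _ in range(max_len + 2):
        logits, (hidden, cell) = atten_decoder(word, hidden, cell, en_repr)
        word = paddle.argmax(logits, axis=1)        # pick the best next token
        token_id = int(word.numpy()[0])
        if token_id == cn_vocab['<eos>']:           # stop at end-of-sentence
            break
        decoded_ids.append(token_id)
        word = paddle.unsqueeze(word, axis=-1)      # shape (batch=1, 1) for the next step
    return decoded_ids

A beam search variant would keep the top-k prefixes at each step instead of a single argmax, which is the improvement the markdown cell alludes to.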