Unverified commit f3b77bfe, authored by jzhang533, committed by GitHub

seq2seq with attention updated (#884)

Parent: 8dda1694
@@ -20,7 +20,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 2,
 "metadata": {},
 "outputs": [
 {
@@ -54,26 +54,32 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 3,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"--2020-09-04 14:06:10--  https://www.manythings.org/anki/cmn-eng.zip\n",
-"Resolving www.manythings.org (www.manythings.org)... 104.24.108.196, 104.24.109.196, 172.67.173.198, ...\n",
-"Connecting to www.manythings.org (www.manythings.org)|104.24.108.196|:443... connected.\n",
-"HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n",
-"\n",
-"    The file is already fully retrieved; nothing to do.\n",
-"\n",
-"Archive:  cmn-eng.zip\n"
+"--2020-09-04 16:13:35--  https://www.manythings.org/anki/cmn-eng.zip\n",
+"Resolving www.manythings.org (www.manythings.org)... 104.24.109.196, 172.67.173.198, 2606:4700:3037::6818:6cc4, ...\n",
+"Connecting to www.manythings.org (www.manythings.org)|104.24.109.196|:443... connected.\n",
+"HTTP request sent, awaiting response... 200 OK\n",
+"Length: 1030722 (1007K) [application/zip]\n",
+"Saving to: ‘cmn-eng.zip’\n",
+"\n",
+"cmn-eng.zip         100%[===================>]   1007K   520KB/s    in 1.9s    \n",
+"\n",
+"2020-09-04 16:13:38 (520 KB/s) - ‘cmn-eng.zip’ saved [1030722/1030722]\n",
+"\n",
+"Archive:  cmn-eng.zip\n",
+"  inflating: cmn.txt                 \n",
+"  inflating: _about.txt              \n"
 ]
 }
 ],
 "source": [
-"!wget -c https://www.manythings.org/anki/cmn-eng.zip && unzip -f cmn-eng.zip"
+"!wget -c https://www.manythings.org/anki/cmn-eng.zip && unzip cmn-eng.zip"
 ]
 },
 {
@@ -108,7 +114,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 3,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -117,7 +123,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 4,
 "metadata": {},
 "outputs": [
 {
@@ -139,7 +145,6 @@
 }
 ],
 "source": [
-"\n",
 "lines = open('cmn.txt', encoding='utf-8').read().strip().split('\\n')\n",
 "words_re = re.compile(r'\\w+')\n",
 "\n",
@@ -155,8 +160,7 @@
 "    if len(x[0]) < MAX_LEN and len(x[1]) < MAX_LEN and \\\n",
 "       x[0][0] in ('i', 'you', 'he', 'she', 'we', 'they'):\n",
 "        filtered_pairs.append(x)\n",
-"\n",
 "    \n",
 "print(len(filtered_pairs))\n",
 "for x in filtered_pairs[:10]: print(x) "
 ]
@@ -177,7 +181,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 5,
 "metadata": {},
 "outputs": [
 {
@@ -193,14 +197,11 @@
 "en_vocab = {}\n",
 "cn_vocab = {}\n",
 "\n",
-"# create special token for unkown, begin of sentence, end of sentence\n",
+"# create special tokens for pad, begin of sentence, end of sentence\n",
 "en_vocab['<pad>'], en_vocab['<bos>'], en_vocab['<eos>'] = 0, 1, 2\n",
 "cn_vocab['<pad>'], cn_vocab['<bos>'], cn_vocab['<eos>'] = 0, 1, 2\n",
 "\n",
-"#print(en_vocab, cn_vocab)\n",
-"\n",
 "en_idx, cn_idx = 3, 3\n",
-"\n",
 "for en, cn in filtered_pairs:\n",
 "    for w in en: \n",
 "        if w not in en_vocab: \n",
@@ -229,7 +230,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 6,
 "metadata": {},
 "outputs": [
 {
@@ -243,7 +244,6 @@
 }
 ],
 "source": [
-"# create padded datasets\n",
 "padded_en_sents = []\n",
 "padded_cn_sents = []\n",
 "padded_cn_label_sents = []\n",
@@ -262,7 +262,6 @@
 "train_cn_sents = np.array(padded_cn_sents)\n",
 "train_cn_label_sents = np.array(padded_cn_label_sents)\n",
 "\n",
-"\n",
 "print(train_en_sents.shape)\n",
 "print(train_cn_sents.shape)\n",
 "print(train_cn_label_sents.shape)"
@@ -280,7 +279,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 7,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -289,7 +288,7 @@
 "num_encoder_lstm_layers = 1\n",
 "en_vocab_size = len(list(en_vocab))\n",
 "cn_vocab_size = len(list(cn_vocab))\n",
-"epochs = 30\n",
+"epochs = 20\n",
 "batch_size = 16"
 ]
 },
@@ -301,12 +300,12 @@
 "\n",
 "In the encoder part, we build the network that encodes the source language by looking up embeddings and feeding them into an LSTM. Besides LSTM, PaddlePaddle's family of RNN APIs also provides SimpleRNN and GRU, and all of them can be used as reversed, bidirectional, or multi-layer RNNs. You can also use the `dropout` parameter to control whether dropout is applied to the intermediate layers of a multi-layer RNN, to prevent overfitting.\n",
 "\n",
-"Besides these sequence-to-sequence RNN operations, APIs such as SimpleRNNCell, GRUCell, and LSTMCell allow building single-step RNN computations more flexibly; you can even integrate RNNCellBase to implement your own RNN cell."
+"Besides these sequence-to-sequence RNN operations, APIs such as SimpleRNNCell, GRUCell, and LSTMCell allow building single-step RNN computations more flexibly; you can even inherit from RNNCellBase to implement your own RNN cell."
 ]
 },
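The encoder cell that follows is untouched by this commit, so the diff elides its body. For reference, a minimal sketch of the Embedding-plus-LSTM encoder described above could look like the following; it reuses `en_vocab_size`, `embedding_size`, `hidden_size`, and `num_encoder_lstm_layers` from the hyperparameter cell, and is a sketch rather than the notebook's verbatim code:

```python
import paddle

# Minimal encoder sketch: embedding lookup followed by an LSTM over the
# source sentence. Swapping paddle.nn.LSTM for paddle.nn.GRU or
# paddle.nn.SimpleRNN, or passing num_layers > 1, dropout=..., or
# direction='bidirect', gives the variants mentioned above.
class Encoder(paddle.nn.Layer):
    def __init__(self):
        super(Encoder, self).__init__()
        self.emb = paddle.nn.Embedding(en_vocab_size, embedding_size)
        self.lstm = paddle.nn.LSTM(input_size=embedding_size,
                                   hidden_size=hidden_size,
                                   num_layers=num_encoder_lstm_layers)

    def forward(self, x):
        x = self.emb(x)       # (batch, seq_len, embedding_size)
        x, _ = self.lstm(x)   # (batch, seq_len, hidden_size)
        return x
```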
 {
 "cell_type": "code",
-"execution_count": 16,
+"execution_count": 8,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -340,7 +339,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 19,
+"execution_count": 9,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -350,11 +349,9 @@
 "    def __init__(self):\n",
 "        super(AttentionDecoder, self).__init__()\n",
 "        self.emb = paddle.nn.Embedding(cn_vocab_size, embedding_size)\n",
-"        \n",
-"        # the lstm layer for to generate target sentence representation\n",
 "        self.lstm = paddle.nn.LSTM(input_size=embedding_size + hidden_size, \n",
 "                                   hidden_size=hidden_size)\n",
-"        \n",
+"\n",
 "        # for computing attention weights\n",
 "        self.attention_linear1 = paddle.nn.Linear(hidden_size * 2, hidden_size)\n",
 "        self.attention_linear2 = paddle.nn.Linear(hidden_size, 1)\n",
@@ -362,7 +359,6 @@
 "        # for computing output logits\n",
 "        self.outlinear = paddle.nn.Linear(hidden_size, cn_vocab_size)\n",
 "\n",
-"\n",
 "    def forward(self, x, previous_hidden, previous_cell, encoder_outputs):\n",
 "        x = self.emb(x)\n",
 "        \n",
@@ -376,7 +372,6 @@
 "        attention_logits = self.attention_linear2(attention_hidden)\n",
 "        attention_logits = paddle.squeeze(attention_logits)\n",
 "\n",
-"        \n",
 "        attention_weights = F.softmax(attention_logits) \n",
 "        attention_weights = paddle.expand_as(paddle.unsqueeze(attention_weights, -1), \n",
 "                                             encoder_outputs)\n",
@@ -418,7 +413,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 20,
+"execution_count": 11,
 "metadata": {},
 "outputs": [
 {
@@ -426,95 +421,65 @@
 "output_type": "stream",
 "text": [
 "epoch:0\n",
-"iter 0, loss:[7.618719]\n",
-"iter 200, loss:[2.9712436]\n",
+"iter 0, loss:[7.6194725]\n",
+"iter 200, loss:[3.4147663]\n",
 "epoch:1\n",
-"iter 0, loss:[2.926154]\n",
-"iter 200, loss:[2.8847036]\n",
+"iter 0, loss:[3.0931656]\n",
+"iter 200, loss:[2.7543137]\n",
 "epoch:2\n",
-"iter 0, loss:[2.9981458]\n",
-"iter 200, loss:[3.099761]\n",
+"iter 0, loss:[2.8413522]\n",
+"iter 200, loss:[2.340513]\n",
 "epoch:3\n",
-"iter 0, loss:[2.6152773]\n",
-"iter 200, loss:[2.5736806]\n",
+"iter 0, loss:[2.597812]\n",
+"iter 200, loss:[2.5552855]\n",
 "epoch:4\n",
-"iter 0, loss:[2.418916]\n",
-"iter 200, loss:[2.0204105]\n",
+"iter 0, loss:[2.0783448]\n",
+"iter 200, loss:[2.4544785]\n",
 "epoch:5\n",
-"iter 0, loss:[2.0660372]\n",
-"iter 200, loss:[1.997014]\n",
+"iter 0, loss:[1.8709135]\n",
+"iter 200, loss:[1.8736631]\n",
 "epoch:6\n",
-"iter 0, loss:[1.7394348]\n",
-"iter 200, loss:[1.9713217]\n",
+"iter 0, loss:[1.9589291]\n",
+"iter 200, loss:[2.119414]\n",
 "epoch:7\n",
-"iter 0, loss:[2.2450879]\n",
-"iter 200, loss:[1.8005365]\n",
+"iter 0, loss:[1.5829577]\n",
+"iter 200, loss:[1.6002902]\n",
 "epoch:8\n",
-"iter 0, loss:[1.7562586]\n",
-"iter 200, loss:[1.8237668]\n",
+"iter 0, loss:[1.6022769]\n",
+"iter 200, loss:[1.52694]\n",
 "epoch:9\n",
-"iter 0, loss:[1.3632518]\n",
-"iter 200, loss:[1.6413273]\n",
+"iter 0, loss:[1.3616685]\n",
+"iter 200, loss:[1.5420443]\n",
 "epoch:10\n",
-"iter 0, loss:[1.0960134]\n",
-"iter 200, loss:[1.4547268]\n",
+"iter 0, loss:[1.0397792]\n",
+"iter 200, loss:[1.2458231]\n",
 "epoch:11\n",
-"iter 0, loss:[1.4081496]\n",
-"iter 200, loss:[1.4078153]\n",
+"iter 0, loss:[1.2107158]\n",
+"iter 200, loss:[1.426417]\n",
 "epoch:12\n",
-"iter 0, loss:[1.1659987]\n",
-"iter 200, loss:[1.1858114]\n",
+"iter 0, loss:[1.1840894]\n",
+"iter 200, loss:[1.0999664]\n",
 "epoch:13\n",
-"iter 0, loss:[1.3759178]\n",
-"iter 200, loss:[1.2046292]\n",
+"iter 0, loss:[1.0968472]\n",
+"iter 200, loss:[0.8149167]\n",
 "epoch:14\n",
-"iter 0, loss:[0.8987882]\n",
-"iter 200, loss:[1.1897587]\n",
+"iter 0, loss:[0.95585203]\n",
+"iter 200, loss:[1.0070628]\n",
 "epoch:15\n",
-"iter 0, loss:[0.83738756]\n",
-"iter 200, loss:[0.78109366]\n",
+"iter 0, loss:[0.89463925]\n",
+"iter 200, loss:[0.8288595]\n",
 "epoch:16\n",
-"iter 0, loss:[0.84268856]\n",
-"iter 200, loss:[0.9557387]\n",
+"iter 0, loss:[0.5672495]\n",
+"iter 200, loss:[0.7317069]\n",
 "epoch:17\n",
-"iter 0, loss:[0.643647]\n",
-"iter 200, loss:[0.9286504]\n",
+"iter 0, loss:[0.76785177]\n",
+"iter 200, loss:[0.5319323]\n",
 "epoch:18\n",
-"iter 0, loss:[0.5729206]\n",
-"iter 200, loss:[0.6324647]\n",
+"iter 0, loss:[0.5250005]\n",
+"iter 200, loss:[0.4182841]\n",
 "epoch:19\n",
-"iter 0, loss:[0.6614718]\n",
-"iter 200, loss:[0.5292754]\n",
-"epoch:20\n",
-"iter 0, loss:[0.45713213]\n",
-"iter 200, loss:[0.6192503]\n",
-"epoch:21\n",
-"iter 0, loss:[0.36670336]\n",
-"iter 200, loss:[0.41927388]\n",
-"epoch:22\n",
-"iter 0, loss:[0.3294798]\n",
-"iter 200, loss:[0.4599006]\n",
-"epoch:23\n",
-"iter 0, loss:[0.29158494]\n",
-"iter 200, loss:[0.27783182]\n",
-"epoch:24\n",
-"iter 0, loss:[0.24686475]\n",
-"iter 200, loss:[0.34916434]\n",
-"epoch:25\n",
-"iter 0, loss:[0.26881775]\n",
-"iter 200, loss:[0.2400788]\n",
-"epoch:26\n",
-"iter 0, loss:[0.20649]\n",
-"iter 200, loss:[0.212987]\n",
-"epoch:27\n",
-"iter 0, loss:[0.12560298]\n",
-"iter 200, loss:[0.17958683]\n",
-"epoch:28\n",
-"iter 0, loss:[0.13129365]\n",
-"iter 200, loss:[0.14788578]\n",
-"epoch:29\n",
-"iter 0, loss:[0.07885154]\n",
-"iter 200, loss:[0.14729765]\n"
+"iter 0, loss:[0.52320284]\n",
+"iter 200, loss:[0.47618982]\n"
 ]
 }
 ],
@@ -542,7 +507,7 @@
 "        x_cn_data = train_cn_sents_shuffled[(batch_size*iteration):(batch_size*(iteration+1))]\n",
 "        x_cn_label_data = train_cn_label_sents_shuffled[(batch_size*iteration):(batch_size*(iteration+1))]\n",
 "\n",
-"        # shape: (batch, num_layer(=1 here) * num_of_direction(=1 here) * hidden_size)\n",
+"        # shape: (batch, num_layer(=1 here) * num_of_direction(=1 here), hidden_size)\n",
 "        hidden = paddle.zeros([batch_size, 1, hidden_size])\n",
 "        cell = paddle.zeros([batch_size, 1, hidden_size])\n",
 "\n",
@@ -573,48 +538,49 @@
 "source": [
 "# Use the model for machine translation\n",
 "\n",
+"Depending on the computing device you use, the training above can take varying amounts of time. (On a Mac laptop, it takes roughly 15-20 minutes.)\n",
 "After the training above completes, we have a model that can translate English into Chinese. Next, we use greedy search to run actual translations with this model. (In a real task, you may want to use the beam search algorithm to improve the results.)"
 ]
 },
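Greedy search simply takes the argmax token at each step and feeds it back as the next input. A minimal sketch of such a decode loop for one source sentence `en_sent` (a `(1, seq_len)` int64 tensor), reusing `encoder`, `atten_decoder`, `cn_vocab`, `hidden_size`, and `MAX_LEN` from earlier cells; the notebook's own prediction cell, partially visible below, follows the same pattern:

```python
import numpy as np
import paddle

# Greedy decoding sketch; names are assumed from earlier cells.
en_repr = encoder(en_sent)

word = paddle.to_tensor(np.array([[cn_vocab['<bos>']]], dtype='int64'))
hidden = paddle.zeros([1, 1, hidden_size])
cell = paddle.zeros([1, 1, hidden_size])

decoded_sent = []
for i in range(MAX_LEN + 2):
    logits, (hidden, cell) = atten_decoder(word, hidden, cell, en_repr)
    word = paddle.argmax(logits, axis=1)    # greedy: best token at this step
    decoded_sent.append(word.numpy())
    word = paddle.unsqueeze(word, axis=-1)  # feed back as the next input

# Beam search would instead keep the k highest-scoring partial
# hypotheses at every step and expand each of them.
```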
 {
 "cell_type": "code",
-"execution_count": 29,
+"execution_count": 18,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"he is poor\n",
-"true: 他很穷。\n",
-"pred: 他很穷。\n",
-"i lent him a cd\n",
-"true: 我借给他一盘CD。\n",
-"pred: 我借给他一盘CD。\n",
-"i m not so brave\n",
-"true: 我没那么勇敢。\n",
-"pred: 我没那么勇敢。\n",
-"he goes to bed at eight o clock\n",
-"true: 他八點上床睡覺。\n",
-"pred: 他八點鐘也會遲到。\n",
-"i know how old you are\n",
-"true: 我知道你多大了。\n",
-"pred: 我知道你多大了。\n",
-"i m a detective\n",
-"true: 我是个侦探。\n",
-"pred: 我是个侦探。\n",
-"i am the fastest runner\n",
-"true: 我是跑得最快的人。\n",
-"pred: 我是最快的跑者。\n",
-"he got down the book from the shelf\n",
-"true: 他從架上拿下書。\n",
-"pred: 他從架上拿下書。\n",
-"he arrived at the station at seven\n",
-"true: 他7点到了火车站。\n",
-"pred: 他7点到了火车站。\n",
-"he fell down on the floor\n",
-"true: 他摔倒在地。\n",
-"pred: 他摔倒在地。\n"
+"i agree with him\n",
+"true: 我同意他。\n",
+"pred: 我同意他。\n",
+"i think i ll take a bath tonight\n",
+"true: 我想我今晚會洗澡。\n",
+"pred: 我想我今晚會洗澡。\n",
+"he asked for a drink of water\n",
+"true: 他要了水喝。\n",
+"pred: 他喝了一杯水。\n",
+"i began running\n",
+"true: 我開始跑。\n",
+"pred: 我開始跑。\n",
+"i m sick\n",
+"true: 我生病了。\n",
+"pred: 我生病了。\n",
+"you had better go to the dentist s\n",
+"true: 你最好去看牙醫。\n",
+"pred: 你最好去看牙醫。\n",
+"we went for a walk in the forest\n",
+"true: 我们去了林中散步。\n",
+"pred: 我們去公园散步。\n",
+"you ve arrived very early\n",
+"true: 你來得很早。\n",
+"pred: 你去早个。\n",
+"he pretended not to be listening\n",
+"true: 他裝作沒在聽。\n",
+"pred: 他假装聽到它。\n",
+"he always wanted to study japanese\n",
+"true: 他一直想學日語。\n",
+"pred: 他一直想學日語。\n"
 ]
 }
 ],
@@ -640,7 +606,6 @@
 "decoded_sent = []\n",
 "for i in range(MAX_LEN + 2):\n",
 "    logits, (hidden, cell) = atten_decoder(word, hidden, cell, en_repr)\n",
-"\n",
 "    word = paddle.argmax(logits, axis=1)\n",
 "    decoded_sent.append(word.numpy())\n",
 "    word = paddle.unsqueeze(word, axis=-1)\n",
...