Commit 39328348 authored by S shusentang

fix bug #84

Parent b3401dd6
@@ -21,7 +21,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"1.0.0 cuda\n"
"1.1.0 cuda\n"
]
}
],
@@ -39,10 +39,10 @@
"sys.path.append(\"..\") \n",
"import d2lzh_pytorch as d2l\n",
"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"7\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"DATA_ROOT = \"/S1/CSCL/tangss/Datasets\"\n",
"DATA_ROOT = \"/data1/tangss/Datasets\"\n",
"\n",
"print(torch.__version__, device)"
]
@@ -88,10 +88,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 12500/12500 [00:04<00:00, 2930.03it/s]\n",
"100%|██████████| 12500/12500 [00:04<00:00, 3008.48it/s]\n",
"100%|██████████| 12500/12500 [00:03<00:00, 3365.08it/s]\n",
"100%|██████████| 12500/12500 [00:03<00:00, 3305.63it/s]\n"
"100%|██████████| 12500/12500 [00:00<00:00, 34211.42it/s]\n",
"100%|██████████| 12500/12500 [00:00<00:00, 38506.48it/s]\n",
"100%|██████████| 12500/12500 [00:00<00:00, 31316.61it/s]\n",
"100%|██████████| 12500/12500 [00:00<00:00, 29664.72it/s]\n"
]
}
],
@@ -108,7 +108,8 @@
" random.shuffle(data)\n",
" return data\n",
"\n",
"train_data, test_data = read_imdb('train'), read_imdb('test')"
"data_root = os.path.join(DATA_ROOT, \"aclImdb\")\n",
"train_data, test_data = read_imdb('train', data_root), read_imdb('test', data_root)"
]
},
{
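Note: only the tail of read_imdb (the shuffle and return) and its updated call site are visible in this hunk. For context, a minimal sketch of a read_imdb that takes an explicit data_root is given below; the aclImdb pos/neg folder layout, the binary labels and the tqdm progress bar are assumptions inferred from the surrounding cells, not code taken from this commit.

import os
import random
from tqdm import tqdm

def read_imdb(folder, data_root):
    # Read [review, label] pairs from data_root/<folder>/{pos,neg};
    # label 1 marks a positive review, 0 a negative one.
    data = []
    for label in ['pos', 'neg']:
        folder_name = os.path.join(data_root, folder, label)
        for file in tqdm(os.listdir(folder_name)):
            with open(os.path.join(folder_name, file), 'rb') as f:
                review = f.read().decode('utf-8').replace('\n', '').lower()
                data.append([review, 1 if label == 'pos' else 0])
    random.shuffle(data)
    return data

# Usage matching the updated call site:
# data_root = os.path.join(DATA_ROOT, "aclImdb")
# train_data, test_data = read_imdb('train', data_root), read_imdb('test', data_root)

Passing data_root explicitly is what lets the notebook switch dataset locations (as in the DATA_ROOT change above) without touching the function itself.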
@@ -152,7 +153,7 @@
{
"data": {
"text/plain": [
"('# words in vocab:', 46151)"
"('# words in vocab:', 46152)"
]
},
"execution_count": 5,
@@ -330,8 +331,7 @@
"ExecuteTime": {
"end_time": "2019-07-03T04:26:47.895604Z",
"start_time": "2019-07-03T04:26:47.685801Z"
- },
- "collapsed": true
+ }
},
"outputs": [],
"source": [
@@ -345,10 +345,17 @@
"ExecuteTime": {
"end_time": "2019-07-03T04:26:48.102388Z",
"start_time": "2019-07-03T04:26:47.897582Z"
- },
- "collapsed": true
+ }
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 21202 oov words.\n"
]
}
],
"source": [
"def load_pretrained_embedding(words, pretrained_vocab):\n",
" \"\"\"从预训练好的vocab中提取出words对应的词向量\"\"\"\n",
@@ -359,9 +366,9 @@
" idx = pretrained_vocab.stoi[word]\n",
" embed[i, :] = pretrained_vocab.vectors[idx]\n",
" except KeyError:\n",
" oov_count += 0\n",
" oov_count += 1\n",
" if oov_count > 0:\n",
" print(\"There are %d oov words.\")\n",
" print(\"There are %d oov words.\" % oov_count)\n",
" return embed\n",
"\n",
"net.embedding.weight.data.copy_(load_pretrained_embedding(vocab.itos, glove_vocab))\n",
@@ -390,11 +397,11 @@
"output_type": "stream",
"text": [
"training on cuda\n",
"epoch 1, loss 0.5759, train acc 0.666, test acc 0.832, time 250.8 sec\n",
"epoch 2, loss 0.1785, train acc 0.842, test acc 0.852, time 253.3 sec\n",
"epoch 3, loss 0.1042, train acc 0.866, test acc 0.856, time 253.7 sec\n",
"epoch 4, loss 0.0682, train acc 0.888, test acc 0.868, time 254.2 sec\n",
"epoch 5, loss 0.0483, train acc 0.901, test acc 0.862, time 251.4 sec\n"
"epoch 1, loss 0.5415, train acc 0.719, test acc 0.819, time 48.7 sec\n",
"epoch 2, loss 0.1897, train acc 0.837, test acc 0.852, time 53.0 sec\n",
"epoch 3, loss 0.1105, train acc 0.857, test acc 0.844, time 51.6 sec\n",
"epoch 4, loss 0.0719, train acc 0.881, test acc 0.865, time 52.1 sec\n",
"epoch 5, loss 0.0519, train acc 0.894, test acc 0.852, time 51.2 sec\n"
]
}
],
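The per-epoch lines above presumably come from the training helper in d2lzh_pytorch (imported as d2l above); the much shorter wall-clock times apparently just reflect re-running on a different GPU setup. As a reading aid, a rough, hypothetical stand-in that produces the same log format is sketched below; net, train_iter, test_iter, loss and optimizer are placeholders, and averaging the loss per sample here need not match the helper exactly.

import time
import torch

def train_sketch(net, train_iter, test_iter, loss, optimizer, device, num_epochs):
    # Hypothetical stand-in: reproduces the "epoch ..., loss ..., train acc ...,
    # test acc ..., time ... sec" log format, not the chapter's exact helper.
    net = net.to(device)
    print("training on", device)
    for epoch in range(num_epochs):
        start, train_l_sum, train_acc_sum, n = time.time(), 0.0, 0.0, 0
        net.train()
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        net.eval()
        with torch.no_grad():
            test_acc_sum, m = 0.0, 0
            for X, y in test_iter:
                X, y = X.to(device), y.to(device)
                test_acc_sum += (net(X).argmax(dim=1) == y).sum().item()
                m += y.shape[0]
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n,
                 test_acc_sum / m, time.time() - start))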
@@ -488,9 +495,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:py36]",
"display_name": "Python [conda env:py36_pytorch]",
"language": "python",
"name": "conda-env-py36-py"
"name": "conda-env-py36_pytorch-py"
},
"language_info": {
"codemirror_mode": {
@@ -502,7 +509,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
"version": "3.6.2"
},
"varInspector": {
"cols": {
......
@@ -1203,9 +1203,9 @@ def load_pretrained_embedding(words, pretrained_vocab):
idx = pretrained_vocab.stoi[word]
embed[i, :] = pretrained_vocab.vectors[idx]
except KeyError:
- oov_count += 0
+ oov_count += 1
if oov_count > 0:
print("There are %d oov words.")
print("There are %d oov words." % oov_count)
return embed
def predict_sentiment(net, vocab, sentence):
......