{
 "cells": [
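  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text Classification on the IMDB Dataset with a BOW Network\n",
    "\n",
    "This tutorial shows how to perform text classification on the IMDB dataset with a simple bag-of-words (BOW) network.\n",
    "\n",
    "The IMDB dataset contains movie reviews labeled as positive or negative, with 25,000 reviews in the training set and 25,000 in the test set.\n",
    "The official page of the dataset is: http://ai.stanford.edu/~amaas/data/sentiment/\n",
    "\n",
    "- Warning: `paddle.dataset.imdb` is currently a very crude implementation, and a replacement will be needed later."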
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Environment Setup\n",
    "\n",
    "This example is based on version 2.0 of the PaddlePaddle open-source framework."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0.0\n",
      "264e76cae6861ad9b1d4bcd8c3212f7a78c01e4d\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "\n",
    "paddle.disable_static()\n",
    "print(paddle.__version__)\n",
    "print(paddle.__git_commit__)\n"
   ]
  },
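  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading the Data\n",
    "\n",
    "We use `paddle.dataset` to download the data, build the dictionary, and prepare the data readers. In PaddlePaddle 2.0, padding is the recommended way to align samples of different lengths within a batch, so we also add a special `<pad>` token to the dictionary, which is used later to fill the shorter sentences in a batch."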
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading IMDB word dict....\n"
     ]
    }
   ],
   "source": [
    "print(\"Loading IMDB word dict....\")\n",
    "word_dict = paddle.dataset.imdb.word_dict()\n",
    "\n",
    "train_reader = paddle.dataset.imdb.train(word_dict)\n",
    "test_reader = paddle.dataset.imdb.test(word_dict)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the:0\n",
      "and:1\n",
      "a:2\n",
      "of:3\n",
      "to:4\n",
      "...\n",
      "virtual:5143\n",
      "warriors:5144\n",
      "widely:5145\n",
      "<unk>:5146\n",
      "<pad>:5147\n",
      "totally 5148 words\n"
     ]
    }
   ],
   "source": [
    "# add a pad token to the dict for later padding the sequence\n",
    "word_dict['<pad>'] = len(word_dict)\n",
    "\n",
    "for k in list(word_dict)[:5]:\n",
    "    print(\"{}:{}\".format(k.decode('ASCII'), word_dict[k]))\n",
    "\n",
    "print(\"...\")\n",
    "\n",
    "for k in list(word_dict)[-5:]:\n",
    "    print(\"{}:{}\".format(k if isinstance(k, str) else k.decode('ASCII'), word_dict[k]))\n",
    "\n",
    "print(\"totally {} words\".format(len(word_dict)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parameter Settings\n",
    "\n",
    "Here we set the vocabulary size, the `embedding` size, batch_size, and so on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size = len(word_dict)\n",
    "emb_size = 256\n",
    "seq_len = 200\n",
    "batch_size = 32\n",
    "epoch_num = 2\n",
    "pad_id = word_dict['<pad>']\n",
    "\n",
    "classes = ['negative', 'positive']\n",
    "\n",
    "def ids_to_str(ids):\n",
    "    #print(ids)\n",
    "    words = []\n",
    "    for k in ids:\n",
    "        w = list(word_dict)[k]\n",
    "        words.append(w if isinstance(w, str) else w.decode('ASCII'))\n",
    "    return \" \".join(words)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we take one sample and print it to get a first, intuitive impression of the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5146, 43, 71, 6, 1092, 14, 0, 878, 130, 151, 5146, 18, 281, 747, 0, 5146, 3, 5146, 2165, 37, 5146, 46, 5, 71, 4089, 377, 162, 46, 5, 32, 1287, 300, 35, 203, 2136, 565, 14, 2, 253, 26, 146, 61, 372, 1, 615, 5146, 5, 30, 0, 50, 3290, 6, 2148, 14, 0, 5146, 11, 17, 451, 24, 4, 127, 10, 0, 878, 130, 43, 2, 50, 5146, 751, 5146, 5, 2, 221, 3727, 6, 9, 1167, 373, 9, 5, 5146, 7, 5, 1343, 13, 2, 5146, 1, 250, 7, 98, 4270, 56, 2316, 0, 928, 11, 11, 9, 16, 5, 5146, 5146, 6, 50, 69, 27, 280, 27, 108, 1045, 0, 2633, 4177, 3180, 17, 1675, 1, 2571] 0\n",
      "<unk> has much in common with the third man another <unk> film set among the <unk> of <unk> europe like <unk> there is much inventive camera work there is an innocent american who gets emotionally involved with a woman he doesnt really understand and whose <unk> is all the more striking in contrast with the <unk> br but id have to say that the third man has a more <unk> storyline <unk> is a bit disjointed in this respect perhaps this is <unk> it is presented as a <unk> and making it too coherent would spoil the effect br br this movie is <unk> <unk> in more than one sense one never sees the sun shine grim but intriguing and frightening\n",
      "negative\n"
     ]
    }
   ],
   "source": [
    "# Take the first sample and see what it looks like.\n",
    "sent, label = next(train_reader())\n",
    "print(sent, label)\n",
    "\n",
    "print(ids_to_str(sent))\n",
    "print(classes[label])"
   ]
  },
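  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Aligning the Data with Padding\n",
    "\n",
    "In text data, every sentence has a different length. To simplify the subsequent neural-network computation, a common approach is to bring all samples in the dataset to the same length: longer samples are truncated, and shorter ones are filled with the special `<pad>` token. The code below applies this processing to the dataset."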
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(25000, 200)\n",
      "(25000, 1)\n",
      "(25000, 200)\n",
      "(25000, 1)\n",
      "<unk> has much in common with the third man another <unk> film set among the <unk> of <unk> europe like <unk> there is much inventive camera work there is an innocent american who gets emotionally involved with a woman he doesnt really understand and whose <unk> is all the more striking in contrast with the <unk> br but id have to say that the third man has a more <unk> storyline <unk> is a bit disjointed in this respect perhaps this is <unk> it is presented as a <unk> and making it too coherent would spoil the effect br br this movie is <unk> <unk> in more than one sense one never sees the sun shine grim but intriguing and frightening <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n",
      "<unk> is the most original movie ive seen in years if you like unique thrillers that are influenced by film noir then this is just the right cure for all of those hollywood summer <unk> <unk> the theaters these days von <unk> <unk> like breaking the waves have gotten more <unk> but this is really his best work it is <unk> without being distracting and offers the perfect combination of suspense and dark humor its too bad he decided <unk> cameras were the wave of the future its hard to say who talked him away from the style he <unk> here but its everyones loss that he went into his heavily <unk> <unk> direction instead <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n",
      "<unk> von <unk> is never <unk> in trying out new techniques some of them are very original while others are best <unk> br he depicts <unk> germany as a <unk> train journey with so many cities lying in ruins <unk> <unk> a young american of german descent feels <unk> to help in their <unk> it is not a simple task as he quickly finds outbr br his uncle finds him a job as a night <unk> on the <unk> <unk> line his job is to <unk> to the needs of the passengers when the shoes are <unk> a <unk> mark is made on the <unk> a terrible argument <unk> when a passengers shoes are not <unk> despite the fact they have been <unk> there are many <unk> to the german <unk> of <unk> to such stupid <unk> br the <unk> journey is like an <unk> <unk> mans <unk> through life with all its <unk> and <unk> in one sequence <unk> <unk> through the back <unk> to discover them filled with <unk> bodies appearing to have just escaped from <unk> these images horrible as they are are <unk> as in a dream each with its own terrible impact yet <unk> br\n"
     ]
    }
   ],
   "source": [
    "def create_padded_dataset(reader):\n",
    "    padded_sents = []\n",
    "    labels = []\n",
    "    for batch_id, data in enumerate(reader):\n",
    "        sent, label = data\n",
    "        padded_sent = sent[:seq_len] + [pad_id] * (seq_len - len(sent))\n",
    "        padded_sents.append(padded_sent)\n",
    "        labels.append(label)\n",
    "    return np.array(padded_sents), np.expand_dims(np.array(labels), axis=1)\n",
    "\n",
    "train_sents, train_labels = create_padded_dataset(train_reader())\n",
    "test_sents, test_labels = create_padded_dataset(test_reader())\n",
    "\n",
    "print(train_sents.shape)\n",
    "print(train_labels.shape)\n",
    "print(test_sents.shape)\n",
    "print(test_labels.shape)\n",
    "\n",
    "for sent in train_sents[:3]:\n",
    "    print(ids_to_str(sent))\n"
   ]
  },
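  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the Network\n",
    "\n",
    "In this example we use a BOW network that ignores word order. After looking up the embedding of each word, we simply average the embeddings to obtain the sentence representation, and then apply a linear transformation with `Linear`. To reduce overfitting, we also use `Dropout`."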
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyNet(paddle.nn.Layer):\n",
    "    def __init__(self):\n",
    "        super(MyNet, self).__init__()\n",
    "        self.emb = paddle.nn.Embedding(vocab_size, emb_size)\n",
    "        self.fc = paddle.nn.Linear(in_features=emb_size, out_features=2)\n",
    "        self.dropout = paddle.nn.Dropout(0.5)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.emb(x)\n",
    "        x = paddle.mean(x, axis=1)\n",
    "        x = self.dropout(x)\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, batch_id: 0, loss is: [0.6926701]\n",
      "epoch: 0, batch_id: 500, loss is: [0.41248566]\n",
      "[validation] accuracy/loss: 0.8505121469497681/0.3615057170391083\n",
      "epoch: 1, batch_id: 0, loss is: [0.29521096]\n",
      "epoch: 1, batch_id: 500, loss is: [0.2916747]\n",
      "[validation] accuracy/loss: 0.86475670337677/0.3259459137916565\n"
     ]
    }
   ],
   "source": [
    "def train(model):\n",
    "    model.train()\n",
    "\n",
    "    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
    "\n",
    "    for epoch in range(epoch_num):\n",
    "        # shuffle data\n",
    "        perm = np.random.permutation(len(train_sents))\n",
    "        train_sents_shuffled = train_sents[perm]\n",
    "        train_labels_shuffled = train_labels[perm]\n",
    "        \n",
    "        for batch_id in range(len(train_sents_shuffled) // batch_size):\n",
    "            x_data = train_sents_shuffled[(batch_id * batch_size):((batch_id+1)*batch_size)]\n",
    "            y_data = train_labels_shuffled[(batch_id * batch_size):((batch_id+1)*batch_size)]\n",
    "            \n",
    "            sent = paddle.to_tensor(x_data)\n",
    "            label = paddle.to_tensor(y_data)\n",
    "            \n",
    "            logits = model(sent)\n",
    "            loss = paddle.nn.functional.softmax_with_cross_entropy(logits, label)\n",
    "            \n",
    "            avg_loss = paddle.mean(loss)\n",
    "            if batch_id % 500 == 0:\n",
    "                print(\"epoch: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n",
    "            avg_loss.backward()\n",
    "            opt.minimize(avg_loss)\n",
    "            model.clear_gradients()\n",
    "\n",
    "        # evaluate model after one epoch\n",
    "        model.eval()\n",
    "        accuracies = []\n",
    "        losses = []\n",
    "        for batch_id in range(len(test_sents) // batch_size):\n",
    "            x_data = test_sents[(batch_id * batch_size):((batch_id+1)*batch_size)]\n",
    "            y_data = test_labels[(batch_id * batch_size):((batch_id+1)*batch_size)]\n",
    "        \n",
    "            sent = paddle.to_tensor(x_data)\n",
    "            label = paddle.to_tensor(y_data)\n",
    "\n",
    "            logits = model(sent)\n",
    "            loss = paddle.nn.functional.softmax_with_cross_entropy(logits, label)\n",
    "            acc = paddle.metric.accuracy(logits, label)\n",
    "            \n",
    "            accuracies.append(acc.numpy())\n",
    "            losses.append(loss.numpy())\n",
    "        \n",
    "        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)\n",
    "        print(\"[validation] accuracy/loss: {}/{}\".format(avg_acc, avg_loss))\n",
    "        \n",
    "        model.train()\n",
    "        \n",
    "model = MyNet()\n",
    "train(model)"
   ]
  },
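  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The End\n",
    "\n",
    "As the output shows, two epochs of training reach an accuracy of about 86% on this dataset. You can also adjust the network architecture and hyperparameters to get better results."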
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "name": "imdb_bow_classification.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}