{
 "cells": [
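  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Training word embeddings on the works of Shakespeare with an N-gram model\n",
    "An N-gram is a concept from computational linguistics and probability theory: a contiguous sequence of N items taken from a given text.\n",
    "With N=1 an N-gram is called a unigram, with N=2 a bigram, with N=3 a trigram, and so on. In practice, bigrams and trigrams are the most commonly used.\n",
    "This example implements a trigram model on the works of Shakespeare."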
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment\n",
    "This tutorial is written against paddle-develop. If your environment uses a different version, install paddle-develop first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.0.0'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import paddle\n",
    "paddle.__version__"
   ]
  },
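  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset and parameters\n",
    "The training data is the complete works of Shakespeare. [Download](https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt) the file and save it as plain text.\n",
    "`context_size` is set to 2, so the model is a trigram model: two context words predict the third. `embedding_dim` is set to 256."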
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-09-09 14:58:26--  https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt\n",
      "正在解析主机 ocw.mit.edu (ocw.mit.edu)... 151.101.110.133\n",
      "正在连接 ocw.mit.edu (ocw.mit.edu)|151.101.110.133|:443... 已连接。\n",
      "已发出 HTTP 请求,正在等待回应... 200 OK\n",
      "长度:5458199 (5.2M) [text/plain]\n",
      "正在保存至: “t8.shakespeare.txt”\n",
      "\n",
      "t8.shakespeare.txt  100%[===================>]   5.21M  94.1KB/s  用时 70s       \n",
      "\n",
      "2020-09-09 14:59:38 (75.7 KB/s) - 已保存 “t8.shakespeare.txt” [5458199/5458199])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_dim = 256\n",
    "context_size = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of text: 5458199 characters\n"
     ]
    }
   ],
   "source": [
    "# Path to the downloaded file\n",
    "path_to_file = './t8.shakespeare.txt'\n",
    "with open(path_to_file, 'rb') as f:\n",
    "    test_sentence = f.read().decode(encoding='utf-8')\n",
    "\n",
    "# The text length is the number of characters in the text\n",
    "print('Length of text: {} characters'.format(len(test_sentence)))"
   ]
  },
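  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Removing punctuation\n",
    "Punctuation carries little meaning for this task, so use `punctuation` from the `string` library to strip the English punctuation marks."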
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'!': '', '\"': '', '#': '', '$': '', '%': '', '&': '', \"'\": '', '(': '', ')': '', '*': '', '+': '', ',': '', '-': '', '.': '', '/': '', ':': '', ';': '', '<': '', '=': '', '>': '', '?': '', '@': '', '[': '', '\\\\': '', ']': '', '^': '', '_': '', '`': '', '{': '', '|': '', '}': '', '~': ''}\n"
     ]
    }
   ],
   "source": [
    "from string import punctuation\n",
    "process_dicts={i:'' for i in punctuation}\n",
    "print(process_dicts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28343\n"
     ]
    }
   ],
   "source": [
    "punc_table = str.maketrans(process_dicts)\n",
    "test_sentence = test_sentence.translate(punc_table)\n",
    "test_sentence = test_sentence.lower().split()\n",
    "vocab = set(test_sentence)\n",
    "print(len(vocab))"
   ]
  },
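  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data preprocessing\n",
    "The text is split into samples of the form (('first word', 'second word'), 'third word'), stored as nested lists. The third word is the prediction target."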
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[['this', 'is'], 'the'], [['is', 'the'], '100th'], [['the', '100th'], 'etext']]\n"
     ]
    }
   ],
   "source": [
    "trigram = [[[test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2]]\n",
    "           for i in range(len(test_sentence) - 2)]\n",
    "\n",
    "word_to_idx = {word: i for i, word in enumerate(vocab)}\n",
    "idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n",
    "# Inspect the first few samples\n",
    "print(trigram[:3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build a `Dataset` class and load the data\n",
    "Build the dataset with `paddle.io.Dataset`, then pass it to `paddle.io.DataLoader` to load the data in batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "batch_size = 256\n",
    "paddle.disable_static()\n",
    "class TrainDataset(paddle.io.Dataset):\n",
    "    def __init__(self, tuple_data):\n",
    "        self.tuple_data = tuple_data\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        data = self.tuple_data[idx][0]\n",
    "        label = self.tuple_data[idx][1]\n",
    "        data = np.array(list(map(lambda w: word_to_idx[w], data)))\n",
    "        label = np.array(word_to_idx[label])\n",
    "        return data, label\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.tuple_data)\n",
    "train_dataset = TrainDataset(trigram)\n",
    "train_loader = paddle.io.DataLoader(train_dataset,places=paddle.CPUPlace(), return_list=True,\n",
    "                                    shuffle=True, batch_size=batch_size, drop_last=True)"
   ]
  },
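  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building and training the network\n",
    "The network is built in Paddle's dynamic graph mode. The trigram model consists of one `Embedding` layer followed by two `Linear` layers. The `Embedding` layer embeds the two context words, and the resulting vectors are fed to the two `Linear` layers for feature extraction."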
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "hidden_size = 1024\n",
    "class NGramModel(paddle.nn.Layer):\n",
    "    def __init__(self, vocab_size, embedding_dim, context_size):\n",
    "        super(NGramModel, self).__init__()\n",
    "        # Keep the sizes on the instance so forward() does not depend on globals\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.context_size = context_size\n",
    "        self.embedding = paddle.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n",
    "        self.linear1 = paddle.nn.Linear(context_size * embedding_dim, hidden_size)\n",
    "        # Output layer: one logit per vocabulary word\n",
    "        self.linear2 = paddle.nn.Linear(hidden_size, vocab_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x holds word indices; look up and concatenate the context embeddings\n",
    "        x = self.embedding(x)\n",
    "        x = paddle.reshape(x, [-1, self.context_size * self.embedding_dim])\n",
    "        x = self.linear1(x)\n",
    "        x = paddle.nn.functional.relu(x)\n",
    "        x = self.linear2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the `train()` function and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, batch_id: 0, loss is: [10.252193]\n",
      "epoch: 0, batch_id: 500, loss is: [6.894636]\n",
      "epoch: 0, batch_id: 1000, loss is: [6.849346]\n",
      "epoch: 0, batch_id: 1500, loss is: [6.931605]\n",
      "epoch: 0, batch_id: 2000, loss is: [6.6860313]\n",
      "epoch: 0, batch_id: 2500, loss is: [6.2472367]\n",
      "epoch: 0, batch_id: 3000, loss is: [6.8818874]\n",
      "epoch: 0, batch_id: 3500, loss is: [6.941615]\n",
      "epoch: 1, batch_id: 0, loss is: [6.3628616]\n",
      "epoch: 1, batch_id: 500, loss is: [6.2065206]\n",
      "epoch: 1, batch_id: 1000, loss is: [6.5334334]\n",
      "epoch: 1, batch_id: 1500, loss is: [6.5788]\n",
      "epoch: 1, batch_id: 2000, loss is: [6.352103]\n",
      "epoch: 1, batch_id: 2500, loss is: [6.6272373]\n",
      "epoch: 1, batch_id: 3000, loss is: [6.801074]\n",
      "epoch: 1, batch_id: 3500, loss is: [6.2274427]\n"
     ]
    }
   ],
   "source": [
    "vocab_size = len(vocab)\n",
    "epochs = 2\n",
    "losses = []\n",
    "def train(model):\n",
    "    model.train()\n",
    "    optim = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())\n",
    "    for epoch in range(epochs):\n",
    "        for batch_id, data in enumerate(train_loader()):\n",
    "            x_data = data[0]\n",
    "            y_data = data[1]\n",
    "            predicts = model(x_data)\n",
    "            y_data = paddle.reshape(y_data, ([-1, 1]))\n",
    "            loss = paddle.nn.functional.softmax_with_cross_entropy(predicts, y_data)\n",
    "            avg_loss = paddle.mean(loss)\n",
    "            avg_loss.backward()\n",
    "            if batch_id % 500 == 0:\n",
    "                losses.append(avg_loss.numpy())\n",
    "                print(\"epoch: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy())) \n",
    "            optim.minimize(avg_loss)\n",
    "            model.clear_gradients()\n",
    "model = NGramModel(vocab_size, embedding_dim, context_size)\n",
    "train(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting the loss curve\n",
    "Plotting the recorded loss values shows how training progressed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       ""
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "%matplotlib inline\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction\n",
    "Use the trained model to make a prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the input words is: of, william\n",
      "the predict words is: shakespeare\n",
      "the true words is: shakespeare\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "def test(model):\n",
    "    model.eval()\n",
    "    # Pick one sample at random from the last 10 samples\n",
    "    idx = random.randint(len(trigram)-10, len(trigram)-1)\n",
    "    print('the input words is: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n",
    "    x_data = list(map(lambda w: word_to_idx[w], trigram[idx][0]))\n",
    "    x_data = paddle.to_tensor(np.array(x_data))\n",
    "    predicts = model(x_data)\n",
    "    predicts = predicts.numpy().tolist()[0]\n",
    "    predicts = predicts.index(max(predicts))\n",
    "    print('the predict words is: ' + idx_to_word[predicts])\n",
    "    y_data = trigram[idx][1]\n",
    "    print('the true words is: ' + y_data)\n",
    "test(model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}