{
 "cells": [
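  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Training word embeddings on the works of Shakespeare with an N-gram model\n",
    "An N-gram is a concept from computational linguistics and probability theory: a contiguous sequence of N items in a given piece of text.\n",
    "When N=1 the N-gram is called a unigram, when N=2 a bigram, when N=3 a trigram, and so on. In practice, bigrams and trigrams are the most common choices.\n",
    "This example implements a trigram model on the works of Shakespeare."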
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment\n",
    "This tutorial is written against paddle-develop. If your environment uses a different version, please install paddle-develop first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.0.0'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import paddle\n",
    "paddle.__version__"
   ]
  },
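  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset and parameters\n",
    "The training data is the complete works of Shakespeare ([download](https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt)); save it as a plain-text file.\n",
    "`context_size` is set to 2, which makes the model a trigram model, and `embedding_dim` is set to 256."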
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-09-09 14:58:26--  https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt\n",
      "Resolving ocw.mit.edu (ocw.mit.edu)... 151.101.110.133\n",
      "Connecting to ocw.mit.edu (ocw.mit.edu)|151.101.110.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 5458199 (5.2M) [text/plain]\n",
      "Saving to: ‘t8.shakespeare.txt’\n",
      "\n",
      "t8.shakespeare.txt  100%[===================>]   5.21M  94.1KB/s    in 70s     \n",
      "\n",
      "2020-09-09 14:59:38 (75.7 KB/s) - ‘t8.shakespeare.txt’ saved [5458199/5458199]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_dim = 256\n",
    "context_size = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of text: 5458199 characters\n"
     ]
    }
   ],
   "source": [
    "# Path to the downloaded text file\n",
    "path_to_file = './t8.shakespeare.txt'\n",
    "with open(path_to_file, 'rb') as f:\n",
    "    test_sentence = f.read().decode(encoding='utf-8')\n",
    "\n",
    "# The length of the text is its number of characters\n",
    "print('Length of text: {} characters'.format(len(test_sentence)))"
   ]
  },
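  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Removing punctuation\n",
    "Punctuation marks carry little meaning for this task, so we use `punctuation` from the `string` module to strip English punctuation from the text."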
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'!': '', '\"': '', '#': '', '$': '', '%': '', '&': '', \"'\": '', '(': '', ')': '', '*': '', '+': '', ',': '', '-': '', '.': '', '/': '', ':': '', ';': '', '<': '', '=': '', '>': '', '?': '', '@': '', '[': '', '\\\\': '', ']': '', '^': '', '_': '', '`': '', '{': '', '|': '', '}': '', '~': ''}\n"
     ]
    }
   ],
   "source": [
    "from string import punctuation\n",
    "process_dicts={i:'' for i in punctuation}\n",
    "print(process_dicts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28343\n"
     ]
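    }
   ],
   "source": [
    "# Build a translation table that deletes every punctuation character\n",
    "punc_table = str.maketrans(process_dicts)\n",
    "test_sentence = test_sentence.translate(punc_table)\n",
    "# Lowercase the text and split it on whitespace into a list of words\n",
    "test_sentence = test_sentence.lower().split()\n",
    "# The vocabulary is the set of distinct words\n",
    "vocab = set(test_sentence)\n",
    "print(len(vocab))"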
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data preprocessing\n",
    "The text is split into tuples of the form (('first word', 'second word'), 'third word'), where the third word is the prediction target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[['this', 'is'], 'the'], [['is', 'the'], '100th'], [['the', '100th'], 'etext']]\n"
     ]
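    }
   ],
   "source": [
    "# Slide a window of three words over the text: two context words and one target word\n",
    "trigram = [[[test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2]]\n",
    "           for i in range(len(test_sentence) - 2)]\n",
    "\n",
    "# Lookup tables between words and integer indices\n",
    "word_to_idx = {word: i for i, word in enumerate(vocab)}\n",
    "idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n",
    "# Take a look at the first few samples\n",
    "print(trigram[:3])"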
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building a `Dataset` class to load the data\n",
    "Build the dataset with `paddle.io.Dataset`, then pass it as an argument to `paddle.io.DataLoader` to load the data in batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "\n",
    "batch_size = 256\n",
    "# Switch to dynamic graph (imperative) mode\n",
    "paddle.disable_static()\n",
    "\n",
    "class TrainDataset(paddle.io.Dataset):\n",
    "    def __init__(self, tuple_data):\n",
    "        self.tuple_data = tuple_data\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Each sample has the form [[word1, word2], word3]\n",
    "        data = self.tuple_data[idx][0]\n",
    "        label = self.tuple_data[idx][1]\n",
    "        # Map the words to their integer indices\n",
    "        data = np.array([word_to_idx[w] for w in data])\n",
    "        label = np.array(word_to_idx[label])\n",
    "        return data, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.tuple_data)\n",
    "\n",
    "train_dataset = TrainDataset(trigram)\n",
    "train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), return_list=True,\n",
    "                                    shuffle=True, batch_size=batch_size, drop_last=True)"
   ]
  },
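  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building and training the network\n",
    "The network is built in Paddle's dynamic graph mode. The trigram model consists of one `Embedding` layer followed by two `Linear` layers: the `Embedding` layer embeds the two input words, and the result is fed through the two `Linear` layers for feature extraction."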
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "\n",
    "hidden_size = 1024\n",
    "\n",
    "class NGramModel(paddle.nn.Layer):\n",
    "    def __init__(self, vocab_size, embedding_dim, context_size):\n",
    "        super(NGramModel, self).__init__()\n",
    "        # Width of the concatenated context embeddings; stored so forward() does not rely on globals\n",
    "        self.flat_dim = context_size * embedding_dim\n",
    "        self.embedding = paddle.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n",
    "        self.linear1 = paddle.nn.Linear(self.flat_dim, hidden_size)\n",
    "        # One output logit per vocabulary word\n",
    "        self.linear2 = paddle.nn.Linear(hidden_size, vocab_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Word indices -> embeddings of shape [batch, context_size, embedding_dim]\n",
    "        x = self.embedding(x)\n",
    "        # Concatenate the context embeddings into one vector per sample\n",
    "        x = paddle.reshape(x, [-1, self.flat_dim])\n",
    "        x = self.linear1(x)\n",
    "        x = paddle.nn.functional.relu(x)\n",
    "        x = self.linear2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义`train()`函数,对模型进行训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, batch_id: 0, loss is: [10.252193]\n",
      "epoch: 0, batch_id: 500, loss is: [6.894636]\n",
      "epoch: 0, batch_id: 1000, loss is: [6.849346]\n",
      "epoch: 0, batch_id: 1500, loss is: [6.931605]\n",
      "epoch: 0, batch_id: 2000, loss is: [6.6860313]\n",
      "epoch: 0, batch_id: 2500, loss is: [6.2472367]\n",
      "epoch: 0, batch_id: 3000, loss is: [6.8818874]\n",
      "epoch: 0, batch_id: 3500, loss is: [6.941615]\n",
      "epoch: 1, batch_id: 0, loss is: [6.3628616]\n",
      "epoch: 1, batch_id: 500, loss is: [6.2065206]\n",
      "epoch: 1, batch_id: 1000, loss is: [6.5334334]\n",
      "epoch: 1, batch_id: 1500, loss is: [6.5788]\n",
      "epoch: 1, batch_id: 2000, loss is: [6.352103]\n",
      "epoch: 1, batch_id: 2500, loss is: [6.6272373]\n",
      "epoch: 1, batch_id: 3000, loss is: [6.801074]\n",
      "epoch: 1, batch_id: 3500, loss is: [6.2274427]\n"
     ]
    }
   ],
   "source": [
    "vocab_size = len(vocab)\n",
    "epochs = 2\n",
    "losses = []\n",
    "\n",
    "def train(model):\n",
    "    model.train()\n",
    "    optim = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())\n",
    "    for epoch in range(epochs):\n",
    "        for batch_id, data in enumerate(train_loader()):\n",
    "            x_data = data[0]\n",
    "            y_data = data[1]\n",
    "            predicts = model(x_data)\n",
    "            # Reshape the labels to [batch_size, 1] before computing the loss\n",
    "            y_data = paddle.reshape(y_data, [-1, 1])\n",
    "            loss = paddle.nn.functional.softmax_with_cross_entropy(predicts, y_data)\n",
    "            avg_loss = paddle.mean(loss)\n",
    "            avg_loss.backward()\n",
    "            # Record and print the loss every 500 batches\n",
    "            if batch_id % 500 == 0:\n",
    "                losses.append(avg_loss.numpy())\n",
    "                print(\"epoch: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n",
    "            # Update the parameters, then reset the gradients\n",
    "            optim.minimize(avg_loss)\n",
    "            model.clear_gradients()\n",
    "\n",
    "model = NGramModel(vocab_size, embedding_dim, context_size)\n",
    "train(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting the loss curve\n",
    "Visualizing the loss curve shows how well the training went."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dd3xV9f3H8dcnBMgAkgCBQIKAgogiIKQoKo46ceHAqq1W/dX5k4qttdUO7c8uOy2Oaqm1rioqLqyjOKqixRH2Hg4gYYWVQPb4/P64NxpiEMjNzb2c+34+Hnnk3nNP7vkw8s7J93zO92vujoiIBFdSrAsQEZHoUtCLiAScgl5EJOAU9CIiAaegFxEJuORYF9BU9+7dvV+/frEuQ0RknzJr1qxN7p7d3GtxF/T9+vWjoKAg1mWIiOxTzGzVrl7T0I2ISMAp6EVEAk5BLyIScAp6EZGAU9CLiAScgl5EJOAU9CIiAReYoC/aVsEfpy9j1eayWJciIhJXAhP0JeU13P3mShatLY11KSIicSUwQZ+blQpA4dbyGFciIhJfAhP0Gant6ZySTOHWiliXIiISVwIT9AB5WWkUKehFRHYSsKBP1Rm9iEgTgQr63MxUCreWowXPRUS+EKigz8tKpay6jpKKmliXIiISNwIW9GkAGr4REWkkYEGvFksRkaZ2G/Rm9qCZbTSzhY22dTWz18xsRfhz1i6+9tLwPivM7NLWLLw5XwS9zuhFRBrsyRn9Q8CpTbbdDLzh7gOBN8LPd2JmXYHbgMOBUcBtu/qB0FoyUtvTqaN66UVEGttt0Lv7O8CWJpvHAQ+HHz8MnN3Ml54CvObuW9x9K/AaX/6B0arMLNx5o6AXEWnQ0jH6nu6+Lvx4PdCzmX1ygTWNnheGt0VVXlYqRdsU9CIiDSK+GOuhpvWIGtfN7CozKzCzguLi4ojqCd00pYuxIiINWhr0G8ysF0D488Zm9ikC+jR6nhfe9iXuPtnd8909Pzs7u4UlheRmpbK9sla99CIiYS0N+mlAQxfNpcALzezzb+BkM8sKX4Q9Obwtqhp66TXnjYhIyJ60Vz4BzAQGmVmhmX0HuAM4ycxWACeGn2Nm+Wb2AIC7bwF+AXwU/rg9vC2q1EsvIrKz5N3t4O4X7eKlE5rZtwC4otHzB4EHW1xdC+RmqpdeRKSxQN0ZC9A1vQOp7dup80ZEJCxwQW9m6rwREWkkcEEPoc4bDd2IiIQEMuh105SIyBcCGvRpbCuvYXuleulFRAIZ9A2dNzqrFxEJaNA39NLrpikRkcAGvVaaEhFpEMig796pAx2TkzR0IyJCQIPezMItluqlFxEJZNBDaPhGQzciIgEO+tzMVF2MFREhwEGfl5XK5rJqyqtrY12KiEhMBTroQS2WIiKBD/pCdd6ISIILcNCrl15EBCIMejObaGYLzWyRmd3QzOvHmVmJmc0Nf9wayfH2RnanjnRol6QWSxFJeLtdYWpXzGwIcCUwCqgGXjWzf7n7yia7znD3MyKosUWSkozemSkaoxeRhBfJGf1g4AN3L3f3WuBt4NzWKat1qJdeRCSyoF8IjDGzbmaWBpwG9Glmv9FmNs/MXjGzQ5p7IzO7yswKzKyguLg4gpJ2lqcFSEREWh707r4E+C0wHXgVmAvUNdltNtDX3YcBdwPP7+K9Jrt7vrvnZ2dnt7SkL8nNTGXTjioqa5qWJSKSOCK6GOvuf3f3ke5+DLAVWN7k9VJ33xF+/DLQ3sy6R3LMvZHXVfPSi4hE2nXTI/x5P0Lj8483eT3HzCz8eFT4eJsjOebeUIuliEgEXTdhz5hZN6AGuM7dt5nZNQDufj8wHrjWzGqBCuBCd/cIj7nHPl9pSkEvIgksoqB39zHNbLu/0eN7gHsiOUYkenZJITnJ1EsvIgktsHfGArRLMnpnqvNGRBJboIMewtMV62KsiCSwwAd9nlaaEpEElwBB
n8aG0iqqatVLLyKJKfBBnxuernjdtsoYVyIiEhuBD/rP56XXBVkRSVAJFPQapxeRxBT4oM/pkkK7JFPnjYgkrMAHfXK7JHK6pGjoRkQSVuCDHtRiKSKJLSGCPjcrVfPdiEjCSoigz8tKY31pJdW19bEuRUSkzSVI0KdS77C+RL30IpJ4EiPow9MVF27TOL2IJJ7ECHotQCIiCSzSFaYmmtlCM1tkZjc087qZ2V1mttLM5pvZiEiO11I5GSkkmYJeRBJTi4PezIYAVwKjgGHAGWY2oMluY4GB4Y+rgPtaerxIdEhOomeXFHXeiEhCiuSMfjDwgbuXu3st8DahdWMbGwc84iHvA5lm1iuCY7aYeulFJFFFEvQLgTFm1s3M0oDTgD5N9skF1jR6XhjethMzu8rMCsysoLi4OIKSdi0vK01DNyKSkFoc9O6+BPgtMB14FZgLtGjSd3ef7O757p6fnZ3d0pK+Um5mKutLK6mtUy+9iCSWiC7Guvvf3X2kux8DbAWWN9mliJ3P8vPC29pcXlYqdfXO+lL10otIYom066ZH+PN+hMbnH2+yyzTg2+HumyOAEndfF8kxW0otliKSqJIj/PpnzKwbUANc5+7bzOwaAHe/H3iZ0Nj9SqAcuDzC47VYw0pT6rwRkUQTUdC7+5hmtt3f6LED10VyjNbSOzMF0Bm9iCSehLgzFqBjcjt6dumoFksRSTgJE/QQ6rzRSlMikmgSKujVSy8iiSjBgj6VtdsqqKv3WJciItJmEiroc7NSqa13Nm5XL72IJI6ECnr10otIIkqwoA8vQKLOGxFJIAkV9LmZumlKRBJPQgV9Svt2dO/UUUM3IpJQEirooWFeegW9iCSOhAv63CzdNCUiiSXhgj4vK5WirRXUq5deRBJEAgZ9GtV19RTvqIp1KSIibSLxgj6zocVSwzcikhgSL+jVSy8iCSbhgj43S2f0IpJYIl1K8HtmtsjMFprZE2aW0uT1y8ys2Mzmhj+uiKzcyKV1SKZregd13ohIwmhx0JtZLnA9kO/uQ4B2wIXN7Pqkuw8PfzzQ0uO1JvXSi0giiXToJhlINbNkIA1YG3lJ0RcKeo3Ri0hiaHHQu3sR8AdgNbAOKHH36c3sep6ZzTezqWbWp7n3MrOrzKzAzAqKi4tbWtIey80M9dKHlrQVEQm2SIZusoBxQH+gN5BuZhc32e1FoJ+7DwVeAx5u7r3cfbK757t7fnZ2dktL2mN5WWlU1dazaUd11I8lIhJrkQzdnAh86u7F7l4DPAsc2XgHd9/s7g13Jj0AjIzgeK2mocVSF2RFJBFEEvSrgSPMLM3MDDgBWNJ4BzPr1ejpWU1fj5Vc9dKLSAJJbukXuvsHZjYVmA3UAnOAyWZ2O1Dg7tOA683srPDrW4DLIi85crm6O1ZEEkiLgx7A3W8Dbmuy+dZGr98C3BLJMaKhc0p7MtPaawESEUkICXdnbIPcTLVYikhiSNig101TIpIoEjboczPTKNqmXnoRCb6EDfq8rFTKq+vYWl4T61JERKIqoYMe1GIpIsGXsEHf0EuvzhsRCbqEDfq8rDRAvfQiEnwJG/QZqe3pnJKsoRsRCbyEDXoIz2Kp+W5EJOASOujzstI0dCMigZfgQR+6aUq99CISZAkf9DuqaimtqI11KSIiUZPwQQ+wRhdkRSTAEjzo1WIpIsGX0EHfMC+9Om9EJMgiCnoz+56ZLTKzhWb2hJmlNHm9o5k9aWYrzewDM+sXyfFaW2Zae9I7tFMvvYgEWiSLg+cC1wP57j4EaAdc2GS37wBb3X0AcCfw25YeLxrMTC2WIhJ4kQ7dJAOpZpYMpAFrm7w+Dng4/HgqcEJ4fdm4kZuVqvluRCTQWhz07l4E/IHQIuHrgBJ3n95kt1xgTXj/WqAE6Nb0vczsKjMrMLOC4uLilpbUIqFeeg3diEhwRTJ0k0XojL0/0BtIN7OLW/Je7j7Z3fPdPT87O7ulJbVIXlYqpZW1lFZqXnoRCaZI
hm5OBD5192J3rwGeBY5ssk8R0AcgPLyTAWyO4JitLjcz1GKp4RsRCapIgn41cISZpYXH3U8AljTZZxpwafjxeOBNj7P5Br5YgERBLyLBFMkY/QeELrDOBhaE32uymd1uZmeFd/s70M3MVgLfB26OsN5Wp5WmRCTokiP5Yne/DbityeZbG71eCZwfyTGirWt6B1LaJ2noRkQCK6HvjAX10otI8CV80EO4xXKbhm5EJJgU9IRXmtIZvYgElIKe0CyWW8tr2FGleelFJHgU9HzReaOzehEJIgU9ofluAIo0Ti8iAaSgRzdNiUiwKeiB7E4d6ZicpKAXkUBS0BPqpVfnjYgElYI+LFfTFYtIQCnow3R3rIgElYI+LC8rlc1l1VRU18W6FBGRVqWgD8tTi6WIBJSCPqwh6Ndo+EZEAkZBH6aVpkQkqBT0YT06d6R9O9MFWREJnEgWBx9kZnMbfZSa2Q1N9jnOzEoa7XPrrt4v1pKSQr30arEUkaBp8QpT7r4MGA5gZu0ILQT+XDO7znD3M1p6nLaUm5VK0Tad0YtIsLTW0M0JwMfuvqqV3i8m8jLVSy8iwdNaQX8h8MQuXhttZvPM7BUzO6S5HczsKjMrMLOC4uLiVipp7+VlpVK8vYrKGvXSi0hwRBz0ZtYBOAt4upmXZwN93X0YcDfwfHPv4e6T3T3f3fOzs7MjLanFGqYrXqvhGxEJkNY4ox8LzHb3DU1fcPdSd98Rfvwy0N7MurfCMaMiLyvUYqnhGxEJktYI+ovYxbCNmeWYmYUfjwofb3MrHDMqNC+9iARRi7tuAMwsHTgJuLrRtmsA3P1+YDxwrZnVAhXAhe7ukRwzmnp2SSE5yTQNgogESkRB7+5lQLcm2+5v9Pge4J5IjtGW2iUZvTJTdEYvIoGiO2ObUIuliASNgr6J3CytNCUiwaKgbyIvK5UN2yupqlUvvYgEg4K+ibysNNxh3bbKWJciItIqFPRN5GY2LECi4RsRCQYFfRNf9NKrxVJEgkFB30SvjBTaJWleehEJDgV9E8ntksjpkqLOGxEJDAV9M3KzUnVGLyKBoaBvRp4WIBGRAFHQNyMvM5V1JRXU1NXHuhQRkYgp6JuRl5VGvcP6EvXSi8i+T0HfDE1XLCJBoqBvRq566UUkQBT0zeiVkYqZzuhFJBhaHPRmNsjM5jb6KDWzG5rsY2Z2l5mtNLP5ZjYi8pKjr0NyuJdenTciEgAtXnjE3ZcBwwHMrB1QBDzXZLexwMDwx+HAfeHPcS83M1VDNyISCK01dHMC8LG7r2qyfRzwiIe8D2SaWa9WOmZU5emmKREJiNYK+gtpfoHwXGBNo+eF4W07MbOrzKzAzAqKi4tbqaTI5GWlsb6kklr10ovIPi7ioDezDsBZwNMtfQ93n+zu+e6en52dHWlJrSI3K5XaemfD9qpYlyIiEpGIFgcPGwvMdvcNzbxWBPRp9DwvvC3ufd5Lv6X88znq483G0kqKd1RxcK8umFmsy9knVNbU8dnmMj7bVMZnm8v5bFMZG0orueW0wRzYs3OsyxOJitYI+otoftgGYBowwcymELoIW+Lu61rhmFGXl5UGxMcCJO7OhtIqFhSVsKCohIXhj43h3zZG79+N/xt3iIIqrLKmjtVbyvl0U0Ogl/HZpnI+21zGuiZ3O3fv1IHKmnomPD6baROOJqV9uxhVLRI9EQW9maUDJwFXN9p2DYC73w+8DJwGrATKgcsjOV5b6pWRAsAbSzfSNb0DvTJSyclIoUtKclTPnt2dtSWVLCgsYdHaL4J9045qAJIMDsjuxNEDujMkN4N6d+5+cyVjJ83g0tH9uOGkgXRJaR+1+uJFVW0da7aU8+mm0Fn5p+Gz9FWby1lbUoH7F/t2Te9Av25pjD6gG/27pdO3ezr9u6XTr3sanVPa8/byYi598EN+/fISbh83JHZ/KJEoMW/8HREH8vPzvaCgINZlAHD6XTNYtLZ0p21pHdqRk5FCr4wUcrqkhj43
PM9IoVdGKllp7ffoh4G7U7i1goXhM/UFRSUsWlvKlrJQqLdLMgb26MSQ3AyG9O7CoXkZDO7VhbQOO/983lJWzR+mL+OJD1fTLb0DPzz1IMaPyCMpKXjDOf9ZtpHbXljEmq3lO4V5Zlp7+nVLp3/3dPqFQzz0OZ2M1N3/4PvlvxbzwLuf8vdL8zlhcM8o/glEosPMZrl7frOvKeh3rbaung3bq1hfUsG6kkrWl1Q2+lzB+pJKNmyvoq5+57/DhhuudvoB0CWFnIxU6ur9i+GXtSVsK68BIDnJOLBnZ4bkduHQ3AyG5IZCfW+GEhYWlXDrCwuZvXobw/tkcvu4Qxial9mqfyextKWsmpP+9DYZae05Y2hv+ofDvH/3dDLTOkT03lW1dZxz739ZX1rJqxPH0KNLSitVLdI2FPRRVFfvbNpRFf4B0MwPhNIKNpRUUd2oTbN9O2NQTufPA31I7wwG5XRulfHh+nrnuTlF/OaVpWwuq+KC/D7cdMogunXqGPF7x9rEKXN4ecE6Xvzu0RyU06XV33/lxu2ccfe7fK1fVx6+fFQgfyOS4PqqoG+Ni7EJrV2S0bNLCj27pECf5s+e3Z0tZdWfXwg8sGdnOiRHZ5qhpCTjvJF5nHxITya9voKH/vsZLy9Yx40nD+Jbh+9Hcrt9c3qj1xdv4IW5a7nhxIFRCXmAAT06c+sZh/Dj5xbw4HufcsWY/aNyHJG2tm9+1+9jzIxunTqGzt5zM6IW8o11TmnPT884mFdvGMPQvExum7aIM+5+lw8+2Rz1Y7e2kooafvL8Ag7K6cz/Hjcgqse6aFQfTjmkJ799dSkLi0qieizZNxVvr+KeN1dw+l0zeHfFpliXs0c0dJMA3J1XF67nly8toWhbBWcN682PTxtMTsa+MQ79o6nzeXrWGp6/7qg2ueawtayasZNmkNaxHf/67tFfuvgticfdmb16K4/MXMXLC9ZRU+d0TkmmY3ISr0w8huzOsR8a/aqhG53RJwAzY+yhvXj9+8dy/QkDeXXRer7+x7f4y1srqaqti3V5X2nGimKeLFjDVccc0GYXlrPSO/Cnbwzj001l/OJfi9vkmBKfKqrrmPLhak6/613Ou28mby7ZyMVH9OWNG4/lmWuPZHtlLTc+PY/6+vg6YW5KZ/QJaPXmcn7x0mJeW7yB/t3TufXMgzl+UI9Yl/UlZVW1nHznO3RMTuLliWPa/Gam3766lPve+pj7Lx7BqUP2ibn4pJV8tqmMR99fxdMFayitrOWgnM5cMrovZw/PJb3jF7/hPfr+Kn72/EJ+ctpgrjwmttd0dDFWdrJftzT+9u183lq2kdtfXMzl//iIEwf34GdnHEzfbumxLu9zv//3MtaWVPD01aNjcsfq9048kPdWbuJHzyxgWJ9MemXE51QY0jrq6p23lm3kkZmreHt5MclJxqlDcvj26H58rV9Ws/fGXHz4fsxYXszv/r2UI/bvxqF5GTGofPd0Rp/gqmvrefC9T7nrjRXU1jtXH7M/1x0/IOZTAXz02RbOv38mlx3Zj5+fdUjM6vh0Uxmn3zWDoXkZ/POKI2inlsvA2VpWzZMFa3js/VUUbq2gR+eOfOvwvlw0qs8e3U/RcE0ntUPomk7jM/62pD562a31JZX85pUlvDB3LWMGdueBS/PpmBybsK+sqeO0STOorqvn3zccE7NvnAZPF6zhpqnz+eGpg6Le9SNtZ37hNh6ZuYpp89ZSXVvP4f278u3R/Tj5kJ6038s25Jkfb+abD7zPeSPy+MP5w6JU8VfT0I3sVk5GCpMuPIyjB3Tnpqnzuf6JOdz7zREx6bu/8/XlfLKpjH9ecXjMQx5g/Mg83lpezJ+mL+fIA7ozfBf3S0j8q6yp46X563jk/VXMW7ONtA7t+EZ+Hpcc0Y9BOS2fFHD0Ad2YcPwA7n5zJcccmM1Zw3q3YtWRi/13kcSV8/P7UFZVy89fXMxNU+fzx/OHtekdovPWbONv73zChV/rw1EDurfZ
cb+KmfHrsw9l7uptTJwyh5euH0OnOPgBJHtu844q/jbjU54qWMOWsmoOyE7n/846hHNH5NK5lSYBnHjCQN5buYmfPLuAw/pk0qdrWqu8b2tQe6V8yWVH9eemUwbx3Jwibp22kLYa3quuredHz8ynR+cUfnz64DY55p7KSGvPnRcMZ82Wcn4+bVGsy5G98NL8dZx05ztMfudjvtYvi39ecTivf/9YLj2yX6uFPEByuyQmXXgYANdPmUNNHK1Op6CXZv3vcQdwzbEH8Nj7q7nj1aVtEvZ/eWslS9dv51fnDInLqZZH9e/KhK8PZOqsQqbNWxvrcuJKvF3rg9AdrNc+NovrHp9NXlYqr0w8hr9eks9RA7pHbarxPl3T+PW5hzJn9TYmvb4iKsdoCf3+Kc0yM3506iDKqmr569uf0LljMhO+PjBqx1u6vpR73lzJ2cN7x/U0wdd/fQDvrijmJ8/F36/n0VJVW8fG0tDEfQ2ztn4xaV9oMr9NO6o5ekB3bh57EIN7RWcuoj3l7kybt5afT1tEWVUdPzr1IK4c07/NrjedOaw37ywv5t63VnLUgO6MPqBbmxz3q0TUdWNmmcADwBDAgf9x95mNXj8OeAH4NLzpWXe//aveU1038aW+3vnB0/N4dk4Rt515MJcf1b/Vj1FbV8+59/2Xoq0VvPb9Y+maHtmUw9G2Zks5YyfN4KCczky56oh9dqI4CN35ub60mQAvqWR9aWhbw6I3jXXumExOo7UY0jsm8+zsIkorazhvRB7fP+lAesdgCc6N2yv56XMLmb54A8P7ZPL78UMZGIOV18qqajnz7ncpr67jlYljyGqD/9PR7LqZBLzq7uPDi4Q3d3ozw93PiPA4EiNJScbvxg+lvLqO/3txMekdkvnG1/rs/gv3wgPvfsr8whLu/eaIuA95CP16/qtzhjBxylzu/c/HTDwxer/pRMM7y4v5/b+XsWZr+efrITSWmdaenC6hAD80N3OnxXV6ZYRmam1ubPuGEw7kL2+t5B///YwX563l8qP6c+1xB+zRwi+Rcneen1vEz6ctpqKmjh+fdhDfOXr/mN33kN4xmbsuOoxz/vIeP3xmPpMvGRnTdZ1bfEZvZhnAXGB/38WbhM/of7A3Qa8z+vhUVVvHlY/M4t0Vxdx10WGcMbR12sc+Lt7B2EkzOH5QNvdfHNtvhr31vSfn8sLcIp66ejT5/brGupzdqq937nv7Y/4wfRn7d0/nqAHdPw/wnl1Cq6PldEkhtUNk908Ubavgj9OX8dycIjJS2zPh+AFcMrpv1O7L2FBayU+eW8DrSzYysm8Wvxs/lAOyO0XlWHvrgRmf8MuXlvCLs4dwyRF9o3qsqNwwZWbDgcnAYmAYMAuY6O5ljfY5DngGKATWEgr9r2xZUNDHr4rqOi598ENmr97K376dz/EHRTY/Tn29c8HkmSzfsIPXvnfMPreq0/bKGk67awb19fDKDWPi8gJyg+2VNdz41DymL97AuOG9uePcoREH+u4sWlvCHa8sZcaKTeRlpXLTKYM4c2jvVmvXdXeemV3E7S8uoqq2nptOGcTlR/WPq7uX6+udyx76iA8+2cy0CUdH1Ku/O9GavTIZGAHc5+6HAWXAzU32mQ30dfdhwN3A87so8CozKzCzguLi4ghKkmhK7dCOBy7LZ3CvLlzz2CxmfhzZ3PaPvr+Kjz7bys/OOHifC3kIzfk/6cLDWF8aGheOx84TgBUbtjPunvd4c+lGbjvzYP58wfCohzzAIb0zePQ7h/Pod0bRJaU9E6fMZdy97/HfjyOfw31dSQWXP/QRP3h6HoNyOvPqDcdwxZjYDdXsSlKS8cfzh9E5JZnvPjGbyprYzBYbSdAXAoXu/kH4+VRCwf85dy919x3hxy8D7c3sS3fBuPtkd8939/zs7OwISpJo65LSnof/ZxT7dU3jioc/Ys7qrS16nzVbyvntq0s59sBszhuR28pVtp0R+2VxwwkDmTZvLc/NKYp1OV/y
0vx1jLv3PUora3n8yiO4/Kj+bT48NmZgNv/67tHcecEwtpRV882/fcDl//iQpetL9/q93J2nPlrDyX96hw8+2cJtZx7Mk1eNpn/3+JmMr6nszh354zeGs3zDDn710pKY1NDioHf39cAaMxsU3nQCoWGcz5lZjoX/V5nZqPDx9r0ljmQnXdM78M8rDqd7545c9o+PWLJu775h3Z1bnl2AAb8+99B9aly+Of97/ABG9evKz55fyKrNZbv/gjZQW1fPr19ewnWPz+agnM68dP3RjOofu+sISUnGOYfl8caNx/KT0wYza9VWxk6awU1Pz2NdScUevUfRtgou/cdH/PCZ+Rzcuwuv3jCGy4/qv0+s7XvsgdlcOaY/j76/iumL1rf58SNtrxxOqL2yA/AJcDlwAYC7329mE4BrgVqgAvi+u//3q95TY/T7jjVbyjn//pnU1tfz1NWj2X8PL4A9+dFqfvTMAn559hAujvIFqrZStK2CsX9+h/7ZnZh6zei9nhSrNW3eUcWEx+cw85PNfHt0X356+sFtsnzl3thWXs1f3vqYh977DDP4ztH9uea4A5q9zuHuTPloDb96aQn17tw89iAuPrzvPhHwjVXX1nPufe9RuLWCVyaOafVprzV7pUTNyo07uOCvM+mYnMRT14wmL+urbyBaX1LJSXe+zSG9u/D4FUfsc9+sX+Wl+eu47vHZTDh+AD84ZdDuvyAK5q3ZxrWPzWJzWTW/PudQzhuZF5M69tSaLeX86bXlPDeniKy09nz36wO5+Ii+n/9gKtxazs3PLODdlZsYvX83fjd+6D59k9onxTs44+53ozLttYJeomrx2lIunDyTrukdeOqa0fTo3PyFVXfnykcKeHflJv59wzFxtchJa/nh1Hk8PauQP18wnNMP7dWmN1M98eFqbnthET26dOT+i0cyJDc+F8FozsKiUIfOuys3sV/XNH5wyiBKK2r4zcuhMe1bThvMN0ftF4gTg4Zpr286ZRDXHd96014r6CXqZq3ayiV//4A+WWk8efURZKZ9+canF+YWMXHKXH56+mCuGBPbZdeipayqlnP+8h7LN+ygR+eOnDMil/NH5jGgR/Ta6ipr6vj5tEVM+WgNxxyYzaQLhrfJnZjR8M7yYjJIGyYAAAdTSURBVH7zytLPr/scPaA7d5x36G5/U9yXuDvXT5nLywvW8dTVoxnZN6tV3ldBL23ivZWbuPyhjxic05nHrjh8p7snN+2o4qQ/vU2/7ulMvebIuGuDa03VtfW8uXQjU2cV8p9lG6mrd4b3yWT8yDzOHNa7Ve8UXbutgmsfm8W8whImHD+A75104D7/d1tf77w4fy317pw9PHefv1jfnNLKGk6bNAOAlye2zj0YCnppM68v3sA1j81iRN8sHr581Of92hMen830RRt46fqjYzL3SKwUb6/ihblFPF1QyLIN2+mQnMQph+QwfmQeRw/oHlEo/3flJiY8MYea2nr++I1hnHxITitWLtE2a9VWvvHXmZx2aC/uunB4xD/QFPTSpqbNW8vEKXM49sBsJl+Sz3+WbeTqR2fxg5MPjOoMmPHM3VlYVMrUWWt4Yd5atpXXkNMlhXNH5DJ+ZN4edyw1vNffZnzCHa8s5YDsTvz1kpF79fUSP+55cwV/mL6c348fyvn5kc0hpaCXNjflw9Xc/OwCTj64J3PWbCO7U0demHBUTNsO40VVbR1vLAkN7by1bCP1DiP7ZjF+ZB6nD+31lb/G76iq5UdT5/PSgnWcfmgvfjd+aFwstygtU1fvfOuB95lfWMK/vnt0RD+wFfQSEw0TOrVLMl647qh9qgukrWwsreS5OUU8PauQlRt3kNI+iVMPyWH8yD4ceUC3nbpMPi7ewdWPzuKT4h3cMnYwV4xp+7tcpfWtK6lg7KQZ5GWl8uy1R7X4ngcFvcTMlA9X0yE5iXNHxHc/d6y5O/MKS5g6aw3T5q6ltLKW3hkpnDcyj/NG5LFsw3ZufGoeHZKTuOei
wzgyTtbTldYxfdF6rnp0FleO6c9PTj+4Re+hoBfZh1TW1PH6kg08XVDIjBXF1Ie/RYf1yeS+b42IyYIeEn2/emkxA3p04oKv7deir1fQi+yj1peEhnbq6uu58pj9ozanu+z7ornClIhEUU5GCtced0Csy5B9nFogREQCTkEvIhJwCnoRkYBT0IuIBJyCXkQk4BT0IiIBp6AXEQk4Bb2ISMDF3Z2xZlYMrIrgLboDm1qpnGiI9/og/muM9/og/muM9/pANe6tvu6e3dwLcRf0kTKzgl3dBhwP4r0+iP8a470+iP8a470+UI2tSUM3IiIBp6AXEQm4IAb95FgXsBvxXh/Ef43xXh/Ef43xXh+oxlYTuDF6ERHZWRDP6EVEpBEFvYhIwAUm6M3sVDNbZmYrzezmWNfTlJn1MbP/mNliM1tkZhNjXVNzzKydmc0xs3/FupbmmFmmmU01s6VmtsTMRse6psbM7Hvhf9+FZvaEmaXEQU0PmtlGM1vYaFtXM3vNzFaEP2fFYY2/D/87zzez58wsM57qa/TajWbmZha3C/kGIujNrB1wLzAWOBi4yMxatsJu9NQCN7r7wcARwHVxWCPARGBJrIv4CpOAV939IGAYcVSrmeUC1wP57j4EaAdcGNuqAHgIOLXJtpuBN9x9IPBG+HksPcSXa3wNGOLuQ4HlwC1tXVQjD/Hl+jCzPsDJwOq2LmhvBCLogVHASnf/xN2rgSnAuBjXtBN3X+fus8OPtxMKqNzYVrUzM8sDTgceiHUtzTGzDOAY4O8A7l7t7ttiW9WXJAOpZpYMpAFrY1wP7v4OsKXJ5nHAw+HHDwNnt2lRTTRXo7tPd/fa8NP3gbw2L+yLWpr7OwS4E/ghENddLUEJ+lxgTaPnhcRZiDZmZv2Aw4APYlvJl/yZ0H/a+lgXsgv9gWLgH+HhpQfMLD3WRTVw9yLgD4TO7tYBJe4+PbZV7VJPd18Xfrwe6BnLYvbA/wCvxLqIxsxsHFDk7vNiXcvuBCXo9xlm1gl4BrjB3UtjXU8DMzsD2Ojus2Jdy1dIBkYA97n7YUAZsR9y+Fx4nHscoR9IvYF0M7s4tlXtnod6rOP2jNTMfkJo6POfsa6lgZmlAT8Gbo11LXsiKEFfBPRp9DwvvC2umFl7QiH/T3d/Ntb1NHEUcJaZfUZo6OvrZvZYbEv6kkKg0N0bfhOaSij448WJwKfuXuzuNcCzwJExrmlXNphZL4Dw540xrqdZZnYZcAbwLY+vm34OIPQDfV74eyYPmG1mOTGtaheCEvQfAQPNrL+ZdSB0AWxajGvaiZkZobHlJe7+p1jX05S73+Luee7ej9Df35vuHldno+6+HlhjZoPCm04AFsewpKZWA0eYWVr43/sE4uhicRPTgEvDjy8FXohhLc0ys1MJDSWe5e7lsa6nMXdf4O493L1f+HumEBgR/j8adwIR9OELNhOAfxP6xnrK3RfFtqovOQq4hNCZ8tzwx2mxLmof9F3gn2Y2HxgO/DrG9Xwu/JvGVGA2sIDQ91fMb5E3syeAmcAgMys0s+8AdwAnmdkKQr+J3BGHNd4DdAZeC3+/3B9n9e0zNAWCiEjABeKMXkREdk1BLyIScAp6EZGAU9CLiAScgl5EJOAU9CIiAaegFxEJuP8HvHiKw1jJ554AAAAASUVORK5CYII=\n",
      "text/plain": [
       ""
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# Plot the losses recorded every 500 batches during training\n",
    "plt.figure()\n",
    "plt.plot(losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction\n",
    "Use the trained model to make a prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the input words are: of, william\n",
      "the predicted word is: shakespeare\n",
      "the true word is: shakespeare\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "def test(model):\n",
    "    model.eval()\n",
    "    # Pick one sample at random from the last 10 trigrams\n",
    "    idx = random.randint(len(trigram) - 10, len(trigram) - 1)\n",
    "    print('the input words are: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n",
    "    x_data = [word_to_idx[w] for w in trigram[idx][0]]\n",
    "    x_data = paddle.to_tensor(np.array(x_data))\n",
    "    predicts = model(x_data)\n",
    "    # Take the index of the highest-scoring word\n",
    "    predicts = predicts.numpy().tolist()[0]\n",
    "    predicts = predicts.index(max(predicts))\n",
    "    print('the predicted word is: ' + idx_to_word[predicts])\n",
    "    y_data = trigram[idx][1]\n",
    "    print('the true word is: ' + y_data)\n",
    "\n",
    "test(model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}