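{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Training word embeddings with an N-Gram model on a Shakespeare sonnet\n", "An N-gram is a concept from computational linguistics and probability theory: a contiguous sequence of N items in a given text.\n", "For N=1 an N-gram is called a unigram, for N=2 a bigram, for N=3 a trigram, and so on. In practice, bigrams and trigrams are the most common choices.\n", "This example trains a trigram model on one of Shakespeare's sonnets; the weights of the learned `Embedding` layer are the word embeddings." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Environment\n", "This tutorial was written for paddle 2.0-alpha. If your environment has a different version, install paddle 2.0-alpha first: APIs used below, such as `paddle.enable_imperative` and `paddle.imperative.to_variable`, belong to that alpha release and may not exist in later versions." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.0.0-alpha0'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import paddle\n", "paddle.__version__" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset and parameters\n", "The training data is Shakespeare's Sonnet 2, split on whitespace, so punctuation stays attached to words (for example `brow,`). `CONTEXT_SIZE` is set to 2, meaning each word is predicted from the two words before it (a trigram model). `EMBEDDING_DIM` is set to 10." ] },
{ "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "CONTEXT_SIZE = 2\n", "EMBEDDING_DIM = 10\n", "\n", "test_sentence = \"\"\"When forty winters shall besiege thy brow,\n", "And dig deep trenches in thy beauty's field,\n", "Thy youth's proud livery so gazed on now,\n", "Will be a totter'd weed of small worth held:\n", "Then being asked, where all thy beauty lies,\n", "Where all the treasure of thy lusty days;\n", "To say, within thine own deep sunken eyes,\n", "Were an all-eating shame, and thriftless praise.\n", "How much more praise deserv'd thy beauty's use,\n", "If thou couldst answer 'This fair child of mine\n", "Shall sum my count, and make my old excuse,'\n", "Proving his beauty by succession thine!\n", "This were to be new made when thou art old,\n", "And see thy blood warm when thou feel'st it cold.\"\"\".split()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Data preprocessing\n", "The text is split into tuples of the form `(('first word', 'second word'), 'third word')`, where the third word is the prediction target. The vocabulary is sorted before indices are assigned so that `word_to_idx` is the same on every run." ] },
{ "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[(('When', 'forty'), 'winters'), (('forty', 'winters'), 'shall'), (('winters', 'shall'), 'besiege')]\n" ] } ], "source": [ "trigram = [((test_sentence[i], test_sentence[i + 1]), test_sentence[i + 2])\n", "           for i in range(len(test_sentence) - 2)]\n", "\n", "# Sort the vocabulary so that word indices are reproducible across runs\n", "vocab = sorted(set(test_sentence))\n", "word_to_idx = {word: i for i, word in enumerate(vocab)}\n", "idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n", "# Inspect the first few trigrams\n", "print(trigram[:3])\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Build a `Dataset` class to load the data\n", "`TrainDataset` wraps the trigram list and returns each sample as word indices: the two context indices as `data` and the target index as `label`. The training loop below iterates over `trigram` directly, so this class is not used in the rest of the example." ] },
{ "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import paddle\n", "\n", "class TrainDataset(paddle.io.Dataset):\n", "    def __init__(self, tuple_data, word_to_idx):\n", "        # tuple_data: list of ((word1, word2), target) trigrams\n", "        # word_to_idx: dict mapping each word to its integer index\n", "        self.tuple_data = tuple_data\n", "        self.word_to_idx = word_to_idx\n", "\n", "    def __getitem__(self, idx):\n", "        context, target = self.tuple_data[idx]\n", "        # Map the context words and the target word to int64 indices\n", "        data = np.array([self.word_to_idx[w] for w in context], dtype='int64')\n", "        label = np.array([self.word_to_idx[target]], dtype='int64')\n", "        return data, label\n", "\n", "    def __len__(self):\n", "        return len(self.tuple_data)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Build the network and train\n", "The network is built in Paddle's dynamic-graph (imperative) mode. Because this is an N-Gram model, one `Embedding` layer and two `Linear` layers are enough: the embeddings of the `context_size` context words are flattened into a single vector of length `context_size * embedding_dim`, passed through a 128-unit hidden layer with ReLU, and projected to `vocab_size` scores, which `softmax` turns into a probability distribution over the next word. The `reshape` to `[1, -1]` assumes one sample per forward pass." ] },
{ "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "import paddle\n", "import numpy as np\n", "\n", "class NGramModel(paddle.nn.Layer):\n", "    def __init__(self, vocab_size, embedding_dim, context_size):\n", "        super(NGramModel, self).__init__()\n", "        self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n", "        self.linear1 = paddle.nn.Linear(context_size * embedding_dim, 128)\n", "        self.linear2 = paddle.nn.Linear(128, vocab_size)\n", "\n", "    def forward(self, x):\n", "        # x: the context_size word indices of one sample\n", "        x = self.embedding(x)\n", "        # Concatenate the context embeddings into a single row vector\n", "        x = paddle.reshape(x, [1, -1])\n", "        x = self.linear1(x)\n", "        x = paddle.nn.functional.relu(x)\n", "        x = self.linear2(x)\n", "        # Convert scores to probabilities; the training loss is computed on these probabilities\n", "        x = paddle.nn.functional.softmax(x)\n", "        return x" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Initialize the model and define the training parameters\n", "Training runs for 1000 epochs with SGD (learning rate 0.001), feeding one trigram at a time and printing the loss every 50 epochs; the last 10 trigrams are held out for the prediction step." ] },
{ "cell_type": "code", "execution_count": 121, 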
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, loss is: [4.631529]\n", "epoch: 50, loss is: [4.6081576]\n", "epoch: 100, loss is: [4.600631]\n", "epoch: 150, loss is: [4.603069]\n", "epoch: 200, loss is: [4.592647]\n", "epoch: 250, loss is: [4.5626693]\n", "epoch: 300, loss is: [4.513106]\n", "epoch: 350, loss is: [4.4345813]\n", "epoch: 400, loss is: [4.3238697]\n", "epoch: 450, loss is: [4.1728854]\n", "epoch: 500, loss is: [3.9622664]\n", "epoch: 550, loss is: [3.67673]\n", "epoch: 600, loss is: [3.2998457]\n", "epoch: 650, loss is: [2.8206367]\n", "epoch: 700, loss is: [2.2514927]\n", "epoch: 750, loss is: [1.6479329]\n", "epoch: 800, loss is: [1.1147357]\n", "epoch: 850, loss is: [0.73231363]\n", "epoch: 900, loss is: [0.49481753]\n", "epoch: 950, loss is: [0.3504072]\n" ] } ], "source": [ "vocab_size = len(vocab)\n", "embedding_dim = EMBEDDING_DIM\n", "context_size = CONTEXT_SIZE\n", "\n", "paddle.enable_imperative()\n", "losses = []\n", "\n", "def train(model):\n", "    model.train()\n", "    optim = paddle.optimizer.SGD(learning_rate=0.001, parameter_list=model.parameters())\n", "    for epoch in range(1000):\n", "        # Hold out the last 10 trigrams for prediction\n", "        for context, target in trigram[:-10]:\n", "            context_idxs = [word_to_idx[w] for w in context]\n", "            # Paddle expects int64 word ids and labels; NumPy's default int type is platform-dependent\n", "            x_data = paddle.imperative.to_variable(np.array(context_idxs, dtype='int64'))\n", "            y_data = paddle.imperative.to_variable(np.array([word_to_idx[target]], dtype='int64'))\n", "            predicts = model(x_data)\n", "            loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", "            loss.backward()\n", "            optim.minimize(loss)\n", "            model.clear_gradients()\n", "        # Report the loss of the last training trigram in this epoch\n", "        if epoch % 50 == 0:\n", "            print(\"epoch: {}, loss is: {}\".format(epoch, loss.numpy()))\n", "            losses.append(loss.numpy())\n", "\n", "model = NGramModel(vocab_size, embedding_dim, context_size)\n", "train(model)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the loss curve\n", "Plotting the recorded losses shows how training progressed: the loss changes little during roughly the first 300 epochs, then falls quickly to about 0.35 by epoch 950. Each recorded value is the loss on the last training trigram of an epoch, not an average over the epoch." ] },
{ "cell_type": "code", "execution_count": 123, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 123, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdkklEQVR4nO3dd3xV9f3H8dfn3uwBmUAggRghMmUFDQ5Qq4JC3VqtVmuttL/aqq3WamtdtcPaarWOuq1aR0XrAJVaxKKIIghE9oaAAQIBAiE7398fuVhUJCHk5tzxfj4e95Hcc06S98nJ451zv/cMc84hIiKhy+d1ABER2T8VtYhIiFNRi4iEOBW1iEiIU1GLiIS4mGB806ysLJefnx+Mby0iEpHmzJmzxTmXva95QSnq/Px8Zs+eHYxvLSISkcxs7dfN09CHiEiIU1GLiIQ4FbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIC8px1G318PSVNDmI8Vnzw+8jPsZHVko82anxZKXEk5kSR6xf/19EJHqEVFHf9fYyauqbWlwuIzmOrJQ4slPjyU5pLvCs1HhS4mNIjPWTGOcnIdZHjM+H32dffJi1OM3s4NclxucjMbY5h7XHNxSRqBVSRV1y8xgampqob3Q0NjkaGpuorm9ky646tuyqpXxn7Vc+zlm3jfKdta0qeK8kxvpJivOTEPiYGOf//B/K3tPTEuPITIkjMyWerOTmj5kpcaQnxeH3qexFolVIFXVcjI+4fQyb98pM3u/XOeeoqmtkd20DNfXN5V5d30hDYxONTY5G11z8X3nsa7pztMdNb5r/yQSy1DVQXd/I7rpGagIfq+saqaiqY8O2wPP6RnZU19PY9NUf7rPmVxGZyfF06RRPbnoiuelJ9EhLJDc9kR7piXRJTVCZi0SokCrqtjIzUuJjSIkP79VpanLsqK5na1Xt568itu6qY+uuWrZU1bFlZy2bKmt4u6ySLbvqvvC1sX4jp3OguNMCRZ6eGCj1RLp3TsSnIhcJS+HdbBHG5zPSk+NIT46jd5f9L1td18iG7dWs37Y78LGaDduan09fXs7mnbVfeGWQGOunsFsqfbum0jcnlcO6pdI/pxNpSXHBXSkROWgq6jCVGOend5cUendJ2ef82oZGyrbXsGF7NaUVu1m2aRdLNlby9uJNvDC79PPl8jOTGJyXxuDcNAbnpTGgeycSYv0dtRoi0goq6ggVH+MnPyuZ/Kwvju875yjfVcuSsp0s+GwHJaU7mLW6glfnfQY0HxrZL6cTg/M6MyI/gxH5GXRPS/RiFUQkwFx7vHP2JUVFRU7Xow4vmyprmFe6nfml25m/fjvzS3ewq7YBgNz0RI7Iz2DEIRkc0zuLvIwkj9OKRB4zm+OcK9rnPBW17Etjk2NxWSWzVlfw8ZoKZq2uYGtV8xuYBdnJjC7MZlRhNiMLMjVUItIOVNRy0JxzrCzfxfRlW/jvsnI+XLWV2oYmEmP9HN83mzEDunFC3y6kJsR6HVUkLKmopd3V1Dfy4aqt/GfxJqYs3ET5zlri/D6O6ZPFqYNyOKl/VzonqrRFWktFLUHV1OSYW7qNtxZs5M0FG1m/rZo4v49RhVmMOzyHE/t11Z62SAtU1NJhnHPMX7+DySWfMbmkjM921BAX42N0YTbjD8/hG/26hv2JSSLBoKIWTzTvaW9nckkZb3xaxsbKGuJjfIwd2I1zhudy1KFZOu1dJEBFLZ5ranLMWbeNV+dt4LV5n1FZ00BO5wTOGtaDc4bncUjW/q/nIhLpVNQSUmrqG5m6eDMT55Ty32XlNDkYVZjNpUfnM7pPtq5JIlFJ
RS0ha1NlDS98XMozH65l885aCrKSueSofM4enquxbIkqKmoJeXUNTby5oIzHZ6xhful2UhNiuHhkLy47poCMZF04SiKfilrCytx123jkvVW8uWAjCTF+LiruyeWjCuiSmuB1NJGgUVFLWFq+aScPvLuSV+dtIMbv4zvFvbji+N7aw5aIpKKWsLZmSxX3TVvBy5+sJzkuhh+MLuB7xxxCUpzGsCVyqKglIizbtJM7pyzl7UWbyE6N5/qxfTlzaA8dJSIRYX9F/dUbFIqEqMKuqTxycRETfziS7mmJXPPifM59aCYLNuzwOppIUKmoJewU5Wfwr/87ij+efThrtlRx2n3vc+Mrn7J9d13LXywShlTUEpZ8PuO8EXm8c+1xXDwyn2c/Wsfxf3qXf35cSjCG80S8pKKWsNY5MZZbThvA5CuPpU+XVK57qYTvPDaL0ordXkcTaTetLmoz85vZXDObFMxAIm3RL6cTz08o5vYzBjJ33TZOvns6T8xYTVOT9q4l/B3IHvVVwOJgBRE5WD6fcVFxL/79s9EccUgGt76+iHMfmsmKzbu8jiZyUFpV1GaWC4wDHg1uHJGD1yMtkScvHcFd5w1mZfkuTr33Pe6ftoL6xiavo4m0SWv3qP8CXAd87V+6mU0ws9lmNru8vLw9som0mZlx1rBc3v7paE7s14U7pyzl7Ac/YPWWKq+jiRywFovazMYDm51zc/a3nHPuYedckXOuKDs7u90CihyM7NR4HrhwOA9cOIy1W3cz7t73mDhnvY4MkbDSmj3qo4HTzGwN8Dxwgpk9E9RUIu3s1EE5vHnVsQzq0ZlrX5zPlc/PY2dNvdexRFqlxaJ2zt3gnMt1zuUD5wPvOOcuCnoykXbWPS2RZy8v5tqTC5lc8hmn3TeDJRsrvY4l0iIdRy1Rxe8zfnxCH569vJhdtQ2ccf8MXpxd6nUskf06oKJ2zr3rnBsfrDAiHaW4IJPJVx7D0Lx0fj6xhF+/skBHhUjI0h61RK0uqQk88/0j+cGoAp7+cC3ffWIWO3Zr3FpCj4paoprfZ9xwaj/uPOdwZq2u4IwHZrCyXCfISGhRUYsA5xbl8dzlxVRW13Pm/TP4cNVWryOJfE5FLRJQlJ/BK1ccTZdOCVz8+CzeXrTJ60gigIpa5AvyMpJ48Qcj6ZfTiR8+M4eJc9Z7HUlERS3yZenJcTz7/SM56tBMrn1xPn//YI3XkSTKqahF9iE5PoZHLyni5P5dufm1hTwxY7XXkSSKqahFvkZ8jJ/7LxzGmAFdufX1RTz+vspavKGiFtmPWL+P+749jFMGduO2SYt49L1VXkeSKKSiFmlBrN/HvRcM5dRB3bh98mKe1DCIdLAYrwOIhINYv497zh9KQ+Mn3PL6IpLjYzi3KM/rWBIltEct0kqxfh9//fZQju2TxS9eKmFySZnXkSRKqKhFDkB8jJ+HvjOcYT3TufqFuUxbstnrSBIFVNQiBygpLobHLx3BYd1S+eEzc3S6uQSdilqkDTolxPL3S48gLyOJy578mHml272OJBFMRS3SRpkp8Txz2ZFkpMRxyeOzWLpxp9eRJEKpqEUOQrfOCTz7/WLiY3xc+sQsNlXWeB1JIpCKWuQg5WUk8fh3R7C9up7vPfkxVbUNXkeSCKOiFmkHA3t05v5vD2NxWSU/eW4uDbqtl7QjFbVIOzm+bxduO30g7yzZzC2vL8Q553UkiRA6M1GkHV1U3IvSbbt56L+r6JmRxIRRh3odSSKAilqknf1iTF/WV1TzuzeW0CszmTEDunkdScKchj5E2pnPZ/z5vMEMzu3MNf+cz4rNOmxPDo6KWiQIEmL9PHjRcOJjfEx4eg47a+q9jiRhTEUtEiTd0xK5/8JhrN26m5/9cz5NTXpzUdpGRS0SRMUFmdw4rh9vL9rEfdNWeB1HwpSKWiTIvntUPmcN7cHd/1nG1MWbvI4jYUhFLRJkZsbvzhpE/5xOXP38PFaV7/I6koQZFbVIB0iIbb6O
dYzfuOLZudTUN3odScKIilqkg+SmJ/Hn8wazuKyS2ycv8jqOhBEVtUgHOqFvVyaMKuCZD9cxqeQzr+NImFBRi3Swn485jKE907jhpU9Zu7XK6zgSBlTUIh0s1u/j3vOHYgY/eW4u9brSnrRARS3igbyMJO44+3BK1u/gr+/o+GrZPxW1iEdOGZTDWcN6cP+0FXyybpvXcSSEqahFPHTLaQPo1imBn70wj911ujOM7FuLRW1mCWY2y8zmm9lCM7u1I4KJRINOCbH8+bzBrK3YzW8nL/Y6joSo1uxR1wInOOcGA0OAsWZWHNRUIlGkuCCTy48t4B8frWPaks1ex5EQ1GJRu2Z7znmNDTx0GTCRdnTNyYX07ZbKzyeWUFFV53UcCTGtGqM2M7+ZzQM2A2875z7axzITzGy2mc0uLy9v55gikS0+xs/d3xpCZXU9N7xcovstyhe0qqidc43OuSFALnCEmQ3cxzIPO+eKnHNF2dnZ7RxTJPL1y+nEtWMKmbJwE/+au8HrOBJCDuioD+fcdmAaMDYoaUSi3GXHFDC8Vzq3TVpE+c5ar+NIiGjNUR/ZZpYW+DwROAlYEuRcIlHJ7zPuOHsQu2sbueW1hV7HkRDRmj3qHGCamZUAH9M8Rj0puLFEolfvLqlcdWIfJn9axlsLNnodR0JATEsLOOdKgKEdkEVEAiaMKmBySRm/fnUBIwsy6ZwU63Uk8ZDOTBQJQbF+H38853Aqqur4ja5dHfVU1CIhamCPzvxwdAET56xn+jId8hrNVNQiIewnJ/Th0Oxkbnj5U6pqdS2QaKWiFglhCbF+7jj7cDZsr+beqcu9jiMeUVGLhLii/AzOH5HHY++vZunGnV7HEQ+oqEXCwHVj+5KSEMOvX1mg08ujkIpaJAxkJMdx/di+zFpTwcuf6PTyaKOiFgkT5xXlMbRnGr9/czE7dtd7HUc6kIpaJEz4fMbtZwykoqqOP/17qddxpAOpqEXCyIDunbl4ZD7PfLSWkvXbvY4jHURFLRJmfnZyIVkp8dz4ygIam/TGYjRQUYuEmU4Jsdw4rh8l63fw7Kx1XseRDqCiFglDpw3uzlGHZnLnW0vYskvXrY50KmqRMGRm3Hb6QKrrG/n9G7o8fKRTUYuEqd5dUrj82AJe+mQ9H63a6nUcCSIVtUgY+8kJfeiRlshNry6kobHJ6zgSJCpqkTCWGOfn1+P7sXTTTp77uNTrOBIkKmqRMDdmQDeKCzK4699LdcZihFJRi4Q5M+Om8QPYUV3PX6Yu8zqOBIGKWiQC9O/eifOP6MlTM9eyYrMuhRppVNQiEeKakwpJivNz26TFuhRqhFFRi0SIzJR4rvpGH6YvK2fa0s1ex5F2pKIWiSAXj8ynIDuZ2yctpq5Bh+tFChW1SASJi/Hx63H9WbWliqc/XOt1HGknKmqRCHN83y4c2yeLe6cu1+F6EUJFLRKBfnlqPypr6rlvmu5cHglU1CIRqF9OJ84dnsvfP1jLuq27vY4jB0lFLRKhrjn5MPw+444purpeuFNRi0Sorp0SmDCqgMklZcxZu83rOHIQVNQiEWzCqAKyU+P53Rs6CSacqahFIlhyfAzXnFTInLXbeGvBRq/jSBupqEUi3LlFeRzWNZU/vLVEJ8GEKRW1SITz+4xfjuvH2q27dRJMmFJRi0SB0YXZOgkmjKmoRaKEToIJXypqkSihk2DCV4tFbWZ5ZjbNzBaZ2UIzu6ojgolI+9NJMOGpNXvUDcA1zrn+QDFwhZn1D24sEQkGnQQTnlosaudcmXPuk8DnO4HFQI9gBxOR4NBJMOHngMaozSwfGAp8tI95E8xstpnNLi8vb6d4ItLe9j4J5k2dBBMWWl3UZpYCvARc7Zyr/PJ859zDzrki51xRdnZ2e2YUkXb2+Ukwb+okmHDQqqI2s1iaS/ofzrmXgxtJRIJtz0kw6yp289TMNV7HkRa05qgPAx4DFjvn7gp+JBHpCKMLsxldmM09U5dTUVXndRzZ
j9bsUR8NfAc4wczmBR6nBjmXiHSAX43rR1VtA/dO1UkwoSympQWcc+8D1gFZRKSDFXZN5YIjevL0h2u5qLgXvbukeB1J9kFnJopEuZ+eVEhSrJ/fv7HY6yjyNVTUIlEuKyWeK07ozdQlm3l/+Rav48g+qKhFhO8elU9ueiK3T15EY5NOggk1KmoRISHWz/Wn9GXJxp1MnFPqdRz5EhW1iAAwblAOw3ulc+eUZeyqbfA6juxFRS0iAJgZN47rx5Zdtfzt3ZVex5G9qKhF5HNDe6Zz+pDuPPLeKjZsr/Y6jgSoqEXkC64b2xeAP76la1aHChW1iHxBj7REvn/sIbw67zPmrtM1q0OBilpEvuL/jutNVko8t0/WNatDgYpaRL4iJT6Ga09uvmb15E/LvI4T9VTUIrJP5xbl0bdbKr9/YwnVdY1ex4lqKmoR2Se/z7j1tAFs2F7Ng++u8DpOVFNRi8jXOrIgk9OHdOdv01exdmuV13GilopaRPbrl6f2I9Zn/GbSIq+jRC0VtYjsV9dOCVz5jT78Z/Fm3lmyyes4UUlFLSItuvToQyjITubW1xdRU683FjuailpEWhQX4+OWbw5g7dbdPPb+aq/jRB0VtYi0yqjCbMYO6MZ976zgM10HpEOpqEWk1W4c348m5/itbtvVoVTUItJquelJXHF8byaXlDF9WbnXcaKGilpEDsiEUQUUZCXzq1c+1RmLHURFLSIHJCHWz2/PHERpRTX3TF3udZyooKIWkQM28tBMzh2eyyPvrWJxWaXXcSKeilpE2uSXp/ajc2IsN7z8qe5cHmQqahFpk/TkOG4a3595pdt55sO1XseJaCpqEWmz04d059g+Wdzx1hJKK3Z7HSdiqahFpM3MjN+fNQifGddNLKFJQyBBoaIWkYOSm57EjeP6MXPVVp7WEEhQqKhF5KB9a0Qeowuz+cObS1izRdetbm8qahE5aGbGH84eRIzf+PnE+RoCaWcqahFpFzmdE7n5mwP4eM02Hn1/lddxIoqKWkTazdnDejBmQFfunLKUBRt2eB0nYqioRaTdmBl/OOtwMpPjufK5uVTVNngdKSKoqEWkXaUnx3H3t4awemsVt72u+yy2BxW1iLS7kYdm8qPjDuWF2aVMLinzOk7Ya7GozexxM9tsZgs6IpCIRIarTyxkaM80fvFSCSvLd3kdJ6y1Zo/6SWBskHOISISJ9fu4/9vDiI/x8cOn57BL49Vt1mJRO+emAxUdkEVEIkz3tET+esFQVpbv4rqJ83FOx1e3hcaoRSSojuqdxfWn9OWNTzfy8HQdX90W7VbUZjbBzGab2ezyct1LTUT+5/JjCxg3KIc73lrCjBVbvI4TdtqtqJ1zDzvnipxzRdnZ2e31bUUkApgZfzzncHp3SeFH//hEby4eIA19iEiHSI6P4bFLRhDrNy594mO27qr1OlLYaM3hec8BM4HDzGy9mV0W/FgiEonyMpJ45OIiNlXW8P2nZlNTr7uYt0Zrjvq4wDmX45yLdc7lOuce64hgIhKZhvZM557zhzCvdDtXPT+XhsYmryOFPA19iEiHGzswh5vG92fKwk26M0wrxHgdQESi06VHH0JVbQN/+vcyEuL8/PaMgZiZ17FCkopaRDxzxfG9qapr5MF3V5IY6+fGcf1U1vugohYRz5gZ1405jOq6Rh57fzVNznHT+P4q6y9RUYuIp8yMm7/ZHzN4YsYaqusa+e2Zg/D7VNZ7qKhFxHNmxk3j+5McF8N901awu66RP583mFi/jncAFbWIhAgz49oxh5EU7+ePby1la1UtD1w4nM6JsV5H85z+XYlISPnRcb3507mDmbW6grMf/IDSit1eR/KcilpEQs45w3N56ntHsrmyhjMfmMGctdu8juQpFbWIhKSRh2by8o+OJjk+hvMfnsmTM1ZH7fWsVdQiErJ6d0nhtSuOYXRhNre8vogrn58XlXc2V1GLSEjrnBTLw98p4udjDmNyyWeM/+v7zF0XXUMh
KmoRCXk+n3HF8b159vJi6hqaOOdvM7nr7WXUR8kFnVTUIhI2igsyefPqYzl9SHfunbqcsx74gAUbdngdK+hU1CISVjolxHLXeUN48MJhlO2o4bT73ufW1xeys6be62hBo6IWkbB0yqAcpl4zmguP7MWTH6zhxLv+y6vzNkTkJVNV1CIStjonxvKbMwbyrx8dTVZKPFc9P4/T75/BBxF2A10VtYiEvSF5abz+42O467zBVFTV8e1HP+K7T8yKmPFrC8YB5EVFRW727Nnt/n1FRFpSU9/IUzPXcN87K6isaeD4w7L58Qm9Gd4rw+to+2Vmc5xzRfucp6IWkUhUWVPP0zPX8tj7q6moqqO4IIMJowo4rrALvhC8hKqKWkSi1u66Bp6bVcoj01exsbKGnhlJXFTck/OK8khLivM63udU1CIS9eobm5iycCNPzVzLrNUVxMf4GH94d84e3oPiQzI938tWUYuI7GVxWSVPf7iW1+Z9xq7aBnqkJXLG0O6cOTSX3l1SPMmkohYR2YfqukbeXryJlz9Zz/Rl5TQ5KOyawtgB3RgzsBv9czp12P0bVdQiIi3YXFnD5E/LmLJwI7NWV9DkIC8jkRP7dWVUYTbFh2SSGOcP2s9XUYuIHIAtu2r5z6JNvLVwIzNXbqW2oYm4GB9H5GcwqjCLo3tn0bdbp3a9Aa+KWkSkjWrqG5m1uoLpy8qZvrycZZt2AZCaEMPwXumMyM9gRH4Gh+d2JiG27Xvc+ytq3dxWRGQ/EmL9jCrMZlRhNgAbd9Tw4aqtzFpTwcerK3h36VIA4vw+huSl8fyE4nY/gkRFLSJyALp1TuCMoT04Y2gPALZV1TF77TZmr6lgR3V9UA7zU1GLiByE9OQ4TurflZP6dw3az9BFmUREQpyKWkQkxKmoRURCnIpaRCTEqahFREKcilpEJMSpqEVEQpyKWkQkxAXlWh9mVg6sbeOXZwGRdQvhlmmdo4PWOfIdzPr2cs5l72tGUIr6YJjZ7K+7MEmk0jpHB61z5AvW+mroQ0QkxKmoRURCXCgW9cNeB/CA1jk6aJ0jX1DWN+TGqEVE5ItCcY9aRET2oqIWEQlxIVPUZjbWzJaa2Qozu97rPO3FzPLMbJqZLTKzhWZ2VWB6hpm9bWbLAx/TA9PNzO4N/B5KzGyYt2vQdmbmN7O5ZjYp8PwQM/sosG4vmFlcYHp84PmKwPx8T4O3kZmlmdlEM1tiZovNbGSkb2cz+2ng73qBmT1nZgmRtp3N7HEz22xmC/aadsDb1cwuCSy/3MwuOZAMIVHUZuYH7gdOAfoDF5hZf29TtZsG4BrnXH+gGLgisG7XA1Odc32AqYHn0Pw76BN4TAAe7PjI7eYqYPFez+8A7nbO9Qa2AZcFpl8GbAtMvzuwXDi6B3jLOdcXGEzzukfsdjazHsCVQJFzbiDgB84n8rbzk8DYL007oO1qZhnAzcCRwBHAzXvKvVWcc54/gJHAlL2e3wDc4HWuIK3rq8BJwFIgJzAtB1ga+Pwh4IK9lv98uXB6ALmBP+ATgEmA0XzGVsyXtzkwBRgZ+DwmsJx5vQ4HuL6dgdVfzh3J2xnoAZQCGYHtNgkYE4nbGcgHFrR1uwIXAA/tNf0Ly7X0CIk9av63wfdYH5gWUQIv9YYCHwFdnXNlgVkbgT03XIuU38VfgOuApsDzTGC7c64h8Hzv9fp8nQPzdwSWDyeHAOXAE4HhnkfNLJkI3s7OuQ3An4B1QBnN220Okb2d9zjQ7XpQ2ztUijrimVkK8BJwtXOucu95rvlfbMQcJ2lm44HNzrk5XmfpQDHAMOBB59xQoIr/vRwGInI7pwOn0/xPqjuQzFeHCCJeR2zXUCnqDUDeXs9zA9MigpnF0lzS/3DOvRyYvMnMcgLzc4DNgemR8Ls4GjjNzNYAz9M8/HEPkGZmMYFl9l6vz9c5ML8zsLUjA7eD9cB659xHgecTaS7uSN7OJwKrnXPlzrl64GWat30kb+c9DnS7HtT2
DpWi/hjoE3i3OI7mNyRe8zhTuzAzAx4DFjvn7tpr1mvAnnd+L6F57HrP9IsD7x4XAzv2eokVFpxzNzjncp1z+TRvy3eccxcC04BzAot9eZ33/C7OCSwfVnuezrmNQKmZHRaY9A1gERG8nWke8ig2s6TA3/medY7Y7byXA92uU4CTzSw98Erk5MC01vF6kH6vwfVTgWXASuBXXudpx/U6huaXRSXAvMDjVJrH5qYCy4H/ABmB5Y3mI2BWAp/S/I665+txEOt/HDAp8HkBMAtYAbwIxAemJwSerwjML/A6dxvXdQgwO7CtXwHSI307A7cCS4AFwNNAfKRtZ+A5msfg62l+5XRZW7Yr8L3Auq8ALj2QDDqFXEQkxIXK0IeIiHwNFbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIU1GLiIS4/wecQTmUPnjbjwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.figure()\n", "plt.plot(losses)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction\n", "Use the trained model to predict the third word of a trigram picked at random from the 10 held-out trigrams." ] },
{ "cell_type": "code", "execution_count": 127, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the input words are: praise., How\n", "the predicted word is: much\n", "the true word is: much\n" ] } ], "source": [ "import random\n", "\n", "def test(model):\n", "    model.eval()\n", "    # Pick one trigram at random from the last 10\n", "    idx = random.randint(len(trigram) - 10, len(trigram) - 1)\n", "    print('the input words are: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n", "    x_data = [word_to_idx[w] for w in trigram[idx][0]]\n", "    x_data = paddle.imperative.to_variable(np.array(x_data, dtype='int64'))\n", "    predicts = model(x_data)\n", "    # Take the index with the highest probability as the predicted word\n", "    predicts = predicts.numpy().tolist()[0]\n", "    predicts = predicts.index(max(predicts))\n", "    print('the predicted word is: ' + idx_to_word[predicts])\n", "    y_data = trigram[idx][1]\n", "    print('the true word is: ' + y_data)\n", "\n", "test(model)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.7.3 64-bit", "language": "python", "name": "python_defaultSpec_1598180286976" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3-final" } }, "nbformat": 4, "nbformat_minor": 4 }