Commit 67394c17 authored by T TC.Long

add two_docs with develop paddle

Parent a7c242a6
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST数据集使用LeNet进行图像分类\n",
"本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。\n",
"手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境\n",
"本教程基于paddle-develop编写,如果您的环境不是本版本,请先安装paddle-develop版本。"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0.0\n"
]
}
],
"source": [
"import paddle\n",
"print(paddle.__version__)\n",
"paddle.disable_static()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 加载数据集\n",
"我们使用飞桨自带的paddle.dataset完成mnist数据集的加载。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"download training data and load training data\n",
"load finished\n"
]
}
],
"source": [
"print('download training data and load training data')\n",
"train_dataset = paddle.vision.datasets.MNIST(mode='train')\n",
"test_dataset = paddle.vision.datasets.MNIST(mode='test')\n",
"print('load finished')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"取训练集中的一条数据看一下。"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_data0 label is: [5]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAIY0lEQVR4nO3dXWhUZxoH8P/jaPxav7KREtNgiooQFvwg1l1cNOr6sQUN3ixR0VUK9cKPXTBYs17ohReLwl5ovCmuZMU1y+IaWpdC0GIuxCJJMLhJa6oWtSl+FVEXvdDK24s5nc5zapKTZ86cOTPz/4Hk/M8xc17w8Z13zpl5RpxzIBquEbkeAOUnFg6ZsHDIhIVDJiwcMmHhkElGhSMiq0WkT0RuisjesAZF8SfW6zgikgDwFYAVAPoBdABY75z7IrzhUVyNzOB33wVw0zn3NQCIyL8A1AEYsHDKyspcVVVVBqekqHV1dX3nnJvq359J4VQA+CYt9wNYONgvVFVVobOzM4NTUtRE5M6b9md9cSwiH4hIp4h0Pnr0KNuno4hkUjjfAqhMy297+xTn3EfOuRrnXM3UqT+b8ShPZVI4HQBmicg7IlICoB7AJ+EMi+LOvMZxzn0vIjsAtAFIADjhnOsNbWQUa5ksjuGc+xTApyGNhfIIrxyTCQuHTFg4ZMLCIRMWDpmwcMiEhUMmLBwyYeGQCQuHTFg4ZMLCIZOMbnIWk9evX6v89OnTwL/b1NSk8osXL1Tu6+tT+dixYyo3NDSo3NLSovKYMWNU3rv3p88N7N+/P/A4h4MzDpmwcMiEhUMmRbPGuXv3rsovX75U+fLlyypfunRJ5SdPnqh85syZ0MZWWVmp8s6dO1VubW1VecKECSrPmTNH5SVLloQ2toFwxiETFg6ZsHDIpGDXOFevXlV52bJlKg/nOkzYEomEygcPHlR5/PjxKm/cuFHladOmqTxlyhSVZ8+enekQh8QZh0xYOGTCwiGTgl3jTJ8+XeWysjKVw1zjLFyom3T41xwXL15UuaSkROVNmzaFNpaocMYhExYOmbBwyKRg1zilpaUqHz58WOVz586pPG/ePJV37do16OPPnTs3tX3hwgV1zH8dpqenR+UjR44M+tj5gDMOmQxZOCJyQkQeikhP2r5SETkvIje8n1MGewwqPEFmnGYAq3379gL4zDk3C8BnXqYiEqjPsYhUAfivc+5XXu4DUOucuyci5QDanXND3iCpqalxcek6+uzZM5X973HZtm2bysePH1f51KlTqe0NGzaEPLr4EJEu51yNf791jfOWc+6et30fwFvmkVFeynhx7JJT1oDTFtvVFiZr4TzwnqLg/Xw40F9ku9rCZL2O8wmAPwL4q/fz49BGFJGJEycOenzSpEmDHk9f89TX16tjI0YU/lWOIC/HWwB8DmC2iPSLyPtIFswKEbkB4HdepiIy5IzjnFs/wKHlIY+F8kjhz6mUFQV7rypTBw4cULmrq0vl9vb21Lb/XtXKlSuzNazY4IxDJiwcMmHhkIn5Ozkt4nSvarhu3bql8vz581PbkydPVseWLl2qck2NvtWzfft2lUUkjCFmRdj3qqjIsXDIhC/HA5oxY4bKzc3Nqe2tW7eqYydPnhw0P3/+XOXNmzerXF5ebh1mZDjjkAkLh0xYOGTCNY7RunXrUtszZ85Ux3bv3q2y/5ZEY2Ojynfu6O+E37dvn8oVFRXmcWYLZxwyYeGQCQuHTHjLIQv8rW39HzfesmWLyv5/g+XL9Xvkzp8/H97ghom3HChULBwyYeGQCdc4OTB69GiVX716pfKoUaNUbmtrU7m2tjYr43oTrnEoVCwcMmHhkAnvVYXg2rVrKvu/kqijo0Nl/5rGr7q6WuXFixdnMLrs4IxDJiwcMmHhkAnXOAH5v+L56NGjqe2zZ8+qY/fv3x/WY48cqf8Z/O85jmPblPiNiPJCkP44lSJyUUS+EJ
FeEfmTt58ta4tYkBnnewC7nXPVAH4NYLuIVIMta4takMZK9wDc87b/LyJfAqgAUAeg1vtr/wDQDuDDrIwyAv51yenTp1VuampS+fbt2+ZzLViwQGX/e4zXrl1rfuyoDGuN4/U7ngfgCtiytqgFLhwR+QWA/wD4s3NOdZcerGUt29UWpkCFIyKjkCyafzrnfnztGahlLdvVFqYh1ziS7MHxdwBfOuf+lnYor1rWPnjwQOXe3l6Vd+zYofL169fN5/J/1eKePXtUrqurUzmO12mGEuQC4CIAmwD8T0S6vX1/QbJg/u21r70D4A/ZGSLFUZBXVZcADNT5hy1ri1T+zZEUCwVzr+rx48cq+782qLu7W2V/a7bhWrRoUWrb/1nxVatWqTx27NiMzhVHnHHIhIVDJiwcMsmrNc6VK1dS24cOHVLH/O/r7e/vz+hc48aNU9n/ddLp95f8XxddDDjjkAkLh0zy6qmqtbX1jdtB+D9ysmbNGpUTiYTKDQ0NKvu7pxc7zjhkwsIhExYOmbDNCQ2KbU4oVCwcMmHhkAkLh0xYOGTCwiETFg6ZsHDIhIVDJiwcMmHhkEmk96pE5BGSn/osA/BdZCcenriOLVfjmu6c+9mH/iMtnNRJRTrfdOMsDuI6triNi09VZMLCIZNcFc5HOTpvEHEdW6zGlZM1DuU/PlWRSaSFIyKrRaRPRG6KSE7b24rICRF5KCI9afti0bs5H3pLR1Y4IpIAcAzA7wFUA1jv9UvOlWYAq3374tK7Of69pZ1zkfwB8BsAbWm5EUBjVOcfYExVAHrSch+Acm+7HEBfLseXNq6PAayI0/iifKqqAPBNWu739sVJ7Ho3x7W3NBfHA3DJ/9Y5fclp7S0dhSgL51sAlWn5bW9fnATq3RyFTHpLRyHKwukAMEtE3hGREgD1SPZKjpMfezcDOezdHKC3NJDr3tIRL/LeA/AVgFsA9uV4wdmC5JebvEJyvfU+gF8i+WrlBoALAEpzNLbfIvk0dA1At/fnvbiMzznHK8dkw8UxmbBwyISFQyYsHDJh4ZAJC4dMWDhkwsIhkx8AyyZIbO5tLBIAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 144x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]\n",
"train_data0 = train_data0.reshape([28,28])\n",
"plt.figure(figsize=(2,2))\n",
"plt.imshow(train_data0, cmap=plt.cm.binary)\n",
"print('train_data0 label is: ' + str(train_label_0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2.组网\n",
"用paddle.nn下的API,如`Conv2d`、`Pool2D`、`Linead`完成LeNet的构建。"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import paddle.nn.functional as F\n",
"class LeNet(paddle.nn.Layer):\n",
" def __init__(self):\n",
" super(LeNet, self).__init__()\n",
" self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n",
" self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n",
" self.conv2 = paddle.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n",
" self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n",
" self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)\n",
" self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)\n",
" self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = F.relu(x)\n",
" x = self.max_pool1(x)\n",
" x = F.relu(x)\n",
" x = self.conv2(x)\n",
" x = self.max_pool2(x)\n",
" x = paddle.reshape(x, shape=[-1, 16*5*5])\n",
" x = self.linear1(x)\n",
" x = F.relu(x)\n",
" x = self.linear2(x)\n",
" x = F.relu(x)\n",
" x = self.linear3(x)\n",
" x = F.softmax(x)\n",
" return x"
]
},
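{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (an illustrative sketch, not part of the original tutorial), the network defined above can be fed a random batch: LeNet should map a `[N, 1, 28, 28]` tensor to a `[N, 10]` probability tensor. This assumes `paddle.to_tensor` is available in your develop build; on older builds, `paddle.to_variable` played the same role."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: run a random batch of four 28x28 images through LeNet\n",
"# and inspect the output shape (should be [4, 10]).\n",
"import numpy as np\n",
"net = LeNet()\n",
"fake_batch = paddle.to_tensor(np.random.rand(4, 1, 28, 28).astype('float32'))\n",
"out = net(fake_batch)\n",
"print(out.shape)"
]
},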
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 训练方式一\n",
"通过`Model` 构建实例,快速完成模型训练"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from paddle.static import InputSpec\n",
"from paddle.metric import Accuracy\n",
"inputs = InputSpec([None, 784], 'float32', 'x')\n",
"labels = InputSpec([None, 10], 'float32', 'x')\n",
"model = paddle.hapi.Model(LeNet(), inputs, labels)\n",
"optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
"\n",
"\n",
"model.prepare(\n",
" optim,\n",
" paddle.nn.loss.CrossEntropyLoss(),\n",
" Accuracy(topk=(1, 2))\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 使用model.fit来训练模型"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"step 10/938 - loss: 2.2369 - acc_top1: 0.3281 - acc_top2: 0.4172 - 18ms/step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Library/Python/3.7/site-packages/paddle/fluid/layers/utils.py:76: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n",
" return (isinstance(seq, collections.Sequence) and\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 20/938 - loss: 2.0185 - acc_top1: 0.3656 - acc_top2: 0.4328 - 17ms/step\n",
"step 30/938 - loss: 1.9579 - acc_top1: 0.4120 - acc_top2: 0.4969 - 16ms/step\n",
"step 40/938 - loss: 1.8549 - acc_top1: 0.4602 - acc_top2: 0.5500 - 16ms/step\n",
"step 50/938 - loss: 1.8628 - acc_top1: 0.5097 - acc_top2: 0.6028 - 16ms/step\n",
"step 60/938 - loss: 1.7139 - acc_top1: 0.5456 - acc_top2: 0.6409 - 16ms/step\n",
"step 70/938 - loss: 1.7296 - acc_top1: 0.5795 - acc_top2: 0.6719 - 15ms/step\n",
"step 80/938 - loss: 1.6302 - acc_top1: 0.6053 - acc_top2: 0.6949 - 15ms/step\n",
"step 90/938 - loss: 1.6688 - acc_top1: 0.6290 - acc_top2: 0.7158 - 15ms/step\n",
"step 100/938 - loss: 1.6401 - acc_top1: 0.6491 - acc_top2: 0.7327 - 15ms/step\n",
"step 110/938 - loss: 1.6357 - acc_top1: 0.6636 - acc_top2: 0.7440 - 15ms/step\n",
"step 120/938 - loss: 1.6309 - acc_top1: 0.6767 - acc_top2: 0.7539 - 15ms/step\n",
"step 130/938 - loss: 1.6445 - acc_top1: 0.6894 - acc_top2: 0.7638 - 15ms/step\n",
"step 140/938 - loss: 1.5961 - acc_top1: 0.7002 - acc_top2: 0.7728 - 15ms/step\n",
"step 150/938 - loss: 1.6822 - acc_top1: 0.7086 - acc_top2: 0.7794 - 15ms/step\n",
"step 160/938 - loss: 1.6243 - acc_top1: 0.7176 - acc_top2: 0.7858 - 15ms/step\n",
"step 170/938 - loss: 1.6159 - acc_top1: 0.7254 - acc_top2: 0.7915 - 15ms/step\n",
"step 180/938 - loss: 1.6820 - acc_top1: 0.7312 - acc_top2: 0.7962 - 15ms/step\n",
"step 190/938 - loss: 1.6733 - acc_top1: 0.7363 - acc_top2: 0.7999 - 15ms/step\n",
"step 200/938 - loss: 1.7717 - acc_top1: 0.7413 - acc_top2: 0.8039 - 15ms/step\n",
"step 210/938 - loss: 1.5468 - acc_top1: 0.7458 - acc_top2: 0.8072 - 15ms/step\n",
"step 220/938 - loss: 1.5654 - acc_top1: 0.7506 - acc_top2: 0.8111 - 15ms/step\n",
"step 230/938 - loss: 1.6129 - acc_top1: 0.7547 - acc_top2: 0.8143 - 15ms/step\n",
"step 240/938 - loss: 1.5937 - acc_top1: 0.7592 - acc_top2: 0.8180 - 15ms/step\n",
"step 250/938 - loss: 1.5457 - acc_top1: 0.7631 - acc_top2: 0.8214 - 15ms/step\n",
"step 260/938 - loss: 1.6041 - acc_top1: 0.7673 - acc_top2: 0.8249 - 15ms/step\n",
"step 270/938 - loss: 1.6049 - acc_top1: 0.7700 - acc_top2: 0.8271 - 15ms/step\n",
"step 280/938 - loss: 1.5989 - acc_top1: 0.7735 - acc_top2: 0.8299 - 15ms/step\n",
"step 290/938 - loss: 1.6950 - acc_top1: 0.7752 - acc_top2: 0.8310 - 15ms/step\n",
"step 300/938 - loss: 1.5888 - acc_top1: 0.7781 - acc_top2: 0.8330 - 15ms/step\n",
"step 310/938 - loss: 1.5983 - acc_top1: 0.7808 - acc_top2: 0.8350 - 15ms/step\n",
"step 320/938 - loss: 1.5133 - acc_top1: 0.7840 - acc_top2: 0.8370 - 15ms/step\n",
"step 330/938 - loss: 1.5587 - acc_top1: 0.7866 - acc_top2: 0.8385 - 15ms/step\n",
"step 340/938 - loss: 1.6093 - acc_top1: 0.7882 - acc_top2: 0.8393 - 15ms/step\n",
"step 350/938 - loss: 1.6259 - acc_top1: 0.7902 - acc_top2: 0.8410 - 15ms/step\n",
"step 360/938 - loss: 1.6194 - acc_top1: 0.7918 - acc_top2: 0.8422 - 15ms/step\n",
"step 370/938 - loss: 1.6531 - acc_top1: 0.7941 - acc_top2: 0.8438 - 15ms/step\n",
"step 380/938 - loss: 1.6986 - acc_top1: 0.7957 - acc_top2: 0.8447 - 15ms/step\n",
"step 390/938 - loss: 1.5932 - acc_top1: 0.7974 - acc_top2: 0.8459 - 15ms/step\n",
"step 400/938 - loss: 1.6512 - acc_top1: 0.7993 - acc_top2: 0.8474 - 15ms/step\n",
"step 410/938 - loss: 1.5698 - acc_top1: 0.8012 - acc_top2: 0.8487 - 15ms/step\n",
"step 420/938 - loss: 1.5889 - acc_top1: 0.8025 - acc_top2: 0.8494 - 15ms/step\n",
"step 430/938 - loss: 1.5518 - acc_top1: 0.8036 - acc_top2: 0.8503 - 15ms/step\n",
"step 440/938 - loss: 1.6057 - acc_top1: 0.8048 - acc_top2: 0.8508 - 15ms/step\n",
"step 450/938 - loss: 1.6081 - acc_top1: 0.8064 - acc_top2: 0.8519 - 15ms/step\n",
"step 460/938 - loss: 1.5742 - acc_top1: 0.8079 - acc_top2: 0.8531 - 15ms/step\n",
"step 470/938 - loss: 1.5704 - acc_top1: 0.8095 - acc_top2: 0.8543 - 15ms/step\n",
"step 480/938 - loss: 1.6083 - acc_top1: 0.8110 - acc_top2: 0.8550 - 15ms/step\n",
"step 490/938 - loss: 1.6081 - acc_top1: 0.8120 - acc_top2: 0.8555 - 15ms/step\n",
"step 500/938 - loss: 1.5156 - acc_top1: 0.8133 - acc_top2: 0.8564 - 15ms/step\n",
"step 510/938 - loss: 1.5856 - acc_top1: 0.8148 - acc_top2: 0.8573 - 15ms/step\n",
"step 520/938 - loss: 1.5275 - acc_top1: 0.8163 - acc_top2: 0.8582 - 15ms/step\n",
"step 530/938 - loss: 1.5345 - acc_top1: 0.8172 - acc_top2: 0.8591 - 15ms/step\n",
"step 540/938 - loss: 1.5387 - acc_top1: 0.8181 - acc_top2: 0.8596 - 15ms/step\n",
"step 550/938 - loss: 1.5753 - acc_top1: 0.8190 - acc_top2: 0.8601 - 15ms/step\n",
"step 560/938 - loss: 1.6103 - acc_top1: 0.8203 - acc_top2: 0.8610 - 15ms/step\n",
"step 570/938 - loss: 1.5571 - acc_top1: 0.8215 - acc_top2: 0.8618 - 15ms/step\n",
"step 580/938 - loss: 1.5575 - acc_top1: 0.8221 - acc_top2: 0.8622 - 15ms/step\n",
"step 590/938 - loss: 1.4821 - acc_top1: 0.8230 - acc_top2: 0.8627 - 15ms/step\n",
"step 600/938 - loss: 1.5644 - acc_top1: 0.8243 - acc_top2: 0.8636 - 15ms/step\n",
"step 610/938 - loss: 1.5317 - acc_top1: 0.8253 - acc_top2: 0.8644 - 15ms/step\n",
"step 620/938 - loss: 1.5849 - acc_top1: 0.8258 - acc_top2: 0.8647 - 15ms/step\n",
"step 630/938 - loss: 1.6087 - acc_top1: 0.8263 - acc_top2: 0.8649 - 15ms/step\n",
"step 640/938 - loss: 1.5617 - acc_top1: 0.8272 - acc_top2: 0.8655 - 15ms/step\n",
"step 650/938 - loss: 1.6376 - acc_top1: 0.8279 - acc_top2: 0.8660 - 15ms/step\n",
"step 660/938 - loss: 1.5428 - acc_top1: 0.8287 - acc_top2: 0.8665 - 15ms/step\n",
"step 670/938 - loss: 1.5797 - acc_top1: 0.8293 - acc_top2: 0.8668 - 15ms/step\n",
"step 680/938 - loss: 1.5210 - acc_top1: 0.8300 - acc_top2: 0.8674 - 15ms/step\n",
"step 690/938 - loss: 1.6159 - acc_top1: 0.8305 - acc_top2: 0.8677 - 15ms/step\n",
"step 700/938 - loss: 1.5592 - acc_top1: 0.8313 - acc_top2: 0.8682 - 15ms/step\n",
"step 710/938 - loss: 1.6400 - acc_top1: 0.8318 - acc_top2: 0.8685 - 15ms/step\n",
"step 720/938 - loss: 1.5638 - acc_top1: 0.8327 - acc_top2: 0.8691 - 15ms/step\n",
"step 730/938 - loss: 1.5691 - acc_top1: 0.8333 - acc_top2: 0.8693 - 15ms/step\n",
"step 740/938 - loss: 1.5848 - acc_top1: 0.8337 - acc_top2: 0.8695 - 15ms/step\n",
"step 750/938 - loss: 1.6317 - acc_top1: 0.8344 - acc_top2: 0.8698 - 15ms/step\n",
"step 760/938 - loss: 1.5127 - acc_top1: 0.8352 - acc_top2: 0.8703 - 15ms/step\n",
"step 770/938 - loss: 1.5822 - acc_top1: 0.8359 - acc_top2: 0.8707 - 15ms/step\n",
"step 780/938 - loss: 1.6010 - acc_top1: 0.8366 - acc_top2: 0.8712 - 15ms/step\n",
"step 790/938 - loss: 1.5238 - acc_top1: 0.8373 - acc_top2: 0.8717 - 15ms/step\n",
"step 800/938 - loss: 1.5858 - acc_top1: 0.8377 - acc_top2: 0.8719 - 15ms/step\n",
"step 810/938 - loss: 1.5800 - acc_top1: 0.8384 - acc_top2: 0.8724 - 15ms/step\n",
"step 820/938 - loss: 1.6312 - acc_top1: 0.8390 - acc_top2: 0.8727 - 15ms/step\n",
"step 830/938 - loss: 1.5812 - acc_top1: 0.8398 - acc_top2: 0.8732 - 15ms/step\n",
"step 840/938 - loss: 1.5661 - acc_top1: 0.8402 - acc_top2: 0.8734 - 15ms/step\n",
"step 850/938 - loss: 1.5379 - acc_top1: 0.8409 - acc_top2: 0.8739 - 15ms/step\n",
"step 860/938 - loss: 1.5266 - acc_top1: 0.8413 - acc_top2: 0.8740 - 15ms/step\n",
"step 870/938 - loss: 1.5264 - acc_top1: 0.8420 - acc_top2: 0.8745 - 15ms/step\n",
"step 880/938 - loss: 1.5688 - acc_top1: 0.8425 - acc_top2: 0.8748 - 15ms/step\n",
"step 890/938 - loss: 1.5707 - acc_top1: 0.8429 - acc_top2: 0.8751 - 15ms/step\n",
"step 900/938 - loss: 1.5564 - acc_top1: 0.8432 - acc_top2: 0.8752 - 15ms/step\n",
"step 910/938 - loss: 1.4924 - acc_top1: 0.8438 - acc_top2: 0.8757 - 15ms/step\n",
"step 920/938 - loss: 1.5514 - acc_top1: 0.8443 - acc_top2: 0.8760 - 15ms/step\n",
"step 930/938 - loss: 1.5850 - acc_top1: 0.8446 - acc_top2: 0.8762 - 15ms/step\n",
"step 938/938 - loss: 1.4915 - acc_top1: 0.8448 - acc_top2: 0.8764 - 15ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/0\n",
"Eval begin...\n",
"step 10/157 - loss: 1.5984 - acc_top1: 0.8797 - acc_top2: 0.8953 - 5ms/step\n",
"step 20/157 - loss: 1.6266 - acc_top1: 0.8789 - acc_top2: 0.9000 - 5ms/step\n",
"step 30/157 - loss: 1.6475 - acc_top1: 0.8771 - acc_top2: 0.8984 - 5ms/step\n",
"step 40/157 - loss: 1.6329 - acc_top1: 0.8730 - acc_top2: 0.8957 - 5ms/step\n",
"step 50/157 - loss: 1.5399 - acc_top1: 0.8712 - acc_top2: 0.8934 - 5ms/step\n",
"step 60/157 - loss: 1.6322 - acc_top1: 0.8750 - acc_top2: 0.8961 - 5ms/step\n",
"step 70/157 - loss: 1.5818 - acc_top1: 0.8721 - acc_top2: 0.8931 - 5ms/step\n",
"step 80/157 - loss: 1.5522 - acc_top1: 0.8760 - acc_top2: 0.8979 - 5ms/step\n",
"step 90/157 - loss: 1.6085 - acc_top1: 0.8785 - acc_top2: 0.8984 - 5ms/step\n",
"step 100/157 - loss: 1.5661 - acc_top1: 0.8784 - acc_top2: 0.8980 - 5ms/step\n",
"step 110/157 - loss: 1.5694 - acc_top1: 0.8805 - acc_top2: 0.8996 - 5ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 120/157 - loss: 1.6012 - acc_top1: 0.8824 - acc_top2: 0.9003 - 5ms/step\n",
"step 130/157 - loss: 1.5378 - acc_top1: 0.8844 - acc_top2: 0.9017 - 5ms/step\n",
"step 140/157 - loss: 1.5068 - acc_top1: 0.8858 - acc_top2: 0.9022 - 5ms/step\n",
"step 150/157 - loss: 1.5424 - acc_top1: 0.8873 - acc_top2: 0.9029 - 5ms/step\n",
"step 157/157 - loss: 1.5862 - acc_top1: 0.8872 - acc_top2: 0.9035 - 5ms/step\n",
"Eval samples: 10000\n",
"Epoch 2/2\n",
"step 10/938 - loss: 1.5988 - acc_top1: 0.8859 - acc_top2: 0.9016 - 15ms/step\n",
"step 20/938 - loss: 1.5702 - acc_top1: 0.8852 - acc_top2: 0.9047 - 15ms/step\n",
"step 30/938 - loss: 1.5999 - acc_top1: 0.8833 - acc_top2: 0.9021 - 15ms/step\n",
"step 40/938 - loss: 1.5652 - acc_top1: 0.8816 - acc_top2: 0.9000 - 15ms/step\n",
"step 50/938 - loss: 1.6163 - acc_top1: 0.8853 - acc_top2: 0.9047 - 15ms/step\n",
"step 60/938 - loss: 1.5307 - acc_top1: 0.8849 - acc_top2: 0.9049 - 15ms/step\n",
"step 70/938 - loss: 1.5542 - acc_top1: 0.8846 - acc_top2: 0.9029 - 15ms/step\n",
"step 80/938 - loss: 1.5694 - acc_top1: 0.8816 - acc_top2: 0.9008 - 15ms/step\n",
"step 90/938 - loss: 1.6030 - acc_top1: 0.8806 - acc_top2: 0.8991 - 15ms/step\n",
"step 100/938 - loss: 1.5631 - acc_top1: 0.8814 - acc_top2: 0.8989 - 15ms/step\n",
"step 110/938 - loss: 1.5598 - acc_top1: 0.8804 - acc_top2: 0.8984 - 15ms/step\n",
"step 120/938 - loss: 1.5773 - acc_top1: 0.8803 - acc_top2: 0.8986 - 15ms/step\n",
"step 130/938 - loss: 1.5076 - acc_top1: 0.8815 - acc_top2: 0.8995 - 15ms/step\n",
"step 140/938 - loss: 1.6064 - acc_top1: 0.8809 - acc_top2: 0.8988 - 15ms/step\n",
"step 150/938 - loss: 1.5279 - acc_top1: 0.8815 - acc_top2: 0.8993 - 15ms/step\n",
"step 160/938 - loss: 1.6039 - acc_top1: 0.8820 - acc_top2: 0.8998 - 15ms/step\n",
"step 170/938 - loss: 1.5709 - acc_top1: 0.8814 - acc_top2: 0.8993 - 15ms/step\n",
"step 180/938 - loss: 1.6164 - acc_top1: 0.8806 - acc_top2: 0.8985 - 15ms/step\n",
"step 190/938 - loss: 1.5920 - acc_top1: 0.8802 - acc_top2: 0.8985 - 15ms/step\n",
"step 200/938 - loss: 1.6457 - acc_top1: 0.8793 - acc_top2: 0.8973 - 15ms/step\n",
"step 210/938 - loss: 1.6045 - acc_top1: 0.8794 - acc_top2: 0.8977 - 15ms/step\n",
"step 220/938 - loss: 1.6614 - acc_top1: 0.8795 - acc_top2: 0.8975 - 15ms/step\n",
"step 230/938 - loss: 1.5384 - acc_top1: 0.8789 - acc_top2: 0.8966 - 15ms/step\n",
"step 240/938 - loss: 1.5556 - acc_top1: 0.8785 - acc_top2: 0.8960 - 15ms/step\n",
"step 250/938 - loss: 1.6006 - acc_top1: 0.8782 - acc_top2: 0.8961 - 15ms/step\n",
"step 260/938 - loss: 1.5552 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n",
"step 270/938 - loss: 1.5805 - acc_top1: 0.8791 - acc_top2: 0.8970 - 15ms/step\n",
"step 280/938 - loss: 1.5404 - acc_top1: 0.8787 - acc_top2: 0.8966 - 15ms/step\n",
"step 290/938 - loss: 1.6023 - acc_top1: 0.8789 - acc_top2: 0.8969 - 15ms/step\n",
"step 300/938 - loss: 1.5706 - acc_top1: 0.8788 - acc_top2: 0.8969 - 15ms/step\n",
"step 310/938 - loss: 1.5424 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n",
"step 320/938 - loss: 1.5823 - acc_top1: 0.8798 - acc_top2: 0.8975 - 15ms/step\n",
"step 330/938 - loss: 1.5600 - acc_top1: 0.8801 - acc_top2: 0.8977 - 15ms/step\n",
"step 340/938 - loss: 1.6258 - acc_top1: 0.8795 - acc_top2: 0.8970 - 15ms/step\n",
"step 350/938 - loss: 1.5093 - acc_top1: 0.8796 - acc_top2: 0.8972 - 15ms/step\n",
"step 360/938 - loss: 1.6030 - acc_top1: 0.8794 - acc_top2: 0.8967 - 15ms/step\n",
"step 370/938 - loss: 1.5732 - acc_top1: 0.8795 - acc_top2: 0.8969 - 15ms/step\n",
"step 380/938 - loss: 1.5980 - acc_top1: 0.8797 - acc_top2: 0.8972 - 15ms/step\n",
"step 390/938 - loss: 1.5902 - acc_top1: 0.8800 - acc_top2: 0.8974 - 15ms/step\n",
"step 400/938 - loss: 1.5395 - acc_top1: 0.8809 - acc_top2: 0.8983 - 15ms/step\n",
"step 410/938 - loss: 1.6623 - acc_top1: 0.8804 - acc_top2: 0.8978 - 15ms/step\n",
"step 420/938 - loss: 1.4987 - acc_top1: 0.8810 - acc_top2: 0.8983 - 15ms/step\n",
"step 430/938 - loss: 1.5989 - acc_top1: 0.8811 - acc_top2: 0.8983 - 15ms/step\n",
"step 440/938 - loss: 1.5722 - acc_top1: 0.8813 - acc_top2: 0.8984 - 15ms/step\n",
"step 450/938 - loss: 1.5549 - acc_top1: 0.8818 - acc_top2: 0.8986 - 15ms/step\n",
"step 460/938 - loss: 1.5536 - acc_top1: 0.8819 - acc_top2: 0.8986 - 15ms/step\n",
"step 470/938 - loss: 1.5247 - acc_top1: 0.8826 - acc_top2: 0.8992 - 15ms/step\n",
"step 480/938 - loss: 1.5520 - acc_top1: 0.8830 - acc_top2: 0.8995 - 15ms/step\n",
"step 490/938 - loss: 1.5518 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n",
"step 500/938 - loss: 1.5227 - acc_top1: 0.8837 - acc_top2: 0.9000 - 15ms/step\n",
"step 510/938 - loss: 1.6014 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n",
"step 520/938 - loss: 1.5526 - acc_top1: 0.8834 - acc_top2: 0.8998 - 15ms/step\n",
"step 530/938 - loss: 1.5849 - acc_top1: 0.8838 - acc_top2: 0.9001 - 15ms/step\n",
"step 540/938 - loss: 1.5607 - acc_top1: 0.8840 - acc_top2: 0.9006 - 15ms/step\n",
"step 550/938 - loss: 1.6438 - acc_top1: 0.8843 - acc_top2: 0.9010 - 15ms/step\n",
"step 560/938 - loss: 1.5229 - acc_top1: 0.8848 - acc_top2: 0.9014 - 15ms/step\n",
"step 570/938 - loss: 1.5395 - acc_top1: 0.8846 - acc_top2: 0.9012 - 15ms/step\n",
"step 580/938 - loss: 1.5409 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n",
"step 590/938 - loss: 1.5851 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n",
"step 600/938 - loss: 1.5383 - acc_top1: 0.8849 - acc_top2: 0.9013 - 15ms/step\n",
"step 610/938 - loss: 1.5969 - acc_top1: 0.8853 - acc_top2: 0.9016 - 15ms/step\n",
"step 620/938 - loss: 1.5634 - acc_top1: 0.8854 - acc_top2: 0.9017 - 15ms/step\n",
"step 630/938 - loss: 1.6308 - acc_top1: 0.8857 - acc_top2: 0.9019 - 15ms/step\n",
"step 640/938 - loss: 1.6413 - acc_top1: 0.8859 - acc_top2: 0.9021 - 15ms/step\n",
"step 650/938 - loss: 1.5954 - acc_top1: 0.8856 - acc_top2: 0.9020 - 15ms/step\n",
"step 660/938 - loss: 1.5278 - acc_top1: 0.8859 - acc_top2: 0.9023 - 15ms/step\n",
"step 670/938 - loss: 1.5144 - acc_top1: 0.8869 - acc_top2: 0.9035 - 15ms/step\n",
"step 680/938 - loss: 1.4612 - acc_top1: 0.8879 - acc_top2: 0.9048 - 15ms/step\n",
"step 690/938 - loss: 1.4820 - acc_top1: 0.8891 - acc_top2: 0.9060 - 15ms/step\n",
"step 700/938 - loss: 1.4766 - acc_top1: 0.8901 - acc_top2: 0.9073 - 15ms/step\n",
"step 710/938 - loss: 1.5245 - acc_top1: 0.8911 - acc_top2: 0.9083 - 15ms/step\n",
"step 720/938 - loss: 1.5183 - acc_top1: 0.8922 - acc_top2: 0.9095 - 15ms/step\n",
"step 730/938 - loss: 1.4971 - acc_top1: 0.8932 - acc_top2: 0.9106 - 15ms/step\n",
"step 740/938 - loss: 1.4744 - acc_top1: 0.8944 - acc_top2: 0.9117 - 15ms/step\n",
"step 750/938 - loss: 1.4789 - acc_top1: 0.8952 - acc_top2: 0.9127 - 15ms/step\n",
"step 760/938 - loss: 1.5114 - acc_top1: 0.8959 - acc_top2: 0.9137 - 15ms/step\n",
"step 770/938 - loss: 1.5035 - acc_top1: 0.8970 - acc_top2: 0.9147 - 15ms/step\n",
"step 780/938 - loss: 1.4668 - acc_top1: 0.8978 - acc_top2: 0.9157 - 15ms/step\n",
"step 790/938 - loss: 1.4850 - acc_top1: 0.8986 - acc_top2: 0.9166 - 15ms/step\n",
"step 800/938 - loss: 1.4777 - acc_top1: 0.8996 - acc_top2: 0.9176 - 15ms/step\n",
"step 810/938 - loss: 1.4783 - acc_top1: 0.9005 - acc_top2: 0.9186 - 15ms/step\n",
"step 820/938 - loss: 1.5256 - acc_top1: 0.9011 - acc_top2: 0.9194 - 15ms/step\n",
"step 830/938 - loss: 1.4801 - acc_top1: 0.9019 - acc_top2: 0.9202 - 15ms/step\n",
"step 840/938 - loss: 1.4873 - acc_top1: 0.9026 - acc_top2: 0.9211 - 15ms/step\n",
"step 850/938 - loss: 1.5093 - acc_top1: 0.9034 - acc_top2: 0.9219 - 15ms/step\n",
"step 860/938 - loss: 1.4727 - acc_top1: 0.9042 - acc_top2: 0.9227 - 15ms/step\n",
"step 870/938 - loss: 1.4917 - acc_top1: 0.9050 - acc_top2: 0.9235 - 15ms/step\n",
"step 880/938 - loss: 1.4792 - acc_top1: 0.9058 - acc_top2: 0.9243 - 15ms/step\n",
"step 890/938 - loss: 1.4854 - acc_top1: 0.9066 - acc_top2: 0.9251 - 15ms/step\n",
"step 900/938 - loss: 1.4616 - acc_top1: 0.9074 - acc_top2: 0.9258 - 15ms/step\n",
"step 910/938 - loss: 1.4954 - acc_top1: 0.9081 - acc_top2: 0.9265 - 15ms/step\n",
"step 920/938 - loss: 1.4875 - acc_top1: 0.9087 - acc_top2: 0.9272 - 15ms/step\n",
"step 930/938 - loss: 1.5037 - acc_top1: 0.9094 - acc_top2: 0.9279 - 15ms/step\n",
"step 938/938 - loss: 1.4964 - acc_top1: 0.9099 - acc_top2: 0.9284 - 15ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n",
"Eval begin...\n",
"step 10/157 - loss: 1.5196 - acc_top1: 0.9719 - acc_top2: 0.9969 - 5ms/step\n",
"step 20/157 - loss: 1.5393 - acc_top1: 0.9672 - acc_top2: 0.9945 - 6ms/step\n",
"step 30/157 - loss: 1.4928 - acc_top1: 0.9630 - acc_top2: 0.9906 - 5ms/step\n",
"step 40/157 - loss: 1.4765 - acc_top1: 0.9617 - acc_top2: 0.9902 - 5ms/step\n",
"step 50/157 - loss: 1.4646 - acc_top1: 0.9631 - acc_top2: 0.9903 - 5ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 60/157 - loss: 1.5646 - acc_top1: 0.9641 - acc_top2: 0.9906 - 5ms/step\n",
"step 70/157 - loss: 1.5167 - acc_top1: 0.9618 - acc_top2: 0.9900 - 5ms/step\n",
"step 80/157 - loss: 1.4728 - acc_top1: 0.9635 - acc_top2: 0.9906 - 5ms/step\n",
"step 90/157 - loss: 1.5030 - acc_top1: 0.9668 - acc_top2: 0.9917 - 5ms/step\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9677 - acc_top2: 0.9914 - 5ms/step\n",
"step 110/157 - loss: 1.4612 - acc_top1: 0.9689 - acc_top2: 0.9913 - 5ms/step\n",
"step 120/157 - loss: 1.4612 - acc_top1: 0.9707 - acc_top2: 0.9919 - 5ms/step\n",
"step 130/157 - loss: 1.4621 - acc_top1: 0.9719 - acc_top2: 0.9923 - 5ms/step\n",
"step 140/157 - loss: 1.4612 - acc_top1: 0.9734 - acc_top2: 0.9929 - 5ms/step\n",
"step 150/157 - loss: 1.4660 - acc_top1: 0.9748 - acc_top2: 0.9933 - 5ms/step\n",
"step 157/157 - loss: 1.5215 - acc_top1: 0.9731 - acc_top2: 0.9930 - 5ms/step\n",
"Eval samples: 10000\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/final\n"
]
}
],
"source": [
"model.fit(train_dataset,\n",
" test_dataset,\n",
" epochs=2,\n",
" batch_size=64,\n",
" save_dir='mnist_checkpoint')"
]
},
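{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Model` also exposes an `evaluate` method (a hedged sketch; the exact signature may differ in your paddle-develop build), so the trained model can be scored on the test set without writing a loop:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the trained high-level-API model on the test set.\n",
"eval_result = model.evaluate(test_dataset, batch_size=64)\n",
"print(eval_result)"
]
},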
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 训练方式1结束\n",
"以上就是训练方式1,可以非常快速的完成网络模型训练。此外,paddle还可以用下面的方式来完成模型的训练。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.训练方式2\n",
"方式1可以快速便捷的完成训练,隐藏了训练时的细节。而方式2则可以用最基本的方式,完成模型的训练。具体如下。"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch_id: 0, loss is: [2.300888], acc is: [0.28125]\n",
"epoch: 0, batch_id: 100, loss is: [1.6948285], acc is: [0.8125]\n",
"epoch: 0, batch_id: 200, loss is: [1.5282547], acc is: [0.96875]\n",
"epoch: 0, batch_id: 300, loss is: [1.509404], acc is: [0.96875]\n",
"epoch: 0, batch_id: 400, loss is: [1.4973292], acc is: [1.]\n",
"epoch: 0, batch_id: 500, loss is: [1.5063374], acc is: [0.984375]\n",
"epoch: 0, batch_id: 600, loss is: [1.490077], acc is: [0.984375]\n",
"epoch: 0, batch_id: 700, loss is: [1.5206413], acc is: [0.984375]\n",
"epoch: 0, batch_id: 800, loss is: [1.5104291], acc is: [1.]\n",
"epoch: 0, batch_id: 900, loss is: [1.5216607], acc is: [0.96875]\n",
"epoch: 1, batch_id: 0, loss is: [1.4949667], acc is: [0.984375]\n",
"epoch: 1, batch_id: 100, loss is: [1.4923338], acc is: [0.96875]\n",
"epoch: 1, batch_id: 200, loss is: [1.5026703], acc is: [1.]\n",
"epoch: 1, batch_id: 300, loss is: [1.4965419], acc is: [0.984375]\n",
"epoch: 1, batch_id: 400, loss is: [1.5270758], acc is: [1.]\n",
"epoch: 1, batch_id: 500, loss is: [1.4774603], acc is: [1.]\n",
"epoch: 1, batch_id: 600, loss is: [1.4762554], acc is: [0.984375]\n",
"epoch: 1, batch_id: 700, loss is: [1.4773959], acc is: [0.984375]\n",
"epoch: 1, batch_id: 800, loss is: [1.5044193], acc is: [1.]\n",
"epoch: 1, batch_id: 900, loss is: [1.4986757], acc is: [0.96875]\n"
]
}
],
"source": [
"import paddle\n",
"train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"def train(model):\n",
" model.train()\n",
" epochs = 2\n",
" batch_size = 64\n",
" optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
" for epoch in range(epochs):\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
" optim.minimize(avg_loss)\n",
" model.clear_gradients()\n",
"model = LeNet()\n",
"train(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 对模型进行验证"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_id: 0, loss is: [1.5017498], acc is: [1.]\n",
"batch_id: 100, loss is: [1.4783669], acc is: [0.984375]\n",
"batch_id: 200, loss is: [1.4958509], acc is: [1.]\n",
"batch_id: 300, loss is: [1.4924574], acc is: [1.]\n",
"batch_id: 400, loss is: [1.4762049], acc is: [1.]\n",
"batch_id: 500, loss is: [1.4817208], acc is: [0.984375]\n",
"batch_id: 600, loss is: [1.4763825], acc is: [0.984375]\n",
"batch_id: 700, loss is: [1.4954926], acc is: [1.]\n",
"batch_id: 800, loss is: [1.5220823], acc is: [0.984375]\n",
"batch_id: 900, loss is: [1.4945463], acc is: [0.984375]\n"
]
}
],
"source": [
"import paddle\n",
"test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"def test(model):\n",
" model.eval()\n",
" batch_size = 64\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
"test(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 训练方式2结束\n",
"以上就是训练方式2,通过这种方式,可以清楚的看到训练和测试中的每一步过程。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 总结\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## 用N-Gram模型在莎士比亚文集中训练word embedding\n",
"N-gram 是计算机语言学和概率论范畴内的概念,是指给定的一段文本中N个项目的序列。\n",
"N=1 时 N-gram 又称为 unigram,N=2 称为 bigram,N=3 称为 trigram,以此类推。实际应用通常采用 bigram 和 trigram 进行计算。\n",
"本示例在莎士比亚文集上实现了trigram。"
]
},
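{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the idea concrete (an illustrative sketch, not part of the original data pipeline), trigrams can be extracted from a tokenized sentence with a simple sliding window: two context words followed by the word to predict."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Slide a window of size 3 over a word list to build (context, target) trigrams.\n",
"words = 'to be or not to be'.split()\n",
"trigrams = [((words[i], words[i + 1]), words[i + 2]) for i in range(len(words) - 2)]\n",
"print(trigrams[0])  # (('to', 'be'), 'or')"
]
},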
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境\n",
"本教程基于paddle-develop编写,如果您的环境不是本版本,请先安装paddle-develop。"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.0.0'"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import paddle\n",
"paddle.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset and related parameters\n",
"The training data is the collected works of Shakespeare ([download](https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt)); save it in plain txt format.<br>\n",
"`context_size` is set to 2, which makes this a trigram model; `embedding_dim` is set to 256."
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"embedding_dim = 256\n",
"context_size = 2"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Length of text: 1115394 characters\n"
]
}
],
"source": [
"# Path to the corpus file\n",
"path_to_file = './shakespeare.txt'\n",
"with open(path_to_file, 'rb') as f:\n",
"    test_sentence = f.read().decode(encoding='utf-8')\n",
"\n",
"# The text length is the number of characters in the file\n",
"print('Length of text: {} characters'.format(len(test_sentence)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Removing punctuation\n",
"Use `punctuation` from the standard `string` module to strip the English punctuation marks."
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'!': '', '\"': '', '#': '', '$': '', '%': '', '&': '', \"'\": '', '(': '', ')': '', '*': '', '+': '', ',': '', '-': '', '.': '', '/': '', ':': '', ';': '', '<': '', '=': '', '>': '', '?': '', '@': '', '[': '', '\\\\': '', ']': '', '^': '', '_': '', '`': '', '{': '', '|': '', '}': '', '~': ''}\n"
]
}
],
"source": [
"from string import punctuation\n",
"process_dicts = {i: '' for i in punctuation}\n",
"print(process_dicts)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12848\n"
]
}
],
"source": [
"punc_table = str.maketrans(process_dicts)\n",
"test_sentence = test_sentence.translate(punc_table)\n",
"test_sentence = test_sentence.lower().split()\n",
"vocab = set(test_sentence)\n",
"print(len(vocab))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data preprocessing\n",
"Split the text into nested pairs of the form (('first word', 'second word'), 'third word'), where the third word is the prediction target."
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[['first', 'citizen'], 'before'], [['citizen', 'before'], 'we'], [['before', 'we'], 'proceed']]\n"
]
}
],
"source": [
"trigram = [[[test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2]]\n",
" for i in range(len(test_sentence) - 2)]\n",
"\n",
"word_to_idx = {word: i for i, word in enumerate(vocab)}\n",
"idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n",
"# Preview the dataset\n",
"print(trigram[:3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building a `Dataset` class to load the data\n",
"Build the dataset with `paddle.io.Dataset`, then pass it to `paddle.io.DataLoader` to complete data loading."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import numpy as np\n",
"batch_size = 256\n",
"paddle.disable_static()\n",
"class TrainDataset(paddle.io.Dataset):\n",
" def __init__(self, tuple_data):\n",
" self.tuple_data = tuple_data\n",
"\n",
" def __getitem__(self, idx):\n",
" data = self.tuple_data[idx][0]\n",
" label = self.tuple_data[idx][1]\n",
" data = np.array(list(map(lambda w: word_to_idx[w], data)))\n",
" label = np.array(word_to_idx[label])\n",
" return data, label\n",
" \n",
" def __len__(self):\n",
" return len(self.tuple_data)\n",
"train_dataset = TrainDataset(trigram)\n",
"train_loader = paddle.io.DataLoader(train_dataset, places=paddle.fluid.cpu_places(), return_list=True,\n",
" shuffle=True, batch_size=batch_size, drop_last=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building and training the network\n",
"The network is built in paddle's dynamic graph mode. The trigram model is composed of one `Embedding` layer followed by two `Linear` layers: the `Embedding` layer embeds the two input context words, and the two `Linear` layers then extract features from the concatenated embeddings."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import numpy as np\n",
"hidden_size = 1024\n",
"class NGramModel(paddle.nn.Layer):\n",
" def __init__(self, vocab_size, embedding_dim, context_size):\n",
" super(NGramModel, self).__init__()\n",
" self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n",
" self.linear1 = paddle.nn.Linear(context_size * embedding_dim, hidden_size)\n",
" self.linear2 = paddle.nn.Linear(hidden_size, len(vocab))\n",
"\n",
" def forward(self, x):\n",
" x = self.embedding(x)\n",
" x = paddle.reshape(x, [-1, context_size * embedding_dim])\n",
" x = self.linear1(x)\n",
" x = paddle.nn.functional.relu(x)\n",
" x = self.linear2(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define a `train()` function to train the model."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch_id: 0, loss is: [9.460927]\n",
"epoch: 0, batch_id: 100, loss is: [6.8197966]\n",
"epoch: 0, batch_id: 200, loss is: [7.0330124]\n",
"epoch: 0, batch_id: 300, loss is: [6.9797225]\n",
"epoch: 0, batch_id: 400, loss is: [6.9343157]\n",
"epoch: 0, batch_id: 500, loss is: [7.0574594]\n",
"epoch: 0, batch_id: 600, loss is: [6.646331]\n",
"epoch: 0, batch_id: 700, loss is: [6.867292]\n",
"epoch: 1, batch_id: 0, loss is: [6.328604]\n",
"epoch: 1, batch_id: 100, loss is: [6.5479784]\n",
"epoch: 1, batch_id: 200, loss is: [6.1158056]\n",
"epoch: 1, batch_id: 300, loss is: [6.1452174]\n",
"epoch: 1, batch_id: 400, loss is: [6.0391626]\n",
"epoch: 1, batch_id: 500, loss is: [6.3031235]\n",
"epoch: 1, batch_id: 600, loss is: [6.2294545]\n",
"epoch: 1, batch_id: 700, loss is: [7.0351915]\n",
"epoch: 2, batch_id: 0, loss is: [5.9047413]\n",
"epoch: 2, batch_id: 100, loss is: [5.9303236]\n",
"epoch: 2, batch_id: 200, loss is: [6.0062647]\n",
"epoch: 2, batch_id: 300, loss is: [5.923294]\n",
"epoch: 2, batch_id: 400, loss is: [6.0315046]\n",
"epoch: 2, batch_id: 500, loss is: [6.00991]\n",
"epoch: 2, batch_id: 600, loss is: [6.2324004]\n",
"epoch: 2, batch_id: 700, loss is: [5.924728]\n",
"epoch: 3, batch_id: 0, loss is: [5.9861994]\n",
"epoch: 3, batch_id: 100, loss is: [5.886199]\n",
"epoch: 3, batch_id: 200, loss is: [6.216494]\n",
"epoch: 3, batch_id: 300, loss is: [6.066014]\n",
"epoch: 3, batch_id: 400, loss is: [5.7648573]\n",
"epoch: 3, batch_id: 500, loss is: [6.0842447]\n",
"epoch: 3, batch_id: 600, loss is: [6.032789]\n",
"epoch: 3, batch_id: 700, loss is: [5.875885]\n",
"epoch: 4, batch_id: 0, loss is: [5.923893]\n",
"epoch: 4, batch_id: 100, loss is: [5.838437]\n",
"epoch: 4, batch_id: 200, loss is: [5.8900037]\n",
"epoch: 4, batch_id: 300, loss is: [5.826645]\n",
"epoch: 4, batch_id: 400, loss is: [5.978036]\n",
"epoch: 4, batch_id: 500, loss is: [6.1803474]\n",
"epoch: 4, batch_id: 600, loss is: [5.9366255]\n",
"epoch: 4, batch_id: 700, loss is: [6.2263923]\n"
]
}
],
"source": [
"import time\n",
"vocab_size = len(vocab)\n",
"epochs = 5\n",
"losses = []\n",
"def train(model):\n",
" model.train()\n",
" optim = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())\n",
" for epoch in range(epochs):\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" y_data = paddle.nn.functional.one_hot(y_data, len(vocab))\n",
"            loss = paddle.nn.functional.softmax_with_cross_entropy(predicts, y_data, soft_label=True)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" losses.append(avg_loss.numpy())\n",
" print(\"epoch: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy())) \n",
" optim.minimize(avg_loss)\n",
" model.clear_gradients()\n",
"model = NGramModel(vocab_size, embedding_dim, context_size)\n",
"train(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting the loss curve\n",
"Visualizing the loss curve shows how the training is progressing."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x1457b58d0>]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deVxU193H8c9vZtgXkUVBwJ24Kypx1yzGNDGL2dqapVnarE2Tpk2bJunzJGmapE3aZm/1MWbtkjQxm9njFmMWTXBDEFTEDQRFRVbZz/PHDDjgIAMCwwy/9+s1L4Z7LzM/rvjlcO6554gxBqWUUt7P4ukClFJKdQwNdKWU8hEa6Eop5SM00JVSykdooCullI+weeqNo6OjzcCBAz319kop5ZXWr19/yBgT42qfxwJ94MCBpKameurtlVLKK4nInpb2aZeLUkr5CA10pZTyERroSinlI9wKdBH5pYiki0iGiNzlYv+ZIlIsIpscjwc6vlSllFIn0+pFUREZDdwETAKqgU9F5ENjTHazQ9cYYy7shBqVUkq5wZ0W+ghgnTGmwhhTC6wGLuvcspRSSrWVO4GeDswUkSgRCQbmAokujpsqIptF5BMRGeXqhUTkZhFJFZHUwsLCUyhbKaVUc60GujEmE3gc+Bz4FNgE1DU7bAMwwBgzDngOeK+F11pkjEkxxqTExLgcF9+qbQWl/PWzbRwpr27X1yullK9y66KoMeZFY8xEY8wsoAjY3mx/iTGmzPH8Y8BPRKI7vFpg16Eynl+VzYGSys54eaWU8lrujnLp4/jYH3v/+X+a7Y8VEXE8n+R43cMdW6pdaIAfAKWVtZ3x8kop5bXcvfX/bRGJAmqA240xR0XkVgBjzELgCuA2EakFjgHzTScthRQaaC+5rKqmM15eKaW8lluBboyZ6WLbQqfnzwPPd2BdLQoNsJesLXSllGrK6+4UDW9soWugK6WUM68L9MYuF22hK6VUE14X6EF+ViyiLXSllGrO6wJdRAgNsGkfulJKNeN1gQ4QFuinga6UUs14ZaCHBth02KJSSjXjnYEeaNM+dKWUasYrAz0s0KajXJRSqhmvDPTQABul2kJXSqkmvDLQtYWulFIn8spAt18U1UBXSilnXhroflRU11FbV+/pUpRSqtvwzkB33P5fXtV8nQ2llOq5vDLQwxpmXNSx6Eop1cg7A11nXFRKqRN4ZaDrjItKKXUi7wz0xi4XDXSllGrglYEepi10pZQ6gbuLRP9SRNJFJENE7nKxX0TkWRHJFpE0EZnQ8aUe17BQtPahK6XUca0GuoiMBm4CJgHjgAtFZGizw84HkhyPm4EFHVxnEw196KWVOspFKaUauNNCHwGsM8ZUGGNqgdXAZc2OmQe8ZuzWAhEiEtfBtTYK9rMiol0uSinlzJ1ATwdmikiUiAQDc4HEZsfEA/ucPs91bGtCRG4WkVQRSS0sLGxvzVgsohN0KaVUM60GujEmE3gc+Bz4FNgEtOsWTWPMImNMijEmJSYmpj0v0SgsQCfoUkopZ25dFDXGvGiMmWiMmQUUAdubHZJH01Z7gmNbp9FFLpRSqil3R7n0cXzsj73//D/NDlkKXOsY7TIFKDbG5Hdopc3ojItKKdWUzc3j3haRKKAGuN0Yc1REbgUwxiwEPsbet54NVAA3dEaxzkID/Sg5pqNclFKqgVuBboyZ6WLbQqfnBri9A+tqVViAjbyiiq58S6WU6ta88k5R0C4XpZRqzmsDXZehU0qpprw20EMDbZRX11FXbzxdilJKdQveG+iOGRfLq7WVrpRS4MWBrjMuKqVUU14b6DrjolJKNeW9ga4zLiqlVBPeG+gNqxZpl4tSSgFeHOjhulC0Uko14bWBrgtFK6VUU94b6AHaQldKKWdeG+gh/tqHrpRSzrw20BtWLdIWulJK2XltoIO920WHLSqllJ13B7quWqSUUo28OtDDAm3ah6
6UUg5eHejah66UUse5u6bor0QkQ0TSReR1EQlstv96ESkUkU2Ox42dU25TOie6Ukod12qgi0g8cCeQYowZDViB+S4O/a8xJtnxWNzBdbqkLXSllDrO3S4XGxAkIjYgGNjfeSW5LzTAT1voSinl0GqgG2PygL8Ce4F8oNgY87mLQy8XkTQRWSIiiR1cp0uhgTbKqmup11WLlFLKrS6X3sA8YBDQDwgRkWuaHfYBMNAYMxZYBrzawmvdLCKpIpJaWFh4apUDYQE2jNFVi5RSCtzrcjkH2GWMKTTG1ADvANOcDzDGHDbGVDk+XQxMdPVCxphFxpgUY0xKTEzMqdQNOK1apP3oSinlVqDvBaaISLCICDAbyHQ+QETinD69uPn+zqIzLiql1HG21g4wxqwTkSXABqAW2AgsEpGHgVRjzFLgThG52LH/CHB955V8XOMiF9pCV0qp1gMdwBjzIPBgs80POO2/D7ivA+tyiy4UrZRSx3n5naK6ULRSSjXw7kDXhaKVUqqRdwe6LhStlFKNfCLQtctFKaW8PNCtFiHE36oXRZVSCi8PdNBFLpRSqoH3B3qATcehK6UUvhDogTrjolJKgQ8EepguFK2UUoAPBLoucqGUUnbeH+i6DJ1SSgE+EOhhgXpRVCmlwBcC3dHlYoyuWqSU6tm8PtBDA+2rFlVU13m6FKWU8ijvD3SdcVEppQBfCHSdcVEppQAfCPQwnXFRKaUAHwj0UF0oWimlADcDXUR+JSIZIpIuIq+LSGCz/QEi8l8RyRaRdSIysDOKdUWXoVNKKbtWA11E4oE7gRRjzGjACsxvdtjPgCJjzFDgKeDxji60JbpQtFJK2bnb5WIDgkTEBgQD+5vtnwe86ni+BJgtItIxJZ5cWMMoF22hK6V6uFYD3RiTB/wV2AvkA8XGmM+bHRYP7HMcXwsUA1HNX0tEbhaRVBFJLSwsPNXaAQgJsALah66UUu50ufTG3gIfBPQDQkTkmva8mTFmkTEmxRiTEhMT056XOIHNaiHIz6rDFpVSPZ47XS7nALuMMYXGmBrgHWBas2PygEQAR7dML+BwRxZ6MrpqkVJKuRfoe4EpIhLs6BefDWQ2O2YpcJ3j+RXAStOFk6vY50TXQFdK9Wzu9KGvw36hcwOwxfE1i0TkYRG52HHYi0CUiGQDvwbu7aR6XQrTFrpSSmFz5yBjzIPAg802P+C0vxL4YQfW1SY6J7pSSvnAnaKgqxYppRT4TKD7aR+6UqrH84lADwvUhaKVUsonAj1UVy1SSikfCfRAG/UGjtXoqkVKqZ7LJwJdZ1xUSikfCXSdcVEppXwk0LWFrpRSPhLoulC0Ukr5TKDrQtFKKeUTgd7Q5aI3FymlejKfCPSGFrp2uSilejLfCHS9KKqUUr4R6H5WC4F+Fm2hK6V6NJ8IdHBM0KWBrpTqwXwm0MN0TnSlVA/nM4EeGqAzLiqlejafCnTtQ1dK9WStBrqIDBORTU6PEhG5q9kxZ4pIsdMxD7T0ep0lNFAXilZK9WytrilqjNkGJAOIiBXIA951cegaY8yFHVue+3ShaKVUT9fWLpfZwE5jzJ7OKOZUhGmXi1Kqh2troM8HXm9h31QR2Swin4jIKFcHiMjNIpIqIqmFhYVtfOuTC3WMctFVi5RSPZXbgS4i/sDFwFsudm8ABhhjxgHPAe+5eg1jzCJjTIoxJiUmJqY99bYoNMCP2npDVW19h76uUkp5i7a00M8HNhhjDjTfYYwpMcaUOZ5/DPiJSHQH1eiWhtv/S3ToolKqh2pLoF9JC90tIhIrIuJ4PsnxuodPvTz3hQXofC5KqZ6t1VEuACISAswBbnHadiuAMWYhcAVwm4jUAseA+aaLO7N1xkWlVE/nVqAbY8qBqGbbFjo9fx54vmNLaxtdhk4p1dP5zp2igbpQtFKqZ/OZQA9rWFdUW+hKqR7KZwK9cZELbaErpXoonwn0kAAroAtFK6V6Lp8J9ACbFX+bRfvQlVI9ls8EOj
jmc9E+dKVUD+Vbga4zLiqlejCfCvRQXYZOKdWD+VagB9i0D10p1WP5WKD7aQtdKdVj+VSghwXaKK3SYYtKqZ7JpwI9VEe5KKV6MN8KdMcoF121SCnVE/lWoAfYqKnTVYuUUj2TTwV6uM7nopTqwXwq0EN1TnSlVA/mW4HeMIWuttCVUj2QjwW6ZxeK/tfaPZzz5Gryjh7zyPsrpXq2VgNdRIaJyCanR4mI3NXsGBGRZ0UkW0TSRGRC55XcMneWoXt/Ux5PfJrFtoLSDnvfunrDHz7I4H/eSyf7YBlvr8/tsNfuCQpLq7j9PxsortB7CJQ6Fa2uKWqM2QYkA4iIFcgD3m122PlAkuMxGVjg+NilWlsoOj2vmLvf3ExtveEfX+xkVL9wLh0fz8XJ/egTFtiu9yytrOHO1zeyalshP50+iC15R3lvUx53nD0UEWn399KTrNp2kI/S8pk3rh/njor1dDlKea22drnMBnYaY/Y02z4PeM3YrQUiRCSuQypsg5OtWlRVW8fdb26md4g/y399Bg9eNBKrRXjko0ym/mkl17/8HUs376eyps7t99t3pIIrFnzLlzsO8eilo3ngopFcOj6BnMJyMvaXdNj35euy8u1/Le0r0q4qpU5Fqy30ZuYDr7vYHg/sc/o817Et3/kgEbkZuBmgf//+bXzr1jV0uZS66HJ5ctl2th0o5eXrT2don1CG9gnlhumDyD5Yyjsb8nhvYx53vr6R0AAbc0b2Zc7IvpxxWgwhAa5P0fo9Rdzyz1Sqaut59YZJzEiKBmDumFgeXJrOexvzGB3fq8O/R1+UmW//5bfvSIWHK1HKu7kd6CLiD1wM3NfeNzPGLAIWAaSkpHT47ZwBNiv+VssJLfTvdx9h0Zc5XDkpkbOG92myb2ifMO45bzi/OXcYa3cd5t0NeSzPPMC7G/Pwt1mYMTSac0f2ZfaIvsSEBQD2fvjfLkkjrlcgb9xs/wXRICLYnzOH9WHp5v3cN3cEVot2u5yMMYasAnug5xZpoCt1KtrSQj8f2GCMOeBiXx6Q6PR5gmNbl2s+J3p5VS13v7mZhN5B/P6CkS1+ncUiTBsSzbQh0dTW1ZO6p4hlWw/w+dYCVmYdRGQLE/r3ZlB0CEvW5zJpUCT/d81Eeof4n/BalyTHs2zrAdbmHGb60OhO+T59xYGSKoocF0P3HdEuF6VORVsC/Upcd7cALAV+ISJvYL8YWmyMyW/h2E4VGmBrslD0Yx9nsq+ogjdumtJ40bQ1NquFKYOjmDI4iv+5YARZBaWN4b5kfS4/nJjAo5eOwd/m+hLE7BF9CA2w8d7GPA30VmQ6Wuej+oWz61A5xhi9mKxUO7mVcCISAswBbnHadiuAMWYh8DEwF8gGKoAbOrxSN4UGHF+GbvX2Qv69bi83zRzE5MFR7Xo9EWFEXDgj4sK5c3YSlTV1BPpZT/o1gX5WfjAqlk/TC/jjJaNbPb4na7gges6IvjyzYgdHyquJCg3wcFVKeSe3RrkYY8qNMVHGmGKnbQsdYY5jdMvtxpghxpgxxpjUziq4NaGBNkoraymuqOGeJZtJ6hPK3ecO67DXdzecLxnfj9KqWlZlHeyw9/ZFmfklxEcEMapfOKAjXZQ6FT51pyhAmKOF/uDSdA6XVfPkj5I90kKeNiSamLAA3tvk/qWExWty+Puq7E6sqvvJKihhRFwYiZHBgI50UepU+F6gB9rYVlDKe5v284uzhzImwTNDB60W4aKx/ViVVejWHZDr9xzh0Y8z+ceqbGrqesb0v1W1dewsLGd4bPjxQNeRLkq1m88Femigjdp6w9iEXtx+1lCP1nLJ+H5U19XzSfrJrw9X1tTx2yVp+FkslFfXsXHv0S6q0LN2HCijrt4wPC6M0AAbvYP9dKSLUqfA5wK9b1ggATYLT/5oHH5Wz357Y+J7MTg6pNVul6eWbyensJxn5idjtQhrdhR2UYWeleWYT2dEnL3/PDEyWMeiK3UKfC7Qb5o1mC9+ey
ZD+4R5uhREhHnJ8azbdYT8Ytctz037jvLClznMPz2R88fEMS6hF2t2HOriSj0jK7+EAJuFgVEhACT2DtY+dKVOgc8FeqCflbheQZ4uo9G85H4YA0s37T9hX1VtHb99azN9wwO5/4IRAMxMiiEt92iPmHkws6CEYbFhjXfTJkQGkXf0GHX1uiasUu3hc4He3QyMDmFcYgTvuwj051Zks+NgGY9dNobwQPviHDOToqk38M1O326lG2PIzC9lRGx447bE3sHU1BkOlFR6sDKlvJcGehe4JLkfW/NL2HHg+Bzs6XnFLFi9k8snJHDWsOPzy4xLjCA0wMaabN8O9MLSKo6UVzM87njXmA5dVOrUaKB3gQvH9sNqkcaLo9W19fzmrc1EhfjzwIVN55fxs1qYOiTK5y+MZjouiA5v0kK3d5XpzUVKtY8GeheICQtg+tBo3t+0H2MM//gim6yCUh69dAy9gv1OOH5mUjT7jhxjz+FyD1TbNbIcU+aOcGqhx/cOQkRb6Eq1lwZ6F7kkuR+5Rcf417q9PL8ym3nJ/Zgzsq/LY2c4JvTy5dEumfklxPUKJCL4+GyVATYrseGBenORUu2kgd5Fzh0VS6Cfhf99L52IYD8eumhUi8cOig4hPiLIp7tdsgpKGR574tDSxN7B5OrNRUq1iwZ6F7GvhGRfL/OP80a7nEe9gYgwMymab3YeptYHpwGorq0n+2BZ4w1FzhIig7SFrlQ7aaB3od+eO4y//XAc549pfbnVGUnRlFbWsjm3uNVjvU32wTJq6w3DXQR6Yu9gCkoqqap1f21XpZSdBnoX6h8VzOUTE9w6dvqQaETgKx/sR29Ycm5knIsul8hgjIH9R3UsulJtpYHeTfUO8WdMfC++yva9fvSsglL8nW75d9Y4dFFHuijVZhro3djMpGg27D3aZEk9X5CZX8JpfUOxuZg8TafRVar93Ap0EYkQkSUikiUimSIytdn+M0WkWEQ2OR4PdE65PcuMoTHU1RvW5hxp9dhj1XXc905ak7tRu6vmt/w76xseiJ9V2KstdKXazN0W+jPAp8aY4cA4INPFMWuMMcmOx8MdVmEPNmFABEF+VreGL/7ls228/t0+XvxqVxdU1n6FpVUcKqtyeUEU7AuDxEcE6dBFpdqh1UAXkV7ALOBFAGNMtTGmZ6zA4GEBNitTBke2emH0+91HePmbXQTYLHyaUdCtVzxquCA6wsUY9AaJkcHa5aJUO7jTQh8EFAIvi8hGEVksIidezYKpIrJZRD4RkZbvmlFtMiMphpxD5S0u/HCsuo57lqQRHxHE45eP5WhFDV9344m9svIdc7i00EIHSNB50ZVqF3cC3QZMABYYY8YD5cC9zY7ZAAwwxowDngPec/VCInKziKSKSGphoe+N3ugMM5Ps0wC01Er/6+fb2HWonCcuH8t5o2MJC7DxUdrJl7zzpMz8EvqGBxB5khurEiODKKqooayqtgsrU8r7uRPouUCuMWad4/Ml2AO+kTGmxBhT5nj+MeAnItHNX8gYs8gYk2KMSYmJiTnF0nuGpD6h9A0PcDmdburuI7z09S6umdKfaUOjCfSzMmdkXz7LKKC6tnt2u2QWlDaZYdGVxN46ja5S7dFqoBtjCoB9IjLMsWk2sNX5GBGJFRFxPJ/keN3DHVxrjyQizBgaw9fZh5qs5NOwsHR8RBD3nT+icfuF4+IoqaztluPXa+rqyT5Y6vKWf2c6L7pS7ePuKJc7gH+LSBqQDDwmIreKyK2O/VcA6SKyGXgWmG+M0XXEOsjMpGiOVtSQsf/4NAB/c+pqCQmwNW6fMTSG8EAbH27uft0uOwvLqKkzTabMdUXnRVeqfWytHwLGmE1ASrPNC532Pw8834F1KSfTnabTHZsQwfo9R1j81S6unmzvanHmb7Pwg1GxfJpeQGVNHYF+Vk+U7FLDBdHWWuiRIf4E+1u1ha5UG+mdol4gJiyAEXHhrNlRaO9qeSuNfr2CuG/uCJfHXzA2jt
Kq2m43n3pmQQn+VguDol0NkjpOROzT6OrQRaXaRAPdS8xMimb9niIe+WgrOYfKeeKKsYQGuP4Da/rQaCKC/fgw7cSFqT0pM7+UoX1C8XNxy39ziZHB7NObi5RqEw10LzEzKZqaOsO/1u7lqsn9G7thXPGzWjhvVCzLtx6gsqb7TEOblV/SandLg0THvOh6KUYp92mge4nTB0YSYLMQHxHE/S10tTi7cGw/yqvr+GLbwS6ornWHy6o4WFrV6gXRBom9g6moruNIeXUnV6aU73DroqjyvEA/K89dOZ7+UcEtdrU4mzI4kqgQfz5My+e80a0vqHGqXv9uLxaByyYkuOxSySpw3CHayhj0BsdnXTxGVGhAxxWqlA/TQPci546KdftYm9XCeaNjeWdDHhXVtQT7d94/dVruUe57ZwsAz63M5o6zh54Q7Jn5jjlc3G2hRx6fFz05MaKDK1bKN2mXiw+7YGwcx2rqWJXVeTcZGWN47ONMIkP8WXD1BCJD/Pnd21uY/bfVvJm6r3FN1Mz8UmLCAtxubTfeLaojXZRymwa6D5s8KIro0IBOHe2yIvMga3OOcNc5SZw/Jo73b5/Oi9elEB5k454lacx+cjVL1ueSsb+Y4SeZYbG5kAAbkSH+OtJFqTbQQPdhVoswd0wsK7MOUt4JE13V1tXzp08yGRwdwpWT+gP2MeSzR/Tlg1/M4IVrUwjxt/GbtzaTVVDKSDdHuDRI7B3k9WPR/7l2D3/4IMPTZahu5Mll2/luV+uL1rSHBrqPu2BMHFW19azI6vjRLv9N3cfOwnJ+d/7wEy6EighzRvbloztnsPCaiZw5LIYLxrbt4mxCpHdPo5ueV8xDSzN4+evdrN9T1OXv/1lGAWf+ZRXpecWtH6y6xJbcYp5dsYN1OZ0z1ZUGuo87fWAkfcIC+HBzx3a7lFXV8tSy7Zw+sDfnjuzb4nEiwnmjY3nlhkmMTWjbxc3E3sHkHT3WZFIyb1FTV89vl6QRFeJPRLAfC77I7tL3N8bw1LLt7D5cwU9eXMe2gu6/NKE3qW/nz+TTy7fTK8iP66cP7NiCHDTQfZzFIswdE8cX2ws7dLHpRat3cqismvvnjsAx0WaHS4wMoqbOcKCkslNevzMt/GInmfklPHrpGK6fNpDlmQcbV2vqCl9nHyaroJQ7zh6Kn9XC1YvXsbOwrMve35c9tDSDc5/+ss3dmJv3HWVF1kFumjmIsEC/TqlNA70HuGhcHNW19SzPPNDiMeVVtW7fxFNQXMmiNTlcODaO8f17d1SZJ/DWedG3Hyjl2ZU7uHhcP+aM7Mv10wYS7G9lwRc7u6yGF7/KITrUn9vPGsp/bpoCGK56YS17Dpd3WQ2+6MO0/bzyzW6yD5bx/Kq2/dX19PLtRAT7cd20gZ1THBroPcL4xN7E9QpsXMmoorqW73cf4aWvdvGr/27inCdXM/qhz5j06HL+b/XOVv+cfHLZNurr4XfnDe/Uup1vLvIWtXX1/PatzYQF+vHgRSMBiAj25+rJ/flg8372Hu78X07ZB8tYta2Qa6YMINDPytA+ofz7xilU19Zz1QvrvP5Cs6fkFlVw3ztbSE6M4JLkfixek+P2Xz0b9xaxalshN80c3Gmtc9BA7xEsFuGCMXGs3l7InCdXM/rBz/jhwm95+MOtfLPzEAOjgrlr9mnMHtGHP32SxXUvf8fBUtfdHJn5Jby1Ppdrpw5oDNzO0i8iEBHvaqG/9PUuNucW84eLRzUZc3/jzMHYLBb+78vOb6W//PUu/G0WrpkyoHHbsNgw/vmzyZRW1nDVC+soKPa+bixPqq2r5643NmEMPDt/PL+/YCSBflYeWprh1nxDTy/fQe9Obp2DBnqP8ePTExkQFUJiZDB3nJ3E4mtTWHf/bNbdfw6LrzudX56TxMJrJvLYpWP4btcR5j6zxuU8MH/6JIvwQD9+cfbQTq85wGYlNjzQawI9p7CMv32+nXNH9uXCZiN6+oYHcvnEBN5KzeVgJ1
4TKCqv5u0NuVyS3I/oZjdxjY7vxas/ncSR8mquemFti7+01YmeX5VN6p4iHr10NP2jgokJC+DuOaexZschPk0vOOnXrt9TxOrthdw8a4hb03acCg30HiKpbxjLf30GL11/Or+acxrnjOxL3/DAJseICFdN7s8Hd8wgOjSA61/+nkc+3EpVrX3Gxi+3F/Ll9kLuOHsoEcEtL/LckRJ7B3vF3aL19YbfvZ1GgM3CI5eMdnmh+NYzBlNbX8+LX+3qtDr+891eKmvq+emMQS73j+/fm5dvOJ384kquWbyu3ZOfVdbU8cOF33Djq6l8sHk/x6q7z6yeHe373Ud4dsUOLpsQz7zk+Mbt10wZwPDYMP744VYqqlu+QPrMih1Ehvhz7dQBLR7TUTTQ1QlO6xvGe7dP59qpA1j81S4uX/AN2QfLeOzjTBIjg/hJF/xgNkiIDPKKu0X/uXYP3+8u4oGLRtGn2S/KBgOiQrhwbD/+tXYPxRUdN+KoQXVtPa99u5uZSdEnnQTt9IGRvHhdCnsOV3Djq9+3a4ri9zbm8f3uIjbuLeKO1zeS8sgyfvXfTXyx7SA1dd1zgfL2KK6o4a43NpEYGczD80Y32WezWvjjJaPZX1zJ8ytdXyBdv6eIL7cXcsuswU2WiuwsbgW6iESIyBIRyRKRTBGZ2my/iMizIpItImkiMqFzylVdJdDPysPzRrPoJxPJLTrGD57+kqyCUu75wXACbF23rF3/yGAOlFY2/pXgCceq604aevuOVPD4p1mccVoMl0+Ib/E4gNvOHEJ5dR2vfru7Y4sEPt6Sz4GSqhZb586mDY3m4Xmj2LD3KCvbeNNZfb1h0ZocRvUL57vfn8PrN03h4uR+rMg8wPUvf8+Ux1bwwPvpbNjb9TdTdSRjDPe/u4UDJZU8O3+8y+6S0wdGctmEeF5Yk0OOiwukTy/fTnSof5c1gtz9lfEM8Kkx5goR8QeaXw07H0hyPCYDCxwflZc7d1QsYxJ6cc+SNCwiJ/QNd7bE3sEYA3lFxxgcE9ql711TV89fP9vGojU5BPlZGRgVwqCYEAZHhzDI6XHfO1sQ4LHLxrQ6Jn9EXDizh/fh5a93cePMQR02C6YxhsVf5TAkJoQzkmLc+prLJt/educAAA7gSURBVCTw3MpsnluZzdnD+7h9P8GKrIPkFJbzzPxkrBZh6pAopg6J4qGLR7F6WyHvb9rPf7/fx2vf7uHXc07jztlJp/Ktecxbqbl8tCWf3503nHEnmfHzvvNHsCzjAA8uzeC1n05qPI+pu4+wZschfj93RKfOduqs1XcRkV7ALOB6AGNMNdC8420e8JqxN2PWOlr0ccaY7rf0vGqzuF5B/PNnnvn97Dx0sT2BvvtQOYu/yiE6NICbZrr/Z+++IxXc+cZGNu49ymUT4okI8mfXoTIy8or5NL3ghLtXH7lkNPERQW699s/PGsLlC77l9e/28TM3WtPu+G7XEdLzSnj00tFYLO4Fs5/Vwm1nDuH376bzdfZhZiS1vAqWs0Vf7iQ+IogLxjT95R5gs3LuqFjOHRVLaWUND76fwZPL7HdGdvbojo62s7CMB5dmMG1IFLfMGnzSY2PCAvj1uafxhw+28llGQeP6A085WudXT+nfFSUD7rXQBwGFwMsiMg5YD/zSGON8h0I8sM/p81zHtiaBLiI3AzcD9O/fdd+k8l7O86K3RX7xMZ5dkc2bqfuwWoTq2nr+vW4vv/3BMK6YkHDS0Ps0vYB7lmzGGPj7VRNOmIOmurae3KIKdh0qZ9ehcuqN4apJ7v88TxwQyeRBkbzwZQ4/mTIAf9uJPZ/GGDbsLeLDtHxmJkVz9vCWp1cAePGrXUQE+3HZ+AS36wC4YmICz67YwXMrd7gV6Bv2FtmvFVw4EttJ1oYNC/TjiSvGUlJZy4NLM+gV5Mcl40/eHdVdVNXWcefrGwn0s/DUj5Pd+gX5kykD+O/3+/jjh5mccVoftuQV83X2Yf7ngq5rnYN7feg2YA
KwwBgzHigH7m3PmxljFhljUowxKTEx7v1ZqHq2vmGB+Fstbo90OVRWxcMfbOWMv3zBkvX7uGZyf7665yze+fk04iOCuGdJGhf//SuXkyNV1dbx0NIMbv3XegZGh/DRnTNdTijmb7MwOCaU2SP6cuPMwdw8a4jbreIGPz9rKAUllby7MbfJ9sqaOt78fh8XPvcVly/4lle+2c1PX0nlniWbW5y6Yc/hcpZlHuDqyf0J8m/b9Y0Am5VbZg1h3a4jbs0AuGh1Dr2C/Pjx6YmtHmuzWnj+qvFMHRzF3W9tZvnWlu9U7i7q6g2/eSuNjP0lPHHFuBNGgrXEZrXw8LzR5B09xt9XZfPUsu1EhwZw9eSuG0AA7gV6LpBrjFnn+HwJ9oB3lgc4/wsnOLYpdUosFiG+dxC5rYx0Ka6o4S+fZTHriVW88s0u5o3rx8q7z+QP80bTJzyQCf178+7Pp/HM/GSOlFXz40Vrue1f6xvv3Nx9qJzLF3zDK9/s5mczBrHk1mn0j+q8G6dmJUUzOj6chatzqKs37D1cwWMfZzLlTyu45+00ausMj1wymo3/O4efnzmEJetzOe/pNXy788RfRC9/vRubRbh26sB21XLlpP5Eh/q3eiv7rkPlfLa1gGum9He76yrQz8oL16Uwql84t/9nA2s7aZbBjtAw9PSDzfu59/zhzDnJpHOuTBoUyWXj41mweiff5hzmtjOHtPkX7Klq9V/FGFMgIvtEZJgxZhswG9ja7LClwC9E5A3sF0OLtf9cdZSE3kEntNCLK2rIKihh24FSMvNL+ShtPyWVtVw4No5fzTmNIS7620WEecnxnDsylsVrcvjHFztZkXmQecn9+CS9AKtFeOHalDb/R24PEeHnZw7l5//ewKX/+JotecVYRPjBqL5cO3UgkwdFNl5cu+e84cwe0Ye739zMlS+s5afTB3HPecMI9LNSUlnDW6n7uHBsP7dbk80F+Vu5ceZg/vxJFpv3HW3xAuDiNTn4WSxt7g8PDbDxyg2T+NH/fcuNr6byxs1TGB3f66Rfs+NAKVW19a0e11GMMTywNJ0l63O565wkbj1jSLte5965w1m29QBB/launtz13crizhhUEUkGFgP+QA5wA/BjAGPMQrH/5D0PnAdUADcYY1JP9popKSkmNfWkhygFwP3vbuGDTfu5cnJ/sgpK2VZQwoGSqsb9vYL8mDo4ijtnJzGyn/uLaBwoqeSJT7fx9oZcJg7ozbNXjnf7wmZHqKs3XPDsGg6VVXHlpP5cNbk/cb1afv+K6lr+9HEW/1y7hyExITz142TW5Rzh0Y8z+fCOGacUfmVVtUz/80omDYrkhWtTTth/qKyK6X9eyaXj4/nz5WPb9R75xce4YsG3HKup481bpjK0z/FfusYYsgpK+WRLPh+nF5B90D4E8PzRsdw/d0SnTjNhjOGRjzJ58atd3HLGYO49b/gpzSCalnsUP6uFEW1c0MVdIrLeGHPiPxJuBnpn0EBX7vrnt7v53/cz8LdaGNonlOGxYQxzPIbHhtM3POCU/gMWFFcSHep/0ot8naWqtg6LyAkLhJzMl9sLuWdJGoVlVQT7WxkRF86bt0xt/Qtb8czyHTy1fDuf/HLmCWH05LLtPLtiB8t/fUaTIG6rXYfK+eHCb/C3WnjrtmkUlVfzSXo+n2wpIOdQORaxL504d0wsRRU1/OOLbIyBW88Ywq1nuNeFsfdwBcsyDzC0TyizkqJb/dn4y2dZ/H3VTq6fNpAHLxrZadNBdxQNdOXV6uoN+48eI65XoEdCtzsqrqjhoQ8yeHdjHi9dn9LqKBh3X3P64ys5c1gMz191/DLZseo6pv15BRMHRLL4Opc50iYZ+4uZv2gtlTV11NQZ+1j2wVGcPyaWH4yKbTIHzf6jx3js40w+TMsnPiKI318wgvNHx54QuofKqvhw837e37yfjXuPNm4fEhPC9dMHcfmEeJejTZ5fuYO/fr6dKycl8tilrd9H0B1ooCvlo4
6UVxMZ0nHz6jzxaRYLVu9k+a/PaLwO8dq3u3ng/QzeunUqpw+M7JD3Wb+niJe+3sWspGjmjIxt9XtYm3OYh5ZmkFVQytTBUTx48UgSegfzeUYB723az9fZh6irNwyPDeOS8fGcNyqWjfuKeOmr3WzJKyY80Mb8Sf25duoAEhzz7C9ek8MjH2Vy2fh4/vrDcW0eqeQpGuhKKbccLqti+uMruWBMP/72o3HU1RvO+usXRIX6885t0zzagq2tq+f17/fxt8+3UXKsBn+bhcqaeuIjgpiX3I95yfEMiw1r8jXGGNbvKeLlr3fzaUYBxhh+MCqWITGhPL8qmwvGxPHM/GSv+svvZIHedSPelVLdXpRj7PQr3+zmrnOSSMstZu+RCu6fe2oXCjuCzWrhJ1MGcOGYOBau3kllTR0XjevHxAG9W6xNREgZGEnKwEjyjh7jtW9388Z3+/gkvYBzRvThaS8L89ZoC10p1cSBkkpmPr6KyycmsHV/McXHalhx95lYvaRLojUV1bWsyznCtKFRXTrRXEfRFrpSym19wwP50ekJ/HvdXoyxz1PjK2EOEOxv46zhfTxdRqfwnb81lFId5pZZQ7CKEBXizxUT2zY/jPIcbaErpU6QGBnMHy8ZTWSIP4F+3tct0VNpoCulXLqyDTNIqu5Bu1yUUspHaKArpZSP0EBXSikfoYGulFI+QgNdKaV8hAa6Ukr5CA10pZTyERroSinlIzw2OZeIFAJ72vnl0cChDiynI2lt7dOda4PuXZ/W1j7eWtsAY0yMqx0eC/RTISKpLc025mlaW/t059qge9entbWPL9amXS5KKeUjNNCVUspHeGugL/J0ASehtbVPd64Nund9Wlv7+FxtXtmHrpRS6kTe2kJXSinVjAa6Ukr5CK8LdBE5T0S2iUi2iNzr6XqcichuEdkiIptExKMrYIvISyJyUETSnbZFisgyEdnh+Ni7G9X2kIjkOc7dJhGZ66HaEkVklYhsFZEMEfmlY7vHz91JavP4uRORQBH5TkQ2O2r7g2P7IBFZ5/j/+l8R8e9Gtb0iIruczltyV9fmVKNVRDaKyIeOz9t33owxXvMArMBOYDDgD2wGRnq6Lqf6dgPRnq7DUcssYAKQ7rTtCeBex/N7gce7UW0PAb/pBuctDpjgeB4GbAdGdodzd5LaPH7uAAFCHc/9gHXAFOBNYL5j+0Lgtm5U2yvAFZ7+mXPU9WvgP8CHjs/bdd68rYU+Ccg2xuQYY6qBN4B5Hq6pWzLGfAkcabZ5HvCq4/mrwCVdWpRDC7V1C8aYfGPMBsfzUiATiKcbnLuT1OZxxq7M8amf42GAs4Elju2eOm8t1dYtiEgCcAGw2PG50M7z5m2BHg/sc/o8l27yA+1ggM9FZL2I3OzpYlzoa4zJdzwvAPp6shgXfiEiaY4uGY90BzkTkYHAeOwtum517prVBt3g3Dm6DTYBB4Fl2P+aPmqMqXUc4rH/r81rM8Y0nLdHHeftKREJ8ERtwNPAPUC94/Mo2nnevC3Qu7sZxpgJwPnA7SIyy9MFtcTY/5brNq0UYAEwBEgG8oG/ebIYEQkF3gbuMsaUOO/z9LlzUVu3OHfGmDpjTDKQgP2v6eGeqMOV5rWJyGjgPuw1ng5EAr/r6rpE5ELgoDFmfUe8nrcFeh6Q6PR5gmNbt2CMyXN8PAi8i/2Hujs5ICJxAI6PBz1cTyNjzAHHf7p64AU8eO5ExA97YP7bGPOOY3O3OHeuautO585Rz1FgFTAViBARm2OXx/+/OtV2nqMLyxhjqoCX8cx5mw5cLCK7sXchnw08QzvPm7cF+vdAkuMKsD8wH1jq4ZoAEJEQEQlreA6cC6Sf/Ku63FLgOsfz64D3PVhLEw1h6XApHjp3jv7LF4FMY8yTTrs8fu5aqq07nDsRiRGRCMfzIGAO9j7+VcAVjsM8dd5c1Zbl9AtasPdRd/l5M8bcZ4xJMMYMxJ5nK40xV9Pe8+bpq7vtuBo8F/vV/Z3A7z
1dj1Ndg7GPutkMZHi6NuB17H9+12Dvg/sZ9r65FcAOYDkQ2Y1q+yewBUjDHp5xHqptBvbulDRgk+Mxtzucu5PU5vFzB4wFNjpqSAcecGwfDHwHZANvAQHdqLaVjvOWDvwLx0gYTz2AMzk+yqVd501v/VdKKR/hbV0uSimlWqCBrpRSPkIDXSmlfIQGulJK+QgNdKWU8hEa6Eop5SM00JVSykf8P5swPO1QLbHWAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker\n",
"%matplotlib inline\n",
"\n",
"plt.figure()\n",
"plt.plot(losses)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prediction\n",
"Use the trained model to make a prediction."
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"the input words is: whiles, thou\n",
"the predict words is: art\n",
"the true words is: art\n"
]
}
],
"source": [
"import random\n",
"def test(model):\n",
" model.eval()\n",
"    # Randomly pick one sample from the last 10 trigrams\n",
" idx = random.randint(len(trigram)-10, len(trigram)-1)\n",
" print('the input words is: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n",
" x_data = list(map(lambda w: word_to_idx[w], trigram[idx][0]))\n",
" x_data = paddle.to_tensor(np.array(x_data))\n",
" predicts = model(x_data)\n",
" predicts = predicts.numpy().tolist()[0]\n",
" predicts = predicts.index(max(predicts))\n",
" print('the predict words is: ' + idx_to_word[predicts])\n",
" y_data = trigram[idx][1]\n",
" print('the true words is: ' + y_data)\n",
"test(model)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}