未验证 提交 04af4d42 编写于 作者: C Chen Long 提交者: GitHub

add tow models with develop-paddle (#881)

上级 7df3c81e
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST数据集使用LeNet进行图像分类\n",
"本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。\n",
"手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境\n",
"本教程基于paddle-develop编写,如果您的环境不是本版本,请先安装paddle-develop版本。"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0.0\n"
]
}
],
"source": [
"import paddle\n",
"print(paddle.__version__)\n",
"paddle.disable_static()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 加载数据集\n",
"我们使用飞桨自带的paddle.dataset完成mnist数据集的加载。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"download training data and load training data\n",
"load finished\n"
]
}
],
"source": [
"print('download training data and load training data')\n",
"train_dataset = paddle.vision.datasets.MNIST(mode='train')\n",
"test_dataset = paddle.vision.datasets.MNIST(mode='test')\n",
"print('load finished')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"取训练集中的一条数据看一下。"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_data0 label is: [5]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAIY0lEQVR4nO3dXWhUZxoH8P/jaPxav7KREtNgiooQFvwg1l1cNOr6sQUN3ixR0VUK9cKPXTBYs17ohReLwl5ovCmuZMU1y+IaWpdC0GIuxCJJMLhJa6oWtSl+FVEXvdDK24s5nc5zapKTZ86cOTPz/4Hk/M8xc17w8Z13zpl5RpxzIBquEbkeAOUnFg6ZsHDIhIVDJiwcMmHhkElGhSMiq0WkT0RuisjesAZF8SfW6zgikgDwFYAVAPoBdABY75z7IrzhUVyNzOB33wVw0zn3NQCIyL8A1AEYsHDKyspcVVVVBqekqHV1dX3nnJvq359J4VQA+CYt9wNYONgvVFVVobOzM4NTUtRE5M6b9md9cSwiH4hIp4h0Pnr0KNuno4hkUjjfAqhMy297+xTn3EfOuRrnXM3UqT+b8ShPZVI4HQBmicg7IlICoB7AJ+EMi+LOvMZxzn0vIjsAtAFIADjhnOsNbWQUa5ksjuGc+xTApyGNhfIIrxyTCQuHTFg4ZMLCIRMWDpmwcMiEhUMmLBwyYeGQCQuHTFg4ZMLCIZOMbnIWk9evX6v89OnTwL/b1NSk8osXL1Tu6+tT+dixYyo3NDSo3NLSovKYMWNU3rv3p88N7N+/P/A4h4MzDpmwcMiEhUMmRbPGuXv3rsovX75U+fLlyypfunRJ5SdPnqh85syZ0MZWWVmp8s6dO1VubW1VecKECSrPmTNH5SVLloQ2toFwxiETFg6ZsHDIpGDXOFevXlV52bJlKg/nOkzYEomEygcPHlR5/PjxKm/cuFHladOmqTxlyhSVZ8+enekQh8QZh0xYOGTCwiGTgl3jTJ8+XeWysjKVw1zjLFyom3T41xwXL15UuaSkROVNmzaFNpaocMYhExYOmbBwyKRg1zilpaUqHz58WOVz586pPG/ePJV37do16OPPnTs3tX3hwgV1zH8dpqenR+UjR44M+tj5gDMOmQxZOCJyQkQeikhP2r5SETkvIje8n1MGewwqPEFmnGYAq3379gL4zDk3C8BnXqYiEqjPsYhUAfivc+5XXu4DUOucuyci5QDanXND3iCpqalxcek6+uzZM5X973HZtm2bysePH1f51KlTqe0NGzaEPLr4EJEu51yNf791jfOWc+6et30fwFvmkVFeynhx7JJT1oDTFtvVFiZr4TzwnqLg/Xw40F9ku9rCZL2O8wmAPwL4q/fz49BGFJGJEycOenzSpEmDHk9f89TX16tjI0YU/lWOIC/HWwB8DmC2iPSLyPtIFswKEbkB4HdepiIy5IzjnFs/wKHlIY+F8kjhz6mUFQV7rypTBw4cULmrq0vl9vb21Lb/XtXKlSuzNazY4IxDJiwcMmHhkIn5Ozkt4nSvarhu3bql8vz581PbkydPVseWLl2qck2NvtWzfft2lUUkjCFmRdj3qqjIsXDIhC/HA5oxY4bKzc3Nqe2tW7eqYydPnhw0P3/+XOXNmzerXF5ebh1mZDjjkAkLh0xYOGTCNY7RunXrUtszZ85Ux3bv3q2y/5ZEY2Ojynfu6O+E37dvn8oVFRXmcWYLZxwyYeGQCQuHTHjLIQv8rW39HzfesmWLyv5/g+XL9Xvkzp8/H97ghom3HChULBwyYeGQCdc4OTB69GiVX716pfKoUaNUbmtrU7m2tjYr43oTrnEoVCwcMmHhkAnvVYXg2rVrKvu/kqijo0Nl/5rGr7q6WuXFixdnMLrs4IxDJiwcMmHhkAnXOAH5v+L56NGjqe2zZ8+qY/fv3x/WY48cqf8Z/O85jmPblPiNiPJCkP44lSJyUUS+EJFeEfmTt58ta4tYkBnnewC7nXPVAH4NYLuIVIMta4takMZK9wDc87b/LyJfAqgAUAeg1vtr/wDQDuDDrIwyAv51yenTp1VuampS+fbt2+ZzLViwQGX/e4zXrl1rfuyoDGuN4/U7ngfgCtiytqgFLhwR+QWA/wD4s3NOdZcerGUt29UWpkCFIyKjkCyafzrnfnztGahlLdvVFqYh1ziS7MHxdwBfOuf+lnYor1rWPnjwQOXe3l6Vd+zYofL169fN5/J/1eKePXtUrqurUzmO12mGEuQC4CIAmwD8T0S6vX1/QbJg/u21r70D4A/ZGSLFUZBXVZcADNT5hy1ri1T+zZEUCwVzr+rx48cq+782qLu7W2V/a7bhWrRoUWrb/1nxVatWqTx27NiMzhVHnHHIhIVDJiwcMsmrNc6VK1dS24cOHVLH/O/r7e/vz+hc48aNU9n/ddLp95f8XxddDDjjkAkLh0zy6qmqtbX1jdtB+D9ysmbNGpUTiYTKDQ0NKvu7pxc7zjhkwsIhExYOmbDNCQ2KbU4oVCwcMmHhkAkLh0xYOGTCwiETFg6ZsHDIhIVDJiwcMmHhkEmk96pE5BGSn/osA/BdZCcenriOLVfjmu6c+9mH/iMtnNRJRTrfdOMsDuI6triNi09VZMLCIZNcFc5HOTpvEHEdW6zGlZM1DuU/PlWRSaSFIyKrRaRPRG6KSE7b24rICRF5KCI9afti0bs5H3pLR1Y4IpIAcAzA7wFUA1jv9UvOlWYAq3374tK7Of69pZ1zkfwB8BsAbWm5EUBjVOcfYExVAHrSch+Acm+7HEBfLseXNq6PAayI0/iifKqqAPBNWu739sVJ7Ho3x7W3NBfHA3DJ/9Y5fclp7S0dhSgL51sAlWn5bW9fnATq3RyFTHpLRyHKwukAMEtE3hGREgD1SPZKjpMfezcDOezdHKC3NJDr3tIRL/LeA/AVgFsA9uV4wdmC5JebvEJyvfU+gF8i+WrlBoALAEpzNLbfIvk0dA1At/fnvbiMzznHK8dkw8UxmbBwyISFQyYsHDJh4ZAJC4dMWDhkwsIhkx8AyyZIbO5tLBIAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 144x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]\n",
"train_data0 = train_data0.reshape([28,28])\n",
"plt.figure(figsize=(2,2))\n",
"plt.imshow(train_data0, cmap=plt.cm.binary)\n",
"print('train_data0 label is: ' + str(train_label_0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2.组网\n",
"用paddle.nn下的API,如`Conv2d`、`Pool2D`、`Linead`完成LeNet的构建。"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import paddle.nn.functional as F\n",
"class LeNet(paddle.nn.Layer):\n",
" def __init__(self):\n",
" super(LeNet, self).__init__()\n",
" self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n",
" self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n",
" self.conv2 = paddle.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n",
" self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n",
" self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)\n",
" self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)\n",
" self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = F.relu(x)\n",
" x = self.max_pool1(x)\n",
" x = F.relu(x)\n",
" x = self.conv2(x)\n",
" x = self.max_pool2(x)\n",
" x = paddle.reshape(x, shape=[-1, 16*5*5])\n",
" x = self.linear1(x)\n",
" x = F.relu(x)\n",
" x = self.linear2(x)\n",
" x = F.relu(x)\n",
" x = self.linear3(x)\n",
" x = F.softmax(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 训练方式一\n",
"通过`Model` 构建实例,快速完成模型训练"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"from paddle.static import InputSpec\n",
"from paddle.metric import Accuracy\n",
"inputs = InputSpec([None, 784], 'float32', 'x')\n",
"labels = InputSpec([None, 10], 'float32', 'x')\n",
"model = paddle.hapi.Model(LeNet(), inputs, labels)\n",
"optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
"\n",
"\n",
"model.prepare(\n",
" optim,\n",
" paddle.nn.loss.CrossEntropyLoss(),\n",
" Accuracy(topk=(1, 2))\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 使用model.fit来训练模型"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"step 10/938 - loss: 2.2369 - acc_top1: 0.3281 - acc_top2: 0.4172 - 18ms/step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Library/Python/3.7/site-packages/paddle/fluid/layers/utils.py:76: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n",
" return (isinstance(seq, collections.Sequence) and\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 20/938 - loss: 2.0185 - acc_top1: 0.3656 - acc_top2: 0.4328 - 17ms/step\n",
"step 30/938 - loss: 1.9579 - acc_top1: 0.4120 - acc_top2: 0.4969 - 16ms/step\n",
"step 40/938 - loss: 1.8549 - acc_top1: 0.4602 - acc_top2: 0.5500 - 16ms/step\n",
"step 50/938 - loss: 1.8628 - acc_top1: 0.5097 - acc_top2: 0.6028 - 16ms/step\n",
"step 60/938 - loss: 1.7139 - acc_top1: 0.5456 - acc_top2: 0.6409 - 16ms/step\n",
"step 70/938 - loss: 1.7296 - acc_top1: 0.5795 - acc_top2: 0.6719 - 15ms/step\n",
"step 80/938 - loss: 1.6302 - acc_top1: 0.6053 - acc_top2: 0.6949 - 15ms/step\n",
"step 90/938 - loss: 1.6688 - acc_top1: 0.6290 - acc_top2: 0.7158 - 15ms/step\n",
"step 100/938 - loss: 1.6401 - acc_top1: 0.6491 - acc_top2: 0.7327 - 15ms/step\n",
"step 110/938 - loss: 1.6357 - acc_top1: 0.6636 - acc_top2: 0.7440 - 15ms/step\n",
"step 120/938 - loss: 1.6309 - acc_top1: 0.6767 - acc_top2: 0.7539 - 15ms/step\n",
"step 130/938 - loss: 1.6445 - acc_top1: 0.6894 - acc_top2: 0.7638 - 15ms/step\n",
"step 140/938 - loss: 1.5961 - acc_top1: 0.7002 - acc_top2: 0.7728 - 15ms/step\n",
"step 150/938 - loss: 1.6822 - acc_top1: 0.7086 - acc_top2: 0.7794 - 15ms/step\n",
"step 160/938 - loss: 1.6243 - acc_top1: 0.7176 - acc_top2: 0.7858 - 15ms/step\n",
"step 170/938 - loss: 1.6159 - acc_top1: 0.7254 - acc_top2: 0.7915 - 15ms/step\n",
"step 180/938 - loss: 1.6820 - acc_top1: 0.7312 - acc_top2: 0.7962 - 15ms/step\n",
"step 190/938 - loss: 1.6733 - acc_top1: 0.7363 - acc_top2: 0.7999 - 15ms/step\n",
"step 200/938 - loss: 1.7717 - acc_top1: 0.7413 - acc_top2: 0.8039 - 15ms/step\n",
"step 210/938 - loss: 1.5468 - acc_top1: 0.7458 - acc_top2: 0.8072 - 15ms/step\n",
"step 220/938 - loss: 1.5654 - acc_top1: 0.7506 - acc_top2: 0.8111 - 15ms/step\n",
"step 230/938 - loss: 1.6129 - acc_top1: 0.7547 - acc_top2: 0.8143 - 15ms/step\n",
"step 240/938 - loss: 1.5937 - acc_top1: 0.7592 - acc_top2: 0.8180 - 15ms/step\n",
"step 250/938 - loss: 1.5457 - acc_top1: 0.7631 - acc_top2: 0.8214 - 15ms/step\n",
"step 260/938 - loss: 1.6041 - acc_top1: 0.7673 - acc_top2: 0.8249 - 15ms/step\n",
"step 270/938 - loss: 1.6049 - acc_top1: 0.7700 - acc_top2: 0.8271 - 15ms/step\n",
"step 280/938 - loss: 1.5989 - acc_top1: 0.7735 - acc_top2: 0.8299 - 15ms/step\n",
"step 290/938 - loss: 1.6950 - acc_top1: 0.7752 - acc_top2: 0.8310 - 15ms/step\n",
"step 300/938 - loss: 1.5888 - acc_top1: 0.7781 - acc_top2: 0.8330 - 15ms/step\n",
"step 310/938 - loss: 1.5983 - acc_top1: 0.7808 - acc_top2: 0.8350 - 15ms/step\n",
"step 320/938 - loss: 1.5133 - acc_top1: 0.7840 - acc_top2: 0.8370 - 15ms/step\n",
"step 330/938 - loss: 1.5587 - acc_top1: 0.7866 - acc_top2: 0.8385 - 15ms/step\n",
"step 340/938 - loss: 1.6093 - acc_top1: 0.7882 - acc_top2: 0.8393 - 15ms/step\n",
"step 350/938 - loss: 1.6259 - acc_top1: 0.7902 - acc_top2: 0.8410 - 15ms/step\n",
"step 360/938 - loss: 1.6194 - acc_top1: 0.7918 - acc_top2: 0.8422 - 15ms/step\n",
"step 370/938 - loss: 1.6531 - acc_top1: 0.7941 - acc_top2: 0.8438 - 15ms/step\n",
"step 380/938 - loss: 1.6986 - acc_top1: 0.7957 - acc_top2: 0.8447 - 15ms/step\n",
"step 390/938 - loss: 1.5932 - acc_top1: 0.7974 - acc_top2: 0.8459 - 15ms/step\n",
"step 400/938 - loss: 1.6512 - acc_top1: 0.7993 - acc_top2: 0.8474 - 15ms/step\n",
"step 410/938 - loss: 1.5698 - acc_top1: 0.8012 - acc_top2: 0.8487 - 15ms/step\n",
"step 420/938 - loss: 1.5889 - acc_top1: 0.8025 - acc_top2: 0.8494 - 15ms/step\n",
"step 430/938 - loss: 1.5518 - acc_top1: 0.8036 - acc_top2: 0.8503 - 15ms/step\n",
"step 440/938 - loss: 1.6057 - acc_top1: 0.8048 - acc_top2: 0.8508 - 15ms/step\n",
"step 450/938 - loss: 1.6081 - acc_top1: 0.8064 - acc_top2: 0.8519 - 15ms/step\n",
"step 460/938 - loss: 1.5742 - acc_top1: 0.8079 - acc_top2: 0.8531 - 15ms/step\n",
"step 470/938 - loss: 1.5704 - acc_top1: 0.8095 - acc_top2: 0.8543 - 15ms/step\n",
"step 480/938 - loss: 1.6083 - acc_top1: 0.8110 - acc_top2: 0.8550 - 15ms/step\n",
"step 490/938 - loss: 1.6081 - acc_top1: 0.8120 - acc_top2: 0.8555 - 15ms/step\n",
"step 500/938 - loss: 1.5156 - acc_top1: 0.8133 - acc_top2: 0.8564 - 15ms/step\n",
"step 510/938 - loss: 1.5856 - acc_top1: 0.8148 - acc_top2: 0.8573 - 15ms/step\n",
"step 520/938 - loss: 1.5275 - acc_top1: 0.8163 - acc_top2: 0.8582 - 15ms/step\n",
"step 530/938 - loss: 1.5345 - acc_top1: 0.8172 - acc_top2: 0.8591 - 15ms/step\n",
"step 540/938 - loss: 1.5387 - acc_top1: 0.8181 - acc_top2: 0.8596 - 15ms/step\n",
"step 550/938 - loss: 1.5753 - acc_top1: 0.8190 - acc_top2: 0.8601 - 15ms/step\n",
"step 560/938 - loss: 1.6103 - acc_top1: 0.8203 - acc_top2: 0.8610 - 15ms/step\n",
"step 570/938 - loss: 1.5571 - acc_top1: 0.8215 - acc_top2: 0.8618 - 15ms/step\n",
"step 580/938 - loss: 1.5575 - acc_top1: 0.8221 - acc_top2: 0.8622 - 15ms/step\n",
"step 590/938 - loss: 1.4821 - acc_top1: 0.8230 - acc_top2: 0.8627 - 15ms/step\n",
"step 600/938 - loss: 1.5644 - acc_top1: 0.8243 - acc_top2: 0.8636 - 15ms/step\n",
"step 610/938 - loss: 1.5317 - acc_top1: 0.8253 - acc_top2: 0.8644 - 15ms/step\n",
"step 620/938 - loss: 1.5849 - acc_top1: 0.8258 - acc_top2: 0.8647 - 15ms/step\n",
"step 630/938 - loss: 1.6087 - acc_top1: 0.8263 - acc_top2: 0.8649 - 15ms/step\n",
"step 640/938 - loss: 1.5617 - acc_top1: 0.8272 - acc_top2: 0.8655 - 15ms/step\n",
"step 650/938 - loss: 1.6376 - acc_top1: 0.8279 - acc_top2: 0.8660 - 15ms/step\n",
"step 660/938 - loss: 1.5428 - acc_top1: 0.8287 - acc_top2: 0.8665 - 15ms/step\n",
"step 670/938 - loss: 1.5797 - acc_top1: 0.8293 - acc_top2: 0.8668 - 15ms/step\n",
"step 680/938 - loss: 1.5210 - acc_top1: 0.8300 - acc_top2: 0.8674 - 15ms/step\n",
"step 690/938 - loss: 1.6159 - acc_top1: 0.8305 - acc_top2: 0.8677 - 15ms/step\n",
"step 700/938 - loss: 1.5592 - acc_top1: 0.8313 - acc_top2: 0.8682 - 15ms/step\n",
"step 710/938 - loss: 1.6400 - acc_top1: 0.8318 - acc_top2: 0.8685 - 15ms/step\n",
"step 720/938 - loss: 1.5638 - acc_top1: 0.8327 - acc_top2: 0.8691 - 15ms/step\n",
"step 730/938 - loss: 1.5691 - acc_top1: 0.8333 - acc_top2: 0.8693 - 15ms/step\n",
"step 740/938 - loss: 1.5848 - acc_top1: 0.8337 - acc_top2: 0.8695 - 15ms/step\n",
"step 750/938 - loss: 1.6317 - acc_top1: 0.8344 - acc_top2: 0.8698 - 15ms/step\n",
"step 760/938 - loss: 1.5127 - acc_top1: 0.8352 - acc_top2: 0.8703 - 15ms/step\n",
"step 770/938 - loss: 1.5822 - acc_top1: 0.8359 - acc_top2: 0.8707 - 15ms/step\n",
"step 780/938 - loss: 1.6010 - acc_top1: 0.8366 - acc_top2: 0.8712 - 15ms/step\n",
"step 790/938 - loss: 1.5238 - acc_top1: 0.8373 - acc_top2: 0.8717 - 15ms/step\n",
"step 800/938 - loss: 1.5858 - acc_top1: 0.8377 - acc_top2: 0.8719 - 15ms/step\n",
"step 810/938 - loss: 1.5800 - acc_top1: 0.8384 - acc_top2: 0.8724 - 15ms/step\n",
"step 820/938 - loss: 1.6312 - acc_top1: 0.8390 - acc_top2: 0.8727 - 15ms/step\n",
"step 830/938 - loss: 1.5812 - acc_top1: 0.8398 - acc_top2: 0.8732 - 15ms/step\n",
"step 840/938 - loss: 1.5661 - acc_top1: 0.8402 - acc_top2: 0.8734 - 15ms/step\n",
"step 850/938 - loss: 1.5379 - acc_top1: 0.8409 - acc_top2: 0.8739 - 15ms/step\n",
"step 860/938 - loss: 1.5266 - acc_top1: 0.8413 - acc_top2: 0.8740 - 15ms/step\n",
"step 870/938 - loss: 1.5264 - acc_top1: 0.8420 - acc_top2: 0.8745 - 15ms/step\n",
"step 880/938 - loss: 1.5688 - acc_top1: 0.8425 - acc_top2: 0.8748 - 15ms/step\n",
"step 890/938 - loss: 1.5707 - acc_top1: 0.8429 - acc_top2: 0.8751 - 15ms/step\n",
"step 900/938 - loss: 1.5564 - acc_top1: 0.8432 - acc_top2: 0.8752 - 15ms/step\n",
"step 910/938 - loss: 1.4924 - acc_top1: 0.8438 - acc_top2: 0.8757 - 15ms/step\n",
"step 920/938 - loss: 1.5514 - acc_top1: 0.8443 - acc_top2: 0.8760 - 15ms/step\n",
"step 930/938 - loss: 1.5850 - acc_top1: 0.8446 - acc_top2: 0.8762 - 15ms/step\n",
"step 938/938 - loss: 1.4915 - acc_top1: 0.8448 - acc_top2: 0.8764 - 15ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/0\n",
"Eval begin...\n",
"step 10/157 - loss: 1.5984 - acc_top1: 0.8797 - acc_top2: 0.8953 - 5ms/step\n",
"step 20/157 - loss: 1.6266 - acc_top1: 0.8789 - acc_top2: 0.9000 - 5ms/step\n",
"step 30/157 - loss: 1.6475 - acc_top1: 0.8771 - acc_top2: 0.8984 - 5ms/step\n",
"step 40/157 - loss: 1.6329 - acc_top1: 0.8730 - acc_top2: 0.8957 - 5ms/step\n",
"step 50/157 - loss: 1.5399 - acc_top1: 0.8712 - acc_top2: 0.8934 - 5ms/step\n",
"step 60/157 - loss: 1.6322 - acc_top1: 0.8750 - acc_top2: 0.8961 - 5ms/step\n",
"step 70/157 - loss: 1.5818 - acc_top1: 0.8721 - acc_top2: 0.8931 - 5ms/step\n",
"step 80/157 - loss: 1.5522 - acc_top1: 0.8760 - acc_top2: 0.8979 - 5ms/step\n",
"step 90/157 - loss: 1.6085 - acc_top1: 0.8785 - acc_top2: 0.8984 - 5ms/step\n",
"step 100/157 - loss: 1.5661 - acc_top1: 0.8784 - acc_top2: 0.8980 - 5ms/step\n",
"step 110/157 - loss: 1.5694 - acc_top1: 0.8805 - acc_top2: 0.8996 - 5ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 120/157 - loss: 1.6012 - acc_top1: 0.8824 - acc_top2: 0.9003 - 5ms/step\n",
"step 130/157 - loss: 1.5378 - acc_top1: 0.8844 - acc_top2: 0.9017 - 5ms/step\n",
"step 140/157 - loss: 1.5068 - acc_top1: 0.8858 - acc_top2: 0.9022 - 5ms/step\n",
"step 150/157 - loss: 1.5424 - acc_top1: 0.8873 - acc_top2: 0.9029 - 5ms/step\n",
"step 157/157 - loss: 1.5862 - acc_top1: 0.8872 - acc_top2: 0.9035 - 5ms/step\n",
"Eval samples: 10000\n",
"Epoch 2/2\n",
"step 10/938 - loss: 1.5988 - acc_top1: 0.8859 - acc_top2: 0.9016 - 15ms/step\n",
"step 20/938 - loss: 1.5702 - acc_top1: 0.8852 - acc_top2: 0.9047 - 15ms/step\n",
"step 30/938 - loss: 1.5999 - acc_top1: 0.8833 - acc_top2: 0.9021 - 15ms/step\n",
"step 40/938 - loss: 1.5652 - acc_top1: 0.8816 - acc_top2: 0.9000 - 15ms/step\n",
"step 50/938 - loss: 1.6163 - acc_top1: 0.8853 - acc_top2: 0.9047 - 15ms/step\n",
"step 60/938 - loss: 1.5307 - acc_top1: 0.8849 - acc_top2: 0.9049 - 15ms/step\n",
"step 70/938 - loss: 1.5542 - acc_top1: 0.8846 - acc_top2: 0.9029 - 15ms/step\n",
"step 80/938 - loss: 1.5694 - acc_top1: 0.8816 - acc_top2: 0.9008 - 15ms/step\n",
"step 90/938 - loss: 1.6030 - acc_top1: 0.8806 - acc_top2: 0.8991 - 15ms/step\n",
"step 100/938 - loss: 1.5631 - acc_top1: 0.8814 - acc_top2: 0.8989 - 15ms/step\n",
"step 110/938 - loss: 1.5598 - acc_top1: 0.8804 - acc_top2: 0.8984 - 15ms/step\n",
"step 120/938 - loss: 1.5773 - acc_top1: 0.8803 - acc_top2: 0.8986 - 15ms/step\n",
"step 130/938 - loss: 1.5076 - acc_top1: 0.8815 - acc_top2: 0.8995 - 15ms/step\n",
"step 140/938 - loss: 1.6064 - acc_top1: 0.8809 - acc_top2: 0.8988 - 15ms/step\n",
"step 150/938 - loss: 1.5279 - acc_top1: 0.8815 - acc_top2: 0.8993 - 15ms/step\n",
"step 160/938 - loss: 1.6039 - acc_top1: 0.8820 - acc_top2: 0.8998 - 15ms/step\n",
"step 170/938 - loss: 1.5709 - acc_top1: 0.8814 - acc_top2: 0.8993 - 15ms/step\n",
"step 180/938 - loss: 1.6164 - acc_top1: 0.8806 - acc_top2: 0.8985 - 15ms/step\n",
"step 190/938 - loss: 1.5920 - acc_top1: 0.8802 - acc_top2: 0.8985 - 15ms/step\n",
"step 200/938 - loss: 1.6457 - acc_top1: 0.8793 - acc_top2: 0.8973 - 15ms/step\n",
"step 210/938 - loss: 1.6045 - acc_top1: 0.8794 - acc_top2: 0.8977 - 15ms/step\n",
"step 220/938 - loss: 1.6614 - acc_top1: 0.8795 - acc_top2: 0.8975 - 15ms/step\n",
"step 230/938 - loss: 1.5384 - acc_top1: 0.8789 - acc_top2: 0.8966 - 15ms/step\n",
"step 240/938 - loss: 1.5556 - acc_top1: 0.8785 - acc_top2: 0.8960 - 15ms/step\n",
"step 250/938 - loss: 1.6006 - acc_top1: 0.8782 - acc_top2: 0.8961 - 15ms/step\n",
"step 260/938 - loss: 1.5552 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n",
"step 270/938 - loss: 1.5805 - acc_top1: 0.8791 - acc_top2: 0.8970 - 15ms/step\n",
"step 280/938 - loss: 1.5404 - acc_top1: 0.8787 - acc_top2: 0.8966 - 15ms/step\n",
"step 290/938 - loss: 1.6023 - acc_top1: 0.8789 - acc_top2: 0.8969 - 15ms/step\n",
"step 300/938 - loss: 1.5706 - acc_top1: 0.8788 - acc_top2: 0.8969 - 15ms/step\n",
"step 310/938 - loss: 1.5424 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n",
"step 320/938 - loss: 1.5823 - acc_top1: 0.8798 - acc_top2: 0.8975 - 15ms/step\n",
"step 330/938 - loss: 1.5600 - acc_top1: 0.8801 - acc_top2: 0.8977 - 15ms/step\n",
"step 340/938 - loss: 1.6258 - acc_top1: 0.8795 - acc_top2: 0.8970 - 15ms/step\n",
"step 350/938 - loss: 1.5093 - acc_top1: 0.8796 - acc_top2: 0.8972 - 15ms/step\n",
"step 360/938 - loss: 1.6030 - acc_top1: 0.8794 - acc_top2: 0.8967 - 15ms/step\n",
"step 370/938 - loss: 1.5732 - acc_top1: 0.8795 - acc_top2: 0.8969 - 15ms/step\n",
"step 380/938 - loss: 1.5980 - acc_top1: 0.8797 - acc_top2: 0.8972 - 15ms/step\n",
"step 390/938 - loss: 1.5902 - acc_top1: 0.8800 - acc_top2: 0.8974 - 15ms/step\n",
"step 400/938 - loss: 1.5395 - acc_top1: 0.8809 - acc_top2: 0.8983 - 15ms/step\n",
"step 410/938 - loss: 1.6623 - acc_top1: 0.8804 - acc_top2: 0.8978 - 15ms/step\n",
"step 420/938 - loss: 1.4987 - acc_top1: 0.8810 - acc_top2: 0.8983 - 15ms/step\n",
"step 430/938 - loss: 1.5989 - acc_top1: 0.8811 - acc_top2: 0.8983 - 15ms/step\n",
"step 440/938 - loss: 1.5722 - acc_top1: 0.8813 - acc_top2: 0.8984 - 15ms/step\n",
"step 450/938 - loss: 1.5549 - acc_top1: 0.8818 - acc_top2: 0.8986 - 15ms/step\n",
"step 460/938 - loss: 1.5536 - acc_top1: 0.8819 - acc_top2: 0.8986 - 15ms/step\n",
"step 470/938 - loss: 1.5247 - acc_top1: 0.8826 - acc_top2: 0.8992 - 15ms/step\n",
"step 480/938 - loss: 1.5520 - acc_top1: 0.8830 - acc_top2: 0.8995 - 15ms/step\n",
"step 490/938 - loss: 1.5518 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n",
"step 500/938 - loss: 1.5227 - acc_top1: 0.8837 - acc_top2: 0.9000 - 15ms/step\n",
"step 510/938 - loss: 1.6014 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n",
"step 520/938 - loss: 1.5526 - acc_top1: 0.8834 - acc_top2: 0.8998 - 15ms/step\n",
"step 530/938 - loss: 1.5849 - acc_top1: 0.8838 - acc_top2: 0.9001 - 15ms/step\n",
"step 540/938 - loss: 1.5607 - acc_top1: 0.8840 - acc_top2: 0.9006 - 15ms/step\n",
"step 550/938 - loss: 1.6438 - acc_top1: 0.8843 - acc_top2: 0.9010 - 15ms/step\n",
"step 560/938 - loss: 1.5229 - acc_top1: 0.8848 - acc_top2: 0.9014 - 15ms/step\n",
"step 570/938 - loss: 1.5395 - acc_top1: 0.8846 - acc_top2: 0.9012 - 15ms/step\n",
"step 580/938 - loss: 1.5409 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n",
"step 590/938 - loss: 1.5851 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n",
"step 600/938 - loss: 1.5383 - acc_top1: 0.8849 - acc_top2: 0.9013 - 15ms/step\n",
"step 610/938 - loss: 1.5969 - acc_top1: 0.8853 - acc_top2: 0.9016 - 15ms/step\n",
"step 620/938 - loss: 1.5634 - acc_top1: 0.8854 - acc_top2: 0.9017 - 15ms/step\n",
"step 630/938 - loss: 1.6308 - acc_top1: 0.8857 - acc_top2: 0.9019 - 15ms/step\n",
"step 640/938 - loss: 1.6413 - acc_top1: 0.8859 - acc_top2: 0.9021 - 15ms/step\n",
"step 650/938 - loss: 1.5954 - acc_top1: 0.8856 - acc_top2: 0.9020 - 15ms/step\n",
"step 660/938 - loss: 1.5278 - acc_top1: 0.8859 - acc_top2: 0.9023 - 15ms/step\n",
"step 670/938 - loss: 1.5144 - acc_top1: 0.8869 - acc_top2: 0.9035 - 15ms/step\n",
"step 680/938 - loss: 1.4612 - acc_top1: 0.8879 - acc_top2: 0.9048 - 15ms/step\n",
"step 690/938 - loss: 1.4820 - acc_top1: 0.8891 - acc_top2: 0.9060 - 15ms/step\n",
"step 700/938 - loss: 1.4766 - acc_top1: 0.8901 - acc_top2: 0.9073 - 15ms/step\n",
"step 710/938 - loss: 1.5245 - acc_top1: 0.8911 - acc_top2: 0.9083 - 15ms/step\n",
"step 720/938 - loss: 1.5183 - acc_top1: 0.8922 - acc_top2: 0.9095 - 15ms/step\n",
"step 730/938 - loss: 1.4971 - acc_top1: 0.8932 - acc_top2: 0.9106 - 15ms/step\n",
"step 740/938 - loss: 1.4744 - acc_top1: 0.8944 - acc_top2: 0.9117 - 15ms/step\n",
"step 750/938 - loss: 1.4789 - acc_top1: 0.8952 - acc_top2: 0.9127 - 15ms/step\n",
"step 760/938 - loss: 1.5114 - acc_top1: 0.8959 - acc_top2: 0.9137 - 15ms/step\n",
"step 770/938 - loss: 1.5035 - acc_top1: 0.8970 - acc_top2: 0.9147 - 15ms/step\n",
"step 780/938 - loss: 1.4668 - acc_top1: 0.8978 - acc_top2: 0.9157 - 15ms/step\n",
"step 790/938 - loss: 1.4850 - acc_top1: 0.8986 - acc_top2: 0.9166 - 15ms/step\n",
"step 800/938 - loss: 1.4777 - acc_top1: 0.8996 - acc_top2: 0.9176 - 15ms/step\n",
"step 810/938 - loss: 1.4783 - acc_top1: 0.9005 - acc_top2: 0.9186 - 15ms/step\n",
"step 820/938 - loss: 1.5256 - acc_top1: 0.9011 - acc_top2: 0.9194 - 15ms/step\n",
"step 830/938 - loss: 1.4801 - acc_top1: 0.9019 - acc_top2: 0.9202 - 15ms/step\n",
"step 840/938 - loss: 1.4873 - acc_top1: 0.9026 - acc_top2: 0.9211 - 15ms/step\n",
"step 850/938 - loss: 1.5093 - acc_top1: 0.9034 - acc_top2: 0.9219 - 15ms/step\n",
"step 860/938 - loss: 1.4727 - acc_top1: 0.9042 - acc_top2: 0.9227 - 15ms/step\n",
"step 870/938 - loss: 1.4917 - acc_top1: 0.9050 - acc_top2: 0.9235 - 15ms/step\n",
"step 880/938 - loss: 1.4792 - acc_top1: 0.9058 - acc_top2: 0.9243 - 15ms/step\n",
"step 890/938 - loss: 1.4854 - acc_top1: 0.9066 - acc_top2: 0.9251 - 15ms/step\n",
"step 900/938 - loss: 1.4616 - acc_top1: 0.9074 - acc_top2: 0.9258 - 15ms/step\n",
"step 910/938 - loss: 1.4954 - acc_top1: 0.9081 - acc_top2: 0.9265 - 15ms/step\n",
"step 920/938 - loss: 1.4875 - acc_top1: 0.9087 - acc_top2: 0.9272 - 15ms/step\n",
"step 930/938 - loss: 1.5037 - acc_top1: 0.9094 - acc_top2: 0.9279 - 15ms/step\n",
"step 938/938 - loss: 1.4964 - acc_top1: 0.9099 - acc_top2: 0.9284 - 15ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n",
"Eval begin...\n",
"step 10/157 - loss: 1.5196 - acc_top1: 0.9719 - acc_top2: 0.9969 - 5ms/step\n",
"step 20/157 - loss: 1.5393 - acc_top1: 0.9672 - acc_top2: 0.9945 - 6ms/step\n",
"step 30/157 - loss: 1.4928 - acc_top1: 0.9630 - acc_top2: 0.9906 - 5ms/step\n",
"step 40/157 - loss: 1.4765 - acc_top1: 0.9617 - acc_top2: 0.9902 - 5ms/step\n",
"step 50/157 - loss: 1.4646 - acc_top1: 0.9631 - acc_top2: 0.9903 - 5ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 60/157 - loss: 1.5646 - acc_top1: 0.9641 - acc_top2: 0.9906 - 5ms/step\n",
"step 70/157 - loss: 1.5167 - acc_top1: 0.9618 - acc_top2: 0.9900 - 5ms/step\n",
"step 80/157 - loss: 1.4728 - acc_top1: 0.9635 - acc_top2: 0.9906 - 5ms/step\n",
"step 90/157 - loss: 1.5030 - acc_top1: 0.9668 - acc_top2: 0.9917 - 5ms/step\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9677 - acc_top2: 0.9914 - 5ms/step\n",
"step 110/157 - loss: 1.4612 - acc_top1: 0.9689 - acc_top2: 0.9913 - 5ms/step\n",
"step 120/157 - loss: 1.4612 - acc_top1: 0.9707 - acc_top2: 0.9919 - 5ms/step\n",
"step 130/157 - loss: 1.4621 - acc_top1: 0.9719 - acc_top2: 0.9923 - 5ms/step\n",
"step 140/157 - loss: 1.4612 - acc_top1: 0.9734 - acc_top2: 0.9929 - 5ms/step\n",
"step 150/157 - loss: 1.4660 - acc_top1: 0.9748 - acc_top2: 0.9933 - 5ms/step\n",
"step 157/157 - loss: 1.5215 - acc_top1: 0.9731 - acc_top2: 0.9930 - 5ms/step\n",
"Eval samples: 10000\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/final\n"
]
}
],
"source": [
"model.fit(train_dataset,\n",
" test_dataset,\n",
" epochs=2,\n",
" batch_size=64,\n",
" save_dir='mnist_checkpoint')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 训练方式1结束\n",
"以上就是训练方式1,可以非常快速的完成网络模型训练。此外,paddle还可以用下面的方式来完成模型的训练。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.训练方式2\n",
"方式1可以快速便捷的完成训练,隐藏了训练时的细节。而方式2则可以用最基本的方式,完成模型的训练。具体如下。"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch_id: 0, loss is: [2.300888], acc is: [0.28125]\n",
"epoch: 0, batch_id: 100, loss is: [1.6948285], acc is: [0.8125]\n",
"epoch: 0, batch_id: 200, loss is: [1.5282547], acc is: [0.96875]\n",
"epoch: 0, batch_id: 300, loss is: [1.509404], acc is: [0.96875]\n",
"epoch: 0, batch_id: 400, loss is: [1.4973292], acc is: [1.]\n",
"epoch: 0, batch_id: 500, loss is: [1.5063374], acc is: [0.984375]\n",
"epoch: 0, batch_id: 600, loss is: [1.490077], acc is: [0.984375]\n",
"epoch: 0, batch_id: 700, loss is: [1.5206413], acc is: [0.984375]\n",
"epoch: 0, batch_id: 800, loss is: [1.5104291], acc is: [1.]\n",
"epoch: 0, batch_id: 900, loss is: [1.5216607], acc is: [0.96875]\n",
"epoch: 1, batch_id: 0, loss is: [1.4949667], acc is: [0.984375]\n",
"epoch: 1, batch_id: 100, loss is: [1.4923338], acc is: [0.96875]\n",
"epoch: 1, batch_id: 200, loss is: [1.5026703], acc is: [1.]\n",
"epoch: 1, batch_id: 300, loss is: [1.4965419], acc is: [0.984375]\n",
"epoch: 1, batch_id: 400, loss is: [1.5270758], acc is: [1.]\n",
"epoch: 1, batch_id: 500, loss is: [1.4774603], acc is: [1.]\n",
"epoch: 1, batch_id: 600, loss is: [1.4762554], acc is: [0.984375]\n",
"epoch: 1, batch_id: 700, loss is: [1.4773959], acc is: [0.984375]\n",
"epoch: 1, batch_id: 800, loss is: [1.5044193], acc is: [1.]\n",
"epoch: 1, batch_id: 900, loss is: [1.4986757], acc is: [0.96875]\n"
]
}
],
"source": [
"import paddle\n",
"train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"def train(model):\n",
" model.train()\n",
" epochs = 2\n",
" batch_size = 64\n",
" optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
" for epoch in range(epochs):\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
" optim.minimize(avg_loss)\n",
" model.clear_gradients()\n",
"model = LeNet()\n",
"train(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 对模型进行验证"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_id: 0, loss is: [1.5017498], acc is: [1.]\n",
"batch_id: 100, loss is: [1.4783669], acc is: [0.984375]\n",
"batch_id: 200, loss is: [1.4958509], acc is: [1.]\n",
"batch_id: 300, loss is: [1.4924574], acc is: [1.]\n",
"batch_id: 400, loss is: [1.4762049], acc is: [1.]\n",
"batch_id: 500, loss is: [1.4817208], acc is: [0.984375]\n",
"batch_id: 600, loss is: [1.4763825], acc is: [0.984375]\n",
"batch_id: 700, loss is: [1.4954926], acc is: [1.]\n",
"batch_id: 800, loss is: [1.5220823], acc is: [0.984375]\n",
"batch_id: 900, loss is: [1.4945463], acc is: [0.984375]\n"
]
}
],
"source": [
"import paddle\n",
"test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n",
"def test(model):\n",
" model.eval()\n",
" batch_size = 64\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n",
" acc = paddle.metric.accuracy(predicts, y_data, k=2)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_acc = paddle.mean(acc)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n",
"test(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 训练方式2结束\n",
"以上就是训练方式2,通过这种方式,可以清楚的看到训练和测试中的每一步过程。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 总结\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## 用N-Gram模型在莎士比亚文集中训练word embedding\n",
"N-gram 是计算机语言学和概率论范畴内的概念,是指给定的一段文本中N个项目的序列。\n",
"N=1 时 N-gram 又称为 unigram,N=2 称为 bigram,N=3 称为 trigram,以此类推。实际应用通常采用 bigram 和 trigram 进行计算。\n",
"本示例在莎士比亚文集上实现了trigram。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境\n",
"本教程基于paddle-develop编写,如果您的环境不是本版本,请先安装paddle-develop。"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.0.0'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import paddle\n",
"paddle.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据集&&相关参数\n",
"训练数据集采用了莎士比亚文集,[下载](https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt),保存为txt格式即可。<br>\n",
"context_size设为2,意味着是trigram。embedding_dim设为256。"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2020-09-03 08:41:10-- https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt\n",
"正在解析主机 ocw.mit.edu (ocw.mit.edu)... 151.101.230.133\n",
"正在连接 ocw.mit.edu (ocw.mit.edu)|151.101.230.133|:443... 已连接。\n",
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
"长度:5458199 (5.2M) [text/plain]\n",
"正在保存至: “t8.shakespeare.txt.1”\n",
"\n",
"t8.shakespeare.txt. 100%[===================>] 5.21M 26.1KB/s 用时 4m 14s \n",
"\n",
"2020-09-03 08:45:25 (21.0 KB/s) - 已保存 “t8.shakespeare.txt.1” [5458199/5458199])\n",
"\n"
]
}
],
"source": [
"!wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"embedding_dim = 256\n",
"context_size = 2"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Length of text: 5458199 characters\n"
]
}
],
"source": [
"# 文件路径\n",
"path_to_file = './t8.shakespeare.txt'\n",
"test_sentence = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n",
"\n",
"# 文本长度是指文本中的字符个数\n",
"print ('Length of text: {} characters'.format(len(test_sentence)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 去除标点符号\n",
"用`string`库中的punctuation,完成英文符号的替换。"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'!': '', '\"': '', '#': '', '$': '', '%': '', '&': '', \"'\": '', '(': '', ')': '', '*': '', '+': '', ',': '', '-': '', '.': '', '/': '', ':': '', ';': '', '<': '', '=': '', '>': '', '?': '', '@': '', '[': '', '\\\\': '', ']': '', '^': '', '_': '', '`': '', '{': '', '|': '', '}': '', '~': ''}\n"
]
}
],
"source": [
"from string import punctuation\n",
"process_dicts={i:'' for i in punctuation}\n",
"print(process_dicts)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"28343\n"
]
}
],
"source": [
"punc_table = str.maketrans(process_dicts)\n",
"test_sentence = test_sentence.translate(punc_table)\n",
"test_sentence = test_sentence.lower().split()\n",
"vocab = set(test_sentence)\n",
"print(len(vocab))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据预处理\n",
"将文本被拆成了元组的形式,格式为(('第一个词', '第二个词'), '第三个词');其中,第三个词就是我们的目标。"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[['this', 'is'], 'the'], [['is', 'the'], '100th'], [['the', '100th'], 'etext']]\n"
]
}
],
"source": [
"trigram = [[[test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2]]\n",
" for i in range(len(test_sentence) - 2)]\n",
"\n",
"word_to_idx = {word: i for i, word in enumerate(vocab)}\n",
"idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n",
"# 看一下数据集\n",
"print(trigram[:3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 构建`Dataset`类 加载数据\n",
"用`paddle.io.Dataset`构建数据集,然后作为参数传入到`paddle.io.DataLoader`,完成数据集的加载。"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import numpy as np\n",
"batch_size = 256\n",
"paddle.disable_static()\n",
"class TrainDataset(paddle.io.Dataset):\n",
" def __init__(self, tuple_data):\n",
" self.tuple_data = tuple_data\n",
"\n",
" def __getitem__(self, idx):\n",
" data = self.tuple_data[idx][0]\n",
" label = self.tuple_data[idx][1]\n",
" data = np.array(list(map(lambda w: word_to_idx[w], data)))\n",
" label = np.array(word_to_idx[label])\n",
" return data, label\n",
" \n",
" def __len__(self):\n",
" return len(self.tuple_data)\n",
"train_dataset = TrainDataset(trigram)\n",
"train_loader = paddle.io.DataLoader(train_dataset,places=paddle.fluid.cpu_places(), return_list=True,\n",
" shuffle=True, batch_size=batch_size, drop_last=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 组网&训练\n",
"这里用paddle动态图的方式组网。为了构建Trigram模型,用一层 `Embedding` 与两层 `Linear` 完成构建。`Embedding` 层对输入的前两个单词embedding,然后输入到后面的两个`Linear`层中,完成特征提取。"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import paddle\n",
"import numpy as np\n",
"hidden_size = 1024\n",
"class NGramModel(paddle.nn.Layer):\n",
" def __init__(self, vocab_size, embedding_dim, context_size):\n",
" super(NGramModel, self).__init__()\n",
" self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n",
" self.linear1 = paddle.nn.Linear(context_size * embedding_dim, hidden_size)\n",
" self.linear2 = paddle.nn.Linear(hidden_size, len(vocab))\n",
"\n",
" def forward(self, x):\n",
" x = self.embedding(x)\n",
" x = paddle.reshape(x, [-1, context_size * embedding_dim])\n",
" x = self.linear1(x)\n",
" x = paddle.nn.functional.relu(x)\n",
" x = self.linear2(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 定义`train()`函数,对模型进行训练。"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch_id: 0, loss is: [10.252256]\n",
"epoch: 0, batch_id: 100, loss is: [7.0485706]\n",
"epoch: 0, batch_id: 200, loss is: [7.282592]\n",
"epoch: 0, batch_id: 300, loss is: [6.9604626]\n",
"epoch: 0, batch_id: 400, loss is: [6.7308316]\n",
"epoch: 0, batch_id: 500, loss is: [6.7940483]\n",
"epoch: 0, batch_id: 600, loss is: [6.6574802]\n",
"epoch: 0, batch_id: 700, loss is: [6.862562]\n",
"epoch: 0, batch_id: 800, loss is: [7.2091002]\n",
"epoch: 0, batch_id: 900, loss is: [7.0172606]\n",
"epoch: 0, batch_id: 1000, loss is: [6.9888105]\n",
"epoch: 0, batch_id: 1100, loss is: [6.9609995]\n",
"epoch: 0, batch_id: 1200, loss is: [6.550024]\n",
"epoch: 0, batch_id: 1300, loss is: [6.714109]\n",
"epoch: 0, batch_id: 1400, loss is: [6.995716]\n",
"epoch: 0, batch_id: 1500, loss is: [6.939434]\n",
"epoch: 0, batch_id: 1600, loss is: [6.5966253]\n",
"epoch: 0, batch_id: 1700, loss is: [6.9880104]\n",
"epoch: 0, batch_id: 1800, loss is: [6.6459093]\n",
"epoch: 0, batch_id: 1900, loss is: [6.8095036]\n",
"epoch: 0, batch_id: 2000, loss is: [6.8447037]\n",
"epoch: 0, batch_id: 2100, loss is: [6.8313]\n",
"epoch: 0, batch_id: 2200, loss is: [6.808483]\n",
"epoch: 0, batch_id: 2300, loss is: [6.502908]\n",
"epoch: 0, batch_id: 2400, loss is: [6.561283]\n",
"epoch: 0, batch_id: 2500, loss is: [7.0093765]\n",
"epoch: 0, batch_id: 2600, loss is: [6.512396]\n",
"epoch: 0, batch_id: 2700, loss is: [6.809763]\n",
"epoch: 0, batch_id: 2800, loss is: [6.806659]\n",
"epoch: 0, batch_id: 2900, loss is: [6.95402]\n",
"epoch: 0, batch_id: 3000, loss is: [6.634927]\n",
"epoch: 0, batch_id: 3100, loss is: [6.644098]\n",
"epoch: 0, batch_id: 3200, loss is: [6.705504]\n",
"epoch: 0, batch_id: 3300, loss is: [6.2121572]\n",
"epoch: 0, batch_id: 3400, loss is: [6.638401]\n",
"epoch: 0, batch_id: 3500, loss is: [6.986831]\n",
"epoch: 1, batch_id: 0, loss is: [6.795429]\n",
"epoch: 1, batch_id: 100, loss is: [6.582568]\n",
"epoch: 1, batch_id: 200, loss is: [6.527663]\n",
"epoch: 1, batch_id: 300, loss is: [6.714637]\n",
"epoch: 1, batch_id: 400, loss is: [6.574902]\n",
"epoch: 1, batch_id: 500, loss is: [6.305031]\n",
"epoch: 1, batch_id: 600, loss is: [6.803609]\n",
"epoch: 1, batch_id: 700, loss is: [6.2429113]\n",
"epoch: 1, batch_id: 800, loss is: [6.7452283]\n",
"epoch: 1, batch_id: 900, loss is: [6.383783]\n",
"epoch: 1, batch_id: 1000, loss is: [6.4906135]\n",
"epoch: 1, batch_id: 1100, loss is: [6.6007314]\n",
"epoch: 1, batch_id: 1200, loss is: [6.63466]\n",
"epoch: 1, batch_id: 1300, loss is: [6.540749]\n",
"epoch: 1, batch_id: 1400, loss is: [6.7752547]\n",
"epoch: 1, batch_id: 1500, loss is: [6.2411666]\n",
"epoch: 1, batch_id: 1600, loss is: [6.540929]\n",
"epoch: 1, batch_id: 1700, loss is: [6.6563463]\n",
"epoch: 1, batch_id: 1800, loss is: [6.4592104]\n",
"epoch: 1, batch_id: 1900, loss is: [7.0268345]\n",
"epoch: 1, batch_id: 2000, loss is: [6.803793]\n",
"epoch: 1, batch_id: 2100, loss is: [6.8454733]\n",
"epoch: 1, batch_id: 2200, loss is: [6.651756]\n",
"epoch: 1, batch_id: 2300, loss is: [6.5876465]\n",
"epoch: 1, batch_id: 2400, loss is: [6.258934]\n",
"epoch: 1, batch_id: 2500, loss is: [6.5422425]\n",
"epoch: 1, batch_id: 2600, loss is: [6.184501]\n",
"epoch: 1, batch_id: 2700, loss is: [6.6847773]\n",
"epoch: 1, batch_id: 2800, loss is: [6.684101]\n",
"epoch: 1, batch_id: 2900, loss is: [6.374978]\n",
"epoch: 1, batch_id: 3000, loss is: [6.8277273]\n",
"epoch: 1, batch_id: 3100, loss is: [6.5195084]\n",
"epoch: 1, batch_id: 3200, loss is: [6.311832]\n",
"epoch: 1, batch_id: 3300, loss is: [6.4282994]\n",
"epoch: 1, batch_id: 3400, loss is: [6.603338]\n",
"epoch: 1, batch_id: 3500, loss is: [6.4541807]\n"
]
}
],
"source": [
"import time\n",
"vocab_size = len(vocab)\n",
"epochs = 2\n",
"losses = []\n",
"def train(model):\n",
" model.train()\n",
" optim = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())\n",
" for epoch in range(epochs):\n",
" for batch_id, data in enumerate(train_loader()):\n",
" x_data = data[0]\n",
" y_data = data[1]\n",
" predicts = model(x_data)\n",
" y_data = paddle.nn.functional.one_hot(y_data, len(vocab))\n",
" loss = paddle.nn.functional.softmax_with_cross_entropy(predicts, y_data,soft_label=True)\n",
" avg_loss = paddle.mean(loss)\n",
" avg_loss.backward()\n",
" if batch_id % 100 == 0:\n",
" losses.append(avg_loss.numpy())\n",
" print(\"epoch: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy())) \n",
" optim.minimize(avg_loss)\n",
" model.clear_gradients()\n",
"model = NGramModel(vocab_size, embedding_dim, context_size)\n",
"train(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 打印loss下降曲线\n",
"通过可视化loss的曲线,可以看到模型训练的效果。"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x166d69048>]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deXzcVbn48c8zM1mapVknSdukTbM03egaSlvaUihbkQuigOBVEAVE8bpdrxdc4Kf+firqVVG8cMviBQVFkE1k3wtdaLo3NGmbdEnSZm32NOuc3x+zdGYyaZaZNGH6vF+vvDrznW/me5pknjnznOecI8YYlFJKhS/LWDdAKaXU6NJAr5RSYU4DvVJKhTkN9EopFeY00CulVJizjXUD/KWmpprs7OyxboZSSn2sbN26td4YYw/02LgL9NnZ2RQVFY11M5RS6mNFRA4P9JimbpRSKsxpoFdKqTCngV4ppcKcBnqllApzGuiVUirMaaBXSqkwp4FeKaXCXNgE+mPNJ/j1a6WU17WNdVOUUmpcCZtAX9faxe/eOkB5XftYN0UppcaVsAn0kTbnf6W7zzHGLVFKqfElfAK91RXoezXQK6WUt/AJ9DYN9EopFUjYBfouTd0opZSPsAn0UVYroD16pZTyFzaB3tOj7+0b45YopdT4EnaBXnv0SinlK2wCvdUiWC2igV4ppfwMGuhF5BERqRWRPV7HkkXkdRHZ7/o3aYDvvdF1zn4RuTGUDQ8k0mrRQK+UUn6G0qP/X+BSv2N3AG8aY/KBN133fYhIMnA3cA6wBLh7oDeEUIm0WXTClFJK+Rk00Btj3gOO+x2+EnjUdftR4JMBvvUS4HVjzHFjTCPwOv3fMEIqyqY9eqWU8jfSHH26MeaY63Y1kB7gnClAhdf9StexUROpgV4ppfoJejDWGGMAE8xziMitIlIkIkV1dXUjfp5Im0UnTCmllJ+RBvoaEZkE4Pq3NsA5VUCW1/1M17F+jDHrjDGFxphCu90+wibpYKxSSgUy0kD/AuCuorkReD7AOa8CF4tIkmsQ9mLXsVGjOXqllOpvKOWVfwE2AgUiUikiXwJ+DlwkIvuBC133EZFCEXkIwBhzHPgJsMX19WPXsVGjOXqllOrPNtgJxpjrB3hoTYBzi4Cbve4/Ajwy4tYNU6TNQmePBnqllPIWNjNjQXP0SikVSHgFek3dKKVUP2EW6K06M1YppfyEV6DX1I1SSvUTXoHeZqFLA71SSvkIq0DvrKPXjUeUUspbWAV6Xb1SKaX6C69A78rRO5ffUUopBeEW6G0WHAZ6HRrolVLKLewCPei+sUop5S28Ar1VA71SSvkLr0Dv7tHrgKxSSnmEZ6DXHr1SSnmEVaCPcgV6nTSllFInhWWg1x69UkqdFFaBXnP0SinVX1CBXkS+ISJ7RKRYRL4Z4PHVItIsIjtcX3cFc73BRFqtgPbolVLK26A7TA1EROYCtwBLgG7gFRF50RhzwO/U9caYy4No45DpYKxSSvUXTI9+FrDZGNNhjOkF3gU+FZpmjczJ1I0ubKaUUm7BBPo9wEoRSRGRGOAyICvAectEZKeIvCwicwI9kYjcKiJFIlJUV1c34gbphCmllOpvxKkbY8xeEbkHeA1oB3YA/l3pbcA0Y0ybiFwGPAfkB3iudcA6gMLCwhEvVBOp5ZVKKdVPUIOxxpiHjTGLjTGrgEZgn9/jLcaYNtftl4AIEUkN5pqnouWVSinVX7BVN2muf6fizM8/4fd4hoiI6/YS1/UagrnmqWh5pVJK9Tfi1I3L30UkBegBbjfGNInIbQDGmAeAq4GviEgvcAK4zoziYvGao1dKqf6CCvTGmJUBjj3gdfs+4L5grjEcWl6plFL9hefMWA30SinlEVaB3mYRRDRHr5RS3sIq0IuIZ99YpZRSTmEV6MGZvtE6eqWUOinsAn2UBnqllPIRdoFeUzdKKeUr/AK9zaKDsUop5SXsAn2UzUp3r65eqZRSbmEX6CNtmrpRSilv4RnoNXWjlFIe4RfodTBWKaV8hF+g19SNUkr5CMtAr3X0Sil1UlgGes3RK6XUSWEX6KM0R6+UUj6C3WHqGyKyR0SKReSbAR4XEfmdiBwQkV0isiiY6w2F5uiVUsrXiAO9iMwFbgGWAPOBy0Ukz++0tTg3A88HbgXuH+n1hkpTN0op5SuYHv0sYLMxpsMY0wu8i3PfWG9XAo8Zp01AoohMCuKag9LySqWU8hVMoN8DrBSRFBGJAS4DsvzOmQJUeN2vdB3zISK3ikiRiBTV1dUF0SRN3SillL8RB3pjzF7gHuA14BVgBzCiRWaMMeuMMYXGmEK73T7SJgHOQN/rMDgco7YHuVJKfawENRhrjHnYGLPYGLMKaAT2+Z1ShW8vP9N1bNR49o3VPL1SSgHBV92kuf6dijM//4TfKS8AN7iqb5YCzcaYY8FcczCRVud/SSdNKaWUky3I7/+7iKQAPcDtxpgmEbkNwBjzAPASztz9AaADuCnI6w0qyt2j10CvlFJAkIHeGLMywLEHvG4b4PZgrjFc7tRNl65Jr5RSQBjOjI3UHr1SSvkIv0BvtQI6GKuUUm7hF+i1R6+UUj7CLtDrYKxSSvkKu0CvPXqllPIVtoG+S3P0SikFhGOgt2qPXimlvIVdoNccvVJK+Qq7QK85eqWU8hW+gV5z9EopBYRjoNccvVJK+Qi/QK+pG6WU8hG+gV5TN0opBYRjoNf16JVSykfYBXoR0Q3ClVLKS7A7TH1LRIpFZI+I/EVEov0e/4KI1InIDtfXzcE1d2h0g3CllDppxIFeRKYAXwcKjTFzAStwXYBTnzTGLHB9PTTS6w1HpM1Cd59uPKKUUhB86sYGTBARGxADHA2+ScHT1I1SSp004kBvjKkCfgUcAY7h3Pj7tQCnflpEdonI0yKSFei5RORWESkSkaK6urqRNslDUzdKKXVSMKmbJOBKYDowGYgVkc/5nfYPINsYMw94HXg00HMZY9YZYwqNMYV2u32kTfKItFm06kYppVyCSd1cCBw0xtQZY3qAZ4Dl3icYYxqMMV2uuw8Bi4O43pBp6kYppU4KJtAfAZaKSIyICLAG2Ot9gohM8rp7hf/jo8U5GKuBXimlwDmYOiLGmM0i8jSwDegFtgPrROTHQJEx5gXg6yJyhevx48AXgm/y4DR1o5RSJ4040AMYY+4G7vY7fJfX43cCdwZzjZGIsllo7ew93ZdVSqlxKexmxoIz0GuOXimlnMIy0GuOXimlTgrPQK9VN0op5RGegV5TN0op5RG+gV5TN0opBYRroLdatUevlFIu4RnoNXWjlFIe4Rvo+xwYY8a6KUopNebCMtBH6b6xSinlEZaB3r1vrKZvlFIqXAO9TQO9Ukq5hXeg19SNUkqFaaDX1I1SSnmEZ6DX1I1SSnmEdaDXNemVUkoDvVJKhb2gAr2IfEtEikVkj4j8RUSi/R6PEpEnReSAiGwWkexgrjdUUZqjV0opjxEHehGZAnwdKDTGzAWswHV+p30JaDTG5AG/Ae4Z6fWGQ6tulFLqpGBTNzZggojYgBjgqN/jVwKPum4/DaxxbSQ+qqJsVkB79EopBUEEemNMFfAr4AhwDGg2xrzmd9oUoMJ1fi/QDKT4P5eI3CoiRSJSVFdXN9ImeWjVjVJKnRRM6iYJZ499OjAZiBWRz43kuYwx64wxhcaYQrvdPtImeZxM3fQF/VxKKfVxF0zq5kLgoDGmzhjTAzwDLPc7pwrIAnCldxKAhiCuOSTao1dKqZOCCfRHgKUiEuPKu68B9vqd8wJwo+v21cBb5jSsHawzY5VS6qRgcvSbcQ6wbgN2u55rnYj8WESucJ32MJAiIgeAbwN3BNneIdE6eqWUOskWzDcbY+4G7vY7fJfX453ANcFcYyR0PXqllDopPGfGaupGKaU8wjLQWyyCzSIa6JVSijAN9KAbhCullFt4B3rN0SulVBgHeqv26JVSCsI50GvqRimlgDAP9F2aulFKqTAO9Jq6UUopIIwDfZSmbpRSCgjjQB9ps9DVq6tXKqVUWAd67dErpVQ4B3qr1tErpRSEc6DXHr1SSgFhHOijbFYN9EopRRgHeu3RK6WUUzB7xhaIyA6vrxYR+abfOatFpNnrnLsGer5Q07VulFLKacQbjxhjSoEFACJixbk/7LMBTl1vjLl8pNcZqUirRXeYUkopQpe6WQOUGWMOh+j5gqYTppRSyilUgf464C8DPLZMRHaKyMsiMifQCSJyq4gUiUhRXV1dSBrkTt2chr3IlVJqXAs60ItIJHAF8FSAh7cB04wx84HfA88Feg5jzDpjTKExptButwfbJMCZujEGeh0a6JVSZ7ZQ9OjXAtuMMTX+DxhjWowxba7bLwERIpIagmsOKtKm+8YqpRSEJtBfzwBpGxHJEBFx3V7iul5DCK45KA30SinlNOKqGwARiQUuAr7sdew2AGPMA8DVwFdEpBc4AVxnTlPS3BPotcRSKXWGCyrQG2PagRS/Yw943b4PuC+Ya4xUpFV79EopBWE+MxbQWnql1BkvbAN9lObolVIKCONArzl6pZRyCt9Ab7UCvj36zp4+mjt6xqpJSik1JsI30AdI3dzx911c9d8f6GxZpdQZJewDvXvf2Mb2bl7aXU15fTvl9e1j2TSllDqtwjfQ+5VXvrDzqCdfv35faNbTUUqpj4PwDfR+g7FPb61k9qSJZKfEsH5//ahe+/HNh9l6+PioXkMppYYqbAN9lFcdfUl1C7urmrmmMJOV+XY2ljeMWtllW1cvdz1fzB1/341DF1RTSo0DYR/ou3sdPF1USYRVuHLBFFbmp9LR3ce2I42jct0tB4/T5zDsr23jnX21o3INpZQajrAN9O7UzYnuPp7bUcWamekkx0ayLDcFq0VYv3908vQbyuqJtFqYlBDNA++Wj8o1lFJqOMI+0L/2UTX1bd1cvTgTgPjoCBZNTRy1PP3G8gYWTE3k5pU5fHjwONtH6ZODUkoNVfgGelfVzZZDjaTGRXFewckNTVbm29ld1czx9u6QXrO5o4fioy0sy0nhurOzSJgQwbr3tFevlBpbYRvobVYLFnHevmrhZCKsJ/+rK/NTMQbePxDaXv3mgw0YA8tzU4iNsvG5pVN5pbiag1q3r5QaQ2Eb6OFk+uaawiyf4/MyE0mYEBHyevqN5Q1E2SwsmJoIwI3Ls4mwWnhwvfbqlVJjZ8SBXkQKRGSH11eLiHzT7xwRkd+JyAER2SUii4Jv8tBFWi3Mz0xgRnq8z3GrRViRl8r6/fUhXQ5hY1kDhdlJRNmc6+ykxUfz6UVTeHprJXWtXSG7jlJKDceIA70xptQYs8AYswBYDHQAz/qdthbId33dCtw/0uuNxDcvnMH3LpsV8LGV+alUt3RyoLYtJNc63t5NSXUry3J89mHh5pU59PQ5WPdemdbVK+VSfLSZ7zy1U5cRP01ClbpZA5QZYw77Hb8SeMw4bQISRWRSiK45qC+umM45foHXbUW+c4/y90JUfbOp3LkV7rJc3+vl2uO4bO4kHlx/kOU/f4u7n9/DhrJ6ekOwfPLB+nb6PoZvHj19Dq5bt5F3SnWewZnqn7uO8fTWSt7YWzPWTTkjhCrQX0fgDcKnABVe9ytdx3yIyK0iUiQiRXV1p2cdmsykGHLssSGrp99Y1kBMpJV5mYn9Hvuva+fzm8/MZ35WAn/dUsFnH9zMkp++yXef3skbH9XQ2dM37OttP9LI+b96h0/dv4GS6pYhfc++mlaueWADjSGuNhqufTWtbCo/zmsfjf8XefOJHm585EP21bSOdVPCSlmd85P0Xz48MsYtOTMEHehFJBK4AnhqpM9hjFlnjCk0xhTa7fbBvyFEVuXb2VTeMKJA629jeQNnZyf7VPe4RUdYuWphJv/z+UK233UR9//rIlbkpfLy7mpufqyIRT95ndsf30Zta+eQr/fynmoirELl8Q4u/937/PLVkkH/Hy/tPsaWQ428N0qTxYaquMr5xlRybGhvUAOpa+2io7s3FE0a0Hv76nh3Xx2/fWPfqF7nTFNe144IrN9fT8XxjrFuTtgLRY9+LbDNGBOoe1YFeJe8ZLqOjQurC+x09jiCTiHUtjpz/f5pm0BiIm2sPWsSv7t+IVt/eBGPfnEJVy2cwivF1Tzy/qEhXc8Yw6vF1SzPTeWNb5/HlQum8Ie3y1h773r2niJ4bjvSBMCGAw1Dus5o2XO0GYB9NW0jHgw/3t7Nxb95lx+98FEom9bPRldK7uU91RzSMtmQ6O1zcKihnSvmT8Yi8OSWisG/SQUlFIH+egKnbQBeAG5wVd8sBZqNMcdCcM2QWJGXSvrEKP5WVBnU82wqd65U6T8QO5hIm4XzZtj5f1edxZzJE9lRMbRZtPtq2jjc0MHFc9JJio3kv66dz5+/dA5NHd38+vXAPU+Hw7D9sPP5N5SP7uqdg9ld5Qz0bV29VDaeGNFz/PLVEho7enizpHZUN5LZVNbA/KxELZMNoYrGE/T0GVbkpbK6II2ntlaEZMxKDSyoQC8iscBFwDNex24Tkdtcd18CyoEDwIPAV4O5XqjZrBY+vSiTd0prqWkZetrE38ayBuKjbMyZPHHEz7EgK5Hdlc1DGlx9rbgaEbhoVrrn2Ir8VC6Zk8Gm8oaAz7G/to3Wrl7mZyZQcfzEmH1c7u1zsPdYCwtdcw1Kqoef+95R0cRft1QwPTWW+rYu9h4bnfx5TUsn5fXtfOKsDD69KJOntEw2JMpclW65aXFcd3YWNS1dvF2qe0SMpqACvTGm3RiTYoxp9jr2gDHmAddtY4y53RiTa4w5yxhTFGyDQ+3awiwcxrle/UhtKm9gyfRkbAHy80O1ICuR9u4+9tcOHrRe+6iGhVmJpE2M9jm+LDeF1s5ePjraP32z1dWbv/38PMD55jQWyura6exx8KlFzrWHSoc4kOzW5zDc9fweUuOieOjGQoBRG3PwVFLlpHLLyun09Dl4bOOhUbnWmcQ9EJubGscFM9NIi4/SQdlRFtYzY4ciOzWWJdOTeaqoYkQpgKNNJzhY3z6k/PypLMhy9nB3uPLop7re7qpmLp6T0e8xdxs2lPVPzWw70khybCQXzkonNS4q4Dmnwx5X2mbp9GSykiewd5g9+ie3VLCrspnvXzaLXHscMzPieW+UdgzbWNbAxGgbsydPJMcex8Wz03ls42Hau0Z3ADjclde1kxoXSUJMBDarhWsKnZ+qjzWPLI2nBnfGB3qAzxRmcaihgy2Hhr/S5N+KKhCBC73SKCMxPTWWhAkR7Kg4daB/3VWSePHs/tdLi48mPy2ODQF669sON7JoahIWi7A8N4UNZQ2nfGNzOAzv76/n9se3seKet7h+3Sa+/+xuHlpfztultTSf6Bnm/9Bpd1UzEyKs5NjjKEifSOkwAn1jeze/eLWEJdOTuXLBZABWzbBTdKhxVKpvNpY3sGS6c1lrgC+fl0vziR4dPAxSWV0bOfY4z/3PFE7FYeBvW5yfqrt7HfxtSwVr713PbwYYcxprH7cxBdtYN2A8WHtWBne/UMyTWypYMj15yN/X3evgz5uOcH5BGtmpsUG1QUSYn5U4aKB/tbiavLQ4nxeKt+W5KTy1tZLuXodnrZ/j7d2U17dzdWGm55wXdh6lrK6NvDTf5SEa27v5y5Yj/PXDCo4c7yAxJoJzc1M52nyCF3cd8wR4i8DcKQksy03h3NxUluemDCl1VXy0mdmTJ2K1CLMmxfN2aS2dPX1ER1h9zuvpc1ByrBWbVYiOsBIdYeG3r++ntbOXn1w5FxFn8F2Vb2fde+VsKm/ggpnBvdl6O9p0gsMNHdywLNtzbNHUJJZkJ/Pw+wf5/LJpAUtp1eDK6tq4dO7JeZNTU2JYkZfK34oqiIu28dD6co41dxIbaeV/3ivjxuXZJMdGjmGLfe2ubObTD2zgH19bQUFG/ODfMA7oXyrOksd/mT+Jl3Yfo7XTt6f64cHjnnSDv5d2H6O+rYsvLM8OSTsWZCWyr6Z1wNRAU0c3mw8eD9ibd1uW69xBa1flyTcM95r4i6cmAbA81zkr2L/n3+cwXP/gJn7xSimTEqK597oFbLpzDX/410U8+9Vz2XHXRWz9wYU8ccs5fO2CfKJsFh5ef5AbHvmQn71cMuj/z+EwFB9tYa5r0LogI54+hwm4DMWD68v5l/veZ+296zn/V++w7Gdv8WRRBTcuy/Z5cRVmJxEdYeG9faFNRbnHMJbm+L7xf/m8HKqaTvDirqMhvZ6/po5urrjv/VOWy34cHW/vprGjh1y7b8fo+iVTqWo6wU9e/IispBj+eNPZPHv7uXT2OPjTRv8J92NrQ1k93b0OXt4zbgoIB6U9epdrC7P4y4cV/HPXMa5bMpXOnj5+9tJeHt14mKSYCN749nmkxEX5fM8fNxwi1x7LStdyCsFamJWIw8CuyuaAOf+3SmrpcxguCZCfd1uak4yIM4gXZjuD1NbDjdgs4pm1m5U8gSmJE9hwoMGnx/r3bZWUVLdy73ULuHJBvwnMiAgpcVEsj4tyvllcNIP2rl6+9sQ2Xtx1lB98Ypanpx1IeX07Hd19zJ2SAMDMDGfAL61u9Rxze7W4hpkZ8XxjTT6dvX109jiIsFq4fJ7vChrREVbOmZ4S8gHZjeUNJMZEMCvDt5Lq/II08tLi+OMHh7hqYWZIr+lt25FGdlU2s35/HbMmjbyaKxRe/6iGyYnRzJmcMPjJg/AMxKb5fiK9ZE46d6ydSeG0JM/fLcD5BXYe23iIL5+X0+9T31jZ4yp2eKe0jm9eOGOMWzM02qN3WZCVSH5aHE8WVVBa3cqV933AoxsPc83iTNq6evnRP3wn5mw/0sjOiia+sDz7lMFtOOa7B2QHSN+8VlxDxsRozpoy8AsuMSaS2ZMm+gy2bj3cyJzJE5kQ6XyhiAjn5qWwsbzBs9BaZ08fv3l9H/OzErli/uQhtzk2ysbauZOoaRm8zLHYNVHKHdSzU2KItFn6LeFQ19rFzoomLp83ibVnTeKqhZlcv2QqVy/ODPhiXzXDTnldO5WNoSsZ3VTewDnTk7FYfH+3Fotww7Jp7KpsZucgabZguMtOT7XonjFm1BfKaz7Rw+1PbOM3r+8PyfOVe1XceLNZLdx2Xq5PkAe4dVUuDe3d/H1bcHNdQqnY9Ql/Z2VTyDcvGi0a6F1EhGsLs9h+pIl/+f37NLR38783nc0vr5nP7efn8cLOo7zptQDTHz84RHyUzVMmGArJsZFMS4kJOHGqs6ePd/fVcdHs9H7Bx9/y3BS2HWmis6ePnj4HOyubWOhK25w8J5XmEz185EoNPLrhEMeaO7nj0pnDfuNy79719iAzjHdXNhNps5Dn6s3ZrBZmpMf1q6V3P8/5M9OGdv0ZrgXqQpS+qTjeQWXjiQEnwF21cAoxkVb+tGn0UgolrjfNsrqBZ+P+/q0DrL13/ahOGHtx11G6ex0cGELZ71CU1bUTabMwJWnCkM5fmpPMvMwEHlp/cFys/trW1cvBhnbWzEzDGEZt7+lQ00Dv5apFU5gYbePcvBRe+eZKVhc4A81XV+cxIz2OHzy3h9bOHmpaOnlp9zGuPTuL2KjQZr8WDDAgu35/PSd6+rh4zuADjstzU+nudbDtcCMlx1rp7HGweJpvoHenhjaWNdDc0cMf3j7A6gL7iMpE0ydGM2fyxEGXkthztJlZkyb6DGIWpE/sF+jf2ltLxsRoZg8xZZFrj2NyQvSQyixPdPcNGjA2elYiDZySi4+O4KqFU/jHzqNDWiDOGMOfNh0e1k5jpV49+oEC+YayekprWjkyipPfnnLNGj98vCMka0KV1baRkxrrqWQajIhwy8ocDta38/o4WOly77EWjHGOKaTERvJ2ycdjBVYN9F5S46LY8oML+eNNS0j1ysdH2izc8+l5VLd0cs8rJTy+6TB9xnDDsmkhb8OCrERqWrp8aoqNMTz4XjmpcVGcM33wQHz29GSsFmFDWQNbDzuXZ/AP9OkTo8m1x7KhrJ7/fvcArV29fPeSmSNu9/kFaWw70kRzR+CyS4fDUFx1ciDWbWZGPHWtXZ6PwF29fazfX8cFs9KG/MlCRFiZb+eDUyz/3Ocw3P9OGfN+9Cp/HaQ8clNZA8mxkeSnBa5sAvjc0ml09TqGNNHugwMN/PC5PXzlz1vpGUJZXnevg7K6NuKjbTSf6AmYHjDGeN4MNh88PuhzjsSB2lZ2VDSxcGoixpzMrwejrK6N3AEqxgaydm4GmUkTeHCU9l+ub+viUH07Rxo6qDjeQVXTiQFnqLsLM+ZlJrBqhp339tePi08ag9FA78e9O5S/hVOTuGn5dP686QiPfHCINTPTmJYSXEllIO6JU97537dLa/nw0HG+sSbPUzJ5KnFRNuZnJrChrJ6tR5rImBjN5MT+H5WX56ayqfw4//vBIa5aMIXZQSzhcP5MO30Ow/oDgXvVR4530NrV22/QdeYkZwWNO0//4cHjtHf3sWaIaRu3VTPstHb2srOy/6ehIw0dXLduI/e8UkKvw7D54MCzgo0xbCxvYGlO//y8t1mTJnJ2dhJ/3nx40Bf6f79zgAkRVkqqW3lo/cFB/y9ldW30OgwXz3YOugfK09e1dtHoelPdXD46gf7prVVYLcJ/XFwwYDuGo6u3jyPHO8ixD+91Y7NauHnFdIoON3pmeIdKdXMnS3/6Jqt/9Q6rfvk2K3/xNuf+/C1+8NzugOfvqWrBHh9F2sRoVhfYOd7eza4BqvLGEw30w/CdS2aQmTSBtq5evrB8+qhcY/bkiURaLWx3Bfo+h+EXr5SSnRLDdUumDvl5luemsrOymY1lDf168yfPSeFETx/GwLcuCq56YEFWEokxEbxdEjjQu1es9B9IdpdKunPSb+6tJcpm8ZSADtWKvFQsAu+68vTGGI63d/PXD4+w9t73KDnWym8+M5/VM+yeawVyuKGDY82dQ1qg7nNLp3G4oYP1p9hkfkdFExvKGvjWRflcMiede9/cx+GGU6dw3D31T8xzBvpAeXp3uis1LooPD4V+OYvePgfPbKvk/AI7hdnOT4j7a4YW6Dt7+vjes7spOuT7BnSkoQOHYdg9enDu+5wwIYL73tof0s12dlU20eswfLhCvFcAABhRSURBVOfiGfzXNfP55dXzWJaTwqvFNQHfwIuPNns+la7MtyPCx2IDHQ30wxATaeOBzy3m2xfN4Ny84JY8GEiUzcqsyRM9SyE8t72KkupWvnNJwbAm6CzPTaHPYahv62LRAIF+aU4KkTYLNyybRlZyTFDttlqEVfl23t1XG/AFsruqmQirkJ/u+yK3x0WREhtJaXUrxhjeKqnl3LxUT4XQUCXERDA/K5E/fnCQFfe8RcEPXmHRT17njmd2c1ZmAq98axVXLcxk5qSJlNW1DbiF3UA7hQVy6dwMUuMi+dPGQwOec/87B0iYEMFnz5nGj66Yi81i4QfP7TnlAGpJdSsRVuHcvFSiIywBUybuN4PPLsmi4vgJjjYNf/mAutYu/ufdMu58ZhdtfnM31h+op7a1i6sXZxJps5CdEjPkzVd+/OJHPLH5CD/6x0c+/09PaeUIAn1slI3bzsvl7dI6PvvgppAtl+D+Od507nQ+vTiTawqzuG5JVsCeemdPH/tr2zyfSpNjI1mQlfixWJBNA/0wzZ2SwNfX5IespDKQhVmJ7K5qpqO7l1+/vo+zpiRw2dzh7cC4aFqSJ80zUI8+KTaSN799HnesHXlu3tv5M+3Ut3V7eu/eiqtaKMiI75caExEKMuIpqW6hrK6dI8c7uGCYaRu3W1bmcNaUBAqnJXHTimzuunw2D99YyBM3L2WKK3U1MyOeXocZMN+8/UgTCRMihhSMomxWPnN2Fm+W1AZcDfRAbSuvFtdw47JpxEXZyEiI5ruXFrB+fz3P7Rh4W4aS6hZy7XFE2azkpMYFbGtJdStp8VFcMtfZ6/9wiHn63j4Hb3xUwy2PFbH0Z2/ys5dL+OuWCm59rMhnsPXprZUkxUR4ZhvPSI8fUurm2e2VPLH5CDMz4tld1ezTLvcnk+GmbtxuOy+HX10zn91Vzay9dz2vFVeP6Hm8ldS0MjU5xqeoYqCeekl1K30O47NK7eoZaeyqbKKhbfirmo50GZGR0EA/Ds3PSqCju4+7ni+mqukEd6ydOWhJpb/oCCuLpyYRZbOcsnolKzkmqFU3va1yvUD80zfGGPYcbWbuABNuZmZMZF9Nm2cdn5EG+svOmsQTtyzlt9ct5M61s/jiiumsmeVbjuqefDTQ9os7K5uYl5kw5Dfy65dMRQi8Jd7975QzIcLKF849meb713OmsXBqIj95ce+AFTul1a3MdKW08tLiAgbY0hrnG+fMjInER9tOOe7g1t3r4LMPbebmx4rYfqSJm1dM541vn8evr53PhrIGvvbEdnr6HDR1dPN6cQ1XLpji6Szkp8VxqKGdrt6BK2/21bTyvWf2sGR6Mk9/ZTlJMRE86DUmUVbbxqSE6BFXqokIVy/O5MV/W8GUxAnc+qet3P38nqAGQ0uOtfRbxiA5NpL5mYm849dTdw/Eek8cW11gd5VZDq+0d09VM2f/3zd4+P3Bx2xCQQP9OLQgy9kDf3prJSvzUzk3b2Qzb7910Qx+fOWcIQ3ghkJKXBTzMxP71dNXNp6gqaOHOQNM9JqZEc+Jnj7+vOkwsyZNDDhwHCrTU2OJtFoC5ulPdDs/ms8PsO/vQDKTYlgzK52H3j/IH94+4EkJVTWd4PkdVVy3JMtnnRarRfjZp86i5UQPv3i1/7IRzR09HGvuZKbrDSnXHkdV0wlOdJ8MsH0Ow/6aNmZmxGO1CEuyk4dUefOjfxTz4cHj/OSTc9l45wXcedks8tLiuGphJj+5cg5v7K3hP57ayfM7jtLd5+DqxSfniOSlx+MwDFgi2t7Vy1f+vJXYKBv3Xb+QuCgbn186jTdLajyTpMrq20fcm/eWY4/jma8u54Zl03h04+ERz4ru7OnjUEOH503V2+oCe78JUcVHm0mYEEGm1xyAs6YkkBIbOew8/W/f2Ed3n4Nfv1Ya1F4YQxXsxiOJIvK0iJSIyF4RWeb3+GoRaRaRHa6vu4Jr7pkhOyWGxJgIAP7z0pGnVZZMT+YzZw99ADcU3C8Q90fZ2tZO/s8LxQAsGCCAuntUVU0nhl1tM1wRVueErUDLIxcfdW78Mi9zeFP9f3rVWayZmcYvXy3lE79bz4cHj3tKAW9emdPv/JkZE7l6cSbP7zjarza91JUHd/9MctNiMX4B1tmzdlDgWp5hyfRkyuvaT7nn8F8/PMLjm49w23m5fH5p/wXZPr8sm/+4pIDndhzl//1zL7MmTfSpkHKXmu4LMCBrjOGOZ3ZzsL6d31+/0LNPwueWTSPCYuGRDw5ijKG8dvillQOJsln53mWziLRZTtmbbuvqHbDk90BtG30OE3BhstUF/SdE7alqYe6UiT6f9iwWYdUMO+/uqxvyIPGuyibe2FvLtYWZ9PQZfj6EdaKCFWxX717gFWPMTGA+sDfAOeuNMQtcXz8O8npnBBHhM2dnccvK6f3KEce7810vkHf31fHs9kou+vV7rD9Qzw8+MYuzBgigM9Ljcb92Lpg1uoEenJ8gAm14srPS+dHcvRTFUNnjo7j/c4t5+MZCOrr7uPZ/NvLnTYf55MIpnrEBf5fPm0xHt3O2szd3Ssndy3QHRu88vXsA0X3OOa4KoS0HA5cebjvSyF3PF7MyP5X/uKRgwP/H7efncdt5uXT3Obhmse+M7xx7LBaBAwEGZN/dV8c/dh7l3y8u8BnETouP5pMLJ/P01kr21Th3OAtVoAdnenJJdjLvnyLQf+2JbXzx0S0BH/P/OXqbNyWB5NhIT/qmu9fhXJMpQPpxdYGdxo4en4UET+W3b+wnMSaCH14+m1tWTefZ7VVsOTQ6JbJuIw70IpIArAIeBjDGdBtjRm/xjzPMnWtn8f1PzB7rZgyb+6PsXc8X860nd5KXFsfL31gZsGfrNiHSSnZKLCmu3OhomzkpnpqWrn4TkXZVNpE+MYp0v527hmrNrHRe//Yqvrwqh7T4KL66OnfAc8/JSSYxJoKXd/uugFhS3UrChAgyXG2YnhqLiG8Ne0l1KxbBs5TEnMkTiYm0BszT17Z0ctuftpKREM3vr1846IzU/7y0gL9/ZTk3+q3IGmVz/o72BxgveO2jGmIjrdwS4Hf8pRU5dPY4+MmLzrWiQhnoAc7NS6W0pjXgp5mmjm7W769n+5HGgAOfpTWtroqi/ukki0VYlZ/Ke/vqcDgM+2tb6e5zBEw/rsq3YxHnz2EwOyqaeKuklltW5hAfHcHt5+cxKSGau58vDmnZqL9gevTTgTrgjyKyXUQecu0h62+ZiOwUkZdFZE4Q11MfAxaLcOncDHr6HPzgE7P425eXDenF/ZXVuXz30oIhT40PhnvVTP8B2d2VzZ4VPkcqJtLGnZfNYsOdawbcMwCcKaSLZ6fz5t5anwHO0upWCjLiPemB6AgrWUkxfj36FrJTYz0LvEVYLSyeltSv8qart4+vPL6N1s5e1t2wmMSYwdd0FxEWT0sK+HvIS4vrF+iNMbxdUsvKfHvAsaCCjHhWzbDzvmuuQW5aaCcZuleO/SDAXIa3S52rvToM/Wr6wfmGmZ8WN2AxwuqCNBrau9ld1UxxlfNvxX9mNzir1y6Ymc7ftlQMukzEb9/YR1JMhOeNNCbSxvc/MYuPjrXwxChupxhMoLcBi4D7jTELgXbgDr9ztgHTjDHzgd8DzwV6IhG5VUSKRKSorm7816SqU/vh5bP58PsXcvPKnCEH7msLs07beIJnNq7XgGzziR7K69uZP8z8fDDWnjWJ1q5eT5ByL2vgn0rItcf6TJoKdM4505MpqW71VPIYY/jBs3vYeriRX14zz/PmFoz89DgO1bf7zEHYe6yVY82dp6yUumWls+ooJtLq+aQSKrMnTSQpJiJgnv614hrs8VFEWi0BB6sDVdx4WzXDXWZZR/HRZmJdnzwD+cLybBrau/nnroHXqN92pJF3Suu4ZVUOcV6VR584axLLclL41aulo7YaZjCBvhKoNMZsdt1/Gmfg9zDGtBhj2ly3XwIiRKRfCYkxZp0xptAYU2i324NokhoPoiOsJEyIGOtmDMh7kpbbyTVMRj915HZubirx0TZe2u2sB69sPEFbV2+/oJyXFkd5nXPgsKO7l8PHOyhI9z3Hk6d39Vwffv8gT22t5OsX5HH5vKEvO30qM9KdcxAOec3sdVdYrZ458Ot2RV4qMzPifT6phIrFIizPS+WDA/U+k7Pcq71eMied+VkJbC73TWs1tndT29oVMD/vlhwbybzMRN7ZV8ueoy3MmZwwYJnzuXkp5NpjT7l5/G/f2E9ybCQ3eu0BAc5PUT+6cg5tXb386rXSQf/PIzHiQG+MqQYqRMQ9urMG8Fm0XUQyxPWbFZElruuFfr62UsPgPUnLzb1GznArboIRabNw0ax0Xv+ohp4+h+eNx7+XmWuPo6vXwdGmE+yvacOY/ufMy0wgyubsub5dWstPX9rLpXMyQroxhntMwHsphDf31nDWlATS4gfuqYsIj31xCX/47KIBzwnGyrxUalq6fMYxPjhQT0d3HxfPzuCc6SnsOdriM/u3xPOzPvUnndUz7OyoaGJPVTNzpgx8rohw4/JsdlY2e3Z087bl0HHe21fHratyAs4jmJEez7cvmsGKEZZSDybYqpt/Ax4XkV3AAuCnInKbiNzmevxqYI+I7AR+B1xnRnPxbKWGaGbGREprWj0DYLsqmpmWEjOkPHYorT1rEs0nethY1tCvtNLNvRvTgbq2AStFomxWFk5N5NXiar7+xHYKMiby68/MH/ZEu1PJtcchAvtda9Mfb+9me0XTkCa4pQ2wsF4orHDl6b3TN69/VEN8lI2lOSkszXEuB+Kdpy/1q24aiHtCVFevY9Adtj61KJO4KBuPbjjkc7y5o4dvPbmDKYkTTrni7e3n53HZWcObAT9UQQV6Y8wOV8plnjHmk8aYRmPMA8aYB1yP32eMmWOMmW+MWWqM2RCaZisVnJmT4unscXgWGNtV2XRa0zZuK/NTiY208vKeY+w91kJW8gSf/C14lVjWtlFS3Up0hIWpAdYmWjI9hcrGE0RFWHjoxkJiIkO7V0J0hJWpyTGeHv27+2oxZuQzmUMlMymG6amxngHfPofhjb01rJ6ZRqTNwqJpidgs4pOnL61pJSkmgrT4qIGeFnCm8pJcc1rmnqJHD85VY69enMk/dx/zVAEZY/j3p3ZS09LJfZ9dGPLfyVDpzFh1RprltV9tXWsXR5s7T+tArFt0hJULZqXzanENHx1r6Zd7B2euODk2krK6NkprWpiRHh+wp37x7HQmJ0TzwOcWD1i/H6z8tHhPj/7NvbWkxkWdcmvL0+XcvBQ2lTfQ0+dg+5FG6tu6uXi2c52emEgb8zJ98/QlftVNA7FahNUFaUyIsA6peuyGZdPo6TP89UPnngcPrT/IG3tr+N5ls/rt8nY6aaBXZ6T89DgsAnurWz0TXcaiRw9w2dwMjrd3U17XPmAqIdceS1ltO6XVbRSkBz5n7pQENty5pt++q6GUnx7Hwfp2Onv6eG9fHecX2EOaHhqpFXl2Orr72H6kidc+qiHCKqwuODlAfE5OCrsqnQsFOhzu6qahVSLdedlMHr/lnCGtHptjj2PVDDuPbz7MpvIGfv5KCWvnZvAFv3kJp5sGenVGio6wkp0aS8mxFnZWNmORwT+aj5bzCuxERzhfigOV++Xa49hd1Ux9W9cpSwJHW35aHD19hme2VdHS2TvmaRu3ZbkpWATe31/Ha8XVLMtNJT76ZOXXOdOT6XUYth5upLLxBB3dfUP+OabFR7NoGL3xLyyfRk1LFzc88iGZSRO45+p5o7ra7VBooFdnrFkZzv1qd1U2kZ8WP2b505hIG+e79ieeNSlw8MlLi+OEazJOKGriRyo/zdm+B9eXE2EVz0DoWEuY4NyP4MmiCg41dHjSNm7uzVM2lx/3VFuN1hvmeTPSPGMof/jsIiZGj32p8dj8ZSs1DszMiOefu4/R2NHNpXMyxrQt7sll01MD54G988Nj2aPPS3NW3hysb+fcvBSfXvNYW5GXyu/fOgDARX6BPi7KxtwpCWw+2ECUawbvQCmwYFktwoM3FNLR3X/rzLGiPXp1xnIvBdza2cu8YS5kFmqLpyVx32cXDTiT2B3oU2IjsQ9SKTKaJkRaPcv0uj+FjBfuGvT5WYkB1ytaOj2ZnRXN7Kxs6rfZSKgVZMSP6eCrPw306ozlPfA5FhU3wzElaQKRNsuY9ubd3Omb8ZKfd1s4NYnMpAk+6+h7Oycnme4+B2+X1o2Ln+PppKkbdcbKTHLWrHf3OsY07z0UVotw0/Jszw5ZY2nt3AyibJZTLto2FiJtFt7/zwsGfLwwOxmLOOvsB5soFW400Kszlogwd8pEevvMaduFKxh3XjZrrJsAwDWFWVxTmDXWzRi2idERzJ48kT1Vp17MLBxpoFdntN9+ZiEGXZXjTHHO9BT2VLVoj16pM0lGQmiXzVXj2+eXTmNChJWcAaqbwpUGeqXUGSM7NZbvnGI7xXA1/hOTSimlgqKBXimlwpwGeqWUCnMa6JVSKswFFehFJFFEnhaREhHZKyLL/B4XEfmdiBwQkV0iMjp7iSmllBpQsFU39wKvGGOuFpFIwH/bm7VAvuvrHOB+179KKaVOkxH36EUkAVgFPAxgjOk2xjT5nXYl8Jhx2gQkisjobIqolFIqoGBSN9OBOuCPIrJdRB4SkVi/c6YAFV73K13HfIjIrSJSJCJFdXV1QTRJKaWUv2BSNzZgEfBvxpjNInIvcAfww+E+kTFmHbAOQETqRORwEO1KBeoHPWvsfVzaCdrW0aJtHR1nalunDfRAMIG+Eqg0xmx23X8aZ6D3VgV4r36U6To2IGOM/VSPD0ZEiowxhcE8x+nwcWknaFtHi7Z1dGhb+xtx6sYYUw1UiIh7PvEa4CO/014AbnBV3ywFmo0xx0Z6TaWUUsMXbNXNvwGPuypuyoGbROQ2AGPMA8BLwGXAAaADuCnI6ymllBqmoAK9MWYH4P+x4wGvxw1wezDXGIF1p/l6I/VxaSdoW0eLtnV0aFv9iDMWK6WUCle6BIJSSoU5DfRKKRXmwibQi8ilIlLqWlfHv8xzTInIIyJSKyJ7vI4li8jrIrLf9W/SWLbRTUSyRORtEflIRIpF5Buu4+OuvSISLSIfishOV1t/5Do+XUQ2u/4WnnQVC4w5EbG6Jhe+6Lo/Xtt5SER2i8gOESlyHRt3v38IvN7WeGyriBS4fp7urxYR+ebpamtYBHoRsQJ/wLm2zmzgehGZPbat8vG/wKV+x+4A3jTG5ANv0n8OwljpBf7dGDMbWArc7vpZjsf2dgEXGGPmAwuAS11lvPcAvzHG5AGNwJfGsI3evgHs9bo/XtsJcL4xZoFXjfd4/P3DyfW2ZgLzcf58x11bjTGlrp/nAmAxzirEZzldbTXGfOy/gGXAq1737wTuHOt2+bUxG9jjdb8UmOS6PQkoHes2DtDu54GLxnt7cS6otw3nonn1gC3Q38YYti/T9UK+AHgRkPHYTldbDgGpfsfG3e8fSAAO4ioqGc9t9WvfxcAHp7OtYdGjZ4hr6owz6ebk5LFqIH0sGxOIiGQDC4HNjNP2utIhO4Ba4HWgDGgyxvS6Thkvfwu/Bb4LOFz3Uxif7QQwwGsislVEbnUdG4+//4HW2xqPbfV2HfAX1+3T0tZwCfQfa8b5dj6u6lxFJA74O/BNY0yL92Pjqb3GmD7j/DicCSwBZo5xk/oRkcuBWmPM1rFuyxCtMMYswpkKvV1EVnk/OI5+/+71tu43xiwE2vFLfYyjtgLgGoe5AnjK/7HRbGu4BPphr6kzDtS4l2x2/Vs7xu3xEJEInEH+cWPMM67D47a9AMa5RPbbOFMgiSLingw4Hv4WzgWuEJFDwF9xpm/uZfy1EwBjTJXr31qceeQljM/ff6D1thYxPtvqthbYZoypcd0/LW0Nl0C/Bch3VTFE4vxo9MIYt2kwLwA3um7fiDMXPuZERHDuMbDXGPNrr4fGXXtFxC4iia7bE3COJezFGfCvdp025m01xtxpjMk0xmTj/Nt8yxjzr4yzdgKISKyIxLtv48wn72Ec/v7NwOttjbu2ermek2kbOF1tHeuBiRAOcFwG7MOZo/3+WLfHr21/AY4BPTh7IV/CmaN9E9gPvAEkj3U7XW1dgfPj4y5gh+vrsvHYXmAesN3V1j3AXa7jOcCHONdYegqIGuu2erV5NfDieG2nq007XV/F7tfSePz9u9q1AChy/Q08BySN47bGAg1Agtex09JWXQJBKaXCXLikbpRSSg1AA71SSoU5DfRKKRXmNNArpVSY00CvlFJhTgO9UkqFOQ30SikV5v4//7LvMN+FUuoAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker\n",
"%matplotlib inline\n",
"\n",
"plt.figure()\n",
"plt.plot(losses)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 预测\n",
"用训练好的模型进行预测。"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"the input words is: of, william\n",
"the predict words is: shakespeare\n",
"the true words is: shakespeare\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Library/Python/3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
" and should_run_async(code)\n"
]
}
],
"source": [
"import random\n",
"def test(model):\n",
" model.eval()\n",
" # 从最后10组数据中随机选取1个\n",
" idx = random.randint(len(trigram)-10, len(trigram)-1)\n",
" print('the input words is: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n",
" x_data = list(map(lambda w: word_to_idx[w], trigram[idx][0]))\n",
" x_data = paddle.to_tensor(np.array(x_data))\n",
" predicts = model(x_data)\n",
" predicts = predicts.numpy().tolist()[0]\n",
" predicts = predicts.index(max(predicts))\n",
" print('the predict words is: ' + idx_to_word[predicts])\n",
" y_data = trigram[idx][1]\n",
" print('the true words is: ' + y_data)\n",
"test(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册