diff --git a/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2e5ba0a50754fc1c2a46afec281d2761039d82e9 --- /dev/null +++ b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb @@ -0,0 +1,640 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MNIST数据集使用LeNet进行图像分类\n", + "本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。\n", + "手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 环境\n", + "本教程基于paddle-develop编写,如果您的环境不是本版本,请先安装paddle-develop版本。" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0.0\n" + ] + } + ], + "source": [ + "import paddle\n", + "print(paddle.__version__)\n", + "paddle.disable_static()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 加载数据集\n", + "我们使用飞桨自带的paddle.dataset完成mnist数据集的加载。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "download training data and load training data\n", + "load finished\n" + ] + } + ], + "source": [ + "print('download training data and load training data')\n", + "train_dataset = paddle.vision.datasets.MNIST(mode='train')\n", + "test_dataset = paddle.vision.datasets.MNIST(mode='test')\n", + "print('load finished')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "取训练集中的一条数据看一下。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train_data0 label is: [5]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAIY0lEQVR4nO3dXWhUZxoH8P/jaPxav7KREtNgiooQFvwg1l1cNOr6sQUN3ixR0VUK9cKPXTBYs17ohReLwl5ovCmuZMU1y+IaWpdC0GIuxCJJMLhJa6oWtSl+FVEXvdDK24s5nc5zapKTZ86cOTPz/4Hk/M8xc17w8Z13zpl5RpxzIBquEbkeAOUnFg6ZsHDIhIVDJiwcMmHhkElGhSMiq0WkT0RuisjesAZF8SfW6zgikgDwFYAVAPoBdABY75z7IrzhUVyNzOB33wVw0zn3NQCIyL8A1AEYsHDKyspcVVVVBqekqHV1dX3nnJvq359J4VQA+CYt9wNYONgvVFVVobOzM4NTUtRE5M6b9md9cSwiH4hIp4h0Pnr0KNuno4hkUjjfAqhMy297+xTn3EfOuRrnXM3UqT+b8ShPZVI4HQBmicg7IlICoB7AJ+EMi+LOvMZxzn0vIjsAtAFIADjhnOsNbWQUa5ksjuGc+xTApyGNhfIIrxyTCQuHTFg4ZMLCIRMWDpmwcMiEhUMmLBwyYeGQCQuHTFg4ZMLCIZOMbnIWk9evX6v89OnTwL/b1NSk8osXL1Tu6+tT+dixYyo3NDSo3NLSovKYMWNU3rv3p88N7N+/P/A4h4MzDpmwcMiEhUMmRbPGuXv3rsovX75U+fLlyypfunRJ5SdPnqh85syZ0MZWWVmp8s6dO1VubW1VecKECSrPmTNH5SVLloQ2toFwxiETFg6ZsHDIpGDXOFevXlV52bJlKg/nOkzYEomEygcPHlR5/PjxKm/cuFHladOmqTxlyhSVZ8+enekQh8QZh0xYOGTCwiGTgl3jTJ8+XeWysjKVw1zjLFyom3T41xwXL15UuaSkROVNmzaFNpaocMYhExYOmbBwyKRg1zilpaUqHz58WOVz586pPG/ePJV37do16OPPnTs3tX3hwgV1zH8dpqenR+UjR44M+tj5gDMOmQxZOCJyQkQeikhP2r5SETkvIje8n1MGewwqPEFmnGYAq3379gL4zDk3C8BnXqYiEqjPsYhUAfivc+5XXu4DUOucuyci5QDanXND3iCpqalxcek6+uzZM5X973HZtm2bysePH1f51KlTqe0NGzaEPLr4EJEu51yNf791jfOWc+6et30fwFvmkVFeynhx7JJT1oDTFtvVFiZr4TzwnqLg/Xw40F9ku9rCZL2O8wmAPwL4q/fz49BGFJGJEycOenzSpEmDHk9f89TX16tjI0YU/lWOIC/HWwB8DmC2iPSLyPtIFswKEbkB4HdepiIy5IzjnFs/wKHlIY+F8kjhz6mUFQV7rypTBw4cULmrq0vl9vb21Lb/XtXKlSuzNazY4IxDJiwcMmHhkIn5Ozkt4nSvarhu3bql8vz581PbkydPVseWLl2qck2NvtWzfft2lUUkjCFmRdj3qqjIsXDIhC/HA5oxY4bKzc3Nqe2tW7eqYydPnhw0P3/+XOXNmzerXF5ebh1mZDjjkAkLh0xYOGTCNY7RunXrUtszZ85Ux3bv3q2y/5ZEY2Ojynfu6O+E37dvn8oVFRXmcWYLZxwyYeGQCQuHTHjLIQv8rW39HzfesmWLyv5/g+XL9Xvkzp8/H97ghom3HChULBwyYeGQCdc4OTB69GiVX716pfKoUaNUbmtrU7m2tjYr43oTrnEoVCwcMmHhkAnvVYXg2rVrKvu/kqijo0Nl/5rGr7q6WuXFixdnMLrs4IxDJiwcMmHhkAnXOAH5v+L56NGjqe2zZ8+qY/fv3x/WY48cqf8Z/O85jmPblPiNiPJCkP44lSJyUUS+EJFeEfmTt58ta4tYkBnnewC7nXPVAH4NYLuIVIMta4takMZK9wDc87b/LyJfAqgAUAeg1vtr/wDQDuDDrIwyAv51yenTp1VuampS+fbt2+ZzLViwQGX/e4zXrl1rfuyoDGuN4/U7ngfgCtiytqgFLhwR+QWA/wD4s3NOdZcerGUt29UWpkCFIyKjkCyafzrnfnztGahlLdvVFqYh1ziS7MHxdwBfOuf+lnYor1rWPnjwQOXe3l6Vd+zYofL169fN5/J/1eKePXtUrqurUzmO12mGEuQC4CIAmwD8T0S6vX1/QbJg/u21r70D4A/ZGSLFUZBXVZcADNT5hy1ri1T+zZEUCwVzr+rx48cq+782qLu7W2V/a7bhWrRoUWrb/1nxVatWqTx27NiMzhVHnHHIhIVDJiwcMsmrNc6VK1dS24cOHVLH/O/r7e/vz+hc48aNU9n/ddLp95f8XxddDDjjkAkLh0zy6qmqtbX1jdtB+D9ysmbNGpUTiYTKDQ0NKvu7pxc7zjhkwsIhExYOmbDNCQ2KbU4oVCwcMmHhkAkLh0xYOGTCwiETFg6ZsHDIhIVDJiwcMmHhkEmk96pE5BGSn/osA/BdZCcenriOLVfjmu6c+9mH/iMtnNRJRTrfdOMsDuI6triNi09VZMLCIZNcFc5HOTpvEHEdW6zGlZM1DuU/PlWRSaSFIyKrRaRPRG6KSE7b24rICRF5KCI9afti0bs5H3pLR1Y4IpIAcAzA7wFUA1jv9UvOlWYAq3374tK7Of69pZ1zkfwB8BsAbWm5EUBjVOcfYExVAHrSch+Acm+7HEBfLseXNq6PAayI0/iifKqqAPBNWu739sVJ7Ho3x7W3NBfHA3DJ/9Y5fclp7S0dhSgL51sAlWn5bW9fnATq3RyFTHpLRyHKwukAMEtE3hGREgD1SPZKjpMfezcDOezdHKC3NJDr3tIRL/LeA/AVgFsA9uV4wdmC5JebvEJyvfU+gF8i+WrlBoALAEpzNLbfIvk0dA1At/fnvbiMzznHK8dkw8UxmbBwyISFQyYsHDJh4ZAJC4dMWDhkwsIhkx8AyyZIbO5tLBIAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]\n", + "train_data0 = train_data0.reshape([28,28])\n", + "plt.figure(figsize=(2,2))\n", + "plt.imshow(train_data0, cmap=plt.cm.binary)\n", + "print('train_data0 label is: ' + str(train_label_0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2.组网\n", + "用paddle.nn下的API,如`Conv2d`、`Pool2D`、`Linead`完成LeNet的构建。" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import paddle.nn.functional as F\n", + "class LeNet(paddle.nn.Layer):\n", + " def __init__(self):\n", + " super(LeNet, self).__init__()\n", + " self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n", + " self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.conv2 = paddle.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n", + " self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)\n", + " self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)\n", + " self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = F.relu(x)\n", + " x = self.max_pool1(x)\n", + " x = F.relu(x)\n", + " x = self.conv2(x)\n", + " x = self.max_pool2(x)\n", + " x = paddle.reshape(x, shape=[-1, 16*5*5])\n", + " x = self.linear1(x)\n", + " x = F.relu(x)\n", + " x = self.linear2(x)\n", + " x = F.relu(x)\n", + " x = self.linear3(x)\n", + " x = F.softmax(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 训练方式一\n", + "通过`Model` 构建实例,快速完成模型训练" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "from paddle.static import InputSpec\n", + "from paddle.metric import Accuracy\n", + "inputs = InputSpec([None, 784], 'float32', 'x')\n", + "labels = InputSpec([None, 10], 'float32', 'x')\n", + "model = paddle.hapi.Model(LeNet(), inputs, labels)\n", + "optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", + "\n", + "\n", + "model.prepare(\n", + " optim,\n", + " paddle.nn.loss.CrossEntropyLoss(),\n", + " Accuracy(topk=(1, 2))\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 使用model.fit来训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/2\n", + "step 10/938 - loss: 2.2369 - acc_top1: 0.3281 - acc_top2: 0.4172 - 18ms/step\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Library/Python/3.7/site-packages/paddle/fluid/layers/utils.py:76: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", + " return (isinstance(seq, collections.Sequence) and\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 20/938 - loss: 2.0185 - acc_top1: 0.3656 - acc_top2: 0.4328 - 17ms/step\n", + "step 30/938 - loss: 1.9579 - acc_top1: 0.4120 - acc_top2: 0.4969 - 16ms/step\n", + "step 40/938 - loss: 1.8549 - acc_top1: 0.4602 - acc_top2: 0.5500 - 16ms/step\n", + "step 50/938 - loss: 1.8628 - acc_top1: 0.5097 - acc_top2: 0.6028 - 16ms/step\n", + "step 60/938 - loss: 1.7139 - acc_top1: 0.5456 - acc_top2: 0.6409 - 16ms/step\n", + "step 70/938 - loss: 1.7296 - acc_top1: 0.5795 - acc_top2: 0.6719 - 15ms/step\n", + "step 80/938 - loss: 1.6302 - acc_top1: 0.6053 - acc_top2: 0.6949 - 15ms/step\n", + "step 90/938 - loss: 1.6688 - acc_top1: 0.6290 - acc_top2: 0.7158 - 15ms/step\n", + "step 100/938 - loss: 1.6401 - acc_top1: 0.6491 - acc_top2: 0.7327 - 15ms/step\n", + "step 110/938 - loss: 1.6357 - acc_top1: 0.6636 - acc_top2: 0.7440 - 15ms/step\n", + "step 120/938 - loss: 1.6309 - acc_top1: 0.6767 - acc_top2: 0.7539 - 15ms/step\n", + "step 130/938 - loss: 1.6445 - acc_top1: 0.6894 - acc_top2: 0.7638 - 15ms/step\n", + "step 140/938 - loss: 1.5961 - acc_top1: 0.7002 - acc_top2: 0.7728 - 15ms/step\n", + "step 150/938 - loss: 1.6822 - acc_top1: 0.7086 - acc_top2: 0.7794 - 15ms/step\n", + "step 160/938 - loss: 1.6243 - acc_top1: 0.7176 - acc_top2: 0.7858 - 15ms/step\n", + "step 170/938 - loss: 1.6159 - acc_top1: 0.7254 - acc_top2: 0.7915 - 15ms/step\n", + "step 180/938 - loss: 1.6820 - acc_top1: 0.7312 - acc_top2: 0.7962 - 15ms/step\n", + "step 190/938 - loss: 1.6733 - acc_top1: 0.7363 - acc_top2: 0.7999 - 15ms/step\n", + "step 200/938 - loss: 1.7717 - acc_top1: 0.7413 - acc_top2: 0.8039 - 15ms/step\n", + "step 210/938 - loss: 1.5468 - acc_top1: 0.7458 - acc_top2: 0.8072 - 15ms/step\n", + "step 220/938 - loss: 1.5654 - acc_top1: 0.7506 - acc_top2: 0.8111 - 15ms/step\n", + "step 230/938 - loss: 1.6129 - acc_top1: 0.7547 - acc_top2: 0.8143 - 15ms/step\n", + "step 240/938 - loss: 1.5937 - acc_top1: 0.7592 - acc_top2: 0.8180 - 15ms/step\n", + "step 250/938 - loss: 1.5457 - acc_top1: 0.7631 - acc_top2: 0.8214 - 15ms/step\n", + "step 260/938 - loss: 1.6041 - acc_top1: 0.7673 - acc_top2: 0.8249 - 15ms/step\n", + "step 270/938 - loss: 1.6049 - acc_top1: 0.7700 - acc_top2: 0.8271 - 15ms/step\n", + "step 280/938 - loss: 1.5989 - acc_top1: 0.7735 - acc_top2: 0.8299 - 15ms/step\n", + "step 290/938 - loss: 1.6950 - acc_top1: 0.7752 - acc_top2: 0.8310 - 15ms/step\n", + "step 300/938 - loss: 1.5888 - acc_top1: 0.7781 - acc_top2: 0.8330 - 15ms/step\n", + "step 310/938 - loss: 1.5983 - acc_top1: 0.7808 - acc_top2: 0.8350 - 15ms/step\n", + "step 320/938 - loss: 1.5133 - acc_top1: 0.7840 - acc_top2: 0.8370 - 15ms/step\n", + "step 330/938 - loss: 1.5587 - acc_top1: 0.7866 - acc_top2: 0.8385 - 15ms/step\n", + "step 340/938 - loss: 1.6093 - acc_top1: 0.7882 - acc_top2: 0.8393 - 15ms/step\n", + "step 350/938 - loss: 1.6259 - acc_top1: 0.7902 - acc_top2: 0.8410 - 15ms/step\n", + "step 360/938 - loss: 1.6194 - acc_top1: 0.7918 - acc_top2: 0.8422 - 15ms/step\n", + "step 370/938 - loss: 1.6531 - acc_top1: 0.7941 - acc_top2: 0.8438 - 15ms/step\n", + "step 380/938 - loss: 1.6986 - acc_top1: 0.7957 - acc_top2: 0.8447 - 15ms/step\n", + "step 390/938 - loss: 1.5932 - acc_top1: 0.7974 - acc_top2: 0.8459 - 15ms/step\n", + "step 400/938 - loss: 1.6512 - acc_top1: 0.7993 - acc_top2: 0.8474 - 15ms/step\n", + "step 410/938 - loss: 1.5698 - acc_top1: 0.8012 - acc_top2: 0.8487 - 15ms/step\n", + "step 420/938 - loss: 1.5889 - acc_top1: 0.8025 - acc_top2: 0.8494 - 15ms/step\n", + "step 430/938 - loss: 1.5518 - acc_top1: 0.8036 - acc_top2: 0.8503 - 15ms/step\n", + "step 440/938 - loss: 1.6057 - acc_top1: 0.8048 - acc_top2: 0.8508 - 15ms/step\n", + "step 450/938 - loss: 1.6081 - acc_top1: 0.8064 - acc_top2: 0.8519 - 15ms/step\n", + "step 460/938 - loss: 1.5742 - acc_top1: 0.8079 - acc_top2: 0.8531 - 15ms/step\n", + "step 470/938 - loss: 1.5704 - acc_top1: 0.8095 - acc_top2: 0.8543 - 15ms/step\n", + "step 480/938 - loss: 1.6083 - acc_top1: 0.8110 - acc_top2: 0.8550 - 15ms/step\n", + "step 490/938 - loss: 1.6081 - acc_top1: 0.8120 - acc_top2: 0.8555 - 15ms/step\n", + "step 500/938 - loss: 1.5156 - acc_top1: 0.8133 - acc_top2: 0.8564 - 15ms/step\n", + "step 510/938 - loss: 1.5856 - acc_top1: 0.8148 - acc_top2: 0.8573 - 15ms/step\n", + "step 520/938 - loss: 1.5275 - acc_top1: 0.8163 - acc_top2: 0.8582 - 15ms/step\n", + "step 530/938 - loss: 1.5345 - acc_top1: 0.8172 - acc_top2: 0.8591 - 15ms/step\n", + "step 540/938 - loss: 1.5387 - acc_top1: 0.8181 - acc_top2: 0.8596 - 15ms/step\n", + "step 550/938 - loss: 1.5753 - acc_top1: 0.8190 - acc_top2: 0.8601 - 15ms/step\n", + "step 560/938 - loss: 1.6103 - acc_top1: 0.8203 - acc_top2: 0.8610 - 15ms/step\n", + "step 570/938 - loss: 1.5571 - acc_top1: 0.8215 - acc_top2: 0.8618 - 15ms/step\n", + "step 580/938 - loss: 1.5575 - acc_top1: 0.8221 - acc_top2: 0.8622 - 15ms/step\n", + "step 590/938 - loss: 1.4821 - acc_top1: 0.8230 - acc_top2: 0.8627 - 15ms/step\n", + "step 600/938 - loss: 1.5644 - acc_top1: 0.8243 - acc_top2: 0.8636 - 15ms/step\n", + "step 610/938 - loss: 1.5317 - acc_top1: 0.8253 - acc_top2: 0.8644 - 15ms/step\n", + "step 620/938 - loss: 1.5849 - acc_top1: 0.8258 - acc_top2: 0.8647 - 15ms/step\n", + "step 630/938 - loss: 1.6087 - acc_top1: 0.8263 - acc_top2: 0.8649 - 15ms/step\n", + "step 640/938 - loss: 1.5617 - acc_top1: 0.8272 - acc_top2: 0.8655 - 15ms/step\n", + "step 650/938 - loss: 1.6376 - acc_top1: 0.8279 - acc_top2: 0.8660 - 15ms/step\n", + "step 660/938 - loss: 1.5428 - acc_top1: 0.8287 - acc_top2: 0.8665 - 15ms/step\n", + "step 670/938 - loss: 1.5797 - acc_top1: 0.8293 - acc_top2: 0.8668 - 15ms/step\n", + "step 680/938 - loss: 1.5210 - acc_top1: 0.8300 - acc_top2: 0.8674 - 15ms/step\n", + "step 690/938 - loss: 1.6159 - acc_top1: 0.8305 - acc_top2: 0.8677 - 15ms/step\n", + "step 700/938 - loss: 1.5592 - acc_top1: 0.8313 - acc_top2: 0.8682 - 15ms/step\n", + "step 710/938 - loss: 1.6400 - acc_top1: 0.8318 - acc_top2: 0.8685 - 15ms/step\n", + "step 720/938 - loss: 1.5638 - acc_top1: 0.8327 - acc_top2: 0.8691 - 15ms/step\n", + "step 730/938 - loss: 1.5691 - acc_top1: 0.8333 - acc_top2: 0.8693 - 15ms/step\n", + "step 740/938 - loss: 1.5848 - acc_top1: 0.8337 - acc_top2: 0.8695 - 15ms/step\n", + "step 750/938 - loss: 1.6317 - acc_top1: 0.8344 - acc_top2: 0.8698 - 15ms/step\n", + "step 760/938 - loss: 1.5127 - acc_top1: 0.8352 - acc_top2: 0.8703 - 15ms/step\n", + "step 770/938 - loss: 1.5822 - acc_top1: 0.8359 - acc_top2: 0.8707 - 15ms/step\n", + "step 780/938 - loss: 1.6010 - acc_top1: 0.8366 - acc_top2: 0.8712 - 15ms/step\n", + "step 790/938 - loss: 1.5238 - acc_top1: 0.8373 - acc_top2: 0.8717 - 15ms/step\n", + "step 800/938 - loss: 1.5858 - acc_top1: 0.8377 - acc_top2: 0.8719 - 15ms/step\n", + "step 810/938 - loss: 1.5800 - acc_top1: 0.8384 - acc_top2: 0.8724 - 15ms/step\n", + "step 820/938 - loss: 1.6312 - acc_top1: 0.8390 - acc_top2: 0.8727 - 15ms/step\n", + "step 830/938 - loss: 1.5812 - acc_top1: 0.8398 - acc_top2: 0.8732 - 15ms/step\n", + "step 840/938 - loss: 1.5661 - acc_top1: 0.8402 - acc_top2: 0.8734 - 15ms/step\n", + "step 850/938 - loss: 1.5379 - acc_top1: 0.8409 - acc_top2: 0.8739 - 15ms/step\n", + "step 860/938 - loss: 1.5266 - acc_top1: 0.8413 - acc_top2: 0.8740 - 15ms/step\n", + "step 870/938 - loss: 1.5264 - acc_top1: 0.8420 - acc_top2: 0.8745 - 15ms/step\n", + "step 880/938 - loss: 1.5688 - acc_top1: 0.8425 - acc_top2: 0.8748 - 15ms/step\n", + "step 890/938 - loss: 1.5707 - acc_top1: 0.8429 - acc_top2: 0.8751 - 15ms/step\n", + "step 900/938 - loss: 1.5564 - acc_top1: 0.8432 - acc_top2: 0.8752 - 15ms/step\n", + "step 910/938 - loss: 1.4924 - acc_top1: 0.8438 - acc_top2: 0.8757 - 15ms/step\n", + "step 920/938 - loss: 1.5514 - acc_top1: 0.8443 - acc_top2: 0.8760 - 15ms/step\n", + "step 930/938 - loss: 1.5850 - acc_top1: 0.8446 - acc_top2: 0.8762 - 15ms/step\n", + "step 938/938 - loss: 1.4915 - acc_top1: 0.8448 - acc_top2: 0.8764 - 15ms/step\n", + "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/0\n", + "Eval begin...\n", + "step 10/157 - loss: 1.5984 - acc_top1: 0.8797 - acc_top2: 0.8953 - 5ms/step\n", + "step 20/157 - loss: 1.6266 - acc_top1: 0.8789 - acc_top2: 0.9000 - 5ms/step\n", + "step 30/157 - loss: 1.6475 - acc_top1: 0.8771 - acc_top2: 0.8984 - 5ms/step\n", + "step 40/157 - loss: 1.6329 - acc_top1: 0.8730 - acc_top2: 0.8957 - 5ms/step\n", + "step 50/157 - loss: 1.5399 - acc_top1: 0.8712 - acc_top2: 0.8934 - 5ms/step\n", + "step 60/157 - loss: 1.6322 - acc_top1: 0.8750 - acc_top2: 0.8961 - 5ms/step\n", + "step 70/157 - loss: 1.5818 - acc_top1: 0.8721 - acc_top2: 0.8931 - 5ms/step\n", + "step 80/157 - loss: 1.5522 - acc_top1: 0.8760 - acc_top2: 0.8979 - 5ms/step\n", + "step 90/157 - loss: 1.6085 - acc_top1: 0.8785 - acc_top2: 0.8984 - 5ms/step\n", + "step 100/157 - loss: 1.5661 - acc_top1: 0.8784 - acc_top2: 0.8980 - 5ms/step\n", + "step 110/157 - loss: 1.5694 - acc_top1: 0.8805 - acc_top2: 0.8996 - 5ms/step\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 120/157 - loss: 1.6012 - acc_top1: 0.8824 - acc_top2: 0.9003 - 5ms/step\n", + "step 130/157 - loss: 1.5378 - acc_top1: 0.8844 - acc_top2: 0.9017 - 5ms/step\n", + "step 140/157 - loss: 1.5068 - acc_top1: 0.8858 - acc_top2: 0.9022 - 5ms/step\n", + "step 150/157 - loss: 1.5424 - acc_top1: 0.8873 - acc_top2: 0.9029 - 5ms/step\n", + "step 157/157 - loss: 1.5862 - acc_top1: 0.8872 - acc_top2: 0.9035 - 5ms/step\n", + "Eval samples: 10000\n", + "Epoch 2/2\n", + "step 10/938 - loss: 1.5988 - acc_top1: 0.8859 - acc_top2: 0.9016 - 15ms/step\n", + "step 20/938 - loss: 1.5702 - acc_top1: 0.8852 - acc_top2: 0.9047 - 15ms/step\n", + "step 30/938 - loss: 1.5999 - acc_top1: 0.8833 - acc_top2: 0.9021 - 15ms/step\n", + "step 40/938 - loss: 1.5652 - acc_top1: 0.8816 - acc_top2: 0.9000 - 15ms/step\n", + "step 50/938 - loss: 1.6163 - acc_top1: 0.8853 - acc_top2: 0.9047 - 15ms/step\n", + "step 60/938 - loss: 1.5307 - acc_top1: 0.8849 - acc_top2: 0.9049 - 15ms/step\n", + "step 70/938 - loss: 1.5542 - acc_top1: 0.8846 - acc_top2: 0.9029 - 15ms/step\n", + "step 80/938 - loss: 1.5694 - acc_top1: 0.8816 - acc_top2: 0.9008 - 15ms/step\n", + "step 90/938 - loss: 1.6030 - acc_top1: 0.8806 - acc_top2: 0.8991 - 15ms/step\n", + "step 100/938 - loss: 1.5631 - acc_top1: 0.8814 - acc_top2: 0.8989 - 15ms/step\n", + "step 110/938 - loss: 1.5598 - acc_top1: 0.8804 - acc_top2: 0.8984 - 15ms/step\n", + "step 120/938 - loss: 1.5773 - acc_top1: 0.8803 - acc_top2: 0.8986 - 15ms/step\n", + "step 130/938 - loss: 1.5076 - acc_top1: 0.8815 - acc_top2: 0.8995 - 15ms/step\n", + "step 140/938 - loss: 1.6064 - acc_top1: 0.8809 - acc_top2: 0.8988 - 15ms/step\n", + "step 150/938 - loss: 1.5279 - acc_top1: 0.8815 - acc_top2: 0.8993 - 15ms/step\n", + "step 160/938 - loss: 1.6039 - acc_top1: 0.8820 - acc_top2: 0.8998 - 15ms/step\n", + "step 170/938 - loss: 1.5709 - acc_top1: 0.8814 - acc_top2: 0.8993 - 15ms/step\n", + "step 180/938 - loss: 1.6164 - acc_top1: 0.8806 - acc_top2: 0.8985 - 15ms/step\n", + "step 190/938 - loss: 1.5920 - acc_top1: 0.8802 - acc_top2: 0.8985 - 15ms/step\n", + "step 200/938 - loss: 1.6457 - acc_top1: 0.8793 - acc_top2: 0.8973 - 15ms/step\n", + "step 210/938 - loss: 1.6045 - acc_top1: 0.8794 - acc_top2: 0.8977 - 15ms/step\n", + "step 220/938 - loss: 1.6614 - acc_top1: 0.8795 - acc_top2: 0.8975 - 15ms/step\n", + "step 230/938 - loss: 1.5384 - acc_top1: 0.8789 - acc_top2: 0.8966 - 15ms/step\n", + "step 240/938 - loss: 1.5556 - acc_top1: 0.8785 - acc_top2: 0.8960 - 15ms/step\n", + "step 250/938 - loss: 1.6006 - acc_top1: 0.8782 - acc_top2: 0.8961 - 15ms/step\n", + "step 260/938 - loss: 1.5552 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n", + "step 270/938 - loss: 1.5805 - acc_top1: 0.8791 - acc_top2: 0.8970 - 15ms/step\n", + "step 280/938 - loss: 1.5404 - acc_top1: 0.8787 - acc_top2: 0.8966 - 15ms/step\n", + "step 290/938 - loss: 1.6023 - acc_top1: 0.8789 - acc_top2: 0.8969 - 15ms/step\n", + "step 300/938 - loss: 1.5706 - acc_top1: 0.8788 - acc_top2: 0.8969 - 15ms/step\n", + "step 310/938 - loss: 1.5424 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n", + "step 320/938 - loss: 1.5823 - acc_top1: 0.8798 - acc_top2: 0.8975 - 15ms/step\n", + "step 330/938 - loss: 1.5600 - acc_top1: 0.8801 - acc_top2: 0.8977 - 15ms/step\n", + "step 340/938 - loss: 1.6258 - acc_top1: 0.8795 - acc_top2: 0.8970 - 15ms/step\n", + "step 350/938 - loss: 1.5093 - acc_top1: 0.8796 - acc_top2: 0.8972 - 15ms/step\n", + "step 360/938 - loss: 1.6030 - acc_top1: 0.8794 - acc_top2: 0.8967 - 15ms/step\n", + "step 370/938 - loss: 1.5732 - acc_top1: 0.8795 - acc_top2: 0.8969 - 15ms/step\n", + "step 380/938 - loss: 1.5980 - acc_top1: 0.8797 - acc_top2: 0.8972 - 15ms/step\n", + "step 390/938 - loss: 1.5902 - acc_top1: 0.8800 - acc_top2: 0.8974 - 15ms/step\n", + "step 400/938 - loss: 1.5395 - acc_top1: 0.8809 - acc_top2: 0.8983 - 15ms/step\n", + "step 410/938 - loss: 1.6623 - acc_top1: 0.8804 - acc_top2: 0.8978 - 15ms/step\n", + "step 420/938 - loss: 1.4987 - acc_top1: 0.8810 - acc_top2: 0.8983 - 15ms/step\n", + "step 430/938 - loss: 1.5989 - acc_top1: 0.8811 - acc_top2: 0.8983 - 15ms/step\n", + "step 440/938 - loss: 1.5722 - acc_top1: 0.8813 - acc_top2: 0.8984 - 15ms/step\n", + "step 450/938 - loss: 1.5549 - acc_top1: 0.8818 - acc_top2: 0.8986 - 15ms/step\n", + "step 460/938 - loss: 1.5536 - acc_top1: 0.8819 - acc_top2: 0.8986 - 15ms/step\n", + "step 470/938 - loss: 1.5247 - acc_top1: 0.8826 - acc_top2: 0.8992 - 15ms/step\n", + "step 480/938 - loss: 1.5520 - acc_top1: 0.8830 - acc_top2: 0.8995 - 15ms/step\n", + "step 490/938 - loss: 1.5518 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n", + "step 500/938 - loss: 1.5227 - acc_top1: 0.8837 - acc_top2: 0.9000 - 15ms/step\n", + "step 510/938 - loss: 1.6014 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n", + "step 520/938 - loss: 1.5526 - acc_top1: 0.8834 - acc_top2: 0.8998 - 15ms/step\n", + "step 530/938 - loss: 1.5849 - acc_top1: 0.8838 - acc_top2: 0.9001 - 15ms/step\n", + "step 540/938 - loss: 1.5607 - acc_top1: 0.8840 - acc_top2: 0.9006 - 15ms/step\n", + "step 550/938 - loss: 1.6438 - acc_top1: 0.8843 - acc_top2: 0.9010 - 15ms/step\n", + "step 560/938 - loss: 1.5229 - acc_top1: 0.8848 - acc_top2: 0.9014 - 15ms/step\n", + "step 570/938 - loss: 1.5395 - acc_top1: 0.8846 - acc_top2: 0.9012 - 15ms/step\n", + "step 580/938 - loss: 1.5409 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n", + "step 590/938 - loss: 1.5851 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n", + "step 600/938 - loss: 1.5383 - acc_top1: 0.8849 - acc_top2: 0.9013 - 15ms/step\n", + "step 610/938 - loss: 1.5969 - acc_top1: 0.8853 - acc_top2: 0.9016 - 15ms/step\n", + "step 620/938 - loss: 1.5634 - acc_top1: 0.8854 - acc_top2: 0.9017 - 15ms/step\n", + "step 630/938 - loss: 1.6308 - acc_top1: 0.8857 - acc_top2: 0.9019 - 15ms/step\n", + "step 640/938 - loss: 1.6413 - acc_top1: 0.8859 - acc_top2: 0.9021 - 15ms/step\n", + "step 650/938 - loss: 1.5954 - acc_top1: 0.8856 - acc_top2: 0.9020 - 15ms/step\n", + "step 660/938 - loss: 1.5278 - acc_top1: 0.8859 - acc_top2: 0.9023 - 15ms/step\n", + "step 670/938 - loss: 1.5144 - acc_top1: 0.8869 - acc_top2: 0.9035 - 15ms/step\n", + "step 680/938 - loss: 1.4612 - acc_top1: 0.8879 - acc_top2: 0.9048 - 15ms/step\n", + "step 690/938 - loss: 1.4820 - acc_top1: 0.8891 - acc_top2: 0.9060 - 15ms/step\n", + "step 700/938 - loss: 1.4766 - acc_top1: 0.8901 - acc_top2: 0.9073 - 15ms/step\n", + "step 710/938 - loss: 1.5245 - acc_top1: 0.8911 - acc_top2: 0.9083 - 15ms/step\n", + "step 720/938 - loss: 1.5183 - acc_top1: 0.8922 - acc_top2: 0.9095 - 15ms/step\n", + "step 730/938 - loss: 1.4971 - acc_top1: 0.8932 - acc_top2: 0.9106 - 15ms/step\n", + "step 740/938 - loss: 1.4744 - acc_top1: 0.8944 - acc_top2: 0.9117 - 15ms/step\n", + "step 750/938 - loss: 1.4789 - acc_top1: 0.8952 - acc_top2: 0.9127 - 15ms/step\n", + "step 760/938 - loss: 1.5114 - acc_top1: 0.8959 - acc_top2: 0.9137 - 15ms/step\n", + "step 770/938 - loss: 1.5035 - acc_top1: 0.8970 - acc_top2: 0.9147 - 15ms/step\n", + "step 780/938 - loss: 1.4668 - acc_top1: 0.8978 - acc_top2: 0.9157 - 15ms/step\n", + "step 790/938 - loss: 1.4850 - acc_top1: 0.8986 - acc_top2: 0.9166 - 15ms/step\n", + "step 800/938 - loss: 1.4777 - acc_top1: 0.8996 - acc_top2: 0.9176 - 15ms/step\n", + "step 810/938 - loss: 1.4783 - acc_top1: 0.9005 - acc_top2: 0.9186 - 15ms/step\n", + "step 820/938 - loss: 1.5256 - acc_top1: 0.9011 - acc_top2: 0.9194 - 15ms/step\n", + "step 830/938 - loss: 1.4801 - acc_top1: 0.9019 - acc_top2: 0.9202 - 15ms/step\n", + "step 840/938 - loss: 1.4873 - acc_top1: 0.9026 - acc_top2: 0.9211 - 15ms/step\n", + "step 850/938 - loss: 1.5093 - acc_top1: 0.9034 - acc_top2: 0.9219 - 15ms/step\n", + "step 860/938 - loss: 1.4727 - acc_top1: 0.9042 - acc_top2: 0.9227 - 15ms/step\n", + "step 870/938 - loss: 1.4917 - acc_top1: 0.9050 - acc_top2: 0.9235 - 15ms/step\n", + "step 880/938 - loss: 1.4792 - acc_top1: 0.9058 - acc_top2: 0.9243 - 15ms/step\n", + "step 890/938 - loss: 1.4854 - acc_top1: 0.9066 - acc_top2: 0.9251 - 15ms/step\n", + "step 900/938 - loss: 1.4616 - acc_top1: 0.9074 - acc_top2: 0.9258 - 15ms/step\n", + "step 910/938 - loss: 1.4954 - acc_top1: 0.9081 - acc_top2: 0.9265 - 15ms/step\n", + "step 920/938 - loss: 1.4875 - acc_top1: 0.9087 - acc_top2: 0.9272 - 15ms/step\n", + "step 930/938 - loss: 1.5037 - acc_top1: 0.9094 - acc_top2: 0.9279 - 15ms/step\n", + "step 938/938 - loss: 1.4964 - acc_top1: 0.9099 - acc_top2: 0.9284 - 15ms/step\n", + "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n", + "Eval begin...\n", + "step 10/157 - loss: 1.5196 - acc_top1: 0.9719 - acc_top2: 0.9969 - 5ms/step\n", + "step 20/157 - loss: 1.5393 - acc_top1: 0.9672 - acc_top2: 0.9945 - 6ms/step\n", + "step 30/157 - loss: 1.4928 - acc_top1: 0.9630 - acc_top2: 0.9906 - 5ms/step\n", + "step 40/157 - loss: 1.4765 - acc_top1: 0.9617 - acc_top2: 0.9902 - 5ms/step\n", + "step 50/157 - loss: 1.4646 - acc_top1: 0.9631 - acc_top2: 0.9903 - 5ms/step\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 60/157 - loss: 1.5646 - acc_top1: 0.9641 - acc_top2: 0.9906 - 5ms/step\n", + "step 70/157 - loss: 1.5167 - acc_top1: 0.9618 - acc_top2: 0.9900 - 5ms/step\n", + "step 80/157 - loss: 1.4728 - acc_top1: 0.9635 - acc_top2: 0.9906 - 5ms/step\n", + "step 90/157 - loss: 1.5030 - acc_top1: 0.9668 - acc_top2: 0.9917 - 5ms/step\n", + "step 100/157 - loss: 1.4612 - acc_top1: 0.9677 - acc_top2: 0.9914 - 5ms/step\n", + "step 110/157 - loss: 1.4612 - acc_top1: 0.9689 - acc_top2: 0.9913 - 5ms/step\n", + "step 120/157 - loss: 1.4612 - acc_top1: 0.9707 - acc_top2: 0.9919 - 5ms/step\n", + "step 130/157 - loss: 1.4621 - acc_top1: 0.9719 - acc_top2: 0.9923 - 5ms/step\n", + "step 140/157 - loss: 1.4612 - acc_top1: 0.9734 - acc_top2: 0.9929 - 5ms/step\n", + "step 150/157 - loss: 1.4660 - acc_top1: 0.9748 - acc_top2: 0.9933 - 5ms/step\n", + "step 157/157 - loss: 1.5215 - acc_top1: 0.9731 - acc_top2: 0.9930 - 5ms/step\n", + "Eval samples: 10000\n", + "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/final\n" + ] + } + ], + "source": [ + "model.fit(train_dataset,\n", + " test_dataset,\n", + " epochs=2,\n", + " batch_size=64,\n", + " save_dir='mnist_checkpoint')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 训练方式1结束\n", + "以上就是训练方式1,可以非常快速的完成网络模型训练。此外,paddle还可以用下面的方式来完成模型的训练。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3.训练方式2\n", + "方式1可以快速便捷的完成训练,隐藏了训练时的细节。而方式2则可以用最基本的方式,完成模型的训练。具体如下。" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, batch_id: 0, loss is: [2.300888], acc is: [0.28125]\n", + "epoch: 0, batch_id: 100, loss is: [1.6948285], acc is: [0.8125]\n", + "epoch: 0, batch_id: 200, loss is: [1.5282547], acc is: [0.96875]\n", + "epoch: 0, batch_id: 300, loss is: [1.509404], acc is: [0.96875]\n", + "epoch: 0, batch_id: 400, loss is: [1.4973292], acc is: [1.]\n", + "epoch: 0, batch_id: 500, loss is: [1.5063374], acc is: [0.984375]\n", + "epoch: 0, batch_id: 600, loss is: [1.490077], acc is: [0.984375]\n", + "epoch: 0, batch_id: 700, loss is: [1.5206413], acc is: [0.984375]\n", + "epoch: 0, batch_id: 800, loss is: [1.5104291], acc is: [1.]\n", + "epoch: 0, batch_id: 900, loss is: [1.5216607], acc is: [0.96875]\n", + "epoch: 1, batch_id: 0, loss is: [1.4949667], acc is: [0.984375]\n", + "epoch: 1, batch_id: 100, loss is: [1.4923338], acc is: [0.96875]\n", + "epoch: 1, batch_id: 200, loss is: [1.5026703], acc is: [1.]\n", + "epoch: 1, batch_id: 300, loss is: [1.4965419], acc is: [0.984375]\n", + "epoch: 1, batch_id: 400, loss is: [1.5270758], acc is: [1.]\n", + "epoch: 1, batch_id: 500, loss is: [1.4774603], acc is: [1.]\n", + "epoch: 1, batch_id: 600, loss is: [1.4762554], acc is: [0.984375]\n", + "epoch: 1, batch_id: 700, loss is: [1.4773959], acc is: [0.984375]\n", + "epoch: 1, batch_id: 800, loss is: [1.5044193], acc is: [1.]\n", + "epoch: 1, batch_id: 900, loss is: [1.4986757], acc is: [0.96875]\n" + ] + } + ], + "source": [ + "import paddle\n", + "train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n", + "def train(model):\n", + " model.train()\n", + " epochs = 2\n", + " batch_size = 64\n", + " optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", + " for epoch in range(epochs):\n", + " for batch_id, data in enumerate(train_loader()):\n", + " x_data = data[0]\n", + " y_data = data[1]\n", + " predicts = model(x_data)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", + " acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_acc = paddle.mean(acc)\n", + " avg_loss.backward()\n", + " if batch_id % 100 == 0:\n", + " print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", + " optim.minimize(avg_loss)\n", + " model.clear_gradients()\n", + "model = LeNet()\n", + "train(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 对模型进行验证" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "batch_id: 0, loss is: [1.5017498], acc is: [1.]\n", + "batch_id: 100, loss is: [1.4783669], acc is: [0.984375]\n", + "batch_id: 200, loss is: [1.4958509], acc is: [1.]\n", + "batch_id: 300, loss is: [1.4924574], acc is: [1.]\n", + "batch_id: 400, loss is: [1.4762049], acc is: [1.]\n", + "batch_id: 500, loss is: [1.4817208], acc is: [0.984375]\n", + "batch_id: 600, loss is: [1.4763825], acc is: [0.984375]\n", + "batch_id: 700, loss is: [1.4954926], acc is: [1.]\n", + "batch_id: 800, loss is: [1.5220823], acc is: [0.984375]\n", + "batch_id: 900, loss is: [1.4945463], acc is: [0.984375]\n" + ] + } + ], + "source": [ + "import paddle\n", + "test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n", + "def test(model):\n", + " model.eval()\n", + " batch_size = 64\n", + " for batch_id, data in enumerate(train_loader()):\n", + " x_data = data[0]\n", + " y_data = data[1]\n", + " predicts = model(x_data)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", + " acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_acc = paddle.mean(acc)\n", + " avg_loss.backward()\n", + " if batch_id % 100 == 0:\n", + " print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", + "test(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 训练方式2结束\n", + "以上就是训练方式2,通过这种方式,可以清楚的看到训练和测试中的每一步过程。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 总结\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/paddle2.0_docs/n_gram_model/n_gram_model.ipynb b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..30bb2126ed232dde3be662933008cd9f98dd2956 --- /dev/null +++ b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb @@ -0,0 +1,487 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## 用N-Gram模型在莎士比亚文集中训练word embedding\n", + "N-gram 是计算机语言学和概率论范畴内的概念,是指给定的一段文本中N个项目的序列。\n", + "N=1 时 N-gram 又称为 unigram,N=2 称为 bigram,N=3 称为 trigram,以此类推。实际应用通常采用 bigram 和 trigram 进行计算。\n", + "本示例在莎士比亚文集上实现了trigram。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 环境\n", + "本教程基于paddle-develop编写,如果您的环境不是本版本,请先安装paddle-develop。" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'0.0.0'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import paddle\n", + "paddle.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据集&&相关参数\n", + "训练数据集采用了莎士比亚文集,[下载](https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt),保存为txt格式即可。
\n", + "context_size设为2,意味着是trigram。embedding_dim设为256。" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2020-09-03 08:41:10-- https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt\n", + "正在解析主机 ocw.mit.edu (ocw.mit.edu)... 151.101.230.133\n", + "正在连接 ocw.mit.edu (ocw.mit.edu)|151.101.230.133|:443... 已连接。\n", + "已发出 HTTP 请求,正在等待回应... 200 OK\n", + "长度:5458199 (5.2M) [text/plain]\n", + "正在保存至: “t8.shakespeare.txt.1”\n", + "\n", + "t8.shakespeare.txt. 100%[===================>] 5.21M 26.1KB/s 用时 4m 14s \n", + "\n", + "2020-09-03 08:45:25 (21.0 KB/s) - 已保存 “t8.shakespeare.txt.1” [5458199/5458199])\n", + "\n" + ] + } + ], + "source": [ + "!wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "embedding_dim = 256\n", + "context_size = 2" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Length of text: 5458199 characters\n" + ] + } + ], + "source": [ + "# 文件路径\n", + "path_to_file = './t8.shakespeare.txt'\n", + "test_sentence = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n", + "\n", + "# 文本长度是指文本中的字符个数\n", + "print ('Length of text: {} characters'.format(len(test_sentence)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 去除标点符号\n", + "用`string`库中的punctuation,完成英文符号的替换。" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'!': '', '\"': '', '#': '', '$': '', '%': '', '&': '', \"'\": '', '(': '', ')': '', '*': '', '+': '', ',': '', '-': '', '.': '', '/': '', ':': '', ';': '', '<': '', '=': '', '>': '', '?': '', '@': '', '[': '', '\\\\': '', ']': '', '^': '', '_': '', '`': '', '{': '', '|': '', '}': '', '~': ''}\n" + ] + } + ], + "source": [ + "from string import punctuation\n", + "process_dicts={i:'' for i in punctuation}\n", + "print(process_dicts)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "28343\n" + ] + } + ], + "source": [ + "punc_table = str.maketrans(process_dicts)\n", + "test_sentence = test_sentence.translate(punc_table)\n", + "test_sentence = test_sentence.lower().split()\n", + "vocab = set(test_sentence)\n", + "print(len(vocab))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据预处理\n", + "将文本被拆成了元组的形式,格式为(('第一个词', '第二个词'), '第三个词');其中,第三个词就是我们的目标。" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[['this', 'is'], 'the'], [['is', 'the'], '100th'], [['the', '100th'], 'etext']]\n" + ] + } + ], + "source": [ + "trigram = [[[test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2]]\n", + " for i in range(len(test_sentence) - 2)]\n", + "\n", + "word_to_idx = {word: i for i, word in enumerate(vocab)}\n", + "idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n", + "# 看一下数据集\n", + "print(trigram[:3])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 构建`Dataset`类 加载数据\n", + "用`paddle.io.Dataset`构建数据集,然后作为参数传入到`paddle.io.DataLoader`,完成数据集的加载。" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import numpy as np\n", + "batch_size = 256\n", + "paddle.disable_static()\n", + "class TrainDataset(paddle.io.Dataset):\n", + " def __init__(self, tuple_data):\n", + " self.tuple_data = tuple_data\n", + "\n", + " def __getitem__(self, idx):\n", + " data = self.tuple_data[idx][0]\n", + " label = self.tuple_data[idx][1]\n", + " data = np.array(list(map(lambda w: word_to_idx[w], data)))\n", + " label = np.array(word_to_idx[label])\n", + " return data, label\n", + " \n", + " def __len__(self):\n", + " return len(self.tuple_data)\n", + "train_dataset = TrainDataset(trigram)\n", + "train_loader = paddle.io.DataLoader(train_dataset,places=paddle.fluid.cpu_places(), return_list=True,\n", + " shuffle=True, batch_size=batch_size, drop_last=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 组网&训练\n", + "这里用paddle动态图的方式组网。为了构建Trigram模型,用一层 `Embedding` 与两层 `Linear` 完成构建。`Embedding` 层对输入的前两个单词embedding,然后输入到后面的两个`Linear`层中,完成特征提取。" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import numpy as np\n", + "hidden_size = 1024\n", + "class NGramModel(paddle.nn.Layer):\n", + " def __init__(self, vocab_size, embedding_dim, context_size):\n", + " super(NGramModel, self).__init__()\n", + " self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n", + " self.linear1 = paddle.nn.Linear(context_size * embedding_dim, hidden_size)\n", + " self.linear2 = paddle.nn.Linear(hidden_size, len(vocab))\n", + "\n", + " def forward(self, x):\n", + " x = self.embedding(x)\n", + " x = paddle.reshape(x, [-1, context_size * embedding_dim])\n", + " x = self.linear1(x)\n", + " x = paddle.nn.functional.relu(x)\n", + " x = self.linear2(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 定义`train()`函数,对模型进行训练。" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, batch_id: 0, loss is: [10.252256]\n", + "epoch: 0, batch_id: 100, loss is: [7.0485706]\n", + "epoch: 0, batch_id: 200, loss is: [7.282592]\n", + "epoch: 0, batch_id: 300, loss is: [6.9604626]\n", + "epoch: 0, batch_id: 400, loss is: [6.7308316]\n", + "epoch: 0, batch_id: 500, loss is: [6.7940483]\n", + "epoch: 0, batch_id: 600, loss is: [6.6574802]\n", + "epoch: 0, batch_id: 700, loss is: [6.862562]\n", + "epoch: 0, batch_id: 800, loss is: [7.2091002]\n", + "epoch: 0, batch_id: 900, loss is: [7.0172606]\n", + "epoch: 0, batch_id: 1000, loss is: [6.9888105]\n", + "epoch: 0, batch_id: 1100, loss is: [6.9609995]\n", + "epoch: 0, batch_id: 1200, loss is: [6.550024]\n", + "epoch: 0, batch_id: 1300, loss is: [6.714109]\n", + "epoch: 0, batch_id: 1400, loss is: [6.995716]\n", + "epoch: 0, batch_id: 1500, loss is: [6.939434]\n", + "epoch: 0, batch_id: 1600, loss is: [6.5966253]\n", + "epoch: 0, batch_id: 1700, loss is: [6.9880104]\n", + "epoch: 0, batch_id: 1800, loss is: [6.6459093]\n", + "epoch: 0, batch_id: 1900, loss is: [6.8095036]\n", + "epoch: 0, batch_id: 2000, loss is: [6.8447037]\n", + "epoch: 0, batch_id: 2100, loss is: [6.8313]\n", + "epoch: 0, batch_id: 2200, loss is: [6.808483]\n", + "epoch: 0, batch_id: 2300, loss is: [6.502908]\n", + "epoch: 0, batch_id: 2400, loss is: [6.561283]\n", + "epoch: 0, batch_id: 2500, loss is: [7.0093765]\n", + "epoch: 0, batch_id: 2600, loss is: [6.512396]\n", + "epoch: 0, batch_id: 2700, loss is: [6.809763]\n", + "epoch: 0, batch_id: 2800, loss is: [6.806659]\n", + "epoch: 0, batch_id: 2900, loss is: [6.95402]\n", + "epoch: 0, batch_id: 3000, loss is: [6.634927]\n", + "epoch: 0, batch_id: 3100, loss is: [6.644098]\n", + "epoch: 0, batch_id: 3200, loss is: [6.705504]\n", + "epoch: 0, batch_id: 3300, loss is: [6.2121572]\n", + "epoch: 0, batch_id: 3400, loss is: [6.638401]\n", + "epoch: 0, batch_id: 3500, loss is: [6.986831]\n", + "epoch: 1, batch_id: 0, loss is: [6.795429]\n", + "epoch: 1, batch_id: 100, loss is: [6.582568]\n", + "epoch: 1, batch_id: 200, loss is: [6.527663]\n", + "epoch: 1, batch_id: 300, loss is: [6.714637]\n", + "epoch: 1, batch_id: 400, loss is: [6.574902]\n", + "epoch: 1, batch_id: 500, loss is: [6.305031]\n", + "epoch: 1, batch_id: 600, loss is: [6.803609]\n", + "epoch: 1, batch_id: 700, loss is: [6.2429113]\n", + "epoch: 1, batch_id: 800, loss is: [6.7452283]\n", + "epoch: 1, batch_id: 900, loss is: [6.383783]\n", + "epoch: 1, batch_id: 1000, loss is: [6.4906135]\n", + "epoch: 1, batch_id: 1100, loss is: [6.6007314]\n", + "epoch: 1, batch_id: 1200, loss is: [6.63466]\n", + "epoch: 1, batch_id: 1300, loss is: [6.540749]\n", + "epoch: 1, batch_id: 1400, loss is: [6.7752547]\n", + "epoch: 1, batch_id: 1500, loss is: [6.2411666]\n", + "epoch: 1, batch_id: 1600, loss is: [6.540929]\n", + "epoch: 1, batch_id: 1700, loss is: [6.6563463]\n", + "epoch: 1, batch_id: 1800, loss is: [6.4592104]\n", + "epoch: 1, batch_id: 1900, loss is: [7.0268345]\n", + "epoch: 1, batch_id: 2000, loss is: [6.803793]\n", + "epoch: 1, batch_id: 2100, loss is: [6.8454733]\n", + "epoch: 1, batch_id: 2200, loss is: [6.651756]\n", + "epoch: 1, batch_id: 2300, loss is: [6.5876465]\n", + "epoch: 1, batch_id: 2400, loss is: [6.258934]\n", + "epoch: 1, batch_id: 2500, loss is: [6.5422425]\n", + "epoch: 1, batch_id: 2600, loss is: [6.184501]\n", + "epoch: 1, batch_id: 2700, loss is: [6.6847773]\n", + "epoch: 1, batch_id: 2800, loss is: [6.684101]\n", + "epoch: 1, batch_id: 2900, loss is: [6.374978]\n", + "epoch: 1, batch_id: 3000, loss is: [6.8277273]\n", + "epoch: 1, batch_id: 3100, loss is: [6.5195084]\n", + "epoch: 1, batch_id: 3200, loss is: [6.311832]\n", + "epoch: 1, batch_id: 3300, loss is: [6.4282994]\n", + "epoch: 1, batch_id: 3400, loss is: [6.603338]\n", + "epoch: 1, batch_id: 3500, loss is: [6.4541807]\n" + ] + } + ], + "source": [ + "import time\n", + "vocab_size = len(vocab)\n", + "epochs = 2\n", + "losses = []\n", + "def train(model):\n", + " model.train()\n", + " optim = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())\n", + " for epoch in range(epochs):\n", + " for batch_id, data in enumerate(train_loader()):\n", + " x_data = data[0]\n", + " y_data = data[1]\n", + " predicts = model(x_data)\n", + " y_data = paddle.nn.functional.one_hot(y_data, len(vocab))\n", + " loss = paddle.nn.functional.softmax_with_cross_entropy(predicts, y_data,soft_label=True)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_loss.backward()\n", + " if batch_id % 100 == 0:\n", + " losses.append(avg_loss.numpy())\n", + " print(\"epoch: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy())) \n", + " optim.minimize(avg_loss)\n", + " model.clear_gradients()\n", + "model = NGramModel(vocab_size, embedding_dim, context_size)\n", + "train(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 打印loss下降曲线\n", + "通过可视化loss的曲线,可以看到模型训练的效果。" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deXzcVbn48c8zM1mapVknSdukTbM03egaSlvaUihbkQuigOBVEAVE8bpdrxdc4Kf+firqVVG8cMviBQVFkE1k3wtdaLo3NGmbdEnSZm32NOuc3x+zdGYyaZaZNGH6vF+vvDrznW/me5pknjnznOecI8YYlFJKhS/LWDdAKaXU6NJAr5RSYU4DvVJKhTkN9EopFeY00CulVJizjXUD/KWmpprs7OyxboZSSn2sbN26td4YYw/02LgL9NnZ2RQVFY11M5RS6mNFRA4P9JimbpRSKsxpoFdKqTCngV4ppcKcBnqllApzGuiVUirMaaBXSqkwp4FeKaXCXNgE+mPNJ/j1a6WU17WNdVOUUmpcCZtAX9faxe/eOkB5XftYN0UppcaVsAn0kTbnf6W7zzHGLVFKqfElfAK91RXoezXQK6WUt/AJ9DYN9EopFUjYBfouTd0opZSPsAn0UVYroD16pZTyFzaB3tOj7+0b45YopdT4EnaBXnv0SinlK2wCvdUiWC2igV4ppfwMGuhF5BERqRWRPV7HkkXkdRHZ7/o3aYDvvdF1zn4RuTGUDQ8k0mrRQK+UUn6G0qP/X+BSv2N3AG8aY/KBN133fYhIMnA3cA6wBLh7oDeEUIm0WXTClFJK+Rk00Btj3gOO+x2+EnjUdftR4JMBvvUS4HVjzHFjTCPwOv3fMEIqyqY9eqWU8jfSHH26MeaY63Y1kB7gnClAhdf9StexUROpgV4ppfoJejDWGGMAE8xziMitIlIkIkV1dXUjfp5Im0UnTCmllJ+RBvoaEZkE4Pq3NsA5VUCW1/1M17F+jDHrjDGFxphCu90+wibpYKxSSgUy0kD/AuCuorkReD7AOa8CF4tIkmsQ9mLXsVGjOXqllOpvKOWVfwE2AgUiUikiXwJ+DlwkIvuBC133EZFCEXkIwBhzHPgJsMX19WPXsVGjOXqllOrPNtgJxpjrB3hoTYBzi4Cbve4/Ajwy4tYNU6TNQmePBnqllPIWNjNjQXP0SikVSHgFek3dKKVUP2EW6K06M1YppfyEV6DX1I1SSvUTXoHeZqFLA71SSvkIq0DvrKPXjUeUUspbWAV6Xb1SKaX6C69A78rRO5ffUUopBeEW6G0WHAZ6HRrolVLKLewCPei+sUop5S28Ar1VA71SSvkLr0Dv7tHrgKxSSnmEZ6DXHr1SSnmEVaCPcgV6nTSllFInhWWg1x69UkqdFFaBXnP0SinVX1CBXkS+ISJ7RKRYRL4Z4PHVItIsIjtcX3cFc73BRFqtgPbolVLK26A7TA1EROYCtwBLgG7gFRF50RhzwO/U9caYy4No45DpYKxSSvUXTI9+FrDZGNNhjOkF3gU+FZpmjczJ1I0ubKaUUm7BBPo9wEoRSRGRGOAyICvAectEZKeIvCwicwI9kYjcKiJFIlJUV1c34gbphCmllOpvxKkbY8xeEbkHeA1oB3YA/l3pbcA0Y0ybiFwGPAfkB3iudcA6gMLCwhEvVBOp5ZVKKdVPUIOxxpiHjTGLjTGrgEZgn9/jLcaYNtftl4AIEUkN5pqnouWVSinVX7BVN2muf6fizM8/4fd4hoiI6/YS1/UagrnmqWh5pVJK9Tfi1I3L30UkBegBbjfGNInIbQDGmAeAq4GviEgvcAK4zoziYvGao1dKqf6CCvTGmJUBjj3gdfs+4L5grjEcWl6plFL9hefMWA30SinlEVaB3mYRRDRHr5RS3sIq0IuIZ99YpZRSTmEV6MGZvtE6eqWUOinsAn2UBnqllPIRdoFeUzdKKeUr/AK9zaKDsUop5SXsAn2UzUp3r65eqZRSbmEX6CNtmrpRSilv4RnoNXWjlFIe4RfodTBWKaV8hF+g19SNUkr5CMtAr3X0Sil1UlgGes3RK6XUSWEX6KM0R6+UUj6C3WHqGyKyR0SKReSbAR4XEfmdiBwQkV0isiiY6w2F5uiVUsrXiAO9iMwFbgGWAPOBy0Ukz++0tTg3A88HbgXuH+n1hkpTN0op5SuYHv0sYLMxpsMY0wu8i3PfWG9XAo8Zp01AoohMCuKag9LySqWU8hVMoN8DrBSRFBGJAS4DsvzOmQJUeN2vdB3zISK3ikiRiBTV1dUF0SRN3SillL8RB3pjzF7gHuA14BVgBzCiRWaMMeuMMYXGmEK73T7SJgHOQN/rMDgco7YHuVJKfawENRhrjHnYGLPYGLMKaAT2+Z1ShW8vP9N1bNR49o3VPL1SSgHBV92kuf6dijM//4TfKS8AN7iqb5YCzcaYY8FcczCRVud/SSdNKaWUky3I7/+7iKQAPcDtxpgmEbkNwBjzAPASztz9AaADuCnI6w0qyt2j10CvlFJAkIHeGLMywLEHvG4b4PZgrjFc7tRNl65Jr5RSQBjOjI3UHr1SSvkIv0BvtQI6GKuUUm7hF+i1R6+UUj7CLtDrYKxSSvkKu0CvPXqllPIVtoG+S3P0SikFhGOgt2qPXimlvIVdoNccvVJK+Qq7QK85eqWU8hW+gV5z9EopBYRjoNccvVJK+Qi/QK+pG6WU8hG+gV5TN0opBYRjoNf16JVSykfYBXoR0Q3ClVLKS7A7TH1LRIpFZI+I/EVEov0e/4KI1InIDtfXzcE1d2h0g3CllDppxIFeRKYAXwcKjTFzAStwXYBTnzTGLHB9PTTS6w1HpM1Cd59uPKKUUhB86sYGTBARGxADHA2+ScHT1I1SSp004kBvjKkCfgUcAY7h3Pj7tQCnflpEdonI0yKSFei5RORWESkSkaK6urqRNslDUzdKKXVSMKmbJOBKYDowGYgVkc/5nfYPINsYMw94HXg00HMZY9YZYwqNMYV2u32kTfKItFm06kYppVyCSd1cCBw0xtQZY3qAZ4Dl3icYYxqMMV2uuw8Bi4O43pBp6kYppU4KJtAfAZaKSIyICLAG2Ot9gohM8rp7hf/jo8U5GKuBXimlwDmYOiLGmM0i8jSwDegFtgPrROTHQJEx5gXg6yJyhevx48AXgm/y4DR1o5RSJ4040AMYY+4G7vY7fJfX43cCdwZzjZGIsllo7ew93ZdVSqlxKexmxoIz0GuOXimlnMIy0GuOXimlTgrPQK9VN0op5RGegV5TN0op5RG+gV5TN0opBYRroLdatUevlFIu4RnoNXWjlFIe4Rvo+xwYY8a6KUopNebCMtBH6b6xSinlEZaB3r1vrKZvlFIqXAO9TQO9Ukq5hXeg19SNUkqFaaDX1I1SSnmEZ6DX1I1SSnmEdaDXNemVUkoDvVJKhb2gAr2IfEtEikVkj4j8RUSi/R6PEpEnReSAiGwWkexgrjdUUZqjV0opjxEHehGZAnwdKDTGzAWswHV+p30JaDTG5AG/Ae4Z6fWGQ6tulFLqpGBTNzZggojYgBjgqN/jVwKPum4/DaxxbSQ+qqJsVkB79EopBUEEemNMFfAr4AhwDGg2xrzmd9oUoMJ1fi/QDKT4P5eI3CoiRSJSVFdXN9ImeWjVjVJKnRRM6iYJZ499OjAZiBWRz43kuYwx64wxhcaYQrvdPtImeZxM3fQF/VxKKfVxF0zq5kLgoDGmzhjTAzwDLPc7pwrIAnCldxKAhiCuOSTao1dKqZOCCfRHgKUiEuPKu68B9vqd8wJwo+v21cBb5jSsHawzY5VS6qRgcvSbcQ6wbgN2u55rnYj8WESucJ32MJAiIgeAbwN3BNneIdE6eqWUOskWzDcbY+4G7vY7fJfX453ANcFcYyR0PXqllDopPGfGaupGKaU8wjLQWyyCzSIa6JVSijAN9KAbhCullFt4B3rN0SulVBgHeqv26JVSCsI50GvqRimlgDAP9F2aulFKqTAO9Jq6UUopIIwDfZSmbpRSCgjjQB9ps9DVq6tXKqVUWAd67dErpVQ4B3qr1tErpRSEc6DXHr1SSgFhHOijbFYN9EopRRgHeu3RK6WUUzB7xhaIyA6vrxYR+abfOatFpNnrnLsGer5Q07VulFLKacQbjxhjSoEFACJixbk/7LMBTl1vjLl8pNcZqUirRXeYUkopQpe6WQOUGWMOh+j5gqYTppRSyilUgf464C8DPLZMRHaKyMsiMifQCSJyq4gUiUhRXV1dSBrkTt2chr3IlVJqXAs60ItIJHAF8FSAh7cB04wx84HfA88Feg5jzDpjTKExptButwfbJMCZujEGeh0a6JVSZ7ZQ9OjXAtuMMTX+DxhjWowxba7bLwERIpIagmsOKtKm+8YqpRSEJtBfzwBpGxHJEBFx3V7iul5DCK45KA30SinlNOKqGwARiQUuAr7sdew2AGPMA8DVwFdEpBc4AVxnTlPS3BPotcRSKXWGCyrQG2PagRS/Yw943b4PuC+Ya4xUpFV79EopBWE+MxbQWnql1BkvbAN9lObolVIKCONArzl6pZRyCt9Ab7UCvj36zp4+mjt6xqpJSik1JsI30AdI3dzx911c9d8f6GxZpdQZJewDvXvf2Mb2bl7aXU15fTvl9e1j2TSllDqtwjfQ+5VXvrDzqCdfv35faNbTUUqpj4PwDfR+g7FPb61k9qSJZKfEsH5//ahe+/HNh9l6+PioXkMppYYqbAN9lFcdfUl1C7urmrmmMJOV+XY2ljeMWtllW1cvdz1fzB1/341DF1RTSo0DYR/ou3sdPF1USYRVuHLBFFbmp9LR3ce2I42jct0tB4/T5zDsr23jnX21o3INpZQajrAN9O7UzYnuPp7bUcWamekkx0ayLDcFq0VYv3908vQbyuqJtFqYlBDNA++Wj8o1lFJqOMI+0L/2UTX1bd1cvTgTgPjoCBZNTRy1PP3G8gYWTE3k5pU5fHjwONtH6ZODUkoNVfgGelfVzZZDjaTGRXFewckNTVbm29ld1czx9u6QXrO5o4fioy0sy0nhurOzSJgQwbr3tFevlBpbYRvobVYLFnHevmrhZCKsJ/+rK/NTMQbePxDaXv3mgw0YA8tzU4iNsvG5pVN5pbiag1q3r5QaQ2Eb6OFk+uaawiyf4/MyE0mYEBHyevqN5Q1E2SwsmJoIwI3Ls4mwWnhwvfbqlVJjZ8SBXkQKRGSH11eLiHzT7xwRkd+JyAER2SUii4Jv8tBFWi3Mz0xgRnq8z3GrRViRl8r6/fUhXQ5hY1kDhdlJRNmc6+ykxUfz6UVTeHprJXWtXSG7jlJKDceIA70xptQYs8AYswBYDHQAz/qdthbId33dCtw/0uuNxDcvnMH3LpsV8LGV+alUt3RyoLYtJNc63t5NSXUry3J89mHh5pU59PQ5WPdemdbVK+VSfLSZ7zy1U5cRP01ClbpZA5QZYw77Hb8SeMw4bQISRWRSiK45qC+umM45foHXbUW+c4/y90JUfbOp3LkV7rJc3+vl2uO4bO4kHlx/kOU/f4u7n9/DhrJ6ekOwfPLB+nb6PoZvHj19Dq5bt5F3SnWewZnqn7uO8fTWSt7YWzPWTTkjhCrQX0fgDcKnABVe9ytdx3yIyK0iUiQiRXV1p2cdmsykGHLssSGrp99Y1kBMpJV5mYn9Hvuva+fzm8/MZ35WAn/dUsFnH9zMkp++yXef3skbH9XQ2dM37OttP9LI+b96h0/dv4GS6pYhfc++mlaueWADjSGuNhqufTWtbCo/zmsfjf8XefOJHm585EP21bSOdVPCSlmd85P0Xz48MsYtOTMEHehFJBK4AnhqpM9hjFlnjCk0xhTa7fbBvyFEVuXb2VTeMKJA629jeQNnZyf7VPe4RUdYuWphJv/z+UK233UR9//rIlbkpfLy7mpufqyIRT95ndsf30Zta+eQr/fynmoirELl8Q4u/937/PLVkkH/Hy/tPsaWQ428N0qTxYaquMr5xlRybGhvUAOpa+2io7s3FE0a0Hv76nh3Xx2/fWPfqF7nTFNe144IrN9fT8XxjrFuTtgLRY9+LbDNGBOoe1YFeJe8ZLqOjQurC+x09jiCTiHUtjpz/f5pm0BiIm2sPWsSv7t+IVt/eBGPfnEJVy2cwivF1Tzy/qEhXc8Yw6vF1SzPTeWNb5/HlQum8Ie3y1h773r2niJ4bjvSBMCGAw1Dus5o2XO0GYB9NW0jHgw/3t7Nxb95lx+98FEom9bPRldK7uU91RzSMtmQ6O1zcKihnSvmT8Yi8OSWisG/SQUlFIH+egKnbQBeAG5wVd8sBZqNMcdCcM2QWJGXSvrEKP5WVBnU82wqd65U6T8QO5hIm4XzZtj5f1edxZzJE9lRMbRZtPtq2jjc0MHFc9JJio3kv66dz5+/dA5NHd38+vXAPU+Hw7D9sPP5N5SP7uqdg9ld5Qz0bV29VDaeGNFz/PLVEho7enizpHZUN5LZVNbA/KxELZMNoYrGE/T0GVbkpbK6II2ntlaEZMxKDSyoQC8iscBFwDNex24Tkdtcd18CyoEDwIPAV4O5XqjZrBY+vSiTd0prqWkZetrE38ayBuKjbMyZPHHEz7EgK5Hdlc1DGlx9rbgaEbhoVrrn2Ir8VC6Zk8Gm8oaAz7G/to3Wrl7mZyZQcfzEmH1c7u1zsPdYCwtdcw1Kqoef+95R0cRft1QwPTWW+rYu9h4bnfx5TUsn5fXtfOKsDD69KJOntEw2JMpclW65aXFcd3YWNS1dvF2qe0SMpqACvTGm3RiTYoxp9jr2gDHmAddtY4y53RiTa4w5yxhTFGyDQ+3awiwcxrle/UhtKm9gyfRkbAHy80O1ICuR9u4+9tcOHrRe+6iGhVmJpE2M9jm+LDeF1s5ePjraP32z1dWbv/38PMD55jQWyura6exx8KlFzrWHSoc4kOzW5zDc9fweUuOieOjGQoBRG3PwVFLlpHLLyun09Dl4bOOhUbnWmcQ9EJubGscFM9NIi4/SQdlRFtYzY4ciOzWWJdOTeaqoYkQpgKNNJzhY3z6k/PypLMhy9nB3uPLop7re7qpmLp6T0e8xdxs2lPVPzWw70khybCQXzkonNS4q4Dmnwx5X2mbp9GSykiewd5g9+ie3VLCrspnvXzaLXHscMzPieW+UdgzbWNbAxGgbsydPJMcex8Wz03ls42Hau0Z3ADjclde1kxoXSUJMBDarhWsKnZ+qjzWPLI2nBnfGB3qAzxRmcaihgy2Hhr/S5N+KKhCBC73SKCMxPTWWhAkR7Kg4daB/3VWSePHs/tdLi48mPy2ODQF669sON7JoahIWi7A8N4UNZQ2nfGNzOAzv76/n9se3seKet7h+3Sa+/+xuHlpfztultTSf6Bnm/9Bpd1UzEyKs5NjjKEifSOkwAn1jeze/eLWEJdOTuXLBZABWzbBTdKhxVKpvNpY3sGS6c1lrgC+fl0vziR4dPAxSWV0bOfY4z/3PFE7FYeBvW5yfqrt7HfxtSwVr713PbwYYcxprH7cxBdtYN2A8WHtWBne/UMyTWypYMj15yN/X3evgz5uOcH5BGtmpsUG1QUSYn5U4aKB/tbiavLQ4nxeKt+W5KTy1tZLuXodnrZ/j7d2U17dzdWGm55wXdh6lrK6NvDTf5SEa27v5y5Yj/PXDCo4c7yAxJoJzc1M52nyCF3cd8wR4i8DcKQksy03h3NxUluemDCl1VXy0mdmTJ2K1CLMmxfN2aS2dPX1ER1h9zuvpc1ByrBWbVYiOsBIdYeG3r++ntbOXn1w5FxFn8F2Vb2fde+VsKm/ggpnBvdl6O9p0gsMNHdywLNtzbNHUJJZkJ/Pw+wf5/LJpAUtp1eDK6tq4dO7JeZNTU2JYkZfK34oqiIu28dD6co41dxIbaeV/3ivjxuXZJMdGjmGLfe2ubObTD2zgH19bQUFG/ODfMA7oXyrOksd/mT+Jl3Yfo7XTt6f64cHjnnSDv5d2H6O+rYsvLM8OSTsWZCWyr6Z1wNRAU0c3mw8eD9ibd1uW69xBa1flyTcM95r4i6cmAbA81zkr2L/n3+cwXP/gJn7xSimTEqK597oFbLpzDX/410U8+9Vz2XHXRWz9wYU8ccs5fO2CfKJsFh5ef5AbHvmQn71cMuj/z+EwFB9tYa5r0LogI54+hwm4DMWD68v5l/veZ+296zn/V++w7Gdv8WRRBTcuy/Z5cRVmJxEdYeG9faFNRbnHMJbm+L7xf/m8HKqaTvDirqMhvZ6/po5urrjv/VOWy34cHW/vprGjh1y7b8fo+iVTqWo6wU9e/IispBj+eNPZPHv7uXT2OPjTRv8J92NrQ1k93b0OXt4zbgoIB6U9epdrC7P4y4cV/HPXMa5bMpXOnj5+9tJeHt14mKSYCN749nmkxEX5fM8fNxwi1x7LStdyCsFamJWIw8CuyuaAOf+3SmrpcxguCZCfd1uak4yIM4gXZjuD1NbDjdgs4pm1m5U8gSmJE9hwoMGnx/r3bZWUVLdy73ULuHJBvwnMiAgpcVEsj4tyvllcNIP2rl6+9sQ2Xtx1lB98Ypanpx1IeX07Hd19zJ2SAMDMDGfAL61u9Rxze7W4hpkZ8XxjTT6dvX109jiIsFq4fJ7vChrREVbOmZ4S8gHZjeUNJMZEMCvDt5Lq/II08tLi+OMHh7hqYWZIr+lt25FGdlU2s35/HbMmjbyaKxRe/6iGyYnRzJmcMPjJg/AMxKb5fiK9ZE46d6ydSeG0JM/fLcD5BXYe23iIL5+X0+9T31jZ4yp2eKe0jm9eOGOMWzM02qN3WZCVSH5aHE8WVVBa3cqV933AoxsPc83iTNq6evnRP3wn5mw/0sjOiia+sDz7lMFtOOa7B2QHSN+8VlxDxsRozpoy8AsuMSaS2ZMm+gy2bj3cyJzJE5kQ6XyhiAjn5qWwsbzBs9BaZ08fv3l9H/OzErli/uQhtzk2ysbauZOoaRm8zLHYNVHKHdSzU2KItFn6LeFQ19rFzoomLp83ibVnTeKqhZlcv2QqVy/ODPhiXzXDTnldO5WNoSsZ3VTewDnTk7FYfH+3Fotww7Jp7KpsZucgabZguMtOT7XonjFm1BfKaz7Rw+1PbOM3r+8PyfOVe1XceLNZLdx2Xq5PkAe4dVUuDe3d/H1bcHNdQqnY9Ql/Z2VTyDcvGi0a6F1EhGsLs9h+pIl/+f37NLR38783nc0vr5nP7efn8cLOo7zptQDTHz84RHyUzVMmGArJsZFMS4kJOHGqs6ePd/fVcdHs9H7Bx9/y3BS2HWmis6ePnj4HOyubWOhK25w8J5XmEz185EoNPLrhEMeaO7nj0pnDfuNy79719iAzjHdXNhNps5Dn6s3ZrBZmpMf1q6V3P8/5M9OGdv0ZrgXqQpS+qTjeQWXjiQEnwF21cAoxkVb+tGn0UgolrjfNsrqBZ+P+/q0DrL13/ahOGHtx11G6ex0cGELZ71CU1bUTabMwJWnCkM5fmpPMvMwEHlp/cFys/trW1cvBhnbWzEzDGEZt7+lQ00Dv5apFU5gYbePcvBRe+eZKVhc4A81XV+cxIz2OHzy3h9bOHmpaOnlp9zGuPTuL2KjQZr8WDDAgu35/PSd6+rh4zuADjstzU+nudbDtcCMlx1rp7HGweJpvoHenhjaWNdDc0cMf3j7A6gL7iMpE0ydGM2fyxEGXkthztJlZkyb6DGIWpE/sF+jf2ltLxsRoZg8xZZFrj2NyQvSQyixPdPcNGjA2elYiDZySi4+O4KqFU/jHzqNDWiDOGMOfNh0e1k5jpV49+oEC+YayekprWjkyipPfnnLNGj98vCMka0KV1baRkxrrqWQajIhwy8ocDta38/o4WOly77EWjHGOKaTERvJ2ycdjBVYN9F5S46LY8oML+eNNS0j1ysdH2izc8+l5VLd0cs8rJTy+6TB9xnDDsmkhb8OCrERqWrp8aoqNMTz4XjmpcVGcM33wQHz29GSsFmFDWQNbDzuXZ/AP9OkTo8m1x7KhrJ7/fvcArV29fPeSmSNu9/kFaWw70kRzR+CyS4fDUFx1ciDWbWZGPHWtXZ6PwF29fazfX8cFs9KG/MlCRFiZb+eDUyz/3Ocw3P9OGfN+9Cp/HaQ8clNZA8mxkeSnBa5sAvjc0ml09TqGNNHugwMN/PC5PXzlz1vpGUJZXnevg7K6NuKjbTSf6AmYHjDGeN4MNh88PuhzjsSB2lZ2VDSxcGoixpzMrwejrK6N3AEqxgaydm4GmUkTeHCU9l+ub+viUH07Rxo6qDjeQVXTiQFnqLsLM+ZlJrBqhp339tePi08ag9FA78e9O5S/hVOTuGn5dP686QiPfHCINTPTmJYSXEllIO6JU97537dLa/nw0HG+sSbPUzJ5KnFRNuZnJrChrJ6tR5rImBjN5MT+H5WX56ayqfw4//vBIa5aMIXZQSzhcP5MO30Ow/oDgXvVR4530NrV22/QdeYkZwWNO0//4cHjtHf3sWaIaRu3VTPstHb2srOy/6ehIw0dXLduI/e8UkKvw7D54MCzgo0xbCxvYGlO//y8t1mTJnJ2dhJ/3nx40Bf6f79zgAkRVkqqW3lo/cFB/y9ldW30OgwXz3YOugfK09e1dtHoelPdXD46gf7prVVYLcJ/XFwwYDuGo6u3jyPHO8ixD+91Y7NauHnFdIoON3pmeIdKdXMnS3/6Jqt/9Q6rfvk2K3/xNuf+/C1+8NzugOfvqWrBHh9F2sRoVhfYOd7eza4BqvLGEw30w/CdS2aQmTSBtq5evrB8+qhcY/bkiURaLWx3Bfo+h+EXr5SSnRLDdUumDvl5luemsrOymY1lDf168yfPSeFETx/GwLcuCq56YEFWEokxEbxdEjjQu1es9B9IdpdKunPSb+6tJcpm8ZSADtWKvFQsAu+68vTGGI63d/PXD4+w9t73KDnWym8+M5/VM+yeawVyuKGDY82dQ1qg7nNLp3G4oYP1p9hkfkdFExvKGvjWRflcMiede9/cx+GGU6dw3D31T8xzBvpAeXp3uis1LooPD4V+OYvePgfPbKvk/AI7hdnOT4j7a4YW6Dt7+vjes7spOuT7BnSkoQOHYdg9enDu+5wwIYL73tof0s12dlU20eswfLhCvFcAABhRSURBVOfiGfzXNfP55dXzWJaTwqvFNQHfwIuPNns+la7MtyPCx2IDHQ30wxATaeOBzy3m2xfN4Ny84JY8GEiUzcqsyRM9SyE8t72KkupWvnNJwbAm6CzPTaHPYahv62LRAIF+aU4KkTYLNyybRlZyTFDttlqEVfl23t1XG/AFsruqmQirkJ/u+yK3x0WREhtJaXUrxhjeKqnl3LxUT4XQUCXERDA/K5E/fnCQFfe8RcEPXmHRT17njmd2c1ZmAq98axVXLcxk5qSJlNW1DbiF3UA7hQVy6dwMUuMi+dPGQwOec/87B0iYEMFnz5nGj66Yi81i4QfP7TnlAGpJdSsRVuHcvFSiIywBUybuN4PPLsmi4vgJjjYNf/mAutYu/ufdMu58ZhdtfnM31h+op7a1i6sXZxJps5CdEjPkzVd+/OJHPLH5CD/6x0c+/09PaeUIAn1slI3bzsvl7dI6PvvgppAtl+D+Od507nQ+vTiTawqzuG5JVsCeemdPH/tr2zyfSpNjI1mQlfixWJBNA/0wzZ2SwNfX5IespDKQhVmJ7K5qpqO7l1+/vo+zpiRw2dzh7cC4aFqSJ80zUI8+KTaSN799HnesHXlu3tv5M+3Ut3V7eu/eiqtaKMiI75caExEKMuIpqW6hrK6dI8c7uGCYaRu3W1bmcNaUBAqnJXHTimzuunw2D99YyBM3L2WKK3U1MyOeXocZMN+8/UgTCRMihhSMomxWPnN2Fm+W1AZcDfRAbSuvFtdw47JpxEXZyEiI5ruXFrB+fz3P7Rh4W4aS6hZy7XFE2azkpMYFbGtJdStp8VFcMtfZ6/9wiHn63j4Hb3xUwy2PFbH0Z2/ys5dL+OuWCm59rMhnsPXprZUkxUR4ZhvPSI8fUurm2e2VPLH5CDMz4tld1ezTLvcnk+GmbtxuOy+HX10zn91Vzay9dz2vFVeP6Hm8ldS0MjU5xqeoYqCeekl1K30O47NK7eoZaeyqbKKhbfirmo50GZGR0EA/Ds3PSqCju4+7ni+mqukEd6ydOWhJpb/oCCuLpyYRZbOcsnolKzkmqFU3va1yvUD80zfGGPYcbWbuABNuZmZMZF9Nm2cdn5EG+svOmsQTtyzlt9ct5M61s/jiiumsmeVbjuqefDTQ9os7K5uYl5kw5Dfy65dMRQi8Jd7975QzIcLKF849meb713OmsXBqIj95ce+AFTul1a3MdKW08tLiAgbY0hrnG+fMjInER9tOOe7g1t3r4LMPbebmx4rYfqSJm1dM541vn8evr53PhrIGvvbEdnr6HDR1dPN6cQ1XLpji6Szkp8VxqKGdrt6BK2/21bTyvWf2sGR6Mk9/ZTlJMRE86DUmUVbbxqSE6BFXqokIVy/O5MV/W8GUxAnc+qet3P38nqAGQ0uOtfRbxiA5NpL5mYm849dTdw/Eek8cW11gd5VZDq+0d09VM2f/3zd4+P3Bx2xCQQP9OLQgy9kDf3prJSvzUzk3b2Qzb7910Qx+fOWcIQ3ghkJKXBTzMxP71dNXNp6gqaOHOQNM9JqZEc+Jnj7+vOkwsyZNDDhwHCrTU2OJtFoC5ulPdDs/ms8PsO/vQDKTYlgzK52H3j/IH94+4EkJVTWd4PkdVVy3JMtnnRarRfjZp86i5UQPv3i1/7IRzR09HGvuZKbrDSnXHkdV0wlOdJ8MsH0Ow/6aNmZmxGO1CEuyk4dUefOjfxTz4cHj/OSTc9l45wXcedks8tLiuGphJj+5cg5v7K3hP57ayfM7jtLd5+DqxSfniOSlx+MwDFgi2t7Vy1f+vJXYKBv3Xb+QuCgbn186jTdLajyTpMrq20fcm/eWY4/jma8u54Zl03h04+ERz4ru7OnjUEOH503V2+oCe78JUcVHm0mYEEGm1xyAs6YkkBIbOew8/W/f2Ed3n4Nfv1Ya1F4YQxXsxiOJIvK0iJSIyF4RWeb3+GoRaRaRHa6vu4Jr7pkhOyWGxJgIAP7z0pGnVZZMT+YzZw99ADcU3C8Q90fZ2tZO/s8LxQAsGCCAuntUVU0nhl1tM1wRVueErUDLIxcfdW78Mi9zeFP9f3rVWayZmcYvXy3lE79bz4cHj3tKAW9emdPv/JkZE7l6cSbP7zjarza91JUHd/9MctNiMX4B1tmzdlDgWp5hyfRkyuvaT7nn8F8/PMLjm49w23m5fH5p/wXZPr8sm/+4pIDndhzl//1zL7MmTfSpkHKXmu4LMCBrjOGOZ3ZzsL6d31+/0LNPwueWTSPCYuGRDw5ijKG8dvillQOJsln53mWziLRZTtmbbuvqHbDk90BtG30OE3BhstUF/SdE7alqYe6UiT6f9iwWYdUMO+/uqxvyIPGuyibe2FvLtYWZ9PQZfj6EdaKCFWxX717gFWPMTGA+sDfAOeuNMQtcXz8O8npnBBHhM2dnccvK6f3KEce7810vkHf31fHs9kou+vV7rD9Qzw8+MYuzBgigM9Ljcb92Lpg1uoEenJ8gAm14srPS+dHcvRTFUNnjo7j/c4t5+MZCOrr7uPZ/NvLnTYf55MIpnrEBf5fPm0xHt3O2szd3Ssndy3QHRu88vXsA0X3OOa4KoS0HA5cebjvSyF3PF7MyP5X/uKRgwP/H7efncdt5uXT3Obhmse+M7xx7LBaBAwEGZN/dV8c/dh7l3y8u8BnETouP5pMLJ/P01kr21Th3OAtVoAdnenJJdjLvnyLQf+2JbXzx0S0BH/P/OXqbNyWB5NhIT/qmu9fhXJMpQPpxdYGdxo4en4UET+W3b+wnMSaCH14+m1tWTefZ7VVsOTQ6JbJuIw70IpIArAIeBjDGdBtjRm/xjzPMnWtn8f1PzB7rZgyb+6PsXc8X860nd5KXFsfL31gZsGfrNiHSSnZKLCmu3OhomzkpnpqWrn4TkXZVNpE+MYp0v527hmrNrHRe//Yqvrwqh7T4KL66OnfAc8/JSSYxJoKXd/uugFhS3UrChAgyXG2YnhqLiG8Ne0l1KxbBs5TEnMkTiYm0BszT17Z0ctuftpKREM3vr1846IzU/7y0gL9/ZTk3+q3IGmVz/o72BxgveO2jGmIjrdwS4Hf8pRU5dPY4+MmLzrWiQhnoAc7NS6W0pjXgp5mmjm7W769n+5HGgAOfpTWtroqi/ukki0VYlZ/Ke/vqcDgM+2tb6e5zBEw/rsq3YxHnz2EwOyqaeKuklltW5hAfHcHt5+cxKSGau58vDmnZqL9gevTTgTrgjyKyXUQecu0h62+ZiOwUkZdFZE4Q11MfAxaLcOncDHr6HPzgE7P425eXDenF/ZXVuXz30oIhT40PhnvVTP8B2d2VzZ4VPkcqJtLGnZfNYsOdawbcMwCcKaSLZ6fz5t5anwHO0upWCjLiPemB6AgrWUkxfj36FrJTYz0LvEVYLSyeltSv8qart4+vPL6N1s5e1t2wmMSYwdd0FxEWT0sK+HvIS4vrF+iNMbxdUsvKfHvAsaCCjHhWzbDzvmuuQW5aaCcZuleO/SDAXIa3S52rvToM/Wr6wfmGmZ8WN2AxwuqCNBrau9ld1UxxlfNvxX9mNzir1y6Ymc7ftlQMukzEb9/YR1JMhOeNNCbSxvc/MYuPjrXwxChupxhMoLcBi4D7jTELgXbgDr9ztgHTjDHzgd8DzwV6IhG5VUSKRKSorm7816SqU/vh5bP58PsXcvPKnCEH7msLs07beIJnNq7XgGzziR7K69uZP8z8fDDWnjWJ1q5eT5ByL2vgn0rItcf6TJoKdM4505MpqW71VPIYY/jBs3vYeriRX14zz/PmFoz89DgO1bf7zEHYe6yVY82dp6yUumWls+ooJtLq+aQSKrMnTSQpJiJgnv614hrs8VFEWi0BB6sDVdx4WzXDXWZZR/HRZmJdnzwD+cLybBrau/nnroHXqN92pJF3Suu4ZVUOcV6VR584axLLclL41aulo7YaZjCBvhKoNMZsdt1/Gmfg9zDGtBhj2ly3XwIiRKRfCYkxZp0xptAYU2i324NokhoPoiOsJEyIGOtmDMh7kpbbyTVMRj915HZubirx0TZe2u2sB69sPEFbV2+/oJyXFkd5nXPgsKO7l8PHOyhI9z3Hk6d39Vwffv8gT22t5OsX5HH5vKEvO30qM9KdcxAOec3sdVdYrZ458Ot2RV4qMzPifT6phIrFIizPS+WDA/U+k7Pcq71eMied+VkJbC73TWs1tndT29oVMD/vlhwbybzMRN7ZV8ueoy3MmZwwYJnzuXkp5NpjT7l5/G/f2E9ybCQ3eu0BAc5PUT+6cg5tXb386rXSQf/PIzHiQG+MqQYqRMQ9urMG8Fm0XUQyxPWbFZElruuFfr62UsPgPUnLzb1GznArboIRabNw0ax0Xv+ohp4+h+eNx7+XmWuPo6vXwdGmE+yvacOY/ufMy0wgyubsub5dWstPX9rLpXMyQroxhntMwHsphDf31nDWlATS4gfuqYsIj31xCX/47KIBzwnGyrxUalq6fMYxPjhQT0d3HxfPzuCc6SnsOdriM/u3xPOzPvUnndUz7OyoaGJPVTNzpgx8rohw4/JsdlY2e3Z087bl0HHe21fHratyAs4jmJEez7cvmsGKEZZSDybYqpt/Ax4XkV3AAuCnInKbiNzmevxqYI+I7AR+B1xnRnPxbKWGaGbGREprWj0DYLsqmpmWEjOkPHYorT1rEs0nethY1tCvtNLNvRvTgbq2AStFomxWFk5N5NXiar7+xHYKMiby68/MH/ZEu1PJtcchAvtda9Mfb+9me0XTkCa4pQ2wsF4orHDl6b3TN69/VEN8lI2lOSkszXEuB+Kdpy/1q24aiHtCVFevY9Adtj61KJO4KBuPbjjkc7y5o4dvPbmDKYkTTrni7e3n53HZWcObAT9UQQV6Y8wOV8plnjHmk8aYRmPMA8aYB1yP32eMmWOMmW+MWWqM2RCaZisVnJmT4unscXgWGNtV2XRa0zZuK/NTiY208vKeY+w91kJW8gSf/C14lVjWtlFS3Up0hIWpAdYmWjI9hcrGE0RFWHjoxkJiIkO7V0J0hJWpyTGeHv27+2oxZuQzmUMlMymG6amxngHfPofhjb01rJ6ZRqTNwqJpidgs4pOnL61pJSkmgrT4qIGeFnCm8pJcc1rmnqJHD85VY69enMk/dx/zVAEZY/j3p3ZS09LJfZ9dGPLfyVDpzFh1RprltV9tXWsXR5s7T+tArFt0hJULZqXzanENHx1r6Zd7B2euODk2krK6NkprWpiRHh+wp37x7HQmJ0TzwOcWD1i/H6z8tHhPj/7NvbWkxkWdcmvL0+XcvBQ2lTfQ0+dg+5FG6tu6uXi2c52emEgb8zJ98/QlftVNA7FahNUFaUyIsA6peuyGZdPo6TP89UPnngcPrT/IG3tr+N5ls/rt8nY6aaBXZ6T89DgsAnurWz0TXcaiRw9w2dwMjrd3U17XPmAqIdceS1ltO6XVbRSkBz5n7pQENty5pt++q6GUnx7Hwfp2Onv6eG9fHecX2EOaHhqpFXl2Orr72H6kidc+qiHCKqwuODlAfE5OCrsqnQsFOhzu6qahVSLdedlMHr/lnCGtHptjj2PVDDuPbz7MpvIGfv5KCWvnZvAFv3kJp5sGenVGio6wkp0aS8mxFnZWNmORwT+aj5bzCuxERzhfigOV++Xa49hd1Ux9W9cpSwJHW35aHD19hme2VdHS2TvmaRu3ZbkpWATe31/Ha8XVLMtNJT76ZOXXOdOT6XUYth5upLLxBB3dfUP+OabFR7NoGL3xLyyfRk1LFzc88iGZSRO45+p5o7ra7VBooFdnrFkZzv1qd1U2kZ8WP2b505hIG+e79ieeNSlw8MlLi+OEazJOKGriRyo/zdm+B9eXE2EVz0DoWEuY4NyP4MmiCg41dHjSNm7uzVM2lx/3VFuN1hvmeTPSPGMof/jsIiZGj32p8dj8ZSs1DszMiOefu4/R2NHNpXMyxrQt7sll01MD54G988Nj2aPPS3NW3hysb+fcvBSfXvNYW5GXyu/fOgDARX6BPi7KxtwpCWw+2ECUawbvQCmwYFktwoM3FNLR3X/rzLGiPXp1xnIvBdza2cu8YS5kFmqLpyVx32cXDTiT2B3oU2IjsQ9SKTKaJkRaPcv0uj+FjBfuGvT5WYkB1ytaOj2ZnRXN7Kxs6rfZSKgVZMSP6eCrPw306ozlPfA5FhU3wzElaQKRNsuY9ubd3Omb8ZKfd1s4NYnMpAk+6+h7Oycnme4+B2+X1o2Ln+PppKkbdcbKTHLWrHf3OsY07z0UVotw0/Jszw5ZY2nt3AyibJZTLto2FiJtFt7/zwsGfLwwOxmLOOvsB5soFW400Kszlogwd8pEevvMaduFKxh3XjZrrJsAwDWFWVxTmDXWzRi2idERzJ48kT1Vp17MLBxpoFdntN9+ZiEGXZXjTHHO9BT2VLVoj16pM0lGQmiXzVXj2+eXTmNChJWcAaqbwpUGeqXUGSM7NZbvnGI7xXA1/hOTSimlgqKBXimlwpwGeqWUCnMa6JVSKswFFehFJFFEnhaREhHZKyLL/B4XEfmdiBwQkV0iMjp7iSmllBpQsFU39wKvGGOuFpFIwH/bm7VAvuvrHOB+179KKaVOkxH36EUkAVgFPAxgjOk2xjT5nXYl8Jhx2gQkisjobIqolFIqoGBSN9OBOuCPIrJdRB4SkVi/c6YAFV73K13HfIjIrSJSJCJFdXV1QTRJKaWUv2BSNzZgEfBvxpjNInIvcAfww+E+kTFmHbAOQETqRORwEO1KBeoHPWvsfVzaCdrW0aJtHR1nalunDfRAMIG+Eqg0xmx23X8aZ6D3VgV4r36U6To2IGOM/VSPD0ZEiowxhcE8x+nwcWknaFtHi7Z1dGhb+xtx6sYYUw1UiIh7PvEa4CO/014AbnBV3ywFmo0xx0Z6TaWUUsMXbNXNvwGPuypuyoGbROQ2AGPMA8BLwGXAAaADuCnI6ymllBqmoAK9MWYH4P+x4wGvxw1wezDXGIF1p/l6I/VxaSdoW0eLtnV0aFv9iDMWK6WUCle6BIJSSoU5DfRKKRXmwibQi8ilIlLqWlfHv8xzTInIIyJSKyJ7vI4li8jrIrLf9W/SWLbRTUSyRORtEflIRIpF5Buu4+OuvSISLSIfishOV1t/5Do+XUQ2u/4WnnQVC4w5EbG6Jhe+6Lo/Xtt5SER2i8gOESlyHRt3v38IvN7WeGyriBS4fp7urxYR+ebpamtYBHoRsQJ/wLm2zmzgehGZPbat8vG/wKV+x+4A3jTG5ANv0n8OwljpBf7dGDMbWArc7vpZjsf2dgEXGGPmAwuAS11lvPcAvzHG5AGNwJfGsI3evgHs9bo/XtsJcL4xZoFXjfd4/P3DyfW2ZgLzcf58x11bjTGlrp/nAmAxzirEZzldbTXGfOy/gGXAq1737wTuHOt2+bUxG9jjdb8UmOS6PQkoHes2DtDu54GLxnt7cS6otw3nonn1gC3Q38YYti/T9UK+AHgRkPHYTldbDgGpfsfG3e8fSAAO4ioqGc9t9WvfxcAHp7OtYdGjZ4hr6owz6ebk5LFqIH0sGxOIiGQDC4HNjNP2utIhO4Ba4HWgDGgyxvS6Thkvfwu/Bb4LOFz3Uxif7QQwwGsislVEbnUdG4+//4HW2xqPbfV2HfAX1+3T0tZwCfQfa8b5dj6u6lxFJA74O/BNY0yL92Pjqb3GmD7j/DicCSwBZo5xk/oRkcuBWmPM1rFuyxCtMMYswpkKvV1EVnk/OI5+/+71tu43xiwE2vFLfYyjtgLgGoe5AnjK/7HRbGu4BPphr6kzDtS4l2x2/Vs7xu3xEJEInEH+cWPMM67D47a9AMa5RPbbOFMgiSLingw4Hv4WzgWuEJFDwF9xpm/uZfy1EwBjTJXr31qceeQljM/ff6D1thYxPtvqthbYZoypcd0/LW0Nl0C/Bch3VTFE4vxo9MIYt2kwLwA3um7fiDMXPuZERHDuMbDXGPNrr4fGXXtFxC4iia7bE3COJezFGfCvdp025m01xtxpjMk0xmTj/Nt8yxjzr4yzdgKISKyIxLtv48wn72Ec/v7NwOttjbu2ermek2kbOF1tHeuBiRAOcFwG7MOZo/3+WLfHr21/AY4BPTh7IV/CmaN9E9gPvAEkj3U7XW1dgfPj4y5gh+vrsvHYXmAesN3V1j3AXa7jOcCHONdYegqIGuu2erV5NfDieG2nq007XV/F7tfSePz9u9q1AChy/Q08BySN47bGAg1Agtex09JWXQJBKaXCXLikbpRSSg1AA71SSoU5DfRKKRXmNNArpVSY00CvlFJhTgO9UkqFOQ30SikV5v4//7LvMN+FUuoAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as ticker\n", + "%matplotlib inline\n", + "\n", + "plt.figure()\n", + "plt.plot(losses)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 预测\n", + "用训练好的模型进行预测。" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the input words is: of, william\n", + "the predict words is: shakespeare\n", + "the true words is: shakespeare\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Library/Python/3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + " and should_run_async(code)\n" + ] + } + ], + "source": [ + "import random\n", + "def test(model):\n", + " model.eval()\n", + " # 从最后10组数据中随机选取1个\n", + " idx = random.randint(len(trigram)-10, len(trigram)-1)\n", + " print('the input words is: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n", + " x_data = list(map(lambda w: word_to_idx[w], trigram[idx][0]))\n", + " x_data = paddle.to_tensor(np.array(x_data))\n", + " predicts = model(x_data)\n", + " predicts = predicts.numpy().tolist()[0]\n", + " predicts = predicts.index(max(predicts))\n", + " print('the predict words is: ' + idx_to_word[predicts])\n", + " y_data = trigram[idx][1]\n", + " print('the true words is: ' + y_data)\n", + "test(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}