{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST数据集使用LeNet进行图像分类\n", "本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。\n", "手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 环境\n", "本教程基于paddle-develop编写,如果您的环境不是本版本,请先安装paddle-develop版本。" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.0.0\n" ] } ], "source": [ "import paddle\n", "print(paddle.__version__)\n", "paddle.disable_static()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 加载数据集\n", "我们使用飞桨自带的paddle.dataset完成mnist数据集的加载。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "download training data and load training data\n", "load finished\n" ] } ], "source": [ "print('download training data and load training data')\n", "train_dataset = paddle.vision.datasets.MNIST(mode='train')\n", "test_dataset = paddle.vision.datasets.MNIST(mode='test')\n", "print('load finished')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "取训练集中的一条数据看一下。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_data0 label is: [5]\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAIY0lEQVR4nO3dXWhUZxoH8P/jaPxav7KREtNgiooQFvwg1l1cNOr6sQUN3ixR0VUK9cKPXTBYs17ohReLwl5ovCmuZMU1y+IaWpdC0GIuxCJJMLhJa6oWtSl+FVEXvdDK24s5nc5zapKTZ86cOTPz/4Hk/M8xc17w8Z13zpl5RpxzIBquEbkeAOUnFg6ZsHDIhIVDJiwcMmHhkElGhSMiq0WkT0RuisjesAZF8SfW6zgikgDwFYAVAPoBdABY75z7IrzhUVyNzOB33wVw0zn3NQCIyL8A1AEYsHDKyspcVVVVBqekqHV1dX3nnJvq359J4VQA+CYt9wNYONgvVFVVobOzM4NTUtRE5M6b9md9cSwiH4hIp4h0Pnr0KNuno4hkUjjfAqhMy297+xTn3EfOuRrnXM3UqT+b8ShPZVI4HQBmicg7IlICoB7AJ+EMi+LOvMZxzn0vIjsAtAFIADjhnOsNbWQUa5ksjuGc+xTApyGNhfIIrxyTCQuHTFg4ZMLCIRMWDpmwcMiEhUMmLBwyYeGQCQuHTFg4ZMLCIZOMbnIWk9evX6v89OnTwL/b1NSk8osXL1Tu6+tT+dixYyo3NDSo3NLSovKYMWNU3rv3p88N7N+/P/A4h4MzDpmwcMiEhUMmRbPGuXv3rsovX75U+fLlyypfunRJ5SdPnqh85syZ0MZWWVmp8s6dO1VubW1VecKECSrPmTNH5SVLloQ2toFwxiETFg6ZsHDIpGDXOFevXlV52bJlKg/nOkzYEomEygcPHlR5/PjxKm/cuFHladOmqTxlyhSVZ8+enekQh8QZh0xYOGTCwiGTgl3jTJ8+XeWysjKVw1zjLFyom3T41xwXL15UuaSkROVNmzaFNpaocMYhExYOmbBwyKRg1zilpaUqHz58WOVz586pPG/ePJV37do16OPPnTs3tX3hwgV1zH8dpqenR+UjR44M+tj5gDMOmQxZOCJyQkQeikhP2r5SETkvIje8n1MGewwqPEFmnGYAq3379gL4zDk3C8BnXqYiEqjPsYhUAfivc+5XXu4DUOucuyci5QDanXND3iCpqalxcek6+uzZM5X973HZtm2bysePH1f51KlTqe0NGzaEPLr4EJEu51yNf791jfOWc+6et30fwFvmkVFeynhx7JJT1oDTFtvVFiZr4TzwnqLg/Xw40F9ku9rCZL2O8wmAPwL4q/fz49BGFJGJEycOenzSpEmDHk9f89TX16tjI0YU/lWOIC/HWwB8DmC2iPSLyPtIFswKEbkB4HdepiIy5IzjnFs/wKHlIY+F8kjhz6mUFQV7rypTBw4cULmrq0vl9vb21Lb/XtXKlSuzNazY4IxDJiwcMmHhkIn5Ozkt4nSvarhu3bql8vz581PbkydPVseWLl2qck2NvtWzfft2lUUkjCFmRdj3qqjIsXDIhC/HA5oxY4bKzc3Nqe2tW7eqYydPnhw0P3/+XOXNmzerXF5ebh1mZDjjkAkLh0xYOGTCNY7RunXrUtszZ85Ux3bv3q2y/5ZEY2Ojynfu6O+E37dvn8oVFRXmcWYLZxwyYeGQCQuHTHjLIQv8rW39HzfesmWLyv5/g+XL9Xvkzp8/H97ghom3HChULBwyYeGQCdc4OTB69GiVX716pfKoUaNUbmtrU7m2tjYr43oTrnEoVCwcMmHhkAnvVYXg2rVrKvu/kqijo0Nl/5rGr7q6WuXFixdnMLrs4IxDJiwcMmHhkAnXOAH5v+L56NGjqe2zZ8+qY/fv3x/WY48cqf8Z/O85jmPblPiNiPJCkP44lSJyUUS+EJFeEfmTt58ta4tYkBnnewC7nXPVAH4NYLuIVIMta4takMZK9wDc87b/LyJfAqgAUAeg1vtr/wDQDuDDrIwyAv51yenTp1VuampS+fbt2+ZzLViw
QGX/e4zXrl1rfuyoDGuN4/U7ngfgCtiytqgFLhwR+QWA/wD4s3NOdZcerGUt29UWpkCFIyKjkCyafzrnfnztGahlLdvVFqYh1ziS7MHxdwBfOuf+lnYor1rWPnjwQOXe3l6Vd+zYofL169fN5/J/1eKePXtUrqurUzmO12mGEuQC4CIAmwD8T0S6vX1/QbJg/u21r70D4A/ZGSLFUZBXVZcADNT5hy1ri1T+zZEUCwVzr+rx48cq+782qLu7W2V/a7bhWrRoUWrb/1nxVatWqTx27NiMzhVHnHHIhIVDJiwcMsmrNc6VK1dS24cOHVLH/O/r7e/vz+hc48aNU9n/ddLp95f8XxddDDjjkAkLh0zy6qmqtbX1jdtB+D9ysmbNGpUTiYTKDQ0NKvu7pxc7zjhkwsIhExYOmbDNCQ2KbU4oVCwcMmHhkAkLh0xYOGTCwiETFg6ZsHDIhIVDJiwcMmHhkEmk96pE5BGSn/osA/BdZCcenriOLVfjmu6c+9mH/iMtnNRJRTrfdOMsDuI6triNi09VZMLCIZNcFc5HOTpvEHEdW6zGlZM1DuU/PlWRSaSFIyKrRaRPRG6KSE7b24rICRF5KCI9afti0bs5H3pLR1Y4IpIAcAzA7wFUA1jv9UvOlWYAq3374tK7Of69pZ1zkfwB8BsAbWm5EUBjVOcfYExVAHrSch+Acm+7HEBfLseXNq6PAayI0/iifKqqAPBNWu739sVJ7Ho3x7W3NBfHA3DJ/9Y5fclp7S0dhSgL51sAlWn5bW9fnATq3RyFTHpLRyHKwukAMEtE3hGREgD1SPZKjpMfezcDOezdHKC3NJDr3tIRL/LeA/AVgFsA9uV4wdmC5JebvEJyvfU+gF8i+WrlBoALAEpzNLbfIvk0dA1At/fnvbiMzznHK8dkw8UxmbBwyISFQyYsHDJh4ZAJC4dMWDhkwsIhkx8AyyZIbO5tLBIAAAAASUVORK5CYII=\n", "text/plain": [ "
<Figure size 144x144 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "train_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]\n", "train_data0 = train_data0.reshape([28,28])\n", "plt.figure(figsize=(2,2))\n", "plt.imshow(train_data0, cmap=plt.cm.binary)\n", "print('train_data0 label is: ' + str(train_label_0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. Build the network\n", "Use the APIs under paddle.nn, such as `Conv2d`, `Pool2D` and `Linear`, to build LeNet." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "import paddle\n", "import paddle.nn.functional as F\n", "class LeNet(paddle.nn.Layer):\n", "    def __init__(self):\n", "        super(LeNet, self).__init__()\n", "        # two conv + max-pool blocks, followed by three fully connected layers\n", "        self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n", "        self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", "        self.conv2 = paddle.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n", "        self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", "        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)\n", "        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)\n", "        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = F.relu(x)\n", "        x = self.max_pool1(x)\n", "        x = self.conv2(x)\n", "        x = F.relu(x)\n", "        x = self.max_pool2(x)\n", "        # flatten to [batch_size, 16*5*5] before the fully connected layers\n", "        x = paddle.reshape(x, shape=[-1, 16*5*5])\n", "        x = self.linear1(x)\n", "        x = F.relu(x)\n", "        x = self.linear2(x)\n", "        x = F.relu(x)\n", "        x = self.linear3(x)\n", "        x = F.softmax(x)\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training approach 1\n", "Build a `Model` instance with the high-level API and finish model training quickly." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import paddle\n", "from paddle.static import InputSpec\n", "from paddle.metric import Accuracy\n", "inputs = InputSpec([None, 1, 28, 28], 'float32', 'x')\n", "labels = InputSpec([None, 1], 'int64', 'label')\n", "model = paddle.hapi.Model(LeNet(), inputs, labels)\n", "optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", "\n", "model.prepare(\n", "    optim,\n", "    paddle.nn.loss.CrossEntropyLoss(),\n", "    Accuracy(topk=(1, 2))\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model with model.fit" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "step 10/938 - loss: 2.2369 - acc_top1: 0.3281 - acc_top2: 0.4172 - 18ms/step\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Library/Python/3.7/site-packages/paddle/fluid/layers/utils.py:76: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", "  return (isinstance(seq, collections.Sequence) and\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 20/938 - loss: 2.0185 - acc_top1: 0.3656 - acc_top2: 0.4328 - 17ms/step\n", "step 30/938 - loss: 1.9579 - acc_top1: 0.4120 - acc_top2: 0.4969 - 16ms/step\n", "step 40/938 - loss: 1.8549 - acc_top1: 0.4602 - acc_top2: 0.5500 - 16ms/step\n", "step 50/938 - loss: 1.8628 - acc_top1: 0.5097 - acc_top2: 0.6028 - 16ms/step\n", "step 60/938 - loss: 1.7139 - acc_top1: 0.5456 - acc_top2: 0.6409 - 16ms/step\n", "step 70/938 - 
loss: 1.7296 - acc_top1: 0.5795 - acc_top2: 0.6719 - 15ms/step\n", "step 80/938 - loss: 1.6302 - acc_top1: 0.6053 - acc_top2: 0.6949 - 15ms/step\n", "step 90/938 - loss: 1.6688 - acc_top1: 0.6290 - acc_top2: 0.7158 - 15ms/step\n", "step 100/938 - loss: 1.6401 - acc_top1: 0.6491 - acc_top2: 0.7327 - 15ms/step\n", "step 110/938 - loss: 1.6357 - acc_top1: 0.6636 - acc_top2: 0.7440 - 15ms/step\n", "step 120/938 - loss: 1.6309 - acc_top1: 0.6767 - acc_top2: 0.7539 - 15ms/step\n", "step 130/938 - loss: 1.6445 - acc_top1: 0.6894 - acc_top2: 0.7638 - 15ms/step\n", "step 140/938 - loss: 1.5961 - acc_top1: 0.7002 - acc_top2: 0.7728 - 15ms/step\n", "step 150/938 - loss: 1.6822 - acc_top1: 0.7086 - acc_top2: 0.7794 - 15ms/step\n", "step 160/938 - loss: 1.6243 - acc_top1: 0.7176 - acc_top2: 0.7858 - 15ms/step\n", "step 170/938 - loss: 1.6159 - acc_top1: 0.7254 - acc_top2: 0.7915 - 15ms/step\n", "step 180/938 - loss: 1.6820 - acc_top1: 0.7312 - acc_top2: 0.7962 - 15ms/step\n", "step 190/938 - loss: 1.6733 - acc_top1: 0.7363 - acc_top2: 0.7999 - 15ms/step\n", "step 200/938 - loss: 1.7717 - acc_top1: 0.7413 - acc_top2: 0.8039 - 15ms/step\n", "step 210/938 - loss: 1.5468 - acc_top1: 0.7458 - acc_top2: 0.8072 - 15ms/step\n", "step 220/938 - loss: 1.5654 - acc_top1: 0.7506 - acc_top2: 0.8111 - 15ms/step\n", "step 230/938 - loss: 1.6129 - acc_top1: 0.7547 - acc_top2: 0.8143 - 15ms/step\n", "step 240/938 - loss: 1.5937 - acc_top1: 0.7592 - acc_top2: 0.8180 - 15ms/step\n", "step 250/938 - loss: 1.5457 - acc_top1: 0.7631 - acc_top2: 0.8214 - 15ms/step\n", "step 260/938 - loss: 1.6041 - acc_top1: 0.7673 - acc_top2: 0.8249 - 15ms/step\n", "step 270/938 - loss: 1.6049 - acc_top1: 0.7700 - acc_top2: 0.8271 - 15ms/step\n", "step 280/938 - loss: 1.5989 - acc_top1: 0.7735 - acc_top2: 0.8299 - 15ms/step\n", "step 290/938 - loss: 1.6950 - acc_top1: 0.7752 - acc_top2: 0.8310 - 15ms/step\n", "step 300/938 - loss: 1.5888 - acc_top1: 0.7781 - acc_top2: 0.8330 - 15ms/step\n", "step 310/938 - loss: 1.5983 - acc_top1: 0.7808 - acc_top2: 0.8350 - 15ms/step\n", "step 320/938 - loss: 1.5133 - acc_top1: 0.7840 - acc_top2: 0.8370 - 15ms/step\n", "step 330/938 - loss: 1.5587 - acc_top1: 0.7866 - acc_top2: 0.8385 - 15ms/step\n", "step 340/938 - loss: 1.6093 - acc_top1: 0.7882 - acc_top2: 0.8393 - 15ms/step\n", "step 350/938 - loss: 1.6259 - acc_top1: 0.7902 - acc_top2: 0.8410 - 15ms/step\n", "step 360/938 - loss: 1.6194 - acc_top1: 0.7918 - acc_top2: 0.8422 - 15ms/step\n", "step 370/938 - loss: 1.6531 - acc_top1: 0.7941 - acc_top2: 0.8438 - 15ms/step\n", "step 380/938 - loss: 1.6986 - acc_top1: 0.7957 - acc_top2: 0.8447 - 15ms/step\n", "step 390/938 - loss: 1.5932 - acc_top1: 0.7974 - acc_top2: 0.8459 - 15ms/step\n", "step 400/938 - loss: 1.6512 - acc_top1: 0.7993 - acc_top2: 0.8474 - 15ms/step\n", "step 410/938 - loss: 1.5698 - acc_top1: 0.8012 - acc_top2: 0.8487 - 15ms/step\n", "step 420/938 - loss: 1.5889 - acc_top1: 0.8025 - acc_top2: 0.8494 - 15ms/step\n", "step 430/938 - loss: 1.5518 - acc_top1: 0.8036 - acc_top2: 0.8503 - 15ms/step\n", "step 440/938 - loss: 1.6057 - acc_top1: 0.8048 - acc_top2: 0.8508 - 15ms/step\n", "step 450/938 - loss: 1.6081 - acc_top1: 0.8064 - acc_top2: 0.8519 - 15ms/step\n", "step 460/938 - loss: 1.5742 - acc_top1: 0.8079 - acc_top2: 0.8531 - 15ms/step\n", "step 470/938 - loss: 1.5704 - acc_top1: 0.8095 - acc_top2: 0.8543 - 15ms/step\n", "step 480/938 - loss: 1.6083 - acc_top1: 0.8110 - acc_top2: 0.8550 - 15ms/step\n", "step 490/938 - loss: 1.6081 - acc_top1: 0.8120 - acc_top2: 0.8555 - 15ms/step\n", 
"step 500/938 - loss: 1.5156 - acc_top1: 0.8133 - acc_top2: 0.8564 - 15ms/step\n", "step 510/938 - loss: 1.5856 - acc_top1: 0.8148 - acc_top2: 0.8573 - 15ms/step\n", "step 520/938 - loss: 1.5275 - acc_top1: 0.8163 - acc_top2: 0.8582 - 15ms/step\n", "step 530/938 - loss: 1.5345 - acc_top1: 0.8172 - acc_top2: 0.8591 - 15ms/step\n", "step 540/938 - loss: 1.5387 - acc_top1: 0.8181 - acc_top2: 0.8596 - 15ms/step\n", "step 550/938 - loss: 1.5753 - acc_top1: 0.8190 - acc_top2: 0.8601 - 15ms/step\n", "step 560/938 - loss: 1.6103 - acc_top1: 0.8203 - acc_top2: 0.8610 - 15ms/step\n", "step 570/938 - loss: 1.5571 - acc_top1: 0.8215 - acc_top2: 0.8618 - 15ms/step\n", "step 580/938 - loss: 1.5575 - acc_top1: 0.8221 - acc_top2: 0.8622 - 15ms/step\n", "step 590/938 - loss: 1.4821 - acc_top1: 0.8230 - acc_top2: 0.8627 - 15ms/step\n", "step 600/938 - loss: 1.5644 - acc_top1: 0.8243 - acc_top2: 0.8636 - 15ms/step\n", "step 610/938 - loss: 1.5317 - acc_top1: 0.8253 - acc_top2: 0.8644 - 15ms/step\n", "step 620/938 - loss: 1.5849 - acc_top1: 0.8258 - acc_top2: 0.8647 - 15ms/step\n", "step 630/938 - loss: 1.6087 - acc_top1: 0.8263 - acc_top2: 0.8649 - 15ms/step\n", "step 640/938 - loss: 1.5617 - acc_top1: 0.8272 - acc_top2: 0.8655 - 15ms/step\n", "step 650/938 - loss: 1.6376 - acc_top1: 0.8279 - acc_top2: 0.8660 - 15ms/step\n", "step 660/938 - loss: 1.5428 - acc_top1: 0.8287 - acc_top2: 0.8665 - 15ms/step\n", "step 670/938 - loss: 1.5797 - acc_top1: 0.8293 - acc_top2: 0.8668 - 15ms/step\n", "step 680/938 - loss: 1.5210 - acc_top1: 0.8300 - acc_top2: 0.8674 - 15ms/step\n", "step 690/938 - loss: 1.6159 - acc_top1: 0.8305 - acc_top2: 0.8677 - 15ms/step\n", "step 700/938 - loss: 1.5592 - acc_top1: 0.8313 - acc_top2: 0.8682 - 15ms/step\n", "step 710/938 - loss: 1.6400 - acc_top1: 0.8318 - acc_top2: 0.8685 - 15ms/step\n", "step 720/938 - loss: 1.5638 - acc_top1: 0.8327 - acc_top2: 0.8691 - 15ms/step\n", "step 730/938 - loss: 1.5691 - acc_top1: 0.8333 - acc_top2: 0.8693 - 15ms/step\n", "step 740/938 - loss: 1.5848 - acc_top1: 0.8337 - acc_top2: 0.8695 - 15ms/step\n", "step 750/938 - loss: 1.6317 - acc_top1: 0.8344 - acc_top2: 0.8698 - 15ms/step\n", "step 760/938 - loss: 1.5127 - acc_top1: 0.8352 - acc_top2: 0.8703 - 15ms/step\n", "step 770/938 - loss: 1.5822 - acc_top1: 0.8359 - acc_top2: 0.8707 - 15ms/step\n", "step 780/938 - loss: 1.6010 - acc_top1: 0.8366 - acc_top2: 0.8712 - 15ms/step\n", "step 790/938 - loss: 1.5238 - acc_top1: 0.8373 - acc_top2: 0.8717 - 15ms/step\n", "step 800/938 - loss: 1.5858 - acc_top1: 0.8377 - acc_top2: 0.8719 - 15ms/step\n", "step 810/938 - loss: 1.5800 - acc_top1: 0.8384 - acc_top2: 0.8724 - 15ms/step\n", "step 820/938 - loss: 1.6312 - acc_top1: 0.8390 - acc_top2: 0.8727 - 15ms/step\n", "step 830/938 - loss: 1.5812 - acc_top1: 0.8398 - acc_top2: 0.8732 - 15ms/step\n", "step 840/938 - loss: 1.5661 - acc_top1: 0.8402 - acc_top2: 0.8734 - 15ms/step\n", "step 850/938 - loss: 1.5379 - acc_top1: 0.8409 - acc_top2: 0.8739 - 15ms/step\n", "step 860/938 - loss: 1.5266 - acc_top1: 0.8413 - acc_top2: 0.8740 - 15ms/step\n", "step 870/938 - loss: 1.5264 - acc_top1: 0.8420 - acc_top2: 0.8745 - 15ms/step\n", "step 880/938 - loss: 1.5688 - acc_top1: 0.8425 - acc_top2: 0.8748 - 15ms/step\n", "step 890/938 - loss: 1.5707 - acc_top1: 0.8429 - acc_top2: 0.8751 - 15ms/step\n", "step 900/938 - loss: 1.5564 - acc_top1: 0.8432 - acc_top2: 0.8752 - 15ms/step\n", "step 910/938 - loss: 1.4924 - acc_top1: 0.8438 - acc_top2: 0.8757 - 15ms/step\n", "step 920/938 - loss: 1.5514 - acc_top1: 0.8443 - acc_top2: 0.8760 - 
15ms/step\n", "step 930/938 - loss: 1.5850 - acc_top1: 0.8446 - acc_top2: 0.8762 - 15ms/step\n", "step 938/938 - loss: 1.4915 - acc_top1: 0.8448 - acc_top2: 0.8764 - 15ms/step\n", "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/0\n", "Eval begin...\n", "step 10/157 - loss: 1.5984 - acc_top1: 0.8797 - acc_top2: 0.8953 - 5ms/step\n", "step 20/157 - loss: 1.6266 - acc_top1: 0.8789 - acc_top2: 0.9000 - 5ms/step\n", "step 30/157 - loss: 1.6475 - acc_top1: 0.8771 - acc_top2: 0.8984 - 5ms/step\n", "step 40/157 - loss: 1.6329 - acc_top1: 0.8730 - acc_top2: 0.8957 - 5ms/step\n", "step 50/157 - loss: 1.5399 - acc_top1: 0.8712 - acc_top2: 0.8934 - 5ms/step\n", "step 60/157 - loss: 1.6322 - acc_top1: 0.8750 - acc_top2: 0.8961 - 5ms/step\n", "step 70/157 - loss: 1.5818 - acc_top1: 0.8721 - acc_top2: 0.8931 - 5ms/step\n", "step 80/157 - loss: 1.5522 - acc_top1: 0.8760 - acc_top2: 0.8979 - 5ms/step\n", "step 90/157 - loss: 1.6085 - acc_top1: 0.8785 - acc_top2: 0.8984 - 5ms/step\n", "step 100/157 - loss: 1.5661 - acc_top1: 0.8784 - acc_top2: 0.8980 - 5ms/step\n", "step 110/157 - loss: 1.5694 - acc_top1: 0.8805 - acc_top2: 0.8996 - 5ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 120/157 - loss: 1.6012 - acc_top1: 0.8824 - acc_top2: 0.9003 - 5ms/step\n", "step 130/157 - loss: 1.5378 - acc_top1: 0.8844 - acc_top2: 0.9017 - 5ms/step\n", "step 140/157 - loss: 1.5068 - acc_top1: 0.8858 - acc_top2: 0.9022 - 5ms/step\n", "step 150/157 - loss: 1.5424 - acc_top1: 0.8873 - acc_top2: 0.9029 - 5ms/step\n", "step 157/157 - loss: 1.5862 - acc_top1: 0.8872 - acc_top2: 0.9035 - 5ms/step\n", "Eval samples: 10000\n", "Epoch 2/2\n", "step 10/938 - loss: 1.5988 - acc_top1: 0.8859 - acc_top2: 0.9016 - 15ms/step\n", "step 20/938 - loss: 1.5702 - acc_top1: 0.8852 - acc_top2: 0.9047 - 15ms/step\n", "step 30/938 - loss: 1.5999 - acc_top1: 0.8833 - acc_top2: 0.9021 - 15ms/step\n", "step 40/938 - loss: 1.5652 - acc_top1: 0.8816 - acc_top2: 0.9000 - 15ms/step\n", "step 50/938 - loss: 1.6163 - acc_top1: 0.8853 - acc_top2: 0.9047 - 15ms/step\n", "step 60/938 - loss: 1.5307 - acc_top1: 0.8849 - acc_top2: 0.9049 - 15ms/step\n", "step 70/938 - loss: 1.5542 - acc_top1: 0.8846 - acc_top2: 0.9029 - 15ms/step\n", "step 80/938 - loss: 1.5694 - acc_top1: 0.8816 - acc_top2: 0.9008 - 15ms/step\n", "step 90/938 - loss: 1.6030 - acc_top1: 0.8806 - acc_top2: 0.8991 - 15ms/step\n", "step 100/938 - loss: 1.5631 - acc_top1: 0.8814 - acc_top2: 0.8989 - 15ms/step\n", "step 110/938 - loss: 1.5598 - acc_top1: 0.8804 - acc_top2: 0.8984 - 15ms/step\n", "step 120/938 - loss: 1.5773 - acc_top1: 0.8803 - acc_top2: 0.8986 - 15ms/step\n", "step 130/938 - loss: 1.5076 - acc_top1: 0.8815 - acc_top2: 0.8995 - 15ms/step\n", "step 140/938 - loss: 1.6064 - acc_top1: 0.8809 - acc_top2: 0.8988 - 15ms/step\n", "step 150/938 - loss: 1.5279 - acc_top1: 0.8815 - acc_top2: 0.8993 - 15ms/step\n", "step 160/938 - loss: 1.6039 - acc_top1: 0.8820 - acc_top2: 0.8998 - 15ms/step\n", "step 170/938 - loss: 1.5709 - acc_top1: 0.8814 - acc_top2: 0.8993 - 15ms/step\n", "step 180/938 - loss: 1.6164 - acc_top1: 0.8806 - acc_top2: 0.8985 - 15ms/step\n", "step 190/938 - loss: 1.5920 - acc_top1: 0.8802 - acc_top2: 0.8985 - 15ms/step\n", "step 200/938 - loss: 1.6457 - acc_top1: 0.8793 - acc_top2: 0.8973 - 15ms/step\n", "step 210/938 - loss: 1.6045 - acc_top1: 0.8794 - acc_top2: 0.8977 - 15ms/step\n", "step 220/938 - loss: 1.6614 - acc_top1: 0.8795 - acc_top2: 0.8975 - 15ms/step\n", "step 230/938 - loss: 
1.5384 - acc_top1: 0.8789 - acc_top2: 0.8966 - 15ms/step\n", "step 240/938 - loss: 1.5556 - acc_top1: 0.8785 - acc_top2: 0.8960 - 15ms/step\n", "step 250/938 - loss: 1.6006 - acc_top1: 0.8782 - acc_top2: 0.8961 - 15ms/step\n", "step 260/938 - loss: 1.5552 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n", "step 270/938 - loss: 1.5805 - acc_top1: 0.8791 - acc_top2: 0.8970 - 15ms/step\n", "step 280/938 - loss: 1.5404 - acc_top1: 0.8787 - acc_top2: 0.8966 - 15ms/step\n", "step 290/938 - loss: 1.6023 - acc_top1: 0.8789 - acc_top2: 0.8969 - 15ms/step\n", "step 300/938 - loss: 1.5706 - acc_top1: 0.8788 - acc_top2: 0.8969 - 15ms/step\n", "step 310/938 - loss: 1.5424 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n", "step 320/938 - loss: 1.5823 - acc_top1: 0.8798 - acc_top2: 0.8975 - 15ms/step\n", "step 330/938 - loss: 1.5600 - acc_top1: 0.8801 - acc_top2: 0.8977 - 15ms/step\n", "step 340/938 - loss: 1.6258 - acc_top1: 0.8795 - acc_top2: 0.8970 - 15ms/step\n", "step 350/938 - loss: 1.5093 - acc_top1: 0.8796 - acc_top2: 0.8972 - 15ms/step\n", "step 360/938 - loss: 1.6030 - acc_top1: 0.8794 - acc_top2: 0.8967 - 15ms/step\n", "step 370/938 - loss: 1.5732 - acc_top1: 0.8795 - acc_top2: 0.8969 - 15ms/step\n", "step 380/938 - loss: 1.5980 - acc_top1: 0.8797 - acc_top2: 0.8972 - 15ms/step\n", "step 390/938 - loss: 1.5902 - acc_top1: 0.8800 - acc_top2: 0.8974 - 15ms/step\n", "step 400/938 - loss: 1.5395 - acc_top1: 0.8809 - acc_top2: 0.8983 - 15ms/step\n", "step 410/938 - loss: 1.6623 - acc_top1: 0.8804 - acc_top2: 0.8978 - 15ms/step\n", "step 420/938 - loss: 1.4987 - acc_top1: 0.8810 - acc_top2: 0.8983 - 15ms/step\n", "step 430/938 - loss: 1.5989 - acc_top1: 0.8811 - acc_top2: 0.8983 - 15ms/step\n", "step 440/938 - loss: 1.5722 - acc_top1: 0.8813 - acc_top2: 0.8984 - 15ms/step\n", "step 450/938 - loss: 1.5549 - acc_top1: 0.8818 - acc_top2: 0.8986 - 15ms/step\n", "step 460/938 - loss: 1.5536 - acc_top1: 0.8819 - acc_top2: 0.8986 - 15ms/step\n", "step 470/938 - loss: 1.5247 - acc_top1: 0.8826 - acc_top2: 0.8992 - 15ms/step\n", "step 480/938 - loss: 1.5520 - acc_top1: 0.8830 - acc_top2: 0.8995 - 15ms/step\n", "step 490/938 - loss: 1.5518 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n", "step 500/938 - loss: 1.5227 - acc_top1: 0.8837 - acc_top2: 0.9000 - 15ms/step\n", "step 510/938 - loss: 1.6014 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n", "step 520/938 - loss: 1.5526 - acc_top1: 0.8834 - acc_top2: 0.8998 - 15ms/step\n", "step 530/938 - loss: 1.5849 - acc_top1: 0.8838 - acc_top2: 0.9001 - 15ms/step\n", "step 540/938 - loss: 1.5607 - acc_top1: 0.8840 - acc_top2: 0.9006 - 15ms/step\n", "step 550/938 - loss: 1.6438 - acc_top1: 0.8843 - acc_top2: 0.9010 - 15ms/step\n", "step 560/938 - loss: 1.5229 - acc_top1: 0.8848 - acc_top2: 0.9014 - 15ms/step\n", "step 570/938 - loss: 1.5395 - acc_top1: 0.8846 - acc_top2: 0.9012 - 15ms/step\n", "step 580/938 - loss: 1.5409 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n", "step 590/938 - loss: 1.5851 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n", "step 600/938 - loss: 1.5383 - acc_top1: 0.8849 - acc_top2: 0.9013 - 15ms/step\n", "step 610/938 - loss: 1.5969 - acc_top1: 0.8853 - acc_top2: 0.9016 - 15ms/step\n", "step 620/938 - loss: 1.5634 - acc_top1: 0.8854 - acc_top2: 0.9017 - 15ms/step\n", "step 630/938 - loss: 1.6308 - acc_top1: 0.8857 - acc_top2: 0.9019 - 15ms/step\n", "step 640/938 - loss: 1.6413 - acc_top1: 0.8859 - acc_top2: 0.9021 - 15ms/step\n", "step 650/938 - loss: 1.5954 - acc_top1: 0.8856 - acc_top2: 0.9020 - 15ms/step\n", "step 
660/938 - loss: 1.5278 - acc_top1: 0.8859 - acc_top2: 0.9023 - 15ms/step\n", "step 670/938 - loss: 1.5144 - acc_top1: 0.8869 - acc_top2: 0.9035 - 15ms/step\n", "step 680/938 - loss: 1.4612 - acc_top1: 0.8879 - acc_top2: 0.9048 - 15ms/step\n", "step 690/938 - loss: 1.4820 - acc_top1: 0.8891 - acc_top2: 0.9060 - 15ms/step\n", "step 700/938 - loss: 1.4766 - acc_top1: 0.8901 - acc_top2: 0.9073 - 15ms/step\n", "step 710/938 - loss: 1.5245 - acc_top1: 0.8911 - acc_top2: 0.9083 - 15ms/step\n", "step 720/938 - loss: 1.5183 - acc_top1: 0.8922 - acc_top2: 0.9095 - 15ms/step\n", "step 730/938 - loss: 1.4971 - acc_top1: 0.8932 - acc_top2: 0.9106 - 15ms/step\n", "step 740/938 - loss: 1.4744 - acc_top1: 0.8944 - acc_top2: 0.9117 - 15ms/step\n", "step 750/938 - loss: 1.4789 - acc_top1: 0.8952 - acc_top2: 0.9127 - 15ms/step\n", "step 760/938 - loss: 1.5114 - acc_top1: 0.8959 - acc_top2: 0.9137 - 15ms/step\n", "step 770/938 - loss: 1.5035 - acc_top1: 0.8970 - acc_top2: 0.9147 - 15ms/step\n", "step 780/938 - loss: 1.4668 - acc_top1: 0.8978 - acc_top2: 0.9157 - 15ms/step\n", "step 790/938 - loss: 1.4850 - acc_top1: 0.8986 - acc_top2: 0.9166 - 15ms/step\n", "step 800/938 - loss: 1.4777 - acc_top1: 0.8996 - acc_top2: 0.9176 - 15ms/step\n", "step 810/938 - loss: 1.4783 - acc_top1: 0.9005 - acc_top2: 0.9186 - 15ms/step\n", "step 820/938 - loss: 1.5256 - acc_top1: 0.9011 - acc_top2: 0.9194 - 15ms/step\n", "step 830/938 - loss: 1.4801 - acc_top1: 0.9019 - acc_top2: 0.9202 - 15ms/step\n", "step 840/938 - loss: 1.4873 - acc_top1: 0.9026 - acc_top2: 0.9211 - 15ms/step\n", "step 850/938 - loss: 1.5093 - acc_top1: 0.9034 - acc_top2: 0.9219 - 15ms/step\n", "step 860/938 - loss: 1.4727 - acc_top1: 0.9042 - acc_top2: 0.9227 - 15ms/step\n", "step 870/938 - loss: 1.4917 - acc_top1: 0.9050 - acc_top2: 0.9235 - 15ms/step\n", "step 880/938 - loss: 1.4792 - acc_top1: 0.9058 - acc_top2: 0.9243 - 15ms/step\n", "step 890/938 - loss: 1.4854 - acc_top1: 0.9066 - acc_top2: 0.9251 - 15ms/step\n", "step 900/938 - loss: 1.4616 - acc_top1: 0.9074 - acc_top2: 0.9258 - 15ms/step\n", "step 910/938 - loss: 1.4954 - acc_top1: 0.9081 - acc_top2: 0.9265 - 15ms/step\n", "step 920/938 - loss: 1.4875 - acc_top1: 0.9087 - acc_top2: 0.9272 - 15ms/step\n", "step 930/938 - loss: 1.5037 - acc_top1: 0.9094 - acc_top2: 0.9279 - 15ms/step\n", "step 938/938 - loss: 1.4964 - acc_top1: 0.9099 - acc_top2: 0.9284 - 15ms/step\n", "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n", "Eval begin...\n", "step 10/157 - loss: 1.5196 - acc_top1: 0.9719 - acc_top2: 0.9969 - 5ms/step\n", "step 20/157 - loss: 1.5393 - acc_top1: 0.9672 - acc_top2: 0.9945 - 6ms/step\n", "step 30/157 - loss: 1.4928 - acc_top1: 0.9630 - acc_top2: 0.9906 - 5ms/step\n", "step 40/157 - loss: 1.4765 - acc_top1: 0.9617 - acc_top2: 0.9902 - 5ms/step\n", "step 50/157 - loss: 1.4646 - acc_top1: 0.9631 - acc_top2: 0.9903 - 5ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 60/157 - loss: 1.5646 - acc_top1: 0.9641 - acc_top2: 0.9906 - 5ms/step\n", "step 70/157 - loss: 1.5167 - acc_top1: 0.9618 - acc_top2: 0.9900 - 5ms/step\n", "step 80/157 - loss: 1.4728 - acc_top1: 0.9635 - acc_top2: 0.9906 - 5ms/step\n", "step 90/157 - loss: 1.5030 - acc_top1: 0.9668 - acc_top2: 0.9917 - 5ms/step\n", "step 100/157 - loss: 1.4612 - acc_top1: 0.9677 - acc_top2: 0.9914 - 5ms/step\n", "step 110/157 - loss: 1.4612 - acc_top1: 0.9689 - acc_top2: 0.9913 - 5ms/step\n", "step 120/157 - loss: 1.4612 - acc_top1: 0.9707 - acc_top2: 0.9919 - 
5ms/step\n", "step 130/157 - loss: 1.4621 - acc_top1: 0.9719 - acc_top2: 0.9923 - 5ms/step\n", "step 140/157 - loss: 1.4612 - acc_top1: 0.9734 - acc_top2: 0.9929 - 5ms/step\n", "step 150/157 - loss: 1.4660 - acc_top1: 0.9748 - acc_top2: 0.9933 - 5ms/step\n", "step 157/157 - loss: 1.5215 - acc_top1: 0.9731 - acc_top2: 0.9930 - 5ms/step\n", "Eval samples: 10000\n", "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/final\n" ] } ], "source": [ "model.fit(train_dataset,\n", "          test_dataset,\n", "          epochs=2,\n", "          batch_size=64,\n", "          save_dir='mnist_checkpoint')" ] },
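{ "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate the trained model\n", "After `model.fit` finishes, the same high-level `Model` instance can score the network on the test set. The cell below is a minimal sketch that assumes this Paddle version provides `Model.load` and `Model.evaluate`, as the Paddle 2.0 high-level API does: it reloads the final checkpoint that `fit` saved under `mnist_checkpoint` and reports the loss and the top-1/top-2 accuracy configured in `model.prepare`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch: reload the last checkpoint saved by model.fit and score it on the test set.\n", "# Assumes Model.load and Model.evaluate are available in this Paddle version (as in the Paddle 2.0 high-level API).\n", "model.load('mnist_checkpoint/final')\n", "eval_result = model.evaluate(test_dataset, batch_size=64, verbose=1)\n", "print(eval_result)" ] },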
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Training approach 1 finished\n", "That is training approach 1: it completes network training very quickly. Besides this, Paddle also lets you train the model in the way shown below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. Training approach 2\n", "Approach 1 is fast and convenient, but it hides the details of the training loop. Approach 2 trains the model with the most basic building blocks, as follows." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, batch_id: 0, loss is: [2.300888], acc is: [0.28125]\n", "epoch: 0, batch_id: 100, loss is: [1.6948285], acc is: [0.8125]\n", "epoch: 0, batch_id: 200, loss is: [1.5282547], acc is: [0.96875]\n", "epoch: 0, batch_id: 300, loss is: [1.509404], acc is: [0.96875]\n", "epoch: 0, batch_id: 400, loss is: [1.4973292], acc is: [1.]\n", "epoch: 0, batch_id: 500, loss is: [1.5063374], acc is: [0.984375]\n", "epoch: 0, batch_id: 600, loss is: [1.490077], acc is: [0.984375]\n", "epoch: 0, batch_id: 700, loss is: [1.5206413], acc is: [0.984375]\n", "epoch: 0, batch_id: 800, loss is: [1.5104291], acc is: [1.]\n", "epoch: 0, batch_id: 900, loss is: [1.5216607], acc is: [0.96875]\n", "epoch: 1, batch_id: 0, loss is: [1.4949667], acc is: [0.984375]\n", "epoch: 1, batch_id: 100, loss is: [1.4923338], acc is: [0.96875]\n", "epoch: 1, batch_id: 200, loss is: [1.5026703], acc is: [1.]\n", "epoch: 1, batch_id: 300, loss is: [1.4965419], acc is: [0.984375]\n", "epoch: 1, batch_id: 400, loss is: [1.5270758], acc is: [1.]\n", "epoch: 1, batch_id: 500, loss is: [1.4774603], acc is: [1.]\n", "epoch: 1, batch_id: 600, loss is: [1.4762554], acc is: [0.984375]\n", "epoch: 1, batch_id: 700, loss is: [1.4773959], acc is: [0.984375]\n", "epoch: 1, batch_id: 800, loss is: [1.5044193], acc is: [1.]\n", "epoch: 1, batch_id: 900, loss is: [1.4986757], acc is: [0.96875]\n" ] } ], "source": [ "import paddle\n", "train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n", "def train(model):\n", "    model.train()\n", "    epochs = 2\n", "    batch_size = 64\n", "    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", "    for epoch in range(epochs):\n", "        for batch_id, data in enumerate(train_loader()):\n", "            x_data = data[0]\n", "            y_data = data[1]\n", "            predicts = model(x_data)\n", "            loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", "            acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", "            avg_loss = paddle.mean(loss)\n", "            avg_acc = paddle.mean(acc)\n", "            avg_loss.backward()\n", "            if batch_id % 100 == 0:\n", "                print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", "            optim.minimize(avg_loss)\n", "            model.clear_gradients()\n", "model = LeNet()\n", "train(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Validate the model" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "batch_id: 0, loss is: [1.5017498], acc is: [1.]\n", "batch_id: 100, loss is: [1.4783669], acc is: [0.984375]\n", "batch_id: 200, loss is: [1.4958509], acc is: [1.]\n", "batch_id: 300, loss is: [1.4924574], acc is: [1.]\n", "batch_id: 400, loss is: [1.4762049], acc is: [1.]\n", "batch_id: 500, loss is: [1.4817208], acc is: [0.984375]\n", "batch_id: 600, loss is: [1.4763825], acc is: [0.984375]\n", "batch_id: 700, loss is: [1.4954926], acc is: [1.]\n", "batch_id: 800, loss is: [1.5220823], acc is: [0.984375]\n", "batch_id: 900, loss is: [1.4945463], acc is: [0.984375]\n" ] } ], "source": [ "import paddle\n", "test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n", "def test(model):\n", "    model.eval()\n", "    batch_size = 64\n", "    for batch_id, data in enumerate(test_loader()):\n", "        x_data = data[0]\n", "        y_data = data[1]\n", "        predicts = model(x_data)\n", "        loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", "        acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", "        avg_loss = paddle.mean(loss)\n", "        avg_acc = paddle.mean(acc)\n", "        if batch_id % 100 == 0:\n", "            print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", "test(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training approach 2 finished\n", "That is training approach 2. With this approach you can see every step of the training and testing process explicitly." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is how to classify the MNIST handwritten digit dataset with LeNet. The example shows two ways to train the model: one builds and trains the model very quickly and is well suited to beginners, while the other goes through the training step by step and suits more advanced users." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }