{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Image Classification on MNIST with LeNet\n", "This tutorial shows how to use LeNet to classify images from the MNIST dataset.\n", "MNIST is a dataset of handwritten digits with 60,000 training examples and 10,000 test examples. The digits have been size-normalized and centered in fixed-size images of 28x28 pixels, with pixel values between 0 and 1. The official dataset page is http://yann.lecun.com/exdb/mnist/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Environment\n", "This tutorial was written against paddle-develop. If your environment runs a different version, install paddle-develop first." ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.0.0\n" ] } ], "source": [ "import paddle\n", "print(paddle.__version__)\n", "# Switch to dynamic-graph (imperative) mode\n", "paddle.disable_static()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Loading the dataset\n", "We load MNIST with the dataset class built into PaddlePaddle, `paddle.vision.datasets.MNIST`." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "download training data and load training data\n", "load finished\n" ] } ], "source": [ "print('download training data and load training data')\n", "train_dataset = paddle.vision.datasets.MNIST(mode='train')\n", "test_dataset = paddle.vision.datasets.MNIST(mode='test')\n", "print('load finished')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Take one sample from the training set and look at it." ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_data0 label is: [5]\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAIY0lEQVR4nO3dXWhUZxoH8P/jaPxav7KREtNgiooQFvwg1l1cNOr6sQUN3ixR0VUK9cKPXTBYs17ohReLwl5ovCmuZMU1y+IaWpdC0GIuxCJJMLhJa6oWtSl+FVEXvdDK24s5nc5zapKTZ86cOTPz/4Hk/M8xc17w8Z13zpl5RpxzIBquEbkeAOUnFg6ZsHDIhIVDJiwcMmHhkElGhSMiq0WkT0RuisjesAZF8SfW6zgikgDwFYAVAPoBdABY75z7IrzhUVyNzOB33wVw0zn3NQCIyL8A1AEYsHDKyspcVVVVBqekqHV1dX3nnJvq359J4VQA+CYt9wNYONgvVFVVobOzM4NTUtRE5M6b9md9cSwiH4hIp4h0Pnr0KNuno4hkUjjfAqhMy297+xTn3EfOuRrnXM3UqT+b8ShPZVI4HQBmicg7IlICoB7AJ+EMi+LOvMZxzn0vIjsAtAFIADjhnOsNbWQUa5ksjuGc+xTApyGNhfIIrxyTCQuHTFg4ZMLCIRMWDpmwcMiEhUMmLBwyYeGQCQuHTFg4ZMLCIZOMbnIWk9evX6v89OnTwL/b1NSk8osXL1Tu6+tT+dixYyo3NDSo3NLSovKYMWNU3rv3p88N7N+/P/A4h4MzDpmwcMiEhUMmRbPGuXv3rsovX75U+fLlyypfunRJ5SdPnqh85syZ0MZWWVmp8s6dO1VubW1VecKECSrPmTNH5SVLloQ2toFwxiETFg6ZsHDIpGDXOFevXlV52bJlKg/nOkzYEomEygcPHlR5/PjxKm/cuFHladOmqTxlyhSVZ8+enekQh8QZh0xYOGTCwiGTgl3jTJ8+XeWysjKVw1zjLFyom3T41xwXL15UuaSkROVNmzaFNpaocMYhExYOmbBwyKRg1zilpaUqHz58WOVz586pPG/ePJV37do16OPPnTs3tX3hwgV1zH8dpqenR+UjR44M+tj5gDMOmQxZOCJyQkQeikhP2r5SETkvIje8n1MGewwqPEFmnGYAq3379gL4zDk3C8BnXqYiEqjPsYhUAfivc+5XXu4DUOucuyci5QDanXND3iCpqalxcek6+uzZM5X973HZtm2bysePH1f51KlTqe0NGzaEPLr4EJEu51yNf791jfOWc+6et30fwFvmkVFeynhx7JJT1oDTFtvVFiZr4TzwnqLg/Xw40F9ku9rCZL2O8wmAPwL4q/fz49BGFJGJEycOenzSpEmDHk9f89TX16tjI0YU/lWOIC/HWwB8DmC2iPSLyPtIFswKEbkB4HdepiIy5IzjnFs/wKHlIY+F8kjhz6mUFQV7rypTBw4cULmrq0vl9vb21Lb/XtXKlSuzNazY4IxDJiwcMmHhkIn5Ozkt4nSvarhu3bql8vz581PbkydPVseWLl2qck2NvtWzfft2lUUkjCFmRdj3qqjIsXDIhC/HA5oxY4bKzc3Nqe2tW7eqYydPnhw0P3/+XOXNmzerXF5ebh1mZDjjkAkLh0xYOGTCNY7RunXrUtszZ85Ux3bv3q2y/5ZEY2Ojynfu6O+E37dvn8oVFRXmcWYLZxwyYeGQCQuHTHjLIQv8rW39HzfesmWLyv5/g+XL9Xvkzp8/H97ghom3HChULBwyYeGQCdc4OTB69GiVX716pfKoUaNUbmtrU7m2tjYr43oTrnEoVCwcMmHhkAnvVYXg2rVrKvu/kqijo0Nl/5rGr7q6WuXFixdnMLrs4IxDJiwcMmHhkAnXOAH5v+L56NGjqe2zZ8+qY/fv3x/WY48cqf8Z/O85jmPblPiNiPJCkP44lSJyUUS+EJFeEfmTt58ta4t
YkBnnewC7nXPVAH4NYLuIVIMta4takMZK9wDc87b/LyJfAqgAUAeg1vtr/wDQDuDDrIwyAv51yenTp1VuampS+fbt2+ZzLViwQGX/e4zXrl1rfuyoDGuN4/U7ngfgCtiytqgFLhwR+QWA/wD4s3NOdZcerGUt29UWpkCFIyKjkCyafzrnfnztGahlLdvVFqYh1ziS7MHxdwBfOuf+lnYor1rWPnjwQOXe3l6Vd+zYofL169fN5/J/1eKePXtUrqurUzmO12mGEuQC4CIAmwD8T0S6vX1/QbJg/u21r70D4A/ZGSLFUZBXVZcADNT5hy1ri1T+zZEUCwVzr+rx48cq+782qLu7W2V/a7bhWrRoUWrb/1nxVatWqTx27NiMzhVHnHHIhIVDJiwcMsmrNc6VK1dS24cOHVLH/O/r7e/vz+hc48aNU9n/ddLp95f8XxddDDjjkAkLh0zy6qmqtbX1jdtB+D9ysmbNGpUTiYTKDQ0NKvu7pxc7zjhkwsIhExYOmbDNCQ2KbU4oVCwcMmHhkAkLh0xYOGTCwiETFg6ZsHDIhIVDJiwcMmHhkEmk96pE5BGSn/osA/BdZCcenriOLVfjmu6c+9mH/iMtnNRJRTrfdOMsDuI6triNi09VZMLCIZNcFc5HOTpvEHEdW6zGlZM1DuU/PlWRSaSFIyKrRaRPRG6KSE7b24rICRF5KCI9afti0bs5H3pLR1Y4IpIAcAzA7wFUA1jv9UvOlWYAq3374tK7Of69pZ1zkfwB8BsAbWm5EUBjVOcfYExVAHrSch+Acm+7HEBfLseXNq6PAayI0/iifKqqAPBNWu739sVJ7Ho3x7W3NBfHA3DJ/9Y5fclp7S0dhSgL51sAlWn5bW9fnATq3RyFTHpLRyHKwukAMEtE3hGREgD1SPZKjpMfezcDOezdHKC3NJDr3tIRL/LeA/AVgFsA9uV4wdmC5JebvEJyvfU+gF8i+WrlBoALAEpzNLbfIvk0dA1At/fnvbiMzznHK8dkw8UxmbBwyISFQyYsHDJh4ZAJC4dMWDhkwsIhkx8AyyZIbO5tLBIAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "# Each sample is an (image, label) pair\n", "train_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]\n", "train_data0 = train_data0.reshape([28,28])\n", "plt.figure(figsize=(2,2))\n", "plt.imshow(train_data0, cmap=plt.cm.binary)\n", "print('train_data0 label is: ' + str(train_label_0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. Building the network\n", "Build LeNet with APIs from `paddle.nn`, such as `Conv2d`, `MaxPool2d`, and `Linear`." ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "import paddle\n", "import paddle.nn.functional as F\n", "class LeNet(paddle.nn.Layer):\n", "    def __init__(self):\n", "        super(LeNet, self).__init__()\n", "        self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n", "        self.max_pool1 = paddle.nn.MaxPool2d(kernel_size=2, stride=2)\n", "        self.conv2 = paddle.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n", "        self.max_pool2 = paddle.nn.MaxPool2d(kernel_size=2, stride=2)\n", "        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)\n", "        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)\n", "        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)\n", "\n", "    def forward(self, x):\n", "        # [N, 1, 28, 28] -> conv1 -> [N, 6, 28, 28] -> pool -> [N, 6, 14, 14]\n", "        x = self.conv1(x)\n", "        x = F.relu(x)\n", "        x = self.max_pool1(x)\n", "        # -> conv2 -> [N, 16, 10, 10] -> pool -> [N, 16, 5, 5]\n", "        x = self.conv2(x)\n", "        x = F.relu(x)\n", "        x = self.max_pool2(x)\n", "        x = paddle.reshape(x, shape=[-1, 16*5*5])\n", "        x = self.linear1(x)\n", "        x = F.relu(x)\n", "        x = self.linear2(x)\n", "        x = F.relu(x)\n", "        x = self.linear3(x)\n", "        # Return raw logits: the cross-entropy losses used below apply softmax internally,\n", "        # so an extra F.softmax here would apply softmax twice\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. Training method 1\n", "With the network defined, we can train it. First build a `train_loader` that loads the training data in batches, then define a `train` function that, for each batch, computes the loss and updates the model parameters." ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "epoch: 0, batch_id: 0, loss is: [2.3064885], acc is: [0.109375]\n", "epoch: 0, batch_id: 100, loss is: [1.5477252], acc is: [1.]\n", "epoch: 0, batch_id: 200, loss is: [1.5201148], acc is: [1.]\n", "epoch: 0, batch_id: 300, loss is: [1.525354], acc is: [0.953125]\n", "epoch: 0, batch_id: 400, loss is: [1.5201038], acc is: [1.]\n", "epoch: 0, batch_id: 500, loss is: [1.4901408], acc is: [1.]\n", "epoch: 0, batch_id: 600, loss is: [1.4925538], acc is: [0.984375]\n", "epoch: 0, batch_id: 700, loss is: [1.5247533], acc is: [0.96875]\n", "epoch: 0, batch_id: 800, loss is: [1.5365943], acc is: [1.]\n", "epoch: 0, batch_id: 900, loss is: [1.5154861], acc is: [0.984375]\n", "epoch: 1, batch_id: 0, loss is: [1.4988302], acc is: [0.984375]\n", "epoch: 1, batch_id: 100, loss is: [1.493154], acc is: [0.984375]\n", "epoch: 1, batch_id: 200, loss is: [1.4974915], acc is: [1.]\n", "epoch: 1, batch_id: 300, loss is: [1.5089471], acc is: [0.984375]\n", "epoch: 1, batch_id: 400, loss is: [1.5041347], acc is: [1.]\n", "epoch: 1, batch_id: 500, loss is: [1.5145375], acc is: [1.]\n", "epoch: 1, batch_id: 600, loss is: [1.4904011], acc is: [0.984375]\n", "epoch: 1, batch_id: 700, loss is: [1.5121607], acc is: [0.96875]\n", "epoch: 1, batch_id: 800, loss is: [1.5078678], acc is: [1.]\n", "epoch: 1, batch_id: 900, loss is: [1.500349], acc is: [0.984375]\n" ] } ], "source": [ "import paddle\n", "# Load the training set in batches of 64\n", "train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n", "def train(model):\n", "    model.train()\n", "    epochs = 2\n", "    # Use Adam as the optimizer\n", "    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", "    for epoch in range(epochs):\n", "        for batch_id, data in enumerate(train_loader()):\n", "            x_data = data[0]\n", "            y_data = data[1]\n", "            predicts = model(x_data)\n", "            # Compute the loss and the top-2 accuracy of the batch\n", "            loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", "            acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", "            avg_loss = paddle.mean(loss)\n", "            avg_acc = paddle.mean(acc)\n", "            # Backpropagate to compute the gradients\n", "            avg_loss.backward()\n", "            if batch_id % 100 == 0:\n", "                print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", "            # Update the parameters, then clear the gradients for the next batch\n", "            optim.minimize(avg_loss)\n", "            model.clear_gradients()\n", "model = LeNet()\n", "train(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluating the model\n", "After training, check how well the model performs: load the test set, run the trained model on it, and compute the loss and accuracy." ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "batch_id: 0, loss is: [1.4659549], acc is: [1.]\n", "batch_id: 100, loss is: [1.4933192], acc is: [0.984375]\n", "batch_id: 200, loss is: [1.4779761], acc is: [1.]\n", "batch_id: 300, loss is: [1.4919193], acc is: [0.984375]\n", "batch_id: 400, loss is: [1.5036212], acc is: [1.]\n", "batch_id: 500, loss is: [1.4922347], acc is: [0.984375]\n", "batch_id: 600, loss is: [1.4765416], acc is: [0.984375]\n", "batch_id: 700, loss is: [1.4997746], acc is: [0.984375]\n", "batch_id: 800, loss is: [1.4831288], acc is: [1.]\n", "batch_id: 900, loss is: [1.498342], acc is: [0.984375]\n" ] } ], "source": [ "import paddle\n", "# Load the test set in batches of 64\n", "test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n", "def test(model):\n", "    # Switch to evaluation mode; no backward pass is run during testing\n", "    model.eval()\n", "    for batch_id, data in enumerate(test_loader()):\n", "        x_data = data[0]\n", "        y_data = data[1]\n", "        # Get the predictions\n", "        predicts = model(x_data)\n", "        loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", "        acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", "        avg_loss = paddle.mean(loss)\n", "        avg_acc = paddle.mean(acc)\n", "        if batch_id % 100 == 0:\n", "            print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", "test(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
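"The `test` function above only prints metrics for every 100th batch, and the accuracy it reports is top-2 (`k=2`). To get a single top-1 accuracy figure for the whole test set, one option is to count correct predictions across all batches. The next cell is a minimal sketch of that idea: the helper name `evaluate_top1` is hypothetical, and it assumes, as `test` does, that each batch from `test_loader` is an `[images, labels]` pair and that `model` returns one score per class for each image." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical helper (a sketch): top-1 accuracy over the whole test set\n", "def evaluate_top1(model, loader):\n", "    model.eval()\n", "    correct, total = 0, 0\n", "    for data in loader():\n", "        x_data, y_data = data[0], data[1]\n", "        # The predicted class is the index of the largest score in each row\n", "        pred = np.argmax(model(x_data).numpy(), axis=1)\n", "        label = y_data.numpy().reshape(-1)\n", "        correct += int((pred == label).sum())\n", "        total += label.shape[0]\n", "    return correct / total\n", "\n", "print('test top-1 accuracy:', evaluate_top1(model, test_loader))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 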
"### End of training method 1\n", "That completes training method 1. It makes every step of training and testing explicit, but it takes a fair amount of code. Training method 2 below trains and tests the model more quickly and concisely." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 4. Training method 2\n", "Wrap the network in Paddle's high-level `paddle.Model` class and use its built-in training and evaluation interfaces to train and test the model quickly." ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "import paddle\n", "from paddle.static import InputSpec\n", "from paddle.metric import Accuracy\n", "# Input: a batch of 1x28x28 grayscale images; label: one int64 class index per image\n", "inputs = InputSpec([None, 1, 28, 28], 'float32', 'x')\n", "labels = InputSpec([None, 1], 'int64', 'label')\n", "model = paddle.Model(LeNet(), inputs, labels)\n", "optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", "\n", "# Configure the optimizer, the loss, and the metrics (top-1 and top-2 accuracy)\n", "model.prepare(\n", "    optim,\n", "    paddle.nn.loss.CrossEntropyLoss(),\n", "    Accuracy(topk=(1, 2))\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training the model with `model.fit`" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "step 10/938 - loss: 2.2434 - acc_top1: 0.1344 - acc_top2: 0.3719 - 14ms/step\n", "step 20/938 - loss: 2.0292 - acc_top1: 0.2836 - acc_top2: 0.4633 - 14ms/step\n", "step 30/938 - loss: 1.9341 - acc_top1: 0.3755 - acc_top2: 0.5214 - 14ms/step\n", "step 40/938 - loss: 1.8009 - acc_top1: 0.4469 - acc_top2: 0.5727 - 14ms/step\n", "step 50/938 - loss: 1.8000 - acc_top1: 0.4975 - acc_top2: 0.6125 - 13ms/step\n", "step 60/938 - loss: 1.6335 - acc_top1: 0.5417 - acc_top2: 0.6438 - 14ms/step\n", "step 70/938 - loss: 1.7931 - acc_top1: 0.5708 - acc_top2: 0.6643 - 13ms/step\n", "step 80/938 - loss: 1.6699 - acc_top1: 0.5961 - acc_top2: 0.6846 - 13ms/step\n", "step 90/938 - loss: 1.6832 - acc_top1: 0.6189 - acc_top2: 0.7069 - 13ms/step\n", "step 100/938 - loss: 1.6336 - acc_top1: 0.6409 - acc_top2: 0.7245 - 14ms/step\n", "step 110/938 - loss: 1.6598 - acc_top1: 0.6557 - acc_top2: 0.7376 - 13ms/step\n", "step 120/938 - loss: 1.6348 - acc_top1: 0.6708 - acc_top2: 0.7488 - 
13ms/step\n", "step 130/938 - loss: 1.6223 - acc_top1: 0.6851 - acc_top2: 0.7601 - 13ms/step\n", "step 140/938 - loss: 1.5622 - acc_top1: 0.6970 - acc_top2: 0.7694 - 13ms/step\n", "step 150/938 - loss: 1.6455 - acc_top1: 0.7065 - acc_top2: 0.7767 - 14ms/step\n", "step 160/938 - loss: 1.6127 - acc_top1: 0.7154 - acc_top2: 0.7837 - 14ms/step\n", "step 170/938 - loss: 1.5963 - acc_top1: 0.7242 - acc_top2: 0.7898 - 14ms/step\n", "step 180/938 - loss: 1.6485 - acc_top1: 0.7310 - acc_top2: 0.7948 - 14ms/step\n", "step 190/938 - loss: 1.6666 - acc_top1: 0.7368 - acc_top2: 0.7992 - 14ms/step\n", "step 200/938 - loss: 1.7862 - acc_top1: 0.7419 - acc_top2: 0.8030 - 14ms/step\n", "step 210/938 - loss: 1.5479 - acc_top1: 0.7464 - acc_top2: 0.8064 - 14ms/step\n", "step 220/938 - loss: 1.5650 - acc_top1: 0.7515 - acc_top2: 0.8106 - 14ms/step\n", "step 230/938 - loss: 1.5822 - acc_top1: 0.7562 - acc_top2: 0.8141 - 14ms/step\n", "step 240/938 - loss: 1.5966 - acc_top1: 0.7608 - acc_top2: 0.8179 - 14ms/step\n", "step 250/938 - loss: 1.5551 - acc_top1: 0.7650 - acc_top2: 0.8213 - 14ms/step\n", "step 260/938 - loss: 1.5584 - acc_top1: 0.7699 - acc_top2: 0.8249 - 14ms/step\n", "step 270/938 - loss: 1.5933 - acc_top1: 0.7730 - acc_top2: 0.8273 - 14ms/step\n", "step 280/938 - loss: 1.5589 - acc_top1: 0.7769 - acc_top2: 0.8301 - 14ms/step\n", "step 290/938 - loss: 1.6513 - acc_top1: 0.7793 - acc_top2: 0.8315 - 14ms/step\n", "step 300/938 - loss: 1.5929 - acc_top1: 0.7821 - acc_top2: 0.8337 - 14ms/step\n", "step 310/938 - loss: 1.5672 - acc_top1: 0.7849 - acc_top2: 0.8360 - 14ms/step\n", "step 320/938 - loss: 1.5147 - acc_top1: 0.7879 - acc_top2: 0.8381 - 14ms/step\n", "step 330/938 - loss: 1.5697 - acc_top1: 0.7902 - acc_top2: 0.8397 - 14ms/step\n", "step 340/938 - loss: 1.5697 - acc_top1: 0.7919 - acc_top2: 0.8406 - 14ms/step\n", "step 350/938 - loss: 1.6122 - acc_top1: 0.7941 - acc_top2: 0.8423 - 14ms/step\n", "step 360/938 - loss: 1.5934 - acc_top1: 0.7960 - acc_top2: 0.8435 - 
14ms/step\n", "step 370/938 - loss: 1.6258 - acc_top1: 0.7982 - acc_top2: 0.8451 - 14ms/step\n", "step 380/938 - loss: 1.6805 - acc_top1: 0.7996 - acc_top2: 0.8463 - 14ms/step\n", "step 390/938 - loss: 1.5997 - acc_top1: 0.8011 - acc_top2: 0.8475 - 14ms/step\n", "step 400/938 - loss: 1.6151 - acc_top1: 0.8029 - acc_top2: 0.8488 - 14ms/step\n", "step 410/938 - loss: 1.5800 - acc_top1: 0.8047 - acc_top2: 0.8499 - 14ms/step\n", "step 420/938 - loss: 1.5950 - acc_top1: 0.8060 - acc_top2: 0.8508 - 14ms/step\n", "step 430/938 - loss: 1.5533 - acc_top1: 0.8075 - acc_top2: 0.8517 - 14ms/step\n", "step 440/938 - loss: 1.6171 - acc_top1: 0.8086 - acc_top2: 0.8521 - 14ms/step\n", "step 450/938 - loss: 1.5756 - acc_top1: 0.8103 - acc_top2: 0.8533 - 14ms/step\n", "step 460/938 - loss: 1.5655 - acc_top1: 0.8121 - acc_top2: 0.8544 - 14ms/step\n", "step 470/938 - loss: 1.5816 - acc_top1: 0.8139 - acc_top2: 0.8555 - 14ms/step\n", "step 480/938 - loss: 1.6202 - acc_top1: 0.8148 - acc_top2: 0.8562 - 14ms/step\n", "step 490/938 - loss: 1.6223 - acc_top1: 0.8157 - acc_top2: 0.8567 - 14ms/step\n", "step 500/938 - loss: 1.5198 - acc_top1: 0.8167 - acc_top2: 0.8574 - 14ms/step\n", "step 510/938 - loss: 1.5853 - acc_top1: 0.8181 - acc_top2: 0.8583 - 14ms/step\n", "step 520/938 - loss: 1.5252 - acc_top1: 0.8196 - acc_top2: 0.8593 - 14ms/step\n", "step 530/938 - loss: 1.5265 - acc_top1: 0.8207 - acc_top2: 0.8601 - 14ms/step\n", "step 540/938 - loss: 1.5297 - acc_top1: 0.8217 - acc_top2: 0.8608 - 14ms/step\n", "step 550/938 - loss: 1.5743 - acc_top1: 0.8226 - acc_top2: 0.8613 - 13ms/step\n", "step 560/938 - loss: 1.6419 - acc_top1: 0.8237 - acc_top2: 0.8622 - 13ms/step\n", "step 570/938 - loss: 1.5556 - acc_top1: 0.8247 - acc_top2: 0.8630 - 13ms/step\n", "step 580/938 - loss: 1.5349 - acc_top1: 0.8254 - acc_top2: 0.8635 - 13ms/step\n", "step 590/938 - loss: 1.4915 - acc_top1: 0.8263 - acc_top2: 0.8640 - 13ms/step\n", "step 600/938 - loss: 1.5672 - acc_top1: 0.8277 - acc_top2: 0.8651 - 
13ms/step\n", "step 610/938 - loss: 1.5464 - acc_top1: 0.8288 - acc_top2: 0.8659 - 13ms/step\n", "step 620/938 - loss: 1.6329 - acc_top1: 0.8292 - acc_top2: 0.8661 - 13ms/step\n", "step 630/938 - loss: 1.6121 - acc_top1: 0.8296 - acc_top2: 0.8662 - 13ms/step\n", "step 640/938 - loss: 1.5636 - acc_top1: 0.8305 - acc_top2: 0.8668 - 13ms/step\n", "step 650/938 - loss: 1.6227 - acc_top1: 0.8311 - acc_top2: 0.8672 - 13ms/step\n", "step 660/938 - loss: 1.5646 - acc_top1: 0.8319 - acc_top2: 0.8678 - 13ms/step\n", "step 670/938 - loss: 1.5620 - acc_top1: 0.8325 - acc_top2: 0.8681 - 13ms/step\n", "step 680/938 - loss: 1.4908 - acc_top1: 0.8333 - acc_top2: 0.8688 - 13ms/step\n", "step 690/938 - loss: 1.6010 - acc_top1: 0.8339 - acc_top2: 0.8691 - 13ms/step\n", "step 700/938 - loss: 1.5592 - acc_top1: 0.8346 - acc_top2: 0.8695 - 13ms/step\n", "step 710/938 - loss: 1.6226 - acc_top1: 0.8352 - acc_top2: 0.8699 - 13ms/step\n", "step 720/938 - loss: 1.5642 - acc_top1: 0.8362 - acc_top2: 0.8705 - 13ms/step\n", "step 730/938 - loss: 1.5807 - acc_top1: 0.8367 - acc_top2: 0.8707 - 13ms/step\n", "step 740/938 - loss: 1.5721 - acc_top1: 0.8371 - acc_top2: 0.8708 - 13ms/step\n", "step 750/938 - loss: 1.6542 - acc_top1: 0.8377 - acc_top2: 0.8711 - 13ms/step\n", "step 760/938 - loss: 1.5128 - acc_top1: 0.8385 - acc_top2: 0.8716 - 13ms/step\n", "step 770/938 - loss: 1.5711 - acc_top1: 0.8391 - acc_top2: 0.8721 - 14ms/step\n", "step 780/938 - loss: 1.6095 - acc_top1: 0.8395 - acc_top2: 0.8725 - 14ms/step\n", "step 790/938 - loss: 1.5348 - acc_top1: 0.8402 - acc_top2: 0.8730 - 14ms/step\n", "step 800/938 - loss: 1.5715 - acc_top1: 0.8407 - acc_top2: 0.8732 - 14ms/step\n", "step 810/938 - loss: 1.5880 - acc_top1: 0.8413 - acc_top2: 0.8737 - 14ms/step\n", "step 820/938 - loss: 1.6160 - acc_top1: 0.8418 - acc_top2: 0.8740 - 14ms/step\n", "step 830/938 - loss: 1.5585 - acc_top1: 0.8426 - acc_top2: 0.8746 - 14ms/step\n", "step 840/938 - loss: 1.5829 - acc_top1: 0.8429 - acc_top2: 0.8748 - 
14ms/step\n", "step 850/938 - loss: 1.5348 - acc_top1: 0.8435 - acc_top2: 0.8753 - 14ms/step\n", "step 860/938 - loss: 1.5448 - acc_top1: 0.8438 - acc_top2: 0.8754 - 14ms/step\n", "step 870/938 - loss: 1.5463 - acc_top1: 0.8443 - acc_top2: 0.8759 - 14ms/step\n", "step 880/938 - loss: 1.5763 - acc_top1: 0.8449 - acc_top2: 0.8762 - 14ms/step\n", "step 890/938 - loss: 1.5699 - acc_top1: 0.8453 - acc_top2: 0.8764 - 14ms/step\n", "step 900/938 - loss: 1.5616 - acc_top1: 0.8456 - acc_top2: 0.8766 - 14ms/step\n", "step 910/938 - loss: 1.5026 - acc_top1: 0.8461 - acc_top2: 0.8771 - 14ms/step\n", "step 920/938 - loss: 1.5380 - acc_top1: 0.8467 - acc_top2: 0.8774 - 14ms/step\n", "step 930/938 - loss: 1.5993 - acc_top1: 0.8470 - acc_top2: 0.8777 - 14ms/step\n", "step 938/938 - loss: 1.4942 - acc_top1: 0.8473 - acc_top2: 0.8778 - 14ms/step\n", "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/0\n", "Epoch 2/2\n", "step 10/938 - loss: 1.5919 - acc_top1: 0.8875 - acc_top2: 0.9047 - 14ms/step\n", "step 20/938 - loss: 1.5900 - acc_top1: 0.8875 - acc_top2: 0.9062 - 14ms/step\n", "step 30/938 - loss: 1.5929 - acc_top1: 0.8891 - acc_top2: 0.9036 - 13ms/step\n", "step 40/938 - loss: 1.5855 - acc_top1: 0.8883 - acc_top2: 0.9027 - 13ms/step\n", "step 50/938 - loss: 1.6197 - acc_top1: 0.8916 - acc_top2: 0.9072 - 13ms/step\n", "step 60/938 - loss: 1.5084 - acc_top1: 0.8914 - acc_top2: 0.9078 - 13ms/step\n", "step 70/938 - loss: 1.5552 - acc_top1: 0.8904 - acc_top2: 0.9067 - 13ms/step\n", "step 80/938 - loss: 1.5700 - acc_top1: 0.8887 - acc_top2: 0.9049 - 13ms/step\n", "step 90/938 - loss: 1.6073 - acc_top1: 0.8866 - acc_top2: 0.9030 - 13ms/step\n", "step 100/938 - loss: 1.5754 - acc_top1: 0.8859 - acc_top2: 0.9022 - 13ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 110/938 - loss: 1.5484 - acc_top1: 0.8848 - acc_top2: 0.9017 - 14ms/step\n", "step 120/938 - loss: 1.5904 - acc_top1: 0.8840 - acc_top2: 
0.9020 - 14ms/step\n", "step 130/938 - loss: 1.5108 - acc_top1: 0.8852 - acc_top2: 0.9025 - 14ms/step\n", "step 140/938 - loss: 1.6199 - acc_top1: 0.8840 - acc_top2: 0.9016 - 14ms/step\n", "step 150/938 - loss: 1.5337 - acc_top1: 0.8842 - acc_top2: 0.9019 - 13ms/step\n", "step 160/938 - loss: 1.6094 - acc_top1: 0.8846 - acc_top2: 0.9023 - 13ms/step\n", "step 170/938 - loss: 1.5653 - acc_top1: 0.8843 - acc_top2: 0.9019 - 13ms/step\n", "step 180/938 - loss: 1.5978 - acc_top1: 0.8835 - acc_top2: 0.9011 - 13ms/step\n", "step 190/938 - loss: 1.5950 - acc_top1: 0.8833 - acc_top2: 0.9012 - 13ms/step\n", "step 200/938 - loss: 1.6422 - acc_top1: 0.8828 - acc_top2: 0.9002 - 13ms/step\n", "step 210/938 - loss: 1.5752 - acc_top1: 0.8831 - acc_top2: 0.9004 - 13ms/step\n", "step 220/938 - loss: 1.6635 - acc_top1: 0.8832 - acc_top2: 0.9001 - 13ms/step\n", "step 230/938 - loss: 1.5726 - acc_top1: 0.8823 - acc_top2: 0.8991 - 13ms/step\n", "step 240/938 - loss: 1.5702 - acc_top1: 0.8814 - acc_top2: 0.8981 - 13ms/step\n", "step 250/938 - loss: 1.5748 - acc_top1: 0.8814 - acc_top2: 0.8981 - 14ms/step\n", "step 260/938 - loss: 1.5589 - acc_top1: 0.8822 - acc_top2: 0.8988 - 14ms/step\n", "step 270/938 - loss: 1.5902 - acc_top1: 0.8823 - acc_top2: 0.8988 - 14ms/step\n", "step 280/938 - loss: 1.5646 - acc_top1: 0.8817 - acc_top2: 0.8982 - 14ms/step\n", "step 290/938 - loss: 1.6280 - acc_top1: 0.8819 - acc_top2: 0.8985 - 14ms/step\n", "step 300/938 - loss: 1.5697 - acc_top1: 0.8815 - acc_top2: 0.8982 - 14ms/step\n", "step 310/938 - loss: 1.5540 - acc_top1: 0.8814 - acc_top2: 0.8981 - 14ms/step\n", "step 320/938 - loss: 1.5598 - acc_top1: 0.8821 - acc_top2: 0.8988 - 14ms/step\n", "step 330/938 - loss: 1.5498 - acc_top1: 0.8824 - acc_top2: 0.8991 - 14ms/step\n", "step 340/938 - loss: 1.6276 - acc_top1: 0.8818 - acc_top2: 0.8984 - 14ms/step\n", "step 350/938 - loss: 1.5129 - acc_top1: 0.8821 - acc_top2: 0.8988 - 14ms/step\n", "step 360/938 - loss: 1.6158 - acc_top1: 0.8818 - acc_top2: 0.8984 
- 14ms/step\n", "step 370/938 - loss: 1.5300 - acc_top1: 0.8820 - acc_top2: 0.8986 - 14ms/step\n", "step 380/938 - loss: 1.5718 - acc_top1: 0.8822 - acc_top2: 0.8988 - 14ms/step\n", "step 390/938 - loss: 1.5898 - acc_top1: 0.8825 - acc_top2: 0.8990 - 14ms/step\n", "step 400/938 - loss: 1.5177 - acc_top1: 0.8834 - acc_top2: 0.9000 - 14ms/step\n", "step 410/938 - loss: 1.6493 - acc_top1: 0.8831 - acc_top2: 0.8997 - 14ms/step\n", "step 420/938 - loss: 1.5071 - acc_top1: 0.8838 - acc_top2: 0.9002 - 14ms/step\n", "step 430/938 - loss: 1.5982 - acc_top1: 0.8840 - acc_top2: 0.9002 - 14ms/step\n", "step 440/938 - loss: 1.5649 - acc_top1: 0.8841 - acc_top2: 0.9003 - 14ms/step\n", "step 450/938 - loss: 1.5555 - acc_top1: 0.8844 - acc_top2: 0.9005 - 14ms/step\n", "step 460/938 - loss: 1.5536 - acc_top1: 0.8845 - acc_top2: 0.9005 - 14ms/step\n", "step 470/938 - loss: 1.5401 - acc_top1: 0.8851 - acc_top2: 0.9011 - 14ms/step\n", "step 480/938 - loss: 1.5549 - acc_top1: 0.8854 - acc_top2: 0.9013 - 14ms/step\n", "step 490/938 - loss: 1.5596 - acc_top1: 0.8858 - acc_top2: 0.9017 - 14ms/step\n", "step 500/938 - loss: 1.5059 - acc_top1: 0.8860 - acc_top2: 0.9018 - 14ms/step\n", "step 510/938 - loss: 1.6073 - acc_top1: 0.8858 - acc_top2: 0.9017 - 14ms/step\n", "step 520/938 - loss: 1.5588 - acc_top1: 0.8857 - acc_top2: 0.9016 - 14ms/step\n", "step 530/938 - loss: 1.6165 - acc_top1: 0.8859 - acc_top2: 0.9019 - 14ms/step\n", "step 540/938 - loss: 1.5884 - acc_top1: 0.8862 - acc_top2: 0.9023 - 14ms/step\n", "step 550/938 - loss: 1.6552 - acc_top1: 0.8863 - acc_top2: 0.9027 - 14ms/step\n", "step 560/938 - loss: 1.5529 - acc_top1: 0.8867 - acc_top2: 0.9030 - 14ms/step\n", "step 570/938 - loss: 1.5441 - acc_top1: 0.8866 - acc_top2: 0.9029 - 14ms/step\n", "step 580/938 - loss: 1.5438 - acc_top1: 0.8867 - acc_top2: 0.9029 - 14ms/step\n", "step 590/938 - loss: 1.5761 - acc_top1: 0.8868 - acc_top2: 0.9029 - 14ms/step\n", "step 600/938 - loss: 1.5384 - acc_top1: 0.8867 - acc_top2: 0.9029 - 
14ms/step\n", "step 610/938 - loss: 1.5858 - acc_top1: 0.8871 - acc_top2: 0.9032 - 14ms/step\n", "step 620/938 - loss: 1.5524 - acc_top1: 0.8872 - acc_top2: 0.9034 - 14ms/step\n", "step 630/938 - loss: 1.6182 - acc_top1: 0.8875 - acc_top2: 0.9035 - 14ms/step\n", "step 640/938 - loss: 1.6326 - acc_top1: 0.8877 - acc_top2: 0.9037 - 14ms/step\n", "step 650/938 - loss: 1.5871 - acc_top1: 0.8877 - acc_top2: 0.9035 - 14ms/step\n", "step 660/938 - loss: 1.5403 - acc_top1: 0.8877 - acc_top2: 0.9034 - 14ms/step\n", "step 670/938 - loss: 1.5539 - acc_top1: 0.8879 - acc_top2: 0.9035 - 14ms/step\n", "step 680/938 - loss: 1.4918 - acc_top1: 0.8881 - acc_top2: 0.9036 - 14ms/step\n", "step 690/938 - loss: 1.6007 - acc_top1: 0.8882 - acc_top2: 0.9036 - 14ms/step\n", "step 700/938 - loss: 1.5539 - acc_top1: 0.8883 - acc_top2: 0.9037 - 14ms/step\n", "step 710/938 - loss: 1.6036 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", "step 720/938 - loss: 1.5943 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n", "step 730/938 - loss: 1.5714 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n", "step 740/938 - loss: 1.5095 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n", "step 750/938 - loss: 1.5069 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", "step 760/938 - loss: 1.5816 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", "step 770/938 - loss: 1.5855 - acc_top1: 0.8880 - acc_top2: 0.9033 - 14ms/step\n", "step 780/938 - loss: 1.5599 - acc_top1: 0.8881 - acc_top2: 0.9034 - 14ms/step\n", "step 790/938 - loss: 1.6029 - acc_top1: 0.8879 - acc_top2: 0.9032 - 14ms/step\n", "step 800/938 - loss: 1.5839 - acc_top1: 0.8880 - acc_top2: 0.9033 - 14ms/step\n", "step 810/938 - loss: 1.5545 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", "step 820/938 - loss: 1.5458 - acc_top1: 0.8881 - acc_top2: 0.9036 - 14ms/step\n", "step 830/938 - loss: 1.5911 - acc_top1: 0.8879 - acc_top2: 0.9033 - 14ms/step\n", "step 840/938 - loss: 1.5845 - acc_top1: 0.8881 - acc_top2: 0.9035 - 
14ms/step\n", "step 850/938 - loss: 1.5628 - acc_top1: 0.8880 - acc_top2: 0.9035 - 14ms/step\n", "step 860/938 - loss: 1.5596 - acc_top1: 0.8880 - acc_top2: 0.9035 - 14ms/step\n", "step 870/938 - loss: 1.5843 - acc_top1: 0.8882 - acc_top2: 0.9036 - 14ms/step\n", "step 880/938 - loss: 1.5393 - acc_top1: 0.8883 - acc_top2: 0.9036 - 14ms/step\n", "step 890/938 - loss: 1.5382 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", "step 900/938 - loss: 1.5910 - acc_top1: 0.8884 - acc_top2: 0.9036 - 14ms/step\n", "step 910/938 - loss: 1.5682 - acc_top1: 0.8886 - acc_top2: 0.9038 - 14ms/step\n", "step 920/938 - loss: 1.5736 - acc_top1: 0.8889 - acc_top2: 0.9039 - 14ms/step\n", "step 930/938 - loss: 1.5283 - acc_top1: 0.8888 - acc_top2: 0.9038 - 14ms/step\n", "step 938/938 - loss: 1.5582 - acc_top1: 0.8888 - acc_top2: 0.9038 - 14ms/step\n", "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n", "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/final\n" ] } ], "source": [ "# Train for 2 epochs in batches of 64, saving a checkpoint under save_dir after each epoch\n", "model.fit(train_dataset,\n", "          epochs=2,\n", "          batch_size=64,\n", "          save_dir='mnist_checkpoint')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluating the model with `model.evaluate`" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Eval begin...\n", "step 10/157 - loss: 1.5447 - acc_top1: 0.8953 - acc_top2: 0.9078 - 5ms/step\n", "step 20/157 - loss: 1.6185 - acc_top1: 0.8930 - acc_top2: 0.9078 - 5ms/step\n", "step 30/157 - loss: 1.6497 - acc_top1: 0.8917 - acc_top2: 0.9057 - 5ms/step\n", "step 40/157 - loss: 1.6318 - acc_top1: 0.8902 - acc_top2: 0.9055 - 5ms/step\n", "step 50/157 - loss: 1.5533 - acc_top1: 0.8856 - acc_top2: 0.9012 - 5ms/step\n", "step 60/157 - loss: 1.6212 - acc_top1: 0.8878 - acc_top2: 0.9036 - 5ms/step\n", "step 70/157 - loss: 1.5674 - acc_top1: 0.8839 - acc_top2: 0.9002 - 5ms/step\n", 
"step 80/157 - loss: 1.5409 - acc_top1: 0.8891 - acc_top2: 0.9043 - 5ms/step\n", "step 90/157 - loss: 1.6133 - acc_top1: 0.8903 - acc_top2: 0.9045 - 5ms/step\n", "step 100/157 - loss: 1.5535 - acc_top1: 0.8909 - acc_top2: 0.9044 - 5ms/step\n", "step 110/157 - loss: 1.5690 - acc_top1: 0.8916 - acc_top2: 0.9054 - 5ms/step\n", "step 120/157 - loss: 1.6147 - acc_top1: 0.8926 - acc_top2: 0.9055 - 5ms/step\n", "step 130/157 - loss: 1.5203 - acc_top1: 0.8944 - acc_top2: 0.9066 - 5ms/step\n", "step 140/157 - loss: 1.5066 - acc_top1: 0.8952 - acc_top2: 0.9068 - 5ms/step\n", "step 150/157 - loss: 1.5536 - acc_top1: 0.8958 - acc_top2: 0.9072 - 5ms/step\n", "step 157/157 - loss: 1.5855 - acc_top1: 0.8956 - acc_top2: 0.9076 - 5ms/step\n", "Eval samples: 10000\n" ] }, { "data": { "text/plain": [ "{'loss': [1.585474], 'acc_top1': 0.8956, 'acc_top2': 0.9076}" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(test_dataset, batch_size=64)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### End of training method 2\n", "That completes training method 2, which trains and evaluates the network quickly and with little code." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 5. Summary\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial used LeNet to classify the handwritten digits in the MNIST dataset. It showed two ways to train a model: method 2 builds, trains, and evaluates the model quickly through the high-level `paddle.Model` API and is well suited to beginners, while method 1 writes out each step of the training loop and suits more advanced users." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }