{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "#
手写数字分类识别入门体验教程
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 实现一个图片分类应用\n", "## 概述\n", "下面我们通过一个实际样例,带领大家体验MindSpore基础的功能,对于一般的用户而言,完成整个样例实践会持续20~30分钟。\n", "\n", "本例子会实现一个简单的图片分类的功能,整体流程如下:\n", "\n", "1、处理需要的数据集,这里使用了MNIST数据集。\n", "\n", "2、定义一个网络,这里我们使用LeNet网络。\n", "\n", "3、定义损失函数和优化器。\n", "\n", "4、加载数据集并进行训练,训练完成后,查看结果及保存模型文件。\n", "\n", "5、加载保存的模型,进行推理。\n", "\n", "6、验证模型,加载测试数据集和训练后的模型,验证结果精度。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "说明:
你可以在这里找到完整可运行的样例代码:https://gitee.com/mindspore/docs/blob/master/tutorials/tutorial_code/lenet.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 一、训练的数据集下载" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 方法一:\n", "从以下网址下载,并将数据包解压缩后放至Jupyter的工作目录下:
训练数据集:{\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n", "
测试数据集:{\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}
我们用下面代码查询jupyter的工作目录。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.getcwd()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "训练数据集放在----Jupyter工作目录+\\MNIST_Data\\train\\,此时train文件夹内应该包含两个文件,train-images-idx3-ubyte和train-labels-idx1-ubyte
测试数据集放在----Jupyter工作目录+\\MNIST_Data\\test\\,此时test文件夹内应该包含两个文件,t10k-images-idx3-ubyte和t10k-labels-idx1-ubyte" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 方法二:\n", "直接执行下面代码,会自动进行训练集的下载与解压,但是整个过程根据网络好坏情况会需要花费几分钟时间。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Network request module, data download module, decompression module\n", "import urllib.request \n", "from urllib.parse import urlparse\n", "import gzip \n", "\n", "def unzipfile(gzip_path):\n", " \"\"\"unzip dataset file\n", " Args:\n", " gzip_path: dataset file path\n", " \"\"\"\n", " open_file = open(gzip_path.replace('.gz',''), 'wb')\n", " gz_file = gzip.GzipFile(gzip_path)\n", " open_file.write(gz_file.read())\n", " gz_file.close()\n", " \n", "def download_dataset():\n", " \"\"\"Download the dataset from http://yann.lecun.com/exdb/mnist/.\"\"\"\n", " print(\"******Downloading the MNIST dataset******\")\n", " train_path = \"./MNIST_Data/train/\" \n", " test_path = \"./MNIST_Data/test/\"\n", " train_path_check = os.path.exists(train_path)\n", " test_path_check = os.path.exists(test_path)\n", " if train_path_check == False and test_path_check ==False:\n", " os.makedirs(train_path)\n", " os.makedirs(test_path)\n", " train_url = {\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n", " test_url = {\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}\n", " \n", " for url in train_url:\n", " url_parse = urlparse(url)\n", " # split the file name from url\n", " file_name = os.path.join(train_path,url_parse.path.split('/')[-1])\n", " if not os.path.exists(file_name.replace('.gz','')):\n", " file = urllib.request.urlretrieve(url, file_name)\n", " unzipfile(file_name)\n", " os.remove(file_name)\n", " \n", " for url in test_url:\n", " url_parse = urlparse(url)\n", " # split the file name from url\n", " file_name = os.path.join(test_path,url_parse.path.split('/')[-1])\n", " if not os.path.exists(file_name.replace('.gz','')):\n", " file = urllib.request.urlretrieve(url, file_name)\n", " unzipfile(file_name)\n", " os.remove(file_name)\n", "\n", "download_dataset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这样就完成了数据集的下载解压缩工作。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 二、处理MNIST数据集" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "由于我们后面会采用LeNet这样的卷积神经网络对数据集进行训练,而采用LeNet在训练数据时,对数据格式是有所要求的,所以接下来的工作需要我们先查看数据集内的数据是什么样的,这样才能构造一个针对性的数据转换函数,将数据集数据转换成符合训练要求的数据形式。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "更多的LeNet网络的介绍不在此赘述,希望详细了解LeNet网络,可以查询http://yann.lecun.com/exdb/lenet/ 。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 查看原始数据集数据" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from mindspore import context\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "import numpy as np\n", "import mindspore.dataset as ds\n", "\n", "context.set_context(mode=context.GRAPH_MODE,device_target=\"CPU\") # Windows version, set to use CPU for graph calculation\n", "train_data_path = \"./MNIST_Data/train\"\n", "test_data_path = \"./MNIST_Data/test\"\n", "mnist_ds = ds.MnistDataset(train_data_path) # Load training dataset\n", "print('The type of mnist_ds:',type(mnist_ds))\n", "print(\"Number of pictures contained in the mnist_ds:\",mnist_ds.get_dataset_size()) # 60000 pictures in total\n", "\n", "dic_ds = mnist_ds.create_dict_iterator() # Convert dataset to dictionary type\n", "item = dic_ds.get_next()\n", "img = item[\"image\"]\n", "label = item[\"label\"]\n", "\n", "print(\"The item of mnist_ds:\",item.keys()) # Take a single data to view the data structure, including two keys, image and label\n", "print(\"Tensor of image in item:\",img.shape) # View the tensor of image (28,28,1)\n", "print(\"The label of item:\",label)\n", "\n", "plt.imshow(np.squeeze(img))\n", "plt.title(\"number:%s\"%item[\"label\"])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的运行情况我们可以看到,训练数据集train-images-idx3-ubyte和train-labels-idx1-ubyte对应的是6万张图片和6万个数字下标,载入数据后经过create_dict_iterator()转换字典型的数据集,取其中的一个数据查看,这是一个key为image和label的字典,其中的image的张量(高度28,宽度28,通道1)和label为对应图片的数字。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 数据处理" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "数据集对于训练非常重要,好的数据集可以有效提高训练精度和效率。在加载数据集前,我们通常会对数据集进行一些处理。\n", "#### 定义数据集及数据操作\n", "我们定义一个函数create_dataset()来创建数据集。在这个函数中,我们定义好需要进行的数据增强和处理操作:\n", "
1、定义数据集。\n", "
2、定义进行数据增强和处理所需要的一些参数。\n", "
3、根据参数,生成对应的数据增强操作。\n", "
4、使用map()映射函数,将数据操作应用到数据集。\n", "
5、对生成的数据集进行处理。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Data processing module\n", "import mindspore.dataset.transforms.vision.c_transforms as CV\n", "import mindspore.dataset.transforms.c_transforms as C\n", "from mindspore.dataset.transforms.vision import Inter\n", "from mindspore.common import dtype as mstype\n", "\n", "\n", "def create_dataset(data_path, batch_size=32, repeat_size=1,\n", " num_parallel_workers=1):\n", " \"\"\" create dataset for train or test\n", " Args:\n", " data_path: Data path\n", " batch_size: The number of data records in each group\n", " repeat_size: The number of replicated data records\n", " num_parallel_workers: The number of parallel workers\n", " \"\"\"\n", " # define dataset\n", " mnist_ds = ds.MnistDataset(data_path)\n", "\n", " # Define some parameters needed for data enhancement and rough justification\n", " resize_height, resize_width = 32, 32\n", " rescale = 1.0 / 255.0\n", " shift = 0.0\n", " rescale_nml = 1 / 0.3081\n", " shift_nml = -1 * 0.1307 / 0.3081\n", "\n", " # According to the parameters, generate the corresponding data enhancement method\n", " resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Resize images to (32, 32) by bilinear interpolation\n", " rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images\n", " rescale_op = CV.Rescale(rescale, shift) # rescale images\n", " hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.\n", " type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network\n", "\n", " # Using map () to apply operations to a dataset\n", " mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", " # Process the generated dataset\n", " buffer_size = 10000\n", " mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", " mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", " mnist_ds = mnist_ds.repeat(repeat_size)\n", "\n", " return mnist_ds\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "其中
\n", "batch_size:每组包含的数据个数,现设置每组包含32个数据。\n", "
repeat_size:数据集复制的数量。\n", "
先进行shuffle、batch操作,再进行repeat操作,这样能保证1个epoch内数据不重复。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "接下来我们查看将要进行训练的数据集内容是什么样的。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "首先,查看数据集内包含多少组数据。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datas = create_dataset(train_data_path) # Process the train dataset\n", "print('Number of groups in the dataset:',datas.get_dataset_size()) # Number of query dataset groups" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "其次,取出其中一组数据,查看包含的key,图片数据的张量,以及下标labels的值。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = datas.create_dict_iterator().get_next() # Take a set of datasets\n", "print(data.keys())\n", "images = data[\"image\"] # Take out the image data in this dataset\n", "labels = data[\"label\"] # Take out the label (subscript) of this data set\n", "print('Tensor of image:',images.shape) # Query the tensor of images in each dataset (32,1,32,32)\n", "print('labels:',labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "最后,查看image的图像和下标对应的值。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "count = 1\n", "for i in images:\n", " plt.subplot(4,8,count) \n", " plt.imshow(np.squeeze(i))\n", " plt.title('num:%s'%labels[count-1])\n", " plt.xticks([])\n", " count+=1\n", " plt.axis(\"off\")\n", "plt.show() # Print a total of 32 pictures in the group" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通过上述三个查询操作,看到经过变换后的图片,数据集内分成了1875组数据,每组数据中含有32张图片,每张图片像数值为32×32,数据全部准备好后,就可以进行下一步的数据训练了。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 三、构造神经网络" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在对手写字体识别上,通常采用卷积神经网络架构(CNN)进行学习预测,最经典的属1998年由Yann LeCun创建的LeNet5架构,
其中分为:
1、输入层;
2、卷积层C1;
3、池化层S2;
4、卷积层C3;
5、池化层S4;
6、全连接F6;
7、全连接;
8、全连接OUTPUT。
结构示意如下图:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### LeNet5结构图" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\"LeNet5\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在构建LeNet5前,我们需要对全连接层以及卷积层进行初始化。\n", "\n", "TruncatedNormal:参数初始化方法,MindSpore支持TruncatedNormal、Normal、Uniform等多种参数初始化方法,具体可以参考MindSpore API的mindspore.common.initializer模块说明。\n", "\n", "初始化示例代码如下:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import mindspore.nn as nn\n", "from mindspore.common.initializer import TruncatedNormal\n", "\n", "# Initialize 2D convolution function\n", "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n", " \"\"\"Conv layer weight initial.\"\"\"\n", " weight = weight_variable()\n", " return nn.Conv2d(in_channels, out_channels,\n", " kernel_size=kernel_size, stride=stride, padding=padding,\n", " weight_init=weight, has_bias=False, pad_mode=\"valid\")\n", "\n", "# Initialize full connection layer\n", "def fc_with_initialize(input_channels, out_channels):\n", " \"\"\"Fc layer weight initial.\"\"\"\n", " weight = weight_variable()\n", " bias = weight_variable()\n", " return nn.Dense(input_channels, out_channels, weight, bias)\n", "\n", "# Set truncated normal distribution\n", "def weight_variable():\n", " \"\"\"Weight initial.\"\"\"\n", " return TruncatedNormal(0.02)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用MindSpore定义神经网络需要继承mindspore.nn.cell.Cell。Cell是所有神经网络(Conv2d等)的基类。\n", "\n", "神经网络的各层需要预先在\\_\\_init\\_\\_()方法中定义,然后通过定义construct()方法来完成神经网络的前向构造。按照LeNet5的网络结构,定义网络各层如下:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LeNet5(nn.Cell):\n", " \"\"\"Lenet network structure.\"\"\"\n", " # define the operator required\n", " def __init__(self):\n", " super(LeNet5, self).__init__()\n", " self.batch_size = 32 # 32 pictures in each group\n", " self.conv1 = conv(1, 6, 5) # Convolution layer 1, 1 channel input (1 Figure), 6 channel output (6 figures), convolution core 5 * 5\n", " self.conv2 = conv(6, 16, 5) # Convolution layer 2,6-channel input, 16 channel output, convolution kernel 5 * 5\n", " self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n", " self.fc2 = fc_with_initialize(120, 84)\n", " self.fc3 = fc_with_initialize(84, 10)\n", " self.relu = nn.ReLU()\n", " self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.flatten = nn.Flatten()\n", "\n", " # use the preceding operators to construct networks\n", " def construct(self, x):\n", " x = self.conv1(x) # 1*32*32-->6*28*28\n", " x = self.relu(x) # 6*28*28-->6*14*14\n", " x = self.max_pool2d(x) # Pool layer\n", " x = self.conv2(x) # Convolution layer\n", " x = self.relu(x) # Function excitation layer\n", " x = self.max_pool2d(x) # Pool layer\n", " x = self.flatten(x) # Dimensionality reduction\n", " x = self.fc1(x) # Full connection\n", " x = self.relu(x) # Function excitation layer\n", " x = self.fc2(x) # Full connection\n", " x = self.relu(x) # Function excitation layer\n", " x = self.fc3(x) # Full connection\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "构建完成后,我们将LeNet5的整体参数打印出来查看一下。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "network = LeNet5()\n", "print(network)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "param = network.trainable_params()\n", "param" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 四、搭建训练网络并进行训练" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "构建完成神经网络后,就可以着手进行训练网络的构建,模型训练函数为Model.train(),参数主要包含:\n", "
1、圈数epoch size(每圈需要遍历完成1875组图片);\n", "
2、数据集ds_train;\n", "
3、回调函数callbacks包含ModelCheckpoint、LossMonitor、SummaryStepckpoint_cb,Callback模型检测参数;\n", "
4、底层数据通道dataset_sink_mode,此参数默认True需设置成False,因为此功能只限于昇腾AI处理器。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Training and testing related modules\n", "import argparse\n", "from mindspore import Tensor\n", "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n", "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor,SummaryStep,Callback\n", "from mindspore.train import Model\n", "from mindspore.nn.metrics import Accuracy\n", "from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits\n", "\n", "def train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb,lmf_info):\n", " \"\"\"Define the training method.\"\"\"\n", " print(\"============== Starting Training ==============\")\n", " # load training dataset\n", " ds_train = create_dataset(os.path.join(mnist_path, \"train\"), 32, repeat_size)\n", " model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(),lmf_info], dataset_sink_mode=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "自定义一个存储每一步训练的step和对应loss值的回调LMF_info函数,本函数继承了Callback类,可以自定义训练过程中的处理措施,非常方便,等训练完成后,可将数据绘图查看loss的变化情况。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Custom callback function\n", "class LMF_info(Callback):\n", " def step_end(self, run_context):\n", " cb_params = run_context.original_args()\n", " # step_ Loss dictionary for saving loss value and step number information\n", " step_loss[\"loss_value\"].append(str(cb_params.net_outputs))\n", " step_loss[\"step\"].append(str(cb_params.cur_step_num))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 定义损失函数及优化器\n", "基本概念\n", "在进行定义之前,先简单介绍损失函数及优化器的概念。\n", "
损失函数:又叫目标函数,用于衡量预测值与实际值差异的程度。深度学习通过不停地迭代来缩小损失函数的值。定义一个好的损失函数,可以有效提高模型的性能。\n", "
优化器:用于最小化损失函数,从而在训练过程中改进模型。\n", "
定义了损失函数后,可以得到损失函数关于权重的梯度。梯度用于指示优化器优化权重的方向,以提高模型性能。\n", "
定义损失函数。\n", "
MindSpore支持的损失函数有SoftmaxCrossEntropyWithLogits、L1Loss、MSELoss等。这里使用SoftmaxCrossEntropyWithLogits损失函数。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "import os\n", "\n", "os.system('del/f/s/q *.ckpt *.meta')# Clean up old run files before\n", "lr = 0.01 # learning rate\n", "momentum = 0.9 #\n", "\n", "# create the network\n", "network = LeNet5()\n", "\n", "# define the optimizer\n", "net_opt = nn.Momentum(network.trainable_params(), lr, momentum)\n", "\n", "\n", "# define the loss function\n", "net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", "# define the model\n", "model = Model(network, net_loss, net_opt,metrics={\"Accuracy\":Accuracy()} )#metrics={\"Accuracy\": Accuracy()}\n", "\n", "\n", "epoch_size = 1\n", "mnist_path = \"./MNIST_Data\"\n", "\n", "config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=16)\n", "# save the network model and parameters for subsequence fine-tuning\n", "\n", "ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\", config=config_ck)\n", "# group layers into an object with training and evaluation features\n", "step_loss = {\"step\":[],\"loss_value\":[]}\n", "# step_ Loss dictionary for saving loss value and step number information\n", "lmf_info=LMF_info()\n", "# save the steps and loss value\n", "repeat_size = 1\n", "train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb,lmf_info)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "训练完成后,能在Jupyter的工作路径上生成多个模型文件,名称具体含义checkpoint_{网络名称}-{第几个epoch}_{第几个step}.ckpt 。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 查看损失函数随着训练步数的变化情况" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "steps=step_loss[\"step\"]\n", "loss_value = step_loss[\"loss_value\"]\n", "steps = list(map(int,steps))\n", "loss_value = list(map(float,loss_value))\n", "plt.plot(steps,loss_value,color=\"red\")\n", "plt.xlabel(\"Steps\")\n", "plt.ylabel(\"Loss_value\")\n", "plt.title(\"Loss function value change chart\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面可以看出来大致分为三个阶段:\n", "\n", "阶段一:开始训练loss值在2.2上下浮动,训练收益感觉并不明显。\n", "\n", "阶段二:训练到某一时刻,loss值减少迅速,训练收益大幅增加。\n", "\n", "阶段三:loss值收敛到一定小的值后,loss值开始振荡在一个小的区间上无法趋0,再继续增加训练并无明显收益,至此训练结束。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 五、数据测试验证模型精度" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "搭建测试网络的过程主要为:
1、载入模型.cptk文件中的参数param;
2、将参数param载入到神经网络LeNet5中;
3、载入测试数据集;
4、调用函数model.eval()传入参数测试数据集ds_eval,就生成模型checkpoint_lenet-1_1875.ckpt的精度值。
dataset_sink_mode表示数据集下沉模式,仅仅支持昇腾AI处理器平台,所以这里设置成False 。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test_net(network, model, mnist_path):\n", " \"\"\"Define the evaluation method.\"\"\"\n", " print(\"============== Starting Testing ==============\")\n", " # load the saved model for evaluation\n", " param_dict = load_checkpoint(\"checkpoint_lenet-1_1875.ckpt\")\n", " # load parameter to the network\n", " load_param_into_net(network, param_dict)\n", " # load testing dataset\n", " ds_eval = create_dataset(os.path.join(mnist_path, \"test\"))\n", " acc = model.eval(ds_eval, dataset_sink_mode=False)\n", " print(\"============== Accuracy:{} ==============\".format(acc))\n", "\n", "test_net(network, model, mnist_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "经过1875步训练后生成的模型精度超过95%,模型优良。\n", "我们可以看一下模型随着训练步数变化,精度随之变化的情况。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "acc_model_info()函数是将每125步的保存的模型,调用model.eval()函数将测试出的精度返回到步数列表和精度列表,如下:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def acc_model_info(network, model, mnist_path, model_numbers):\n", " \"\"\"Define the plot info method\"\"\"\n", " step_list=[]\n", " acc_list =[]\n", " for i in range(1,model_numbers+1):\n", " # load the saved model for evaluation\n", " param_dict = load_checkpoint(\"checkpoint_lenet-1_{}.ckpt\".format(str(i*125)))\n", " # load parameter to the network\n", " load_param_into_net(network, param_dict)\n", " # load testing dataset\n", " ds_eval = create_dataset(os.path.join(mnist_path, \"test\"))\n", " acc = model.eval(ds_eval, dataset_sink_mode=False)\n", " acc_list.append(acc['Accuracy'])\n", " step_list.append(i*125)\n", " return step_list,acc_list\n", "\n", "# Draw line chart according to training steps and model accuracy\n", "l1,l2 = acc_model_info(network, model, mnist_path,15)\n", "plt.xlabel(\"Model of Steps\")\n", "plt.ylabel(\"Model accuracy\")\n", "plt.title(\"Model accuracy variation chart\")\n", "plt.plot(l1,l2,'red')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从图中可以看出训练得到的模型精度变化分为三个阶段:1、缓慢上升,2、迅速上升,3、缓慢上升趋近于不到1的某个值时附近振荡,说明随着训练数据的增加,会对模型精度有着正相关的影响,但是随着精度到达一定程度,训练收益会降低。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 六、模型预测应用" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们尝试使用生成的模型应用到分类预测单个或者单组图片数据上,具体步骤如下:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1、需要将要测试的数据转换成适应LeNet5的数据类型。\n", "
2、提取出image的数据。\n", "
3、使用函数model.predict()预测image对应的数字。需要说明的是predict返回的是image对应0-9的概率值。\n", "
4、调用plot_pie()将预测的各数字的概率显示出来。负概率的数字会被去掉。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "载入要测试的数据集并调用create_dataset()转换成符合格式要求的数据集,并选取其中一组32张图片进行预测。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds_test = create_dataset(test_data_path).create_dict_iterator()\n", "data = ds_test.get_next()\n", "images = data[\"image\"]\n", "labels = data[\"label\"] # The subscript of data picture is the standard for us to judge whether it is correct or not\n", "\n", "output =model.predict(Tensor(data['image']))\n", "# The predict function returns the probability of 0-9 numbers corresponding to each picture\n", "prb = output.asnumpy()\n", "pred = np.argmax(output.asnumpy(),axis=1)\n", "err_num = []\n", "index = 1\n", "for i in range(len(labels)):\n", " plt.subplot(4,8,i+1)\n", " color = 'blue' if pred[i]==labels[i] else 'red'\n", " plt.title(\"pre:{}\".format(pred[i]),color = color)\n", " plt.imshow(np.squeeze(images[i]))\n", " plt.axis(\"off\")\n", " if color =='red':\n", " index=0\n", " # Print out the wrong data identified by the current group\n", " print(\"Row {}, column {} is incorrectly identified as {}, the correct value should be {}\".format(int(i/8)+1,i%8+1,pred[i],labels[i]),'\\n')\n", "if index:\n", " print(\"All the figures in this group are predicted correctly!\")\n", "print(pred,\"<--Predicted figures\") # Print the numbers recognized by each group of pictures\n", "print(labels,\"<--The right number\") # Print the subscript corresponding to each group of pictures\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "构建一个概率分析的饼图函数。\n", "\n", "备注:prb为上一段代码中,存储这组数对应的数字概率。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define the pie drawing function of probability analysis\n", "def plot_pie(prbs):\n", " dict1={}\n", " # Remove the negative number and build the dictionary dict1. The key is the number and the value is the probability value\n", " for i in range(10):\n", " if prbs[i]>0:\n", " dict1[str(i)]=prbs[i]\n", "\n", " label_list = dict1.keys() # Label of each part\n", " size = dict1.values() # Size of each part\n", " colors = [\"red\", \"green\",\"pink\",\"blue\",\"purple\",\"orange\",\"gray\"] # Building a round cake pigment Library\n", " color = colors[:len(size)]# Color of each part\n", " plt.pie(size, colors=color, labels=label_list, labeldistance=1.1, autopct=\"%1.1f%%\", shadow=False, startangle=90, pctdistance=0.6)\n", " plt.axis(\"equal\") # Set the scale size of x-axis and y-axis to be equal\n", " plt.legend()\n", " plt.title(\"Image classification\")\n", " plt.show()\n", " \n", " \n", "for i in range(2):\n", " print(\"Figure {} probability of corresponding numbers [0-9]:\\n\".format(i+1),prb[i])\n", " plot_pie(prb[i])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "以上过程就是这次手写数字分类训练的全部体验过程。" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:root] *", "language": "python", "name": "conda-root-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }