1-LeNet5_MNIST.ipynb 25.1 KB
Notebook
Newer Older
D
dyonghan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 style=\"text-align:center\">基于LeNet5的手写数字识别</h1>\n",
    "\n",
    "## 实验介绍\n",
    "\n",
    "LeNet5 + MINST被誉为深度学习领域的“Hello world”。本实验主要介绍使用MindSpore在MNIST数据集上开发和训练一个LeNet5模型,并验证模型精度。\n",
    "\n",
    "## 实验目的\n",
    "\n",
    "- 了解如何使用MindSpore进行简单卷积神经网络的开发。\n",
    "- 了解如何使用MindSpore进行简单图片分类任务的训练。\n",
    "- 了解如何使用MindSpore进行简单图片分类任务的验证。\n",
    "\n",
    "## 预备知识\n",
    "\n",
    "- 熟练使用Python,了解Shell及Linux操作系统基本知识。\n",
    "- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略等。\n",
23
    "- 了解华为云的基本使用方法,包括[OBS(对象存储)](https://www.huaweicloud.com/product/obs.html)、[ModelArts(AI开发平台)](https://www.huaweicloud.com/product/modelarts.html)、[Notebook(开发工具)](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0033.html)、[训练作业](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0046.html)等服务。华为云官网:https://www.huaweicloud.com\n",
D
dyonghan 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36
    "- 了解并熟悉MindSpore AI计算框架,MindSpore官网:https://www.mindspore.cn\n",
    "\n",
    "## 实验环境\n",
    "\n",
    "- MindSpore 0.2.0(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套);\n",
    "- 华为云ModelArts:ModelArts是华为云提供的面向开发者的一站式AI开发平台,集成了昇腾AI处理器资源池,用户可以在该平台下体验MindSpore。ModelArts官网:https://www.huaweicloud.com/product/modelarts.html\n",
    "\n",
    "## 实验准备\n",
    "\n",
    "### 创建OBS桶\n",
    "\n",
    "本实验需要使用华为云OBS存储实验脚本和数据集,可以参考[快速通过OBS控制台上传下载文件](https://support.huaweicloud.com/qs-obs/obs_qs_0001.html)了解使用OBS创建桶、上传文件、下载文件的使用方法。\n",
    "\n",
N
njzheng 已提交
37
    "> **提示:** 华为云新用户使用OBS时通常需要创建和配置“访问密钥”,可以在使用OBS时根据提示完成创建和配置。也可以参考[获取访问密钥并完成ModelArts全局配置](https://support.huaweicloud.com/prepare-modelarts/modelarts_08_0002.html)获取并配置访问密钥。\n",
D
dyonghan 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
    "\n",
    "创建OBS桶的参考配置如下:\n",
    "\n",
    "- 区域:华北-北京四\n",
    "- 数据冗余存储策略:单AZ存储\n",
    "- 桶名称:如ms-course\n",
    "- 存储类别:标准存储\n",
    "- 桶策略:公共读\n",
    "- 归档数据直读:关闭\n",
    "- 企业项目、标签等配置:免\n",
    "\n",
    "### 数据集准备\n",
    "\n",
    "MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。\n",
    "\n",
    "从MNIST官网下载如下4个文件到本地并解压:\n",
    "\n",
    "```\n",
    "train-images-idx3-ubyte.gz:  training set images (9912422 bytes)\n",
    "train-labels-idx1-ubyte.gz:  training set labels (28881 bytes)\n",
    "t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes)\n",
    "t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)\n",
    "```\n",
    "\n",
    "### 脚本准备\n",
    "\n",
    "从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。\n",
    "\n",
    "### 上传文件\n",
    "\n",
    "将脚本和数据集上传到OBS桶中,组织为如下形式:\n",
    "\n",
    "```\n",
    "experiment_1\n",
    "├── MNIST\n",
    "│   ├── test\n",
    "│   │   ├── t10k-images-idx3-ubyte\n",
    "│   │   └── t10k-labels-idx1-ubyte\n",
    "│   └── train\n",
    "│       ├── train-images-idx3-ubyte\n",
    "│       └── train-labels-idx1-ubyte\n",
79 80
    "├── *.ipynb\n",
    "└── main.py\n",
D
dyonghan 已提交
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
    "```\n",
    "\n",
    "## 实验步骤(方案一)\n",
    "\n",
    "### 创建Notebook\n",
    "\n",
    "可以参考[创建并打开Notebook](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0034.html)来创建并打开本实验的Notebook脚本。\n",
    "\n",
    "创建Notebook的参考配置:\n",
    "\n",
    "- 计费模式:按需计费\n",
    "- 名称:experiment_1\n",
    "- 工作环境:Python3\n",
    "- 资源池:公共资源\n",
    "- 类型:Ascend\n",
    "- 规格:单卡1*Ascend 910\n",
    "- 存储位置:对象存储服务(OBS)->选择上述新建的OBS桶中的experiment_1文件夹\n",
    "- 自动停止等配置:默认\n",
    "\n",
    "> **注意:**\n",
    "> - 打开Notebook前,在Jupyter Notebook文件列表页面,勾选目录里的所有文件/文件夹(实验脚本和数据集),并点击列表上方的“Sync OBS”按钮,使OBS桶中的所有文件同时同步到Notebook工作环境中,这样Notebook中的代码才能访问数据集。参考[使用Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html)。\n",
    "> - 打开Notebook后,选择MindSpore环境作为Kernel。\n",
    "\n",
N
njzheng 已提交
104
    "> **提示:** 上述数据集和脚本的准备工作也可以在Notebook环境中完成,在Jupyter Notebook文件列表页面,点击右上角的\"New\"->\"Terminal\",进入Notebook环境所在终端,进入`work`目录,可以使用常用的linux shell命令,如`wget, gzip, tar, mkdir, mv`等,完成数据集和脚本的下载和准备。\n",
D
dyonghan 已提交
105
    "\n",
N
njzheng 已提交
106
    "> **提示:** 请从上至下阅读提示并执行代码框进行体验。代码框执行过程中左侧呈现[\\*],代码框执行完毕后左侧呈现如[1],[2]等。请等上一个代码框执行完毕后再执行下一个代码框。\n",
D
dyonghan 已提交
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
    "\n",
    "导入MindSpore模块和辅助模块:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ['DEVICE_ID'] = '0'\n",
    "import matplotlib.pyplot as plt\n",
    "import mindspore as ms\n",
    "import mindspore.context as context\n",
    "import mindspore.dataset.transforms.c_transforms as C\n",
    "import mindspore.dataset.transforms.vision.c_transforms as CV\n",
    "\n",
D
dyonghan 已提交
125
    "from mindspore import nn\n",
126
    "from mindspore.model_zoo.lenet import LeNet5\n",
D
dyonghan 已提交
127
    "from mindspore.train import Model\n",
D
dyonghan 已提交
128
    "from mindspore.train.callback import LossMonitor\n",
D
dyonghan 已提交
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
    "\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据处理\n",
    "\n",
    "在使用数据集训练网络前,首先需要对数据进行预处理,如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR_TRAIN = \"MNIST/train\" # 训练集信息\n",
    "DATA_DIR_TEST = \"MNIST/test\" # 测试集信息\n",
    "\n",
    "def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),\n",
    "                   rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):\n",
    "    ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)\n",
154
    "    \n",
D
dyonghan 已提交
155
    "    ds = ds.map(input_columns=\"image\", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])\n",
D
dyonghan 已提交
156
    "    ds = ds.map(input_columns=\"label\", operations=C.TypeCast(ms.int32))\n",
D
dyonghan 已提交
157
    "    ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch)\n",
D
dyonghan 已提交
158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
    "    \n",
    "    return ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对其中几张图片进行可视化,可以看到图片中的手写数字,图片的大小为32x32。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATsAAAD7CAYAAAAVQzPHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcm0lEQVR4nO3deZRV1Zk28OepQWaBYrIQAkZBIKyICjjE1U3aEDHdaU268QuiTRxCVqKt+aJGErOiMdqxTaL9pfOZDh0ZooKxo+0QtQnNEhLRBis4oSggDhArTIIWU0FVvf3HPexzCupW3enc4ezntxar3nvGXfCy795n2JtmBhGRpKsqdQFERIpBlZ2IeEGVnYh4QZWdiHhBlZ2IeEGVnYh4QZVdBkguIHlbqcshUmg+5XZFVnYk3yG5lWSvyLIrSS4vYbEKiuRnSK4huZfkZpIXlbpMEr+k5zbJO4N8/ojkuyRvKta5K7KyC9QAuLbUhcgWyeoMthkHYBGAmwD0BTABwB9jLpqUj8TmNoB7AYwxs2MBnA3gYpJfjLdkKZVc2f0IwPUk+x25guRIkkayJrJsOckrg/jLJFeSvJvkbpKbSJ4dLN9MchvJWUccdiDJpSSbSK4gOSJy7DHBug9IvhlthQXdhJ+TfIrkXgCfzuB3+y6AX5jZ02bWYmY7zeytLP9+pHIlNrfN7E0z2xtZ1AbgpIz/ZvJQyZVdA4DlAK7Pcf8zALwCYABSragHAUxC6i/+EgA/I9k7sv1MAD8AMBDASwAeAICgu7E0OMZgADMA3EPyE5F9LwZwO4A+AJ4leTHJVzop25nBsV8l2UjyfpJ1Of6eUnmSnNsgOYfkHgBbAPQKjh+7Sq7sAOB7AP6R5KAc9n3bzOabWSuAXwMYDuBWM2s2s98BOIj23zhPmtnvzawZqe7lWSSHA/gbAO8Ex2oxszUAHgbw95F9HzOzlWbWZmYHzGyRmX2yk7INA3ApgL8DMApADwD/msPvKJUrqbkNM7sDqcrxNAD3Afgwh98xaxVd2ZnZWgC/BTAnh923RuL9wfGOXBb99tscOe8eAB8AGApgBIAzgi7DbpK7kfqmPK6jfTO0H8B8M1sfnOufAHwuy2NIBUtwbh8+j5nZi0FZvp/LMbJV0/UmZe9mAGsA/CSy7PA1gZ4APgri6D9QLoYfDoIuQB2A95H6x15hZlM72TfboWVeyWEfSZ4k5vaRagCcmOcxMlLRLTsAMLONSDXVr4ks2w7gTwAuIVlN8nLk/xf6OZLnkDwGqesbq8xsM1LfvqNJXkqyNvgzieTYPM41H8BlJD9OsieAG4PziEeSltskq0h+lWR/pkwGcBWAZXmWPyMVX9kFbkXqQmfUVwDcAGAngE8AeC7PcyxC6pv2AwCnI9Wch5k1AfgsgC8h9W34ZwD/DKBbugORnEnytXTrzWwegF8BWAXgXQDNiCS8eCVRuQ3gCwDeAtAE4H6krkUX5Xo0NXiniPggKS07EZFOqbITES+oshMRL+RV2ZGcFrxCspFkLs8DiZQl5Xby5HyDgqmXftcDmIrUax8vAJhhZq8XrngixafcTqZ8HiqeDGCjmW0CAJIPArgAQNqEOIbdrPtRd9GlFJqwa4eZ5fIqkg+U2xXqAPbioDWzo3X5VHbHo/2rIluQegG5HZKzAcwGgO7oiTN4bh6nlEL5b/vNu6UuQxlTbleoVZb++eR8rtl1VHse1Sc2s7lmNtHMJtamfxZRpJwotxMon8puCyLv1CE1Usf7+RVHpCwotxMon8ruBQCjSJ4QvFP3JQCPF6ZYIiWl3E6gnK/ZmVkLyasBLAFQDWCemXX2TpxIRVBuJ1NeQzyZ2VMAnipQWUTKhnI7efQGhYh4QZWdiHghCSMVl43G6852cdO4g1ntywPhLHQnzwmfXW1rasq/YCKilp2I+EGVnYh4Qd3YAqo7L3zu9JXxj2a177qD+1x83a0XhivUjfVaVc+eLn7v2gkubiuTFzYGvdTi4h6Pri5hSbqmlp2IeEGVnYh4Qd3YPO37YjgYxuQBL5SwJJIU1YPCkbca/88oFy/52p0uHlbTG+Vg6rrPu/jg/okuPmZJQymK0ym17ETEC6rsRMQLquxExAu6ZpdGVZ8+Lm457aS02116+xMunt03uyHPtrXudfFdWz8brmhp6WBr8cWhccNc/OJ37omsKY/rdFFLx4b5f8nNU1y89cCpLmZLOO5p1aq1LrYi57ladiLiBVV2IuIFdWMj2C18LH3flLEuXvGLubGcb+6u01383hl7I2v2Hr2xSJm7f+Ty8MPiMF5/KMznb5x/WbjNzt0ubNv9YbtjWXNzoYunlp2I+EGVnYh4Qd3YiN3TwztIi2//cWRN+d0FE6kUJ9b0cPG/PD3fxa0Wzlh52Xe/2W6fvvf/T8HLoZadiHhBlZ2IeMH7buy2q8Oh1L99zQMuPqG2cF3XU1bPcPHgu7u7uHrvochWayECADVrNrr4L78628ULfnaXizPJz3R5l6np/7bExdk+MB9VzbBNNbq2V4fbtNayw+WF1GXLjuQ8kttIro0sqyO5lOSG4Gf/eIspUnjKbb9k0o1dAGDaEcvmAFhmZqMALAs+i1SaBVBue6PLbqyZ/Z7kyCMWXwBgShAvBLAcwI0FLFfRHBgYxhf1/jD9hnn4aGfYdD9uRTjOl3W0sRRNueZ2dEa5Hr972cUzbrrexZl0+4as3+9irnyxw22q+/V18Z8WDG23bkrPDZFPHXc/03loT3jcH/50pouf+FbpxuTL9QbFEDNrBIDg5+B0G5KcTbKBZMMhFP6paJECU24nVOx3Y81srplNNLOJtSiTWUJECkC5XVlyvRu7lWS9mTWSrAewrZCFitsHl53l4jOnvRrLOaasDWcIG/6EnvCpIGWV29F3RPN50NY+Fc5MtuHy8L99VbdWFzec/v/b7dO/Oruua9T6A/UuPm7eSy4+r8+3XBydIW3kml3t9m/L+czp5fq/8HEAs4J4FoDHClMckZJTbidUJo+eLAbwPICTSW4heQWAOwBMJbkBwNTgs0hFUW77JZO7sTPSrDq3wGWJ1e5/CLuu478aPsA7/2N/iP3c2yeEf83HHhuWo9+vno/93JJeUnI7ndYpp7l4y9fDB9jfPmdBmj16plmemQeaBrj4vsc/7eKR+8I8H/bD5zrcN45u65F0MUlEvKDKTkS84M27sT0uaXRxMbquy8c/Gn4YH4bXNYZdi9V7wgm2ez6yKvYySTIdPC+cnHrP0FoXN52/x8VvnHNf7OVYs2eEiwevKUbHNDtq2YmIF1TZiYgXvOnGlouf1K9x8W237HPxHx7JfggeEQCw63a4+IXo5ZMii+b23Nv/7OLfROZE1ryxIiIxU2UnIl5IdDe2ekCdi7vXHOpky9LoVhWWqXrQcBe37gi7JTANBCWVJzqy8ZRF4Tu30Xlj29ZvcnExurRq2YmIF1TZiYgXEt2NPf6pcHicu49/KrKmPO58XtP/DRePem6ri+eeHb4/27p9e1HLJFJo6eaNvfriq1zMlS8hbmrZiYgXVNmJiBdU2YmIFxJ9zW5Ej50u7l3V9XW6S96Z4uLXfzXWxWu+9/OCluuwbgxf2j6/Zzgs9YZnwlvyy74cXr+zBk2kLUfr8c0wt0+5LRyi7+XJi0tRnKOkmyTbasIZ0uKfIlstOxHxhCo7EfFCIrqx6Sb6vajvLyNbdT1T0pY9/Vx83MMbXTyp+Wtp97nh24vC8+UxyXa0S3tD3Vsu/l2vv3CxvpmkI21rw0eYhvwonEVs0uj0eZutv7g6HG8x+sJ/JdH/HxHxgio7EfFCIrqx6BbOtvvghHtdHL3zk63omwt189O/xfDDXjNdfPPAcHl08u1iDAMvArR/E6FuZeGO+8cZI8MPSe3GkhxO8hmS60i+RvLaYHkdyaUkNwQ/+8dfXJHCUW77JZNubAuA68xsLIAzAVxFchyAOQCWmdkoAMuCzyKVRLntkUwmyW4E0BjETSTXATgewAUApgSbLQSwHMCNsZSySP56aNj1nHfLeSUsiRRDueR21Slj231+5wvl15C8fOiSUhchb1ndoCA5EsCpAFYBGBIky+GkGZxmn9kkG0g2HEJzR5uIlJxyO/kyruxI9gbwMIBvmNlHme5nZnPNbKKZTaxFt653ECky5bYfMrobS7IWqWR4wMweCRZvJVlvZo0k6wFsi6uQxRJ9mPeG2feUsCRSLOWQ2ztO69fu8zrlXiwyuRtLAPcCWGdmd0VWPQ5gVhDPAvBY4YsnEh/ltl8yadl9CsClAF4lefghnu8AuAPAQySvAPAegOnxFFEkNsptj2RyN/ZZpB+B5dzCFkekeEqZ2zXDh7m4aWQxBjgSvS4mIl5QZSciXkjGu7Ft4UTSbx6KPhIV3kQbXhPW65mMWlwMzRZOkr3pUMeTeLNFk2Qn0aYrPubiN76SvLuv5ZjbatmJiBdU2YmIFxLRjW3dscPF0QmmURXe5WpbHI4E/F9jnixKubry011jXPzMuSd2uE3VznCSHXVopVKUY26rZSciXlBlJyJeSEQ3FhY2gqMjDEc1t4wsUmGOdsrqcC7PwXeHd4Kr94Z3qWyr5oT1ycfvfc/FY/j1dusq9e5sdN7lHdeED02XS26rZSciXlBlJyJeSEY3NgP8STgbzqShhZtPMxND1u8Py7HyRRfr7qq/WjZvcfGJC9v/NxyDsFtb7l3aqes+7+KWO4e4+JiGhlIUp1Nq2YmIF1TZiYgXVNmJiBe8uWZ3zJLwGkJdCcshcqSWTe+0+3ziL1tcPNa+jnI26KWwrD2WrC5hSbqmlp2IeEGVnYh4wZturEiliD6W8rFbtnSypWRDLTsR8YIqOxHxQibzxnYnuZrkyyRfI/n9YHkdyaUkNwQ/+8dfXJHCUW77JZOWXTOAvzKzUwBMADCN5JkA5gBYZmajACwLPotUEuW2R7qs7CxlT/CxNvhjAC4AsDBYvhDAhbGUUCQmym2/ZHTNjmR1MGP6NgBLzWwVgCFm1ggAwc/BnR1DpBwpt/2RUWVnZq1mNgHAMACTSY7P9AQkZ5NsINlwCM25llMkFsptf2R1N9bMdgNYDmAagK0k6wEg+LktzT5zzWyimU2sRbc8iysSD+V28mVyN3YQyX5B3APAZwC8AeBxALOCzWYBeCyuQorEQbntl0zeoKgHsJBkNVKV40Nm9luSzwN4iOQVAN4DMD3GcorEQbntkS4rOzN7BcCpHSzfCeDcOAolUgzKbb/QrHiDg5PcDuDdop1QOjPCzAaVuhBJodwuG2nzuqiVnYhIqejdWBHxgio7EfGCKrsMkFxA8rZSl0Ok0HzK7Yqs7Ei+Q3IryV6RZVeSXF7CYhUMyTtJbib5Ecl3Sd5U6jJJcSQ9twGA5GdIriG5N8jzi4px3oqs7AI1AK4tdSGyFTzT1ZV7AYwxs2MBnA3gYpJfjLdkUkYSm9skxwFYBOAmAH2RGm3mjzEXDUBlV3Y/AnD94Sfgo0iOJGkkayLLlpO8Moi/THIlybtJ7ia5ieTZwfLNJLeRnHXEYQcGY5s1kVxBckTk2GOCdR+QfDP6TRV0E35O8imSewF8uqtfzMzeNLO9kUVtAE7K+G9GKl1icxvAdwH8wsyeNrMWM9tpZm9l+feTk0qu7BqQepfx+hz3PwPAKwAGIPVN8yCASUhVKpcA+BnJ3pHtZwL4AYCBAF4C8AAABN2NpcExBgOYAeAekp+I7HsxgNsB9AHwLMmLSb7SWeFIziG5B8AWAL2C44sfkpzbZwbHfpVkI8n7SRZldtNKruwA4HsA/pFkLg/Hvm1m882sFcCvAQwHcKuZNZvZ7wAcRPvW1JNm9nsza0aqCX4WyeEA/gbAO8GxWsxsDYCHAfx9ZN/HzGylmbWZ2QEzW2Rmn+yscGZ2B1IJdBqA+wB8mMPvKJUrqbk9DMClAP4OwCgAPQD8aw6/Y9YqurIzs7UAfovcRpLdGon3B8c7cln0229z5Lx7AHwAYCiAEQDOCLoMu0nuRuqb8riO9s1GMLjki0FZvp/LMaQyJTi39wOYb2brg3P9E4DPZXmMnCRhKsWbAawB8JPIssPXu3oC+CiIo/9AuRh+OAi6AHUA3kfqH3uFmU3tZN98X1OpAXBinseQypPE3H4lh30KoqJbdgBgZhuRaqpfE1m2HcCfAFzC1Ei0lyP/yuJzJM8heQxS1zdWmdlmpL59R5O8lGRt8GcSybG5nIRkFcmvkuzPlMkArkJqLgTxSNJyOzAfwGUkP06yJ4Abg/PEruIru8CtSF3Ej/oKgBsA7ATwCQDP5XmORUh9034A4HSkmvMwsyYAnwXwJaS+Df8M4J+B9KM5kpxJ8rVOzvUFAG8BaAJwP1LXNIpyXUPKTqJy28zmAfgVgFVIDZzQjEhlHicNBCAiXkhKy05EpFOq7ETEC3lVdiSnBU9VbySpiYQlMZTbyZPzNTum3oNbD2AqUk/5vwBghpm9XrjiiRSfcjuZ8nnObjKAjWa2CQBIPojUTOppE+IYdrPuR91YklJowq4dGpY9LeV2hTqAvThozexoXT6V3fFo//T0FqTeyUurO3rhDGoek3Lw3/YbzZeQnnK7Qq2y9I+j5lPZdVR7HtUnJjkbwGwA6I6eeZxOpGiU2wmUzw2KLYi8ZoLUC77vH7mRZk2XCqTcTqB8KrsXAIwieULwmsmXkJpJXaTSKbcTKOdurJm1kLwawBIA1QDmmVlnr0CJVATldjLlNeqJmT0F4KkClUWkbCQttw+eN9HFdt2OjPbp8c3uLm5b+0bBy1RseoNCRLygyk5EvJCEwTtFpAt7hta6+IXxj2a0z9QBl7k4Ca2iJPwOIiJdUmUnIl5IdDd229Vnu/jAwPjPN/I/d7m47eV18Z9QRDKmlp2IeEGVnYh4IRHdWHYL30vcPf1UF3/7mgdcfFHv+OeYPmHwbBcPfOEsF/dfv9/FXPlS7OUQkaOpZSciXlBlJyJeUGUnIl5IxDW7qn59XTz/trtcPPaY4g6o+PaFc8MPF4bhKatnuHjoh2NcnISXq6V81Qwf5uKmkR2OVO4VtexExAuq7ETEC4noxpa7lycvdvGUu8L+bbfPlqI0kmRVffq4eP1V4cjyG/7hnlIUp6yoZSciXlBlJyJeUDdWJEHevGOci//nb38cWaMJvNWyExEvqLITES8kohvbtvMDF1878+sutpqOH6Tc9n8PuDh6p1Sk0ln3VhcPrlbXNarLlh3JeSS3kVwbWVZHcinJDcHP/vEWU6TwlNt+yaQbuwDAtCOWzQGwzMxGAVgWfBapNAug3PZGl91YM/s9yZFHLL4AwJQgXghgOYAbC1iurFhLi4uj48WlextwSMsEF08a/bUuj9/SKzzSE9+6s926YTW9MyyllJtKyO24bWnZ4+LP3/mtduvqX9/g4lZUvlxvUAwxs0YACH4OLlyRREpKuZ1Qsd+gIDkbwGwA6I7ijkIiEifldmXJtbLbSrLezBpJ1gPYlm5DM5sLYC4AHMs6y/F8BRXt6tat7Hr76iHhl3vT9XpaJ+EqOrez1dQW5nP9f2xst651+/ZiFydWuf7PfRzArCCeBeCxwhRHpOSU2wmVyaMniwE8D+BkkltIXgHgDgBTSW4AMDX4LFJRlNt+yeRu7Iw0q84tcFnKSnSU1+hQOQOqK7K3Ih3wNbd9pQtQIuIFVXYi4oVEvBubLftU+FDxrtE9OtwmOkFJ+1Fe9b6h+OfgeRNdvGdobVb7Vh8KL/30+48XXWzNzfkXLAtq2YmIF1TZiYgXvOzGbrg8/LXfPv/nRT33sN67Xbxj4ngXW8PajjYX6VLV+HAu4mMH7O1y+22t4TZ3bY3M+hR5xxwAGMnP428OHzi+f+TyrMr39qHw/dsvf/hNF/dcvs7FbU1NWR0zF2rZiYgXVNmJiBe87MaWUrQLcNu8sPvxh092L0FpJAn23xUZeXv8o11uP3fX6S5+78x9Lq4e2H4wg3MXPO/iG+reyrl8J9SGw6Ct+MVcF0+dcZmLq1a8iLipZSciXlBlJyJeUDdWxGPVAwe6ePZzz7dbd37PXZFP2T1IXI7UshMRL6iyExEvqLITES/omp2Iz6rCAS9Orm0/An03dj2vximrwyEBm18Op9h94yv3dLR5O9P/bYmL77vp8+3W9XxkVZf7Z0stOxHxgio7EfGCl93YUfPCF54nPdv1JNmdueHbi1x8Ue8P8zqWSDFc1PePLn528YkuHl6TWdtnzLOXuvhj/6/axbtGZzdlwey+77v43/u2P3ccE1OqZSciXlBlJyJe8LIbm+0k2VV9+rj4zTvGtVs3snZH5FPlP2UuyTe6Npxa4L/GPBlZk34wimjXddg9YZ5z5ZrIgc8qSPniksm8scNJPkNyHcnXSF4bLK8juZTkhuBn/66OJVJOlNt+yaQb2wLgOjMbC+BMAFeRHAdgDoBlZjYKwLLgs0glUW57JJNJshsBNAZxE8l1AI4HcAGAKcFmCwEsB3BjLKUsMfYMZyB78q/vbrdu7DFx3DeSYlBuZ67P0+GYdAf7tbp4+y1nu9jGZTe0evSB5CHr9+dRusxkdYOC5EgApwJYBWBIkCyHk2ZwoQsnUizK7eTLuLIj2RvAwwC+YWYfZbHfbJINJBsOobjzRIpkQrnth4zuxpKsRSoZHjCzR4LFW0nWm1kjyXoA2zra18zmApgLAMeyLrunDjPEbt1cvHv6qS5urWVHm2etpVd4nD5VbXkda/n+8Pvl3tXnuHg0GvI6ruSm3HM7E++vqXfxAyMGuHhmn50FO8eOSWHX9aSTG128buwTOR9z8N3h3V+uLINh2UkSwL0A1pnZXZFVjwOYFcSzADxW+OKJxEe57ZdMWnafAnApgFdJHn5A7TsA7gDwEMkrALwHYHo8RRSJjXLbI5ncjX0WQLr+4LmFLU7mog/67psy1sWLb/+xi6OzGhVOfse85a2/dfHoK9V1LaVyze1snTAnHE79u4O+4OKZ5/+yYOd4+8K5XW+URrMdcvFPd4Uz6lXvDZcX4xqAXhcTES+oshMRL1Tsu7Etp53k4ujEu/l2M+OwqzWciHjXvvAB5eNKURhJNB4Ih1xadzDMu+hTBMNq4v8/Eu26Pr0vfNvumU9/3MW2fW3s5YhSy05EvKDKTkS8ULHd2EoyccVVLj756k0ubu1oY5E8nDzndRdfd+uFLm6cHl72efE7XU+Gk6/oXddo17V1x46ONi8KtexExAuq7ETEC+rGFkFbc3iHrHW3JuWR+LQ1RYZZisT1vw4f25366mWxl6PdA8NFvuuajlp2IuIFVXYi4oWK7cbWvr7FxZNu6nju11LO6RqdoCQ6T61IKbRu3+7iqhXbO9myMEo23lUn1LITES+oshMRL6iyExEvVOw1u+g1iLr5HV+D+GGvmS6+eWDsRWpn2PIDLm43kbCIlIRadiLiBVV2IuKFiu3GZmLwz54rdRFEpEyoZSciXlBlJyJeUGUnIl7IZJLs7iRXk3yZ5Gskvx8sryO5lOSG4Gf/ro4lUk6U237JpGXXDOCvzOwUABMATCN5JoA5AJaZ2SgAy4LPIpVEue2RLis7S9kTfKwN/hiACwAsDJYvBHBhB7uLlC3ltl8yumZHsprkSwC2AVhqZqsADDGzRgAIfg6Or5gi8VBu+yOjys7MWs1sAoBhACaTHJ/pCUjOJtlAsuEQmnMtp0gslNv+yOpurJntBrAcwDQAW0nWA0Dwc1uafeaa2UQzm1iLbnkWVyQeyu3ky+Ru7CCS/YK4B4DPAHgDwOMAZgWbzQLwWFyFFImDctsvmbwuVg9gIclqpCrHh8zstySfB/AQySsAvAdgeozlFImDctsjNCveAMoktwN4t2gnlM6MMLNBpS5EUii3y0bavC5qZSciUip6XUxEvKDKTkS8oMpORLygyk5EvKDKTkS8oMpORLygyk5EvKDKTkS8oMpORLzwv9NPrlrn6D7QAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "ds = create_dataset(training=False)\n",
    "data = ds.create_dict_iterator().get_next()\n",
    "images = data['image']\n",
    "labels = data['label']\n",
    "\n",
    "for i in range(1, 5):\n",
    "    plt.subplot(2, 2, i)\n",
D
dyonghan 已提交
195
    "    plt.imshow(images[i][0])\n",
D
dyonghan 已提交
196 197 198 199 200 201 202 203 204 205 206
    "    plt.title('Number: %s' % labels[i])\n",
    "    plt.xticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义模型\n",
    "\n",
207
    "MindSpore model_zoo中提供了多种常见的模型,可以直接使用。这里使用其中的LeNet5模型,模型结构如下图所示:\n",
D
dyonghan 已提交
208
    "\n",
D
dyonghan 已提交
209
    "<img src=\"https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg\">\n",
N
njzheng 已提交
210
    "\n",
D
dyonghan 已提交
211
    "[1] 图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf"
D
dyonghan 已提交
212 213 214 215 216 217 218 219
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练\n",
    "\n",
220
    "使用MNIST数据集对上述定义的LeNet5模型进行训练。训练策略如下表所示,可以调整训练策略并查看训练效果,要求验证精度大于95%。\n",
D
dyonghan 已提交
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
    "\n",
    "| batch size | number of epochs | learning rate | optimizer |\n",
    "| -- | -- | -- | -- |\n",
    "| 32 | 3 | 0.01 | Momentum 0.9 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 1875 ,loss is 2.3086565\n",
      "epoch: 2 step: 1875 ,loss is 0.22017351\n",
      "epoch: 3 step: 1875 ,loss is 0.025683485\n",
      "Metrics: {'acc': 0.9742588141025641, 'loss': 0.08628832848253062}\n"
     ]
    }
   ],
   "source": [
D
dyonghan 已提交
244 245
    "ds_train = create_dataset(num_epoch=3)\n",
    "ds_eval = create_dataset(training=False)\n",
D
dyonghan 已提交
246
    "\n",
D
dyonghan 已提交
247 248 249
    "net = LeNet5()\n",
    "loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
    "opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)\n",
D
dyonghan 已提交
250
    "\n",
D
dyonghan 已提交
251
    "loss_cb = LossMonitor(per_print_times=1)\n",
D
dyonghan 已提交
252
    "\n",
D
dyonghan 已提交
253 254 255 256
    "model = Model(net, loss, opt, metrics={'acc', 'loss'})\n",
    "model.train(3, ds_train, callbacks=[loss_cb])\n",
    "metrics = model.eval(ds_eval)\n",
    "print('Metrics:', metrics)"
D
dyonghan 已提交
257 258 259 260 261 262 263 264
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 实验步骤(方案二)\n",
    "\n",
265 266 267 268 269 270 271 272
    "除了Notebook,ModelArts还提供了训练作业服务。相比Notebook,训练作业资源池更大,且具有作业排队等功能,适合大规模并发使用。使用训练作业时,也会有修改代码和调试的需求,有如下三个方案:\n",
    "\n",
    "1. 在本地修改代码后重新上传;\n",
    "\n",
    "2. 使用[PyCharm ToolKit](https://support.huaweicloud.com/tg-modelarts/modelarts_15_0001.html)配置一个本地Pycharm+ModelArts的开发环境,便于上传代码、提交训练作业和获取训练日志。\n",
    "\n",
    "3. 在ModelArts上创建Notebook,然后设置[Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html),可以在线修改代码并自动同步到OBS中。因为只用Notebook来编辑代码,所以创建CPU类型最低规格的Notebook就行。\n",
    "\n",
D
dyonghan 已提交
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331
    "### 代码梳理\n",
    "\n",
    "创建训练作业时,运行参数会通过脚本传参的方式输入给脚本代码,脚本必须解析传参才能在代码中使用相应参数。如data_url和train_url,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。脚本对传参进行解析后赋值到`args`变量里,在后续代码里可以使用。\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--data_url', required=True, default=None, help='Location of data.')\n",
    "parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')\n",
    "parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.')\n",
    "args, unknown = parser.parse_known_args()\n",
    "```\n",
    "\n",
    "MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing提供的API与OBS交互。将OBS中存储的数据拷贝至执行容器:\n",
    "\n",
    "```python\n",
    "import moxing as mox\n",
    "mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')\n",
    "```\n",
    "\n",
    "如需将训练输出(如模型Checkpoint)从执行容器拷贝至OBS,请参考:\n",
    "\n",
    "```python\n",
    "import moxing as mox\n",
    "mox.file.copy_parallel(src_url='output', dst_url='s3://OBS/PATH')\n",
    "```\n",
    "\n",
    "其他代码分析请参考方案一。\n",
    "\n",
    "### 创建训练作业\n",
    "\n",
    "可以参考[使用常用框架训练模型](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0238.html)来创建并启动训练作业。\n",
    "\n",
    "创建训练作业的参考配置:\n",
    "\n",
    "- 算法来源:常用框架->Ascend-Powered-Engine->MindSpore\n",
    "- 代码目录:选择上述新建的OBS桶中的experiment_1目录\n",
    "- 启动文件:选择上述新建的OBS桶中的experiment_1目录下的`main.py`\n",
    "- 数据来源:数据存储位置->选择上述新建的OBS桶中的experiment_1目录下的MNIST目录\n",
    "- 训练输出位置:选择上述新建的OBS桶中的experiment_1目录并在其中创建output目录\n",
    "- 作业日志路径:同训练输出位置\n",
    "- 规格:Ascend:1*Ascend 910\n",
    "- 其他均为默认\n",
    "\n",
    "启动并查看训练过程:\n",
    "\n",
    "1. 点击提交以开始训练;\n",
    "2. 在训练作业列表里可以看到刚创建的训练作业,在训练作业页面可以看到版本管理;\n",
    "3. 点击运行中的训练作业,在展开的窗口中可以查看作业配置信息,以及训练过程中的日志,日志会不断刷新,等训练作业完成后也可以下载日志到本地进行查看;\n",
    "4. 在训练日志中可以看到`epoch: 3 step: 1875 ,loss is 0.025683485`等字段,即训练过程的loss值;\n",
    "5. 在训练日志中可以看到`Metrics: {'acc': 0.9742588141025641, 'loss': 0.08628832848253062}`字段,即训练完成后的验证精度。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 实验小结\n",
    "\n",
332
    "本实验展示了如何使用MindSpore进行手写数字识别,以及开发和训练LeNet5模型。通过对LeNet5模型做几代的训练,然后使用训练后的LeNet5模型对手写数字进行识别,识别准确率大于95%。即LeNet5学习到了如何进行手写数字识别。"
D
dyonghan 已提交
333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
D
dyonghan 已提交
352
   "version": "3.7.5"
D
dyonghan 已提交
353 354 355 356
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
D
dyonghan 已提交
357
}