{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "

训练时模型的保存和加载

\n", "\n", "## 实验介绍\n", "\n", "本实验主要介绍使用MindSpore实现训练时模型的保存和加载。建议先阅读MindSpore官网教程中关于模型参数保存和加载的内容。\n", "\n", "在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及中断后再训练使用。使用场景如下:\n", "\n", "- 训练后推理场景\n", "\n", " - 模型训练完毕后保存模型的参数,用于推理或预测操作。\n", "\n", " - 训练过程中,通过实时验证精度,把精度最高的模型参数保存下来,用于预测操作。\n", "\n", "- 再训练场景\n", "\n", " - 进行长时间训练任务时,保存训练过程中的CheckPoint文件,防止任务异常退出后从初始状态开始训练。\n", "\n", " - Fine-tuning(微调)场景,即训练一个模型并保存参数,基于该模型,面向第二个类似任务进行模型训练。\n", "\n", "## 实验目的\n", "\n", "- 了解如何使用MindSpore实现训练时模型的保存。\n", "- 了解如何使用MindSpore加载保存的模型文件并继续训练。\n", "- 了解如何MindSpore的Callback功能。\n", "\n", "## 预备知识\n", "\n", "- 熟练使用Python,了解Shell及Linux操作系统基本知识。\n", "- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略、Checkpoint等。\n", "- 了解华为云的基本使用方法,包括[OBS(对象存储)](https://www.huaweicloud.com/product/obs.html)、[ModelArts(AI开发平台)](https://www.huaweicloud.com/product/modelarts.html)、[Notebook(开发工具)](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0033.html)、[训练作业](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0046.html)等功能。华为云官网:https://www.huaweicloud.com\n", "- 了解并熟悉MindSpore AI计算框架,MindSpore官网:https://www.mindspore.cn/\n", "\n", "## 实验环境\n", "\n", "- MindSpore 0.2.0(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套);\n", "- 华为云ModelArts:ModelArts是华为云提供的面向开发者的一站式AI开发平台,集成了昇腾AI处理器资源池,用户可以在该平台下体验MindSpore。ModelArts官网:https://www.huaweicloud.com/product/modelarts.html\n", "\n", "## 实验准备\n", "\n", "### 创建OBS桶\n", "\n", "本实验需要使用华为云OBS存储实验脚本和数据集,可以参考[快速通过OBS控制台上传下载文件](https://support.huaweicloud.com/qs-obs/obs_qs_0001.html)了解使用OBS创建桶、上传文件、下载文件的使用方法。\n", "\n", "> **提示:** 华为云新用户使用OBS时通常需要创建和配置“访问密钥”,可以在使用OBS时根据提示完成创建和配置。也可以参考[获取访问密钥并完成ModelArts全局配置](https://support.huaweicloud.com/prepare-modelarts/modelarts_08_0002.html)获取并配置访问密钥。\n", "\n", "创建OBS桶的参考配置如下:\n", "\n", "- 区域:华北-北京四\n", "- 数据冗余存储策略:单AZ存储\n", "- 桶名称:如ms-course\n", "- 存储类别:标准存储\n", "- 桶策略:公共读\n", "- 归档数据直读:关闭\n", "- 企业项目、标签等配置:免\n", "\n", "### 数据集准备\n", "\n", "MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)。\n", "\n", "从MNIST官网下载如下4个文件到本地并解压:\n", "\n", "```\n", "train-images-idx3-ubyte.gz: training set images (9912422 bytes)\n", "train-labels-idx1-ubyte.gz: training set labels (28881 bytes)\n", "t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)\n", "t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)\n", "```\n", "\n", "### 脚本准备\n", "\n", "从[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。\n", "\n", "### 上传文件\n", "\n", "将脚本和数据集上传到OBS桶中,组织为如下形式:\n", "\n", "```\n", "experiment_2\n", "├── MNIST\n", "│   ├── test\n", "│   │   ├── t10k-images-idx3-ubyte\n", "│   │   └── t10k-labels-idx1-ubyte\n", "│   └── train\n", "│   ├── train-images-idx3-ubyte\n", "│   └── train-labels-idx1-ubyte\n", "├── *.ipynb\n", "└── main.py\n", "```\n", "\n", "## 实验步骤(方案一)\n", "\n", "### 创建Notebook\n", "\n", "可以参考[创建并打开Notebook](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0034.html)来创建并打开本实验的Notebook脚本。\n", "\n", "创建Notebook的参考配置:\n", "\n", "- 计费模式:按需计费\n", "- 名称:experiment_2\n", "- 工作环境:Python3\n", "- 资源池:公共资源\n", "- 类型:Ascend\n", "- 规格:单卡1*Ascend 910\n", "- 存储位置:对象存储服务(OBS)->选择上述新建的OBS桶中的experiment_2文件夹\n", "- 自动停止等配置:默认\n", "\n", "> **注意:**\n", "> - 打开Notebook前,在Jupyter Notebook文件列表页面,勾选目录里的所有文件/文件夹(实验脚本和数据集),并点击列表上方的“Sync OBS”按钮,使OBS桶中的所有文件同时同步到Notebook工作环境中,这样Notebook中的代码才能访问数据集。参考[使用Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html)。\n", "> - 打开Notebook后,选择MindSpore环境作为Kernel。\n", "\n", "> **提示:** 上述数据集和脚本的准备工作也可以在Notebook环境中完成,在Jupyter Notebook文件列表页面,点击右上角的\"New\"->\"Terminal\",进入Notebook环境所在终端,进入`work`目录,可以使用常用的linux shell命令,如`wget, gzip, tar, mkdir, mv`等,完成数据集和脚本的下载和准备。\n", "\n", "> **提示:** 请从上至下阅读提示并执行代码框进行体验。代码框执行过程中左侧呈现[\\*],代码框执行完毕后左侧呈现如[1],[2]等。请等上一个代码框执行完毕后再执行下一个代码框。\n", "\n", "导入MindSpore模块和辅助模块:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "# os.environ['DEVICE_ID'] = '0'\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import mindspore as ms\n", "import mindspore.context as context\n", "import mindspore.dataset.transforms.c_transforms as C\n", "import mindspore.dataset.transforms.vision.c_transforms as CV\n", "\n", "from mindspore.dataset.transforms.vision import Inter\n", "from mindspore import nn, Tensor\n", "from mindspore.train import Model\n", "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n", "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n", "\n", "import logging; logging.getLogger('matplotlib.font_manager').disabled = True\n", "\n", "context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 数据处理\n", "\n", "在使用数据集训练网络前,首先需要对数据进行预处理,如下:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "DATA_DIR_TRAIN = \"MNIST/train\" # 训练集信息\n", "DATA_DIR_TEST = \"MNIST/test\" # 测试集信息\n", "\n", "def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),\n", " rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):\n", " ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)\n", " \n", " # define map operations\n", " resize_op = CV.Resize(resize)\n", " rescale_op = CV.Rescale(rescale, shift)\n", " hwc2chw_op = CV.HWC2CHW()\n", " \n", " # apply map operations on images\n", " ds = ds.map(input_columns=\"image\", operations=[resize_op, rescale_op, hwc2chw_op])\n", " ds = ds.map(input_columns=\"label\", operations=C.TypeCast(ms.int32))\n", " \n", " ds = ds.shuffle(buffer_size=buffer_size)\n", " ds = ds.batch(batch_size, drop_remainder=True)\n", " ds = ds.repeat(num_epoch)\n", " \n", " return ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 定义模型\n", "\n", "定义LeNet5模型,模型结构如下图所示:\n", "\n", "\n", "\n", "[1] 图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class LeNet5(nn.Cell):\n", " def __init__(self):\n", " super(LeNet5, self).__init__()\n", " self.relu = nn.ReLU()\n", " self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')\n", " self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')\n", " self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.flatten = nn.Flatten()\n", " self.fc1 = nn.Dense(400, 120)\n", " self.fc2 = nn.Dense(120, 84)\n", " self.fc3 = nn.Dense(84, 10)\n", " \n", " def construct(self, input_x):\n", " output = self.conv1(input_x)\n", " output = self.relu(output)\n", " output = self.pool(output)\n", " output = self.conv2(output)\n", " output = self.relu(output)\n", " output = self.pool(output)\n", " output = self.flatten(output)\n", " output = self.fc1(output)\n", " output = self.fc2(output)\n", " output = self.fc3(output)\n", " \n", " return output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 保存模型Checkpoint\n", "\n", "MindSpore提供了Callback功能,可用于训练/测试过程中执行特定的任务。常用的Callback如下:\n", "\n", "- `ModelCheckpoint`:保存网络模型和参数,用于再训练或推理;\n", "- `LossMonitor`:监控loss值,当loss值为Nan或Inf时停止训练;\n", "- `SummaryStep`:把训练过程中的信息存储到文件中,用于后续查看或可视化展示。\n", "\n", "`ModelCheckpoint`会生成模型(.meta)和Chekpoint(.ckpt)文件,如每个epoch结束时,都保存一次checkpoint。\n", "\n", "```python\n", "class CheckpointConfig:\n", " \"\"\"\n", " The config for model checkpoint.\n", "\n", " Args:\n", " save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.\n", " save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.\n", " Can't be used with save_checkpoint_steps at the same time.\n", " keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.\n", " keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.\n", " Can't be used with keep_checkpoint_max at the same time.\n", " integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.\n", " Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.\n", "\n", " Raises:\n", " ValueError: If the input_param is None or 0.\n", " \"\"\"\n", "\n", "class ModelCheckpoint(Callback):\n", " \"\"\"\n", " The checkpoint callback class.\n", "\n", " It is called to combine with train process and save the model and network parameters after traning.\n", "\n", " Args:\n", " prefix (str): Checkpoint files names prefix. Default: \"CKP\".\n", " directory (str): Lolder path into which checkpoint files will be saved. Default: None.\n", " config (CheckpointConfig): Checkpoint strategy config. Default: None.\n", "\n", " Raises:\n", " ValueError: If the prefix is invalid.\n", " TypeError: If the config is not CheckpointConfig type.\n", " \"\"\"\n", "```\n", "\n", "MindSpore提供了多种Metric评估指标,如`accuracy`、`loss`、`precision`、`recall`、`F1`。定义一个metrics字典/元组,里面包含多种指标,传递给`Model`,然后调用`model.eval`接口来计算这些指标。`model.eval`会返回一个字典,包含各个指标及其对应的值。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 1875 ,loss is 2.3151364\n", "epoch: 2 step: 1875 ,loss is 0.3097728\n", "Metrics: {'acc': 0.9417067307692307, 'loss': 0.18866610953894755}\n", "b_lenet-1_1875.ckpt\n", "b_lenet-2_1875.ckpt\n" ] } ], "source": [ "os.system('rm -f *.ckpt *.ir *.meta') # 清理旧的运行文件\n", "\n", "def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name=\"b_lenet\"):\n", " ds_train = create_dataset(num_epoch=num_epoch)\n", " ds_eval = create_dataset(training=False)\n", " steps_per_epoch = ds_train.get_dataset_size()\n", " \n", " net = LeNet5()\n", " loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", " opt = nn.Momentum(net.trainable_params(), lr, momentum)\n", " \n", " ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)\n", " ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg)\n", " loss_cb = LossMonitor(steps_per_epoch)\n", " \n", " model = Model(net, loss, opt, metrics={'acc', 'loss'})\n", " model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=True)\n", " metrics = model.eval(ds_eval)\n", " print('Metrics:', metrics)\n", "\n", "test_train()\n", "print('\\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')])))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 加载Checkpoint继续训练\n", "\n", "```python\n", "def load_checkpoint(ckpoint_file_name, net=None):\n", " \"\"\"\n", " Loads checkpoint info from a specified file.\n", "\n", " Args:\n", " ckpoint_file_name (str): Checkpoint file name.\n", " net (Cell): Cell network. Default: None\n", "\n", " Returns:\n", " Dict, key is parameter name, value is a Parameter.\n", "\n", " Raises:\n", " ValueError: Checkpoint file is incorrect.\n", " \"\"\"\n", "\n", "def load_param_into_net(net, parameter_dict):\n", " \"\"\"\n", " Loads parameters into network.\n", "\n", " Args:\n", " net (Cell): Cell network.\n", " parameter_dict (dict): Parameter dict.\n", "\n", " Raises:\n", " TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict.\n", " \"\"\"\n", "```\n", "\n", "> 使用load_checkpoint接口加载数据时,需要把数据传入给原始网络,而不能传递给带有优化器和损失函数的训练网络。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 1875 ,loss is 0.1638589\n", "epoch: 2 step: 1875 ,loss is 0.060048036\n", "Metrics: {'acc': 0.9742588141025641, 'loss': 0.07910804035148034}\n", "b_lenet_1-1_1875.ckpt\n", "b_lenet_1-2_1875.ckpt\n" ] } ], "source": [ "CKPT = 'b_lenet-2_1875.ckpt'\n", "\n", "def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name=\"b_lenet\"):\n", " ds_train = create_dataset(num_epoch=num_epoch)\n", " ds_eval = create_dataset(training=False)\n", " steps_per_epoch = ds_train.get_dataset_size()\n", " \n", " net = LeNet5()\n", " loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", " opt = nn.Momentum(net.trainable_params(), lr, momentum)\n", " \n", " param_dict = load_checkpoint(CKPT)\n", " load_param_into_net(net, param_dict)\n", " load_param_into_net(opt, param_dict)\n", " \n", " ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)\n", " ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg)\n", " loss_cb = LossMonitor(steps_per_epoch)\n", " \n", " model = Model(net, loss, opt, metrics={'acc', 'loss'})\n", " model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb])\n", " \n", " metrics = model.eval(ds_eval)\n", " print('Metrics:', metrics)\n", "\n", "resume_train()\n", "print('\\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')])))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 加载Checkpoint进行推理\n", " \n", "使用matplotlib定义一个将推理结果可视化的辅助函数,如下:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def plot_images(pred_fn, ds, net):\n", " for i in range(1, 5):\n", " pred, image, label = pred_fn(ds, net)\n", " plt.subplot(2, 2, i)\n", " plt.imshow(np.squeeze(image))\n", " color = 'blue' if pred == label else 'red'\n", " plt.title(\"prediction: {}, truth: {}\".format(pred, label), color=color)\n", " plt.xticks([])\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用训练后的LeNet5模型对手写数字进行识别,可以看到识别结果基本上是正确的。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUoAAAD7CAYAAAAMyN1hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcv0lEQVR4nO3de5RU5Znv8e9D03TLJUIraCMgXgDJMlEZguRoEuJl1IkecpxlJo7jQpfaIdE1OmO8xMl9NPHkmMvMMpmIE4TxbtBRYuIkSjQRZVCWgSSKCqMIKHJROgEEpJvn/FGbXbuaqn5316Wrqvv3WatXP7v27d3VTz+1330rc3dERKSwAdVugIhIrVOhFBEJUKEUEQlQoRQRCVChFBEJUKEUEQmoyUJpxjwzbozij5nxSpHL+bEZXylv6+qDGePNcDMGVrstkqXcLl01crsmC2WSO0+7Myk0nRkXmbG4y7yz3fnnyrUuXvexZvzSjC1m9OjCVDNmmLG+DG1YY8ZppS4nsbxPmvGkGX8yY025litZyu3Uy6l6ble8UPaTPZo9wAPAJZVYeJXewx3AXOCaKqy7Lii3S1c3ue3uPf4BXwP+JfCXwLeC3wHeHI2bAb4e/Drwt8HvjF4/G3w5eDv4s+AfTizvBPAXwLeB3w9+H/iNyeUlph0L/hD4ZvB3wG8Fnwy+C7wTfDt4ezTtvH3LiYYvA18N/i74QvDRiXEOPht8VbRNPwS3Hr4vR4N7D6YfAr4TfG/U7u3go8G/Dr4A/C7wP4Nfmmdb4vcF/M5oGTujZVwLPj7aplnga8G3gP9TEX/r08DXFJMn9fij3FZu5/spZY/yAuAM4ChgIvDlxLhDgRbgcKDNjClkKvjngIOA24CFZjSZMQh4GLgzmuenwF/nW6EZDcCjwBvAeOAw4D53VgKzgSXuDHVneJ55TwG+DXwGaI2WcV+Xyc4GPgIcF013RjTvODPazRiX9s1Jw50dwFnAW1G7h7rzVjR6JrAAGA7cHVjOhcBa4JxoGd9JjD4ZmAScCnzVjMnRNp1sRns5t6cPUW6XqK/ldimF8lZ31rnzLnATcH5i3F7ga+7sdmcncBlwmztL3el0Zz6wG5ge/TQCP3BnjzsLgOcLrHMaMBq4xp0d7uxyzz12040LgLnuvODObuBLwEfNGJ+Y5mZ32t1ZCzwJHA/gzlp3hkev95Yl7jzszt7oPSzWN9zZ6c4KYAWZfxTcWZzvn04A5Xal1V1ul1Io1yXiN8j8kffZ7M6uxPDhwNXRJ1d7VO3HRvOMBt50zzlQ/EaBdY4F3nCno4j2jk4u153twDtkPrn3eTsRvwcMLWI95bIuPEkqtbRN9UK5XVl1l9ulFMqxiXgcxLvVwH5nx9YBN0WfXPt+BrtzL7ABOMwM67K8fNYB4wocAA6dkXuLTFIDYMYQMl2lNwPzVVqhdnd9fQcwODF8aMrlSM8pt8ujz+R2KYXycjPGmNEC3ADc3820twOzzTjRDDNjiBmfMmMYsAToAP7ejIFmnEumG5LPc2SS7+ZoGc1mnBSN2wiMiY4L5XMPcLEZx5vRBHwLWOpe+qUv0TY1Q2bdUbuaEuPnmTGvwOwbgYPMODCwmuXAX5nRYsahwFV5lnNkURuQhxkDom1qhMz2dfPe9jXK7YhyO6OUQnkP8CvgtejnxkITurOMzLGcW4GtwGrgomjc+8C50fBW4G+AhwospxM4BziazAHe9dH0AL8GXgTeNmNLnnkXAV8BHiSTkEcBn02zodEB7+3dHPA+HNgZrZ8oTl5IPBZ4psA2vQzcC7wWdd1G55uOzAmBFcAaMu9713/ebwNfjpbxxcAm7bvYeXs3k3w82o5fkNkL2hmttz9QbmcptwHLnCbvmegizUvdeaLHM/cz0SfVCuDD7uypdnuke8rt9PpTbveHC2arKtqrmFztdoiUW3/K7Zq/hVFEpNqK6nqLiPQnJe1RmtmZZvaKma02s+vL1SiRalNuS1LRe5Rm1gC8CpxO5gzd88D57v5S+Zon0vuU29JVKSdzpgGr3f01ADO7j8w9nAWTaZA1eTNDSlillMs2tm5x95HVbkeNUm7XqV3s4H3fbeEpe6aUQnkYubcirQdO7DqRmbUBbQDNDOZEO7WEVUq5POELCt1KJ8rturXUF1VkuaUco8xXtffrx7v7HHef6u5TG7MX9IvUMuW25CilUK4n957YMeTeEytSr5TbkqOUQvk8MMHMjjCzQWRumVpYnmaJVJVyW3IUfYzS3TvM7Argl0ADMNfdXwzMJlLzlNvSVUm3MLr7L8jcWC7Spyi3JUm3MIqIBKhQiogEqFCKiASoUIqIBKhQiogE6MG9KQwYnP3eo7VXHp8zbm+KGzJGLs9+sd4BDz9XtnaJSO/QHqWISIAKpYhIgLreKdiw7PeqL/jcLTnjJg8a3HXy/Rzx2KVxPPHh8rVLJK2BR46P402faE01z8hHV8dx5+bN5W5SXdEepYhIgAqliEiACqWISICOUYr0A8njks/f9G+p5pnS9Pk4PvTB7Ov98Xil9ihFRAJUKEVEAtT1FpG8Xvhqtov+kd3ZbnjLHep6i4hIFyqUIiIBKpQiIgEqlCIiASqUIiIBOust0g8kH3CRvJA8eWZbCgvuUZrZXDPbZGZ/TLzWYmaPm9mq6PeIyjZTpPyU25JWmq73PODMLq9dDyxy9wnAomhYpN7MQ7ktKQS73u7+WzMb3+XlmcCMKJ4PPAVcV8Z2iVRcf8rt5P3ZB60c0+P5P37F0jhe3PnROB7+H0tKa1idKPZkziHuvgEg+j2q0IRm1mZmy8xs2R52F7k6kV6j3Jb9VPyst7vPcfep7j61kRTfxCVSJ5Tb/UexZ703mlmru28ws1ZgUzkbJVJFyu08vtv6QhxPPnJ6HA+vRmOqoNg9yoXArCieBTxSnuaIVJ1yW/aT5vKge4ElwCQzW29mlwA3A6eb2Srg9GhYpK4otyWtNGe9zy8w6tQyt0WkVym3JS3dwigiEqBCKSISoHu9C2gYOTKON5x3dBwPG7C3Gs0RkSrSHqWISIAKpYhIgLreBez5YPZ+2N/d8KPEmKG93xgRqSrtUYqIBKhQiogEqOudYE3ZBxu8/4HGkpa1qXNHdrm7GkpalohUl/YoRUQCVChFRALU9U5oP++EOL73plsSY3p+pvuku78Yx5Nuir+SBV2uLlJ/tEcpIhKgQikiEqBCKSISoGOUCZ2NFsdHNJZ2B87AXdll7d22raRliUh1aY9SRCRAhVJEJECFUkQkQIVSRCRAhVJEJECFUkQkIM33eo81syfNbKWZvWhmV0avt5jZ42a2Kvo9ovLNFSkf5baklWaPsgO42t0nA9OBy83sg8D1wCJ3nwAsioZF6olyW1IJXnDu7huADVG8zcxWAocBM4EZ0WTzgaeA6yrSygrqnDEljredtb2KLZHe1tdzW8qnR8cozWw8cAKwFDgkSrR9CTeqwDxtZrbMzJbtYXdprRWpEOW2dCd1oTSzocCDwFXu/ue087n7HHef6u5TG2kKzyDSy5TbEpLqXm8zaySTSHe7+0PRyxvNrNXdN5hZK7CpUo2spDdnNMfxyyfPrWJLpBr6cm5L+aQ5623AT4CV7v69xKiFwKwongU8Uv7miVSOclvSSrNHeRJwIfAHM1sevXYDcDPwgJldAqwFzqtME0UqRrktqaQ5670YsAKjTy1vc3pHw4Qj43hX654qtkSqqS/mtlSG7swREQlQoRQRCeiXTzhfeX1LHL9+1u1lW+76juwF6wN0WZ3UKOvwOH51z444PmrgATnTNVj+/aiO5uz8A4YNi+O+/CR/7VGKiASoUIqIBPTLrnelnPOda+N43Nzlcby3Go0RKWDA0j/G8VVnXRzHP3jsjpzpJjYOyTv/MxfcEsfTh/1jHE+4fGm5mlhztEcpIhKgQikiEqCud4mmfPPzcdz64Ko47nzvvWo0RyTIOzqyA++0x2GnF7r2PteohmyX3Js7y9auWqY9ShGRABVKEZEAdb1TSF5InjyzDV2625s391qbRKT3aI9SRCRAhVJEJECFUkQkoF8eoxz7s+znw+R1XwhOn3zARfKOG9BlQCL9gfYoRUQCVChFRAL6Zdf7gIefi+NxD/dsXj3gQvoSf29nHH/q5/+QOy7FXTcHL24se5tqkfYoRUQCVChFRAKCXW8zawZ+CzRF0y9w96+ZWQtwPzAeWAN8xt23Vq6pIuWl3M79+oa+/DzJUqXZo9wNnOLuxwHHA2ea2XTgemCRu08AFkXDIvVEuS2pBAulZ+y72bkx+nFgJjA/en0+8OmKtFCkQpTbklaqY5Rm1mBmy4FNwOPuvhQ4xN03AES/R1WumSKVodyWNFIVSnfvdPfjgTHANDM7Nu0KzKzNzJaZ2bI96DtcpbYotyWNHp31dvd24CngTGCjmbUCRL83FZhnjrtPdfepjTSV2FyRylBuS3eChdLMRprZ8Cg+ADgNeBlYCMyKJpsFPFKpRopUgnJb0kpzZ04rMN/MGsgU1gfc/VEzWwI8YGaXAGuB8yrYTpFKUG5LKsFC6e6/B07I8/o7wKmVaJRIb1BuS1rm7r23MrPNwBu9tkLpzuHuPrLajegrlNs1oyJ53auFUkSkHulebxGRABVKEZGAmiyUZswz48Yo/pgZrxS5nB+b8ZXytq4+mDHeDDfrn88crVXK7dJVI7drslAmufO0O5NC05lxkRmLu8w7251/rlzr4nUfa8YvzdhiRo8O+poxw4z1ZWjDGjNOK3U5ieV90ownzfiTGWvKtVzJqofcjtZ/pBmPmrEtyvHvpJyvJnM7WuYUM35rxnYzNppxZXfTV7xQ9pM9mj3AA8AllVh4ld7DHcBc4JoqrLsu9IfcNmMQ8Djwa+BQMrd63lXG5ff6e2jGwcB/AbcBBwFHA7/qdiZ37/EP+BrwL4G/BL4V/A7w5mjcDPD14NeBvw1+Z/T62eDLwdvBnwX/cGJ5J4C/AL4N/H7w+8BvTC4vMe1Y8IfAN4O/A34r+GTwXeCd4NvB26Np5+1bTjR8Gfhq8HfBF4KPToxz8Nngq6Jt+iG49fB9OTp6Jk3a6YeA7wTfG7V7O/ho8K+DLwC/C/zP4Jfm2Zb4fQG/M1rGzmgZ14KPj7ZpFvha8C3g/1TE3/o08DXF5Ek9/ii393s/2sCfLuJ9rNncBv/Wvr9d2p9S9igvAM4AjgImAl9OjDsUaAEOB9rMmEJm7+RzZCr4bcBCM5qiT6yHgTujeX4K/HW+FZrRADxK5nq18cBhwH3urARmA0vcGerO8DzzngJ8G/gMmTsy3gDu6zLZ2cBHgOOi6c6I5h1nRrsZ49K+OWm4swM4C3gravdQd96KRs8EFgDDgbsDy7mQzB0k50TLSHaNTgYmkbmA+qtmTI626WQz2su5PX2IcjtrOrDGjMeibvdTZnyowLSxGs/t6cC7ZjxrxiYzfhb63y6lUN7qzjp33gVuAs5PjNsLfM2d3e7sBC4DbnNnqTud7swn89DU6dFPI/ADd/a4swB4vsA6pwGjgWvc2eHOLvfcYzfduACY684L7uwGvgR81IzxiWludqfdnbXAk2Qe5oo7a90ZHr3eW5a487A7e6P3sFjfcGenOyuAFWT+UXBncb5/OgGU20ljgM8C/xq17+fAI9GHQLGqndtjyNzDfyUwDngduLe7FZVSKNcl4jfIvIn7bHZnV2L4cODq6JOrPar2Y6N5RgNvuuecBCl0h8NY4A13Oopo7+jkct3ZDrxD5pN7n7cT8XvA0CLWUy7rwpOkUkvbVC+U21k7gcXuPObO+8AtZPacJxfRzn2qnds7gf905/nob/kN4H+ZcWChGUoplGMT8TiId6uB/c78rgNuij659v0MdudeYANwmBnWZXn5rAPGFTgAHDrb/BaZpAbAjCFk/uBvBuartELt7vr6DmBwYvjQlMuRnlNuZ/0+xfoLqdXc7rpN+2LLMy1QWqG83IwxZrQAN5D5MqZCbgdmm3GiGWbGEDM+ZcYwYAnQAfy9GQPNOJdMNySf58gk383RMprNOCkatxEY002X4B7gYjOON6MJ+Baw1L30S1+ibWqGzLqjdjUlxs8zY16B2TcCB3X3aRZZDvyVGS1mHApclWc5Rxa1AXmYMSDapkbIbF+J3a16otzOuguYbsZp0XHUq4AtwEqoz9wG7gD+T/R+NQJfIbPXXPC4ZimF8h4yp9Rfi35uLDShO8vIHMu5FdgKrAYuisa9D5wbDW8F/gZ4qMByOoFzyJzOXwusj6aHzOULLwJvm7Elz7yLyLwhD5JJyKPIHHsJig54b+/mgO/hZHbnX4yGd0LOhcRjgWcKbNPLZI6PvBZ13Ubnm47MCYEVZL4V8Ffs/8/7beDL0TK+GNikfRc7b+9mko9H2/ELMntBOwldQtF3KLezy34F+Dvgx9E2zAT+d7RtUIe57c6vyXwA/pzMQ5mPBv6222VmTpf3jGUuQL7UnSd6PHM/E+0FrAA+7M6eardHuqfcTq8/5Xafv2C22qJP3lIOfIvUpP6U2zV/C6OISLUVVSjdGe/OE2Z2ppm9YmarzUxfEi91T7kt+RT94N7oe0ZeBU4nc+D5eeB8d3+pfM0T6X3KbemqlGOU04DV7v4agJndR+aMWMFkGmRN3syQElYp5bKNrVtcXwVRiHK7Tu1iB+/77oLXQxarlEJ5GLlX2K8HTuxuhmaGcKLpO5tqwRO+QN/vUphyu04t9UUVWW4phTJf1d6vH29mbUAbQHPOxfciNUu5LTlKOeu9ntxbvcaQe6sXAO4+x92nuvvUxuzNKiK1TLktOUoplM8DE8zsCDMbROZOgIXlaZZIVSm3JUfRXW937zCzK4BfAg3AXHd/MTCbSM1TbktXJd2Z4+6/IHMvsEifotyWJN2ZIyISoEIpIhKgQikiEqBCKSISoEIpIhKg51GK9FGdM6bE8ZszmvNOM2B3Nh73L8tzxu19772KtKseaY9SRCRAhVJEJEBd7xK9d272oTK7Dsz/uTPi1ex3vNszy/NOI9ITafJu21nZ79d6+eS5eadZ35Gd5pxt1+aMa71/VRx3bt5cVDv7Cu1RiogEqFCKiASoUIqIBOgYZQo2MPs27T3x2JxxF970szhuO3C/RxYCcMRjl8bxxLxfFS/SM2d8/Tdx/OWDXy56OWMGDo3j393wo5xxp//h4jge8BsdoxQRkW6oUIqIBKjrncKAg1ri+Pt353ZPJg/Sd6VIBVn263saDj44jpsGvFaN1vRb2qMUEQlQoRQRCVDXW6SGJbvbbc8uieOzBm9NTNXYiy3qn7RHKSISoEIpIhKgrrdIjbGp2ZsaPjkvf3e7ycLd7eOeOz+OR30///MouzPwhdVxvLfHc/ctwT1KM5trZpvM7I+J11rM7HEzWxX9HlHZZoqUn3Jb0krT9Z4HnNnlteuBRe4+AVgUDYvUm3kotyWFYNfb3X9rZuO7vDwTmBHF84GngOvK2K6qG3DsMXG891+3xfHYgTqs21fUam53Dsl2q69p+Z/EmHB3+5jFF8bxuH9piGN75nc9bkd/724nFftff4i7bwCIfo8qX5NEqkq5Lfup+MkcM2sD2gCa0e1+0ncot/uPYgvlRjNrdfcNZtYKbCo0obvPAeYAfMBavMj19bqOgw6I48ePuS8xpudnD6Wu1HVu20vDsvEzz1axJX1LsV3vhcCsKJ4FPFKe5ohUnXJb9pPm8qB7gSXAJDNbb2aXADcDp5vZKuD0aFikrii3Ja00Z73PLzDq1DK3paY0vvWnOD7q19knPS/7xA9zphvRkP/Y1MVrPxbHBy/Wvbi1qFZye8Bxk3OGX/3bnh0RO33lOXE8cnlHWdqUVueMKXH85oziD0uN/8+tOcN7V6wselmVoGtdREQCVChFRAJ0r3cBnauyT5CedHX2Urq3l+ZON6KBvP77vz4Ux+Pu0NlHKWzLlOE5w6+f8289mr/9rjFx3PLwkm6mLL9kd3tl24+6mbJ7R4xqyxk+5rbs4Yha6IZrj1JEJECFUkQkQF3vAgYMy164u2Pa+DhuNt0BK/1bw4Qj43hX656yLPP1T8/JGZ686QtxPG5FWVZREu1RiogEqFCKiASo611Ax5Sj4/g3tyW7BUN7vzEiVdYw/MA4fuVr2fj1U26vyPo6mrO3zicPg+3dti3f5BWnPUoRkQAVShGRAHW9RSTozXmj43jZXySfd1CZ53A+c8EtcTx92D/G8YTLl+abvOK0RykiEqBCKSISoEIpIhKgY5RldMzt2bsJjvrJ2jju3ScEipTHu49OjOOffujf43hEw5AeLef1Pdvj+KIrsscbL/jOo3HcduBbOfOMSqzDmzt7tL5K0B6liEiACqWISIC63mU0bE32boKOdeur2BKpJ6N+syFnOHkI5+XLws94/PgV2UtmFnd+NI6H/0dpz6acOe73cTyxsWfd7Qe2Z+/e+b/f/XwcN8zeEsczBq9KzNGz5fc27VGKiASoUIqIBKjrneAnHR/Hm/5hV6p5jvjZZXF8zAvtcaynVkpaHa+tyRk+6t+z10lMaMp2W5N3qyTPCn+39YU4vvgL2def+sTUOB720qA4bv1uuq8meXDOKXHcNDv73MlrWv4nOO+ru1qzbb0z+0DJtS3Z/7E1k7JfgTGxsTzPtayUNN/rPdbMnjSzlWb2opldGb3eYmaPm9mq6PeIyjdXpHyU25JWmq53B3C1u08GpgOXm9kHgeuBRe4+AVgUDYvUE+W2pBLserv7BmBDFG8zs5XAYcBMYEY02XzgKeC6irSyl2ydeEAcr5g2L9U8R9+T7SbVwrfFSXq1mtvJKyYmfP/9OH7nsxbHowp8++cd457ODiTi/3fiUXE8d9gZZWhl96YNznbP7772lG6mrA89OpljZuOBE4ClwCFRou1LuFGF5xSpbcpt6U7qQmlmQ4EHgavc/c89mK/NzJaZ2bI97C6mjSIVpdyWkFRnvc2skUwi3e3uD0UvbzSzVnffYGatwKZ887r7HGAOwAesxfNNI1ItNZ/bu7MF+LPLL4njn56Qvfc6zcXgyTPV17SFL2Iv1V8Ozp7FXtkL66u0NGe9DfgJsNLdv5cYtRCYFcWzgEfK3zyRylFuS1pp9ihPAi4E/mBmy6PXbgBuBh4ws0uAtcB5lWmiSMUotyWVNGe9FwNWYPSp5W2OSO+ph9zubP9THB/66Wx8waMXx3Hynuy/HPaHOJ7W1Fjh1pVP8t5wyL1gfdCG6m+HbmEUEQlQoRQRCdC93iJ1qOXsV+P4aZrjeO4Pr4jjn3/q+73aplJ875uzc4YPvOu/43g8pT0urhy0RykiEqBCKSISoK63SB8y6fqX4vjqb366ii3pmeHtv8sZrrU7U7RHKSISoEIpIhKgrrdIH7J327bsQDKWkmiPUkQkQIVSRCRAhVJEJECFUkQkQIVSRCRAhVJEJECFUkQkQIVSRCRAhVJEJEB35iSMeHVnHB/x2KWp5pn81rtx3Fn2FolILdAepYhIgAqliEiAut4J9szyOJ74TLp51N0W6fuCe5Rm1mxmz5nZCjN70cy+Eb3eYmaPm9mq6PeIyjdXpHyU25JWmq73buAUdz8OOB4408ymA9cDi9x9ArAoGhapJ8ptSSVYKD1jezTYGP04MBOYH70+H6if586LoNyW9FKdzDGzBjNbDmwCHnf3pcAh7r4BIPo9qnLNFKkM5bakkapQununux8PjAGmmdmxaVdgZm1mtszMlu1hd7HtFKkI5bak0aPLg9y9HXgKOBPYaGatANHvTQXmmePuU919aiNNJTZXpDKU29KdNGe9R5rZ8Cg+ADgNeBlYCMyKJpsFPFKpRopUgnJb0kpzHWUrMN/MGsgU1gfc/VEzWwI8YGaXAGuB8yrYTpFKUG5LKubee181bmabgTd6bYXSncPdfWS1G9FXKLdrRkXyulcLpYhIPdK93iIiASqUIiIBKpQiIgEqlCIiASqUIiIBKpQiIgEqlCIiASqUIiIBKpQiIgH/HyKwOLKinO/gAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "CKPT = 'b_lenet_1-2_1875.ckpt'\n", "\n", "def infer(ds, model):\n", " data = ds.get_next()\n", " images = data['image']\n", " labels = data['label']\n", " output = model.predict(Tensor(data['image']))\n", " pred = np.argmax(output.asnumpy(), axis=1)\n", " return pred[0], images[0], labels[0]\n", "\n", "ds = create_dataset(training=False, batch_size=1).create_dict_iterator()\n", "net = LeNet5()\n", "param_dict = load_checkpoint(CKPT, net)\n", "model = Model(net)\n", "plot_images(infer, ds, model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 实验步骤(方案二)\n", "\n", "### 代码梳理\n", "\n", "创建训练作业时,运行参数会通过脚本传参的方式输入给脚本代码,脚本必须解析传参才能在代码中使用相应参数。如data_url和train_url,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。脚本对传参进行解析后赋值到`args`变量里,在后续代码里可以使用。\n", "\n", "```python\n", "import argparse\n", "parser = argparse.ArgumentParser()\n", "parser.add_argument('--data_url', required=True, default=None, help='Location of data.')\n", "parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')\n", "parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.')\n", "args, unknown = parser.parse_known_args()\n", "```\n", "\n", "MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing提供的API与OBS交互。将OBS中存储的数据拷贝至执行容器:\n", "\n", "```python\n", "import moxing as mox\n", "mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')\n", "```\n", "\n", "如需将训练输出(如模型Checkpoint)从执行容器拷贝至OBS,请参考:\n", "\n", "```python\n", "import moxing as mox\n", "mox.file.copy_parallel(src_url='output', dst_url='s3://OBS/PATH')\n", "```\n", "\n", "其他代码分析请参考方案一。\n", "\n", "### 创建训练作业\n", "\n", "可以参考[使用常用框架训练模型](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0238.html)来创建并启动训练作业。\n", "\n", "创建训练作业的参考配置:\n", "\n", "- 算法来源:常用框架->Ascend-Powered-Engine->MindSpore\n", "- 代码目录:选择上述新建的OBS桶中的experiment_2目录\n", "- 启动文件:选择上述新建的OBS桶中的experiment_2目录下的`main.py`\n", "- 数据来源:数据存储位置->选择上述新建的OBS桶中的experiment_2文件夹下的MNIST目录\n", "- 训练输出位置:选择上述新建的OBS桶中的experiment_2目录并在其中创建output目录\n", "- 作业日志路径:同训练输出位置\n", "- 规格:Ascend:1*Ascend 910\n", "- 其他均为默认\n", "\n", "启动并查看训练过程:\n", "\n", "1. 点击提交以开始训练;\n", "2. 在训练作业列表里可以看到刚创建的训练作业,在训练作业页面可以看到版本管理;\n", "3. 点击运行中的训练作业,在展开的窗口中可以查看作业配置信息,以及训练过程中的日志,日志会不断刷新,等训练作业完成后也可以下载日志到本地进行查看;\n", "4. 在训练日志中可以看到`epoch: 3 step: 1875 ,loss is 0.025683485`等字段,即训练过程的loss值;\n", "5. 在训练日志中可以看到`Metrics: {'acc': 0.9742588141025641, 'loss': 0.08628832848253062}`等字段,即训练完成后的验证精度;\n", "6. 在训练日志里可以看到`b_lenet_1-2_1875.ckpt`等字段,即训练过程保存的Checkpoint。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 实验小结\n", "\n", "本实验展示了MindSpore的Checkpoint、断点继续训练等高级特性:\n", "\n", "1. 使用MindSpore的ModelCheckpoint接口每个epoch保存一次Checkpoint,训练2个epoch并终止。\n", "2. 使用MindSpore的load_checkpoint和load_param_into_net接口加载上一步保存的Checkpoint继续训练2个epoch。\n", "3. 观察训练过程中Loss的变化情况,加载Checkpoint继续训练后loss进一步下降。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }