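{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Mixed Precision Training" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n", "\n", "When a neural network is trained, its data, weights and other parameters are usually computed and stored as single-precision floating-point numbers (float32). As networks grow more complex, the amount of computation increases and memory consumption becomes very large. Anyone who trains models regularly knows that running short of memory lowers training efficiency; in short, training gets slower. Is there a good way to speed up training without adding hardware? This tutorial introduces one such method, mixed precision training. Part of the computation is done in half-precision floating point (float16), which uses half as many bits per value as float32 and therefore roughly halves the memory taken by the values kept in float16. To preserve model accuracy, however, not every computation can be switched to half precision. To balance model accuracy and training efficiency, MindSpore provides automatic mixed precision training in the framework. In this tutorial we train a ResNet-50 network and compare MindSpore mixed precision training with single-precision training." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The overall procedure is as follows:" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "1. Introduce how MindSpore mixed precision training works.\n", "2. Prepare the dataset.\n", "3. Define the dynamic learning rate.\n", "4. Define the loss function.\n", "5. Define the ResNet-50 network.\n", "6. Define the `One_Step_Time` callback.\n", "7. Define the training network (this is where the automatic mixed precision parameter `amp_level` is set).\n", "8. Evaluate the model accuracy.\n", "9. Compare mixed precision training with single-precision training." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "> You can find the complete runnable sample code here: ." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## How MindSpore Mixed Precision Training Works" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "![image](https://www.mindspore.cn/tutorial/zh-CN/r0.7/_images/mix_precision.jpg)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "1. Parameters are stored in FP32.\n", "2. During the forward pass, when an FP16 operator is encountered, its inputs and parameters are `cast` from FP32 to FP16 for the computation.\n", "3. The loss layer is computed in FP32.\n", "4. In the backward pass, the loss is first multiplied by the loss scale so that small gradients do not underflow in FP16.\n", "5. FP16 parameters take part in the gradient computation, and the results are cast back to FP32.\n", "6. The gradients are divided by the loss scale to undo the amplification.\n", "7. The gradients are checked for overflow. If an overflow occurred, the update is skipped; otherwise the optimizer updates the original FP32 parameters." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In short (float16 is half precision, float32 is single precision), MindSpore casts the forward computation of the network to float16 to save memory and improve performance. It keeps the `loss` in float32 for both computation and storage, and it computes with the `weight` in float16 while storing it in float32. This improves training efficiency while keeping model accuracy at an acceptable level." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset Preparation" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Dataset download address: .\n", "\n", "After downloading, extract the dataset to `datasets/cifar10` under the Jupyter working directory. The archive extracts the test batch and the training batches into the same folder, so you need to move them into two separate folders laid out as follows:" ] },
{ "cell_type": "raw", "metadata": {}, "source": [ "cifar10\n", "├── test\n", "│   └── test_batch.bin\n", "└── train\n", "    ├── data_batch_1.bin\n", "    ├── data_batch_2.bin\n", "    ├── data_batch_3.bin\n", "    ├── data_batch_4.bin\n", "    └── data_batch_5.bin\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "If the files are placed correctly, opening the Jupyter home page URL followed by `/tree/datasets/cifar10` shows the `test` and `train` folders." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Data Augmentation" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "First, visualize one image from the raw CIFAR-10 dataset:" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the cifar dataset size is : 50000\n", "the tensor of image is: (32, 32, 3)\n" ] }, { "data": { "image/png": 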
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAActElEQVR4nO2da4xdV3XH/+vcx7z8jmPHOE4MITwCgUDdNFIookBRSpECUqHQCuUDxagCqUj0Q0SlQqV+gKqA+FBRmRKRVpSQ8hBpQS0ogiIqNWQIwQlxCI7jhImfiV8znrmvc1Y/3BvkhP1fM57HvSn7/5Msz+x19znr7nvWPXf2/661zN0hhPjNpxi1A0KI4aBgFyITFOxCZIKCXYhMULALkQkKdiEyob6SyWZ2E4DPAqgB+Cd3/0T0+ImJSd+wcRM7WnQiYuCyYXA0uFfBPD6zIjJlN5hTG2tym/H32jLwsVbU+LxuLzletefpnCJceu5jtMZWpOex8f7x+BGj18yrMvCEXSPLud6ASKleroxdLiwkx9udFp9US69ju9NFr1cmn8Cyg93MagD+AcDvA5gBcK+Z3eXuD7E5GzZuwh//6Z+lj1fjrtTqjeR4FbzIteC17HXTiwsA9SCQOr1OcvxJT/sHAJt3v4ja1k2MU9tsEJzrxjdQ29mTTyfHF37+EzpnfZMHUmNsktqiN6vmWPq5NafWB8fj69hpzXFbe5baUHXT4wW/3ooi8KPLr7my4sFelnyN5x7+WXL80UMH6Jxq40Ry/OGDv6RzVvIx/noAB939kLt3ANwB4OYVHE8IsYasJNh3ArjwbWRmMCaEeB6ykmBPfVD+tc8xZrbXzKbNbHphnn80FUKsLSsJ9hkAuy74/XIAR577IHff5+573H3PxCT/+08IsbasJNjvBXC1mb3QzJoA3g3grtVxSwix2ix7N97de2b2IQD/hb70dpu7p7cVfzUH6FVsmzyQoZDeIe90yU4rgMLTO+cA4JGkEcgnHbKh+th8egccAGznFdS2+ZIt1LZpfB21PXns1z5A/YqHjs0kx7ct8Odca/F1bExxWWOc7LgDgFv6mFWN72YXRHUBgFaLv5694LkVlj6fI7h2anznvApCplfya6fbDaRDslNvgZLXJUpUJP+tSGd3928D+PZKjiGEGA76Bp0QmaBgFyITFOxCZIKCXYhMULALkQkr2o2/eBwo01lZQVITld6CpDdUbS7HdBfOU1sRJOT4eDqJ47LtXEKrAplv5iSX0Lqd4NuGQQLQtnXpLy7ViAQFAIiy6Npc8krnVvVpNEjyEnn9AaAK5NIqSDIJTPCKyFrBc47kq6ri/ofXI/EDAHrk4o/uxAVNQuIviu7sQmSCgl2ITFCwC5EJCnYhMkHBLkQmDHU33t1R9tIJCB7WJkvvgHZawY51sAvuQTmlWpMndxS19A7zutoYnVOniT+Az/PyWO1ATagHO8LrWul5wR4y6uM89bgIdnfZawkAVTe9s251vivNyn4BwMICf63rUTE8croiql8Y3AIjVSAqS8XqFwJARX3hc3wZc3RnFyITFOxCZIKCXYhMULALkQkKdiEyQcEuRCYMXXrrMXkl6EHk7XZyvNvinUAaY7yeWb3OpTILbHXW0mj2NJ3TOcfr0xWRZBe9D7d5dxRrnUuOrx9LdxABgHqdy41W534UJfe/TerC1Qsuoc0FpQGfOPQwtV155S5qa5CEl24vqjPHk4aCfBb0Sj6v2wmOybrFRN2kltFpSnd2ITJBwS5EJijYhcgEBbsQmaBgFyITFOxCZMKKpDczOwxgFkAJoOfue6LHuzs6pKZZpx1ksNXS70lFLaidFpRc65Y8W6tW8WMWRVoi4cIV4OASoJX8vXaCPGcAqAouefWIdMgyBwGgdC57RqXrUAT1+hrpVTm/kJZRAaDR5lmAu+v8NfMg+65XT2tUlfP19aCN03LVMAukZTNiY+MAaqxWYjBnNXT
233P3p1bhOEKINUQf44XIhJUGuwP4jpn92Mz2roZDQoi1YaUf42909yNmtg3Ad83sYXf/wYUPGLwJ7AWAqSnehlgIsbas6M7u7kcG/58A8A0A1yces8/d97j7nrFx/v1sIcTasuxgN7MpM1v/zM8A3gLgwdVyTAixuqzkY/x2AN8YyAZ1AP/q7v8ZTSgAjNfS0kAx2aTzas30JwLW2gcArOTFC521k8IiEgmxFVFLo+B40VttVLzQAj2syVScoN1RLZAwIz0pKjjJ2hP1evySe3LmMLXt2rqB2irwa6fqpH2sB9l8YFloWKzgZNDiKch6K4mPZZhFl76+o9ZVyw52dz8E4NXLnS+EGC6S3oTIBAW7EJmgYBciExTsQmSCgl2ITBhqwcmJiTFc+4oXJm0WSDwswyeStZqBnGRB/7KoJxeToepBhlq9HmTm9bgf4+Prqa3bOk9tC8RWBlJklCHY6wWSUdCbjUlAR2f5+j4wzQt3Xn4JX4/L1vPL+NDBx5Pjm6a4lLd5y2ZqC5Ll0A2kyIVmIMttTZ/PxvjzOt9Ir+PjMzwnTXd2ITJBwS5EJijYhcgEBbsQmaBgFyIThrob315YwGP704lx3WATvEXq1s2ePUPnbNzAd1vHp3iqbRElrpCN6Xabt6GqnCfJuPHlf+3r3kRtjz7yM2o7c/J4cnyiyc/Vq/guMslnAQDU6/yY45NTyfGHHjlJ5zx9mrfKOrtwCbWNzZ2itoMz6d34dpe/Lq9+9TXUtnnzRm7byOs1vGA9n9fbnn5uC6TtGQDU16ev4YcefoTO0Z1diExQsAuRCQp2ITJBwS5EJijYhcgEBbsQmTBU6a2o1TC1KS1BtIIMg85sujWUBdLEBDkPAGy9ZCu1bdiwidpA6o8dOfIYnXL6DE/u8EjXChJ5CtJaCQDqzbTkVQV16zxoh+XdoB4btQCnT80lx48dPUbntNo8wWfm0BPUdu70WWqbXUi3lJqb5XLpj+69l9o2rk+vLwC84sUvobZx8roAwJHH0s9tbv4cnXPJNiLXzfE11J1diExQsAuRCQp2ITJBwS5EJijYhcgEBbsQmbCo9GZmtwF4G4AT7v7KwdgWAF8BsBvAYQDvcneuMQ2oyhJz59KZat2KZ5t1FtISW7MeZKh1W9R06vhRajt5lNuYGtZqp+UdALBA1ioKvvw1C+rrBZl54yS77dxZLsl0IwlzYozaFjpcfDt89ERy/NRpXiOtQVqDAYA3+H3p3HxUky/92jSDjL2Feb4ezUYQMsFrVga31VaRlnR9jGdnlo10yysPfFjKnf2LAG56ztitAO5296sB3D34XQjxPGbRYB/0W39uwvDNAG4f/Hw7gLevsl9CiFVmuX+zb3f3owAw+H/b6rkkhFgL1nyDzsz2mtm0mU0vtPjf0UKItWW5wX7czHYAwOD/9G4MAHff5+573H3PxDj/TrcQYm1ZbrDfBeCWwc+3APjm6rgjhFgrliK9fRnAGwBsNbMZAB8D8AkAd5rZ+wA8AeCdSzlZr6zw9Gz6o3y3xwsAdjrpgpNRp6aqx2WhqNXUQofLLkzVCBQj1MCflwWZbfOB/NOa51lejx1+NDn+4CO/oHOqoP/Tju1bqK1T8jU+/lQ6Y6vd5eeKin2ePsOfcydou+SePl+jwSXFrdv4FlSz4K/nZLPB/WC9wwBsXp8uVNkNWm9tnEzLcrUiaEVGLQPc/T3ExMufCiGed+gbdEJkgoJdiExQsAuRCQp2ITJBwS5EJgy14CQAONLyigUZYGPN9HtSUXDpql7nMkitxm31cS7jgEhDFvRza9a4PFULnnOUbTY+zm3HTqb7pVV1vlYerONTZ9OFIwGgIgU4AS6lBiolvTYAYKGVll8BwLj7mBxLZ4dt2rCezrn22mup7cgTvJdat8O/IVoEff2YFNwhPQ4BACWTZvlroju7EJmgYBciExTsQmSCgl2ITFCwC5EJCnYhMmG4vd4Kw/h4Wgppd7l8VZHkn6jepHs
gWwTpclEGmxOpqVdxua4VZIZVgR/n53kRy8dnjlDbmfPp4otXXHpZcC6eYffU2bSUBwCtNpeauh0ivdWWd8n12EUAoBHIiixjshbMefHuXdS2oeB+lEFGXy/I6pyaTPeBqzcCiZhkt1kgX+rOLkQmKNiFyAQFuxCZoGAXIhMU7EJkwlB34+v1GrZtTe88zp3nO7usnBzbkQSAIqrFFdR+Kyw4Zi1tawXtk1pdbiMb1gCAH983TW37H+L15HbtujI5vmXzJjrn8BO/pLbZ2XQtOYCvBwD0SO23KkjuqEeJTcG5yi5XQ84TxWDiHFc7OkErsrmg5dWZ07PUNj7JKyvPk+unHVxXZxfS69sN6gnqzi5EJijYhcgEBbsQmaBgFyITFOxCZIKCXYhMWEr7p9sAvA3ACXd/5WDs4wDeD+Dk4GEfdfdvL3asibEarnlRWgLqdYNabfV0zbVI+onq00WV0Io6P6aTxJUykNC6vXRiCgA8eew0td357/9Nbb9z7VXU9rs3/FZyvB34+OWTx6itEyQojRW8Ft6u7VuT43MkUQcAzgbJP2HbqEAuZf53Arnu0o3p1koAMPnindTWaqWfMwA0x7iP5+fS8mDkY0USg/YfeJjOWcqd/YsAbkqMf8bdrxv8WzTQhRCjZdFgd/cfADg1BF+EEGvISv5m/5CZ7Tez28xs86p5JIRYE5Yb7J8DcBWA6wAcBfAp9kAz22tm02Y2fW5ufpmnE0KslGUFu7sfd/fS3SsAnwdwffDYfe6+x933bFg3uVw/hRArZFnBbmY7Lvj1HQAeXB13hBBrxVKkty8DeAOArWY2A+BjAN5gZteh32vmMIAPLOVkhRVYN56+u1s6Ga7vZDNdi6sWyWvO38dabS7jtDs804i11qloKx5gLJDyZn55lNp2XXYptb35927g52uka/z98H9/Qud4i/95tfuKy6mtBi6Xvv9P/jA5HkloBw4+QW3/M83vJ6fPcDnvzTdelxz/6YHH+PGO8hZPN/72y6gtalFVD4obsjKFvEIhwNTGb32Hy6GLBru7vycx/IXF5gkhnl/oG3RCZIKCXYhMULALkQkKdiEyQcEuRCYMteBkr9fF6adOJG1TU1wyKFl2W9A+qSq5rRVkcoUtd4jNSXFFAOgs8HNNBHLMK6/aRm3o8eKcTx47nhz/xaHDdM61L08XqQSAex84SG1v+J1rqO2KnZckx7tBwcmX7d5CbZdv20Btd3zrHmp73fUvTY5ffUXaPwB4/IkZarvhVXytqh7PUgsuEYw1iIQctjdL63UWCHa6swuRCQp2ITJBwS5EJijYhcgEBbsQmaBgFyIThiq9VWWJ82fTFa68k87WAoCJibStqgLprce1jm6Py2HFBO/J1eum/TDjy9gI9JMdpO9d/5iBdHiOVwk7djxte8kVXMo7dYb3KNuyjkuiL7tyO7WdP5MupmmBXFor+VqtDwo2XrqZr2Ovnc7ou/6ll9E535/j63EwyMx7wfaN1DYf9GCbHE9fV/UgY7IgS2VEkgN0ZxciGxTsQmSCgl2ITFCwC5EJCnYhMmGou/Fl2cPZU+nd4jMeJIyQ3UrSAQdAmCNDW+cAcSJMSXbWreDLWKtzWxWpAsZ3ps88xXfjT59K74Kva/AFKVu8htuWKV7n78SRI9TGsjjKHk+EqQWJQafn+bydm4LL+PzZ5PBjh9PjALBlgvtxPiiHfm6cr9XpU09TW6OR9r/JEmQA1Ml2fKfDk3F0ZxciExTsQmSCgl2ITFCwC5EJCnYhMkHBLkQmLKX90y4A/wzgMgAVgH3u/lkz2wLgKwB2o98C6l3untZ9BlRlifmzacmjChI/5mfTtqDMHMogScarIFmgzuUOI3JScKqQqH1VWQXtggJZriTSy/kWf85VIIfNBq2yZp7k0lu7nfYjWnsr+EKebXE/jp9doLafP0pacwUv2jiRwgBg4ybenfxYcMyTx3irL5DXOihBR2vNtRZ4K7Kl3Nl7AD7i7i8
HcAOAD5rZNQBuBXC3u18N4O7B70KI5ymLBru7H3X3+wY/zwI4AGAngJsB3D542O0A3r5WTgohVs5F/c1uZrsBvAbAPQC2u/tRoP+GACCofSyEGDVLDnYzWwfgawA+7O7nLmLeXjObNrPpefJ3nBBi7VlSsJtZA/1A/5K7f30wfNzMdgzsOwAkuz+4+z533+PueybH+PfOhRBry6LBbmaGfj/2A+7+6QtMdwG4ZfDzLQC+ufruCSFWi6Vkvd0I4L0AHjCz+wdjHwXwCQB3mtn7ADwB4J2LHqmqUHbSrYtYRhkAOJXlov443FRE5+pe/J8aFkhhEZ0Ol0naXS5RtYL6eg3SKssjSTHIAtw0yWvyeSA1zS+kX+caa+UFoGCF1QBUQcuuyQaf1+2k5zWCC6QEX/uzT5+ktvNB7boiOCYjys6sSK05D57XosHu7j8Ej6o3LTZfCPH8QN+gEyITFOxCZIKCXYhMULALkQkKdiEyYagFJ82AJslsaneD6pHkLWmsEck4wftYIJV51J6ITQsy9kJ50LiPZZDS1ymjtkDMxo+3ocl9rNX561IzbpvakL60orZWUabiJc1AOjT+ZS22HL3gegteFjSDiOn0eDHKQHGkV4gHBTjN0usRqJe6swuRCwp2ITJBwS5EJijYhcgEBbsQmaBgFyIThiq9AUBBpJdGIDPUiJ5QD9+qeJZRtwzkpEC7YNlyVXnxGU1AXKiyWeNS01iTP/F2hxePZBRB4ctul2fmWVCYcZy8npGsVUVFNoMMsEimZAlxjaD3XSSXFoGEWasHGWfB9e3keUcZbCzTMkrA1J1diExQsAuRCQp2ITJBwS5EJijYhciEoe/Gs0SIItitZJuSnU5UaC5KdODvcZEbbDe+F+wUh3kwwXttVfJaeOb8oA2iJkStpoLcH9RqPMkkWquSJDxVzmvJ1aKt5MA0NRUkBpGJ3WDnH8H6stpvAFCQ5BQg3o3vkSWJ9IJexdeRoTu7EJmgYBciExTsQmSCgl2ITFCwC5EJCnYhMmFR6c3MdgH4ZwCXoZ9dss/dP2tmHwfwfgDP9MP5qLt/e7HjMTkhqj9mRO4IckUQdEgKpaaohlcN6YPWg9ZKUSm8KH+mCmxl8OQmSPPMIpC1Kqb9AGgEa1wtI2EkSjQK9bWoFVI30g7Tl3gRXQRBnbxelHQTLJb3gheU+NIg8iXAE16i1V2Kzt4D8BF3v8/M1gP4sZl9d2D7jLv//RKOIYQYMUvp9XYUwNHBz7NmdgDAzrV2TAixulzU3+xmthvAawDcMxj6kJntN7PbzGzzKvsmhFhFlhzsZrYOwNcAfNjdzwH4HICrAFyH/p3/U2TeXjObNrPp+fbFf8VPCLE6LCnYrV+F/2sAvuTuXwcAdz/u7qW7VwA+D+D61Fx33+fue9x9z+TY0L+KL4QYsGiwW7/+zRcAHHD3T18wvuOCh70DwIOr754QYrVYyq32RgDvBfCAmd0/GPsogPeY2XXoq2mHAXxg0SMZqDZQCzQqs7RsEclazSADKaoYF2a9MUkmKqwWZEl5lC0XZrbx8/W66Wy5KNOvGbTRsiD3qiyD9kTkfF3iX/94fK3GAlkrkvMqssZF0I8pTL4LFDsP9NJaMNGp/5H0dvFfkVnKbvwPkQ7RRTV1IcTzB32DTohMULALkQkKdiEyQcEuRCYo2IXIhKF+y8UA1InMUAZZSBWTIJy/V9WD3lBRW50qaA1VsGMGvkdZSFHLq6ilUT2QjTxYEzonyDiMkrUiWbHeJNlm4BJaacHaB0/LopZdRML0oKBndLKo5Vi0jha07Jqop9eqF1yLvXYg2xJ0ZxciExTsQmSCgl2ITFCwC5EJCnYhMkHBLkQmDD3BnBU+7AU9wJiAFUpQQS+sMpA0LEp5Iv3BGvUo64ofLspEi+oyluFByfGCA3a7fD08yK4aCwptMlWx3gyyEaNbT/C6tLt8PVgh0/Egi64IKpnWgiKQnW5wDQdZjE4kzDApksl8UTF
VbhJC/CahYBciExTsQmSCgl2ITFCwC5EJCnYhMmEEtZ2JNlBwuWOcSTxBtlkvkKciySsq5McywBrBe2YVFpyM+oYFxwwStpj8UwsraQY9ysKeaEHWHpnnvainX5ARF+hQURYj878MsgOjhLgo0y+iCq7HspP2MZTegpeFoTu7EJmgYBciExTsQmSCgl2ITFCwC5EJi+7Gm9k4gB8AGBs8/qvu/jEzeyGAOwBsAXAfgPe6e2eRY8GKRtIW1duqk/ckC5JnGs0mP95Y2gcAaM21qI3VOquCXfVaVAsvmBftgkd10Jpj6R1tD6rhefCe3y75GreDBBqWdxPkzqDs8XN1gtZQjSAhiokQ7aCGG23zBaBR435EyUbtQIVokPNFbb7Y9RHlcS3lzt4G8EZ3fzX67ZlvMrMbAHwSwGfc/WoApwG8bwnHEkKMiEWD3fvMDX5tDP45gDcC+Opg/HYAb18TD4UQq8JS+7PXBh1cTwD4LoBHAZxx/9Xn6BkAO9fGRSHEarCkYHf30t2vA3A5gOsBvDz1sNRcM9trZtNmNn2+FRWoEEKsJRe1G+/uZwB8H8ANADaZ2TMbfJcDOELm7HP3Pe6+Z2p8BN/OFUIAWEKwm9mlZrZp8PMEgDcDOADgewD+aPCwWwB8c62cFEKsnKXcancAuN3Maui/Odzp7v9hZg8BuMPM/hbATwB8YbEDWVHD2NSGpK0IpJUxIqNFuR3dsk1tdSJPAUBngauHTK6xoAZdEeRNFIG00g3qqkXSUEFqqzmpnwcAnaDHU6CuoXJuHGuk5c2wRRL48ZpBXbhO0CqLaVEe1IQLW5EFtkYQTc1AczSarMOPx2o5Riwa7O6+H8BrEuOH0P/7XQjx/wB9g06ITFCwC5EJCnYhMkHBLkQmKNiFyAQLa4yt9snMTgJ4fPDrVgBPDe3kHPnxbOTHs/n/5seV7n5pyjDUYH/Wic2m3X3PSE4uP+RHhn7oY7wQmaBgFyITRhns+0Z47guRH89Gfjyb3xg/RvY3uxBiuOhjvBCZMJJgN7ObzOznZnbQzG4dhQ8DPw6b2QNmdr+ZTQ/xvLeZ2Qkze/CCsS1m9l0z+8Xg/80j8uPjZvbkYE3uN7O3DsGPXWb2PTM7YGY/M7O/GIwPdU0CP4a6JmY2bmY/MrOfDvz4m8H4C83snsF6fMXMeFXVFO4+1H8AauiXtXoRgCaAnwK4Zth+DHw5DGDrCM77egCvBfDgBWN/B+DWwc+3AvjkiPz4OIC/HPJ67ADw2sHP6wE8AuCaYa9J4MdQ1wSAAVg3+LkB4B70C8bcCeDdg/F/BPDnF3PcUdzZrwdw0N0Peb/09B0Abh6BHyPD3X8A4NRzhm9Gv3AnMKQCnsSPoePuR939vsHPs+gXR9mJIa9J4MdQ8T6rXuR1FMG+E8AvL/h9lMUqHcB3zOzHZrZ3RD48w3Z3Pwr0LzoA20boy4fMbP/gY/6a/zlxIWa2G/36CfdghGvyHD+AIa/JWhR5HUWwp0psjEoSuNHdXwvgDwB80MxePyI/nk98DsBV6PcIOArgU8M6sZmtA/A1AB9293PDOu8S/Bj6mvgKirwyRhHsMwB2XfA7LVa51rj7kcH/JwB8A6OtvHPczHYAwOD/E6Nwwt2PDy60CsDnMaQ1MbMG+gH2JXf/+mB46GuS8mNUazI490UXeWWMItjvBXD1YGexCeDdAO4athNmNmVm65/5GcBbADwYz1pT7kK/cCcwwgKezwTXgHdgCGtiZoZ+DcMD7v7pC0xDXRPmx7DXZM2KvA5rh/E5u41vRX+n81EAfzUiH16EvhLwUwA/G6YfAL6M/sfBLvqfdN4H4BIAdwP4xeD/LSPy418APABgP/rBtmMIfrwO/Y+k+wHcP/j31mGvSeDHUNcEwKvQL+K6H/03lr++4Jr9EYCDAP4NwNjFHFffoBMiE/QNOiEyQcEuRCYo2IXIBAW7EJmgYBciExTsQmSCgl2ITFCwC5EJ/wfUZjIsdjFZ/gA
AAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import mindspore.dataset.engine as de\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# load the CIFAR-10 training set (binary version), shuffled, with 8 parallel workers\n", "train_path = \"./datasets/cifar10/train\"\n", "ds = de.Cifar10Dataset(train_path, num_parallel_workers=8, shuffle=True)\n", "print(\"the cifar dataset size is :\", ds.get_dataset_size())\n", "\n", "# take one sample as a dict of columns and display its image (HWC layout)\n", "dict1 = ds.create_dict_iterator()\n", "datas = dict1.get_next()\n", "image = datas[\"image\"]\n", "print(\"the tensor of image is:\", image.shape)\n", "plt.imshow(np.array(image))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output shows that the CIFAR-10 training set contains 50,000 color images of 32×32 pixels." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the Data Augmentation Function" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "import mindspore.common.dtype as mstype\n", "import mindspore.dataset.engine as de\n", "import mindspore.dataset.transforms.vision.c_transforms as C\n", "import mindspore.dataset.transforms.c_transforms as C2\n", "\n", "def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target=\"GPU\"):\n", "    # note: `target` is accepted but not used inside this function\n", "    ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)\n", "\n", "    # define map operations\n", "    trans = []\n", "    if do_train:\n", "        # training only: pad 4 pixels on each side, randomly crop back to 32x32, randomly flip horizontally\n", "        trans += [\n", "            C.RandomCrop((32, 32), (4, 4, 4, 4)),\n", "            C.RandomHorizontalFlip(prob=0.5)\n", "        ]\n", "\n", "    # all data: resize to 224x224, scale pixels to [0, 1], normalize per channel, convert HWC to CHW\n", "    trans += [\n", "        C.Resize((224, 224)),\n", "        C.Rescale(1.0 / 255.0, 0.0),\n", "        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),\n", "        C.HWC2CHW()\n", "    ]\n", "\n", "    # cast labels to int32\n", "    type_cast_op = C2.TypeCast(mstype.int32)\n", "\n", "    ds = ds.map(input_columns=\"label\", num_parallel_workers=8, operations=type_cast_op)\n", "    ds = ds.map(input_columns=\"image\", num_parallel_workers=8, operations=trans)\n", "\n", "    # apply batch operations; drop_remainder=True discards the last incomplete batch\n", "    ds = ds.batch(batch_size, drop_remainder=True)\n", "    # apply dataset repeat 
operation\n", " ds = ds.repeat(repeat_num)\n", "\n", " return ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "定义完成数据集增强函数后,我们来看一下,数据集增强后的效果是如何的:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the cifar dataset size is: 1562\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "the tensor of image is: (32, 3, 224, 224)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD8CAYAAAB3lxGOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO2df6wsZ3nfP0+uB64nuV7wXcArc09tHzlYBKkOQRCJlKZNkwBK61A1FFQlDkV1kEBKpFSKSaIWRYqUpCFRorRIjrCAiPKjIgkWok0slIj0DwjYMdjEGHyMc3zxYnevk7m3nXvJ3JO3f7zvu/Pse2b37Dn7Y/aceT7SnJmdMzvz7u6833ne533e5xXnHIZhdJfvaLsAhmG0i4mAYXQcEwHD6DgmAobRcUwEDKPjmAgYRsdZmQiIyOtF5DEReVxE7l7VdQzDWAxZRZyAiJwCvgb8MHAe+ALwVufcXy/9YoZhLMSqLIFXA487555wzv098FHgjhVdyzCMBbhmRee9EXhKvT4PvGbawSJiYYuGsXpGzrkXpTtXJQLSsG+ioovIXcBdK7q+YRj7+ZumnasSgfPAOfX6pcDT+gDn3D3APWCWgGG0yap8Al8AbhWRm0XkecBbgPtWdC3DMBZgJZaAc+6qiLwL+BPgFHCvc+4rq7iWYRiLsZIuwkMXwpoDhrEOHnDOvSrdaRGDhtFxTAQMo+OYCBhGxzERMIyOYyJgGB3HRMAwOo6JgGF0HBMBw+g4JgKG0XFMBAyj45gIGEbHMREwjI5jImAYHcdEwDA6jomAYXScI4uAiJwTkT8TkUdF5Csi8rNh/3tE5Jsi8lBY3ri84hqGsWwWySx0Ffh559yDInIGeEBE7g//+23n3G8uXjzDMFbNkUXAOTcEhmH7kog8ik81bhjGMWIpPgERuQn4XuDzYde7ROTLInKviLxwGdcwDGM1LCwCIvJdwCeAn3POXQTeB2wDt+MthfdOed9dIvJFEfniomUwDOPoLJRoVEQy4FPAnzjnfqvh/zcBn3LOveKA81iiUcNYPctNNCoiArwfeFQLgIgM1GFvAh456jUMw1g9i/QOvBb4SeBhEXko7PtF4K0icjt+2rEngZ9ZqISGYawUm3fAMLqDzTtwKLK2C2AY68FEYBpV2wUwjPVgImAYHcdEwDA6jomAYXQcEwHD6DgmAobRcUwEDKPjmAgYRscxETCMjmMiYBgdx0TAMDqOiYBhdBwTAcPoOCYChtFxTAQMo+MsklkIABF5ErgE7AFXnXOvEpHrgY8BN+GzC73ZOfe3i17LMIzlsyxL4J85525XWUvuBj7jnLsV+Ex4bRjG
BrKq5sAdwAfD9geBH1/RdQzDWJBliIAD/lREHhCRu8K+l4QZiuJMRS9O32TzDhjGZrCwTwB4rXPuaRF5MXC/iHx1njc55+4B7gFLNGoYbbKwJeCcezqsnwX+CHg18EycfyCsn130OoZhrIaFREBEvjPMSIyIfCfwI/jJRu4D7gyH3Ql8cpHrGIaxOhZtDrwE+CM/GRHXAP/dOfe/ROQLwMdF5O3ALvATC17HMIwVYZOPGEZ3sMlHDMPYj4mAYXQcEwHD6DgmAobRcUwEDKPjmAgYRscxETCMjmMiYBgdx0TAMDqOiYBhdBwTAcPoOCYChtFxTAQMo+OYCBhGxzERMIyOc+SkIiLyMvzcApFbgP8EvAD4D8D/Cft/0Tn36SOX0DC6zgAow1It//RLSSoiIqeAbwKvAd4G/F/n3G8e4v2WVMQwpvED1CIQheBootCYVGQZ2YYBfgjYcc79TUg1ZhjGsijDOgPysB3XFT6B3wIsyyfwFuAj6vW7ROTLInKviLxwSdcwjG5ShSWKQA/fRBgA20B/sdMvLAIi8jzgXwH/I+x6H75otwND4L1T3meTj2wYp4FTbRfC2I+2BHph6QNb+Jq2HfYdkWU0B94APOicewYgrgFE5PeBTzW9ySYf2Tyi/yk2Oa+0WxwjEtv92hLQS/zRHuZIjsNliMBbUU0BERnEKciAN+HnITA2nDPADfj76TL7fU8mCC1SMukLiJbAIKy1cs81/9ckC4mAiOTADwM/o3b/hojcjp+j8Mnkf8aG0mf//XQtkw+aKAp7LZXxDP6+/1pL12+N+HTPwxIFYIBvEqQ/0CEdhQuJgHOuBM4m+35ykXMa6ydaAWlzIC7XhbUWhEodt2pROBXKdg4YbEO1A99Y8TU3ilJta8fgFrA9gHK4/0cbzX/6ZXURGscY3YzMw+s82R/vsfRei02HgtUJwl44fw4w9N7mTqG//Cp9HX4l7S/oYyJgHJ0shzyDLPPbWQZkcE0OlwsoR1AWdaXX1sJl6v3L9iFc8sWgKjvqnyiTZSwIU0RgwNxqaSJg1OS+4vd6cG0P8hzyHvTyjLyXU4wKihEUI7gc1gV+6YV1KgaXkkukXZCHsRyeYyVRs8eDRgHA/wlCPSECTdGEU6wDEwGjJvOV/to+9PvQ62f0+316/T5n+30uDocUwyGjYUExhCKHYggXKn/PxYofxeAi/p5MmxV6nYrEQRz2+BNDGjI8yxKIQmAiYByW2ATo9aE/yBkMBtwwGDDYGpANBrDbY9Tvcba3y4V8SJHDKIN8F8qqdhym1gHU9+xVJh9kna3URyEVANgvArH3IO7TPNx8WhMBoyb3zYBeP6M/GHDD9hZbW1uwvQVb56Dfo7+b088zRjmMsqG/7yrvoC6qycofBSF1KlbUDsUzmBDMTdoU0E96bQnE16kITMFEwKjJIOvhzf+tgReA27Zh8DJgG7ai1xD6WUVOSV4V5GWo/Lt15U99A+mSMRkDY0IwB00+gTimQIuAfj0HJgIGV4DHgGzHm/VlNaSsSi5XJbdUJb2qgK0RjJ6AnV3Y3aEY7jAaFlwsoCihLCe7s7Mp21XD6846+w5L7P/fpf4S09DOtB93DkwEDMB73h8Byt3Q/18VFNWDlFXBLdWIQVVQ7e5SDHcoh7tcHFa+d6DwXYZVNVmhY5yBvlerhv3GIRkxOYy4yTJIva8HYCJgjHkOb75f3I0WAZTVDkU5oqhGXB4NKYcjqlGIFwgxA2UFmbrhsmQd0darcURiWHAUgEMEBU3DRMCYYA8fmz929FVQlAVF9TDVCKrCL6jtqoCs3N8U1b6paAXoUbHGESmBHbwAaB+AVl9zDBqLcp4QDFR6IbgYnzzB05clHr9r8ZU83ne5WqIAXOZQ96ZxELEvNhJ/gJzJH+MATASMqVwCvlTClx6E6x+s7604ulBX9D4+BiBW8hizcpbJiq97B4wlE82sKAwmAsYyeS4sZ6ijAPtMWqDRGshz6GVwNoN+DtXQ
Ow6b8mReVtfo5JiAVWKOwY6TsZK+t0tM9gBEq3PsD+hBL4frcuiH7SqDq8HZGAOFYni77sWqaC9XQZeZK8dgSBj6rIg8ovZdLyL3i8jXw/qFYb+IyO+KyOMh2egrV1V4YwoZCyefnMUV6iAgfck8D1ZAH3oDv7AF/QGc3YJ+5ot1ljq8vYcPLIrNB2P9zJto9APA65N9dwOfcc7dCnwmvAafc/DWsNyFTzxqrIsoACuuUZeoQ3/Bm5R5GIF4Xc8PQBonvohCMPDNBC0E0bcQmxKnV1tso4G5mgPOuc+KyE3J7juAHwzbHwT+HPiFsP9Dzs9q8jkReUGSd9BYJXoU2YqJ7fuMevBRHH6c6RRY1Np0GR+Q1MOPMtTh7lep/Q3WLFgfi/gEXhIrtnNuKCIvDvtvBJ5Sx50P+0wEVo0eTx6tgWLmOxZCD1nPMj+sIO/55gDaEsAf2IsBSGUINKKOddHBbzk2lmCdrMIx2DQF0b6U4iJyF765YCwLnYZ6TSIAk5bAtT0vBBOJMFU4az+IQFFBr6if/D11mInAellEBJ6JZr6IDIBnw/7z+JyQkZcCT6dvtnkHVkBT9pkVsqcuVxQ+uOhizycb6ekgghjUEnoHyIL/oPAJTnV8i3YQPoOxDhaZgeg+4M6wfSfwSbX/p0IvwfcDhfkD1kTTaLIVE7XmMn4cwWjkl2qEbwAO8Ta/Kk+W1U7EPv6JcY56Qp3b8FNcn1l98Q3mtARE5CN4J2BfRM4D/xn4NeDjIvJ2/JCGnwiHfxp4I/A4/qd/25LLbEwjHVa6BhG4qi5ZlN4auBD8A4MYPBAtgdIHDcVmQ/RdlhX0yv05Bwrgr1f/ETrPvL0Db53yrx9qONYB71ykUMYCpEKwhsvFS17Ei8BIZSruh+whVVUPN47NgZjJuCpDIFFVb18MWYouYM2CVWMRgyeNfdloV3+5dM6LbKQqeIgrnhhqHDIYUXmLYSwQFVwN5S52/LlGmAismm6LgE7OcFLS27RoCRSEnoIKP9Q4CkAG12TjzGT1vAYZk999VefNHAX/whB4CSYEq6TbIjBgMgtLmpVlzvRMG8caRU1rThQB8DuqMDtWLwQREZyBmXqdljNaDLeMYDjyASc7mAiskm6LQNPEjSdhjKtOQbViYlxQHEgUhwlXwMUSRjt+JGGv7wcUXa2g6nlfQQ77v+8kLdFJMdA2mW6LQNMdFqNVjrsYrDBIKHIaPwYgTXOfXvpiCWd3oVIH5lWDCOjtaCkcV2vsGNFtEZjGSUiEtwYRSCe7SXMF6O0SgtfQNweuK9XX3JQaK699CldX/1E6jYlAE+NRMZg9OoNoBdyA78rTDsKC0GVIkh6/8DEBvSokJdHJCCYSE3ixWIeYdR0TgVnoLJnGBGeo/QFNlsAFvMtljyQJSQn9kKG4xPcUTGQoTSyBk9Rxs6mYCMwieq+PkMv9pKObAn3qB3a0BIbUw4GfI8k7WMKgUhaCFoA0XZGxckwEppEmz2+6Iad5toeceDM2NgXiOjT3Ad8MSPMFPqOOvSFGCOKf9pm2BMKSxWwja+AU3c5fYCIwjSHNbdWmnO7zCMYJ4nommwJn+zAY1SnHYvv/G+o9340fGHQbsL0NW9uwtcX+JIX4E5Qx1BgzvlaNicAstEcrvVmbhCC+jsee0Lt3XPnxMQD9QRhFOJr0qWZ4v0CYztSPDrwNtm+Dc9vUCUdSqslW2An9GjcGE4F5SO/Epil2OiICOjZgAAwG0N8KTrwcGHnnX8wNcBYvAC/DV/7t27wQ9LbxU54XhR81VE3W/MuVWQLrwkRgXvTdmFoC6esTHOgSrYAb8MlEz4YMQoPg+ct6kBeQj4IDEDiX+8q/tZ0IwNaWHyBQFCHnWC0GVaX8Bu193E5gIjAvTQIwo497XU6tdROtgD6+GdCPacRyGJSQ9+G6wncDjgoYFL7yn9uGwTbk25kXgO0tb0LE6KFRETKRFLUAVGsbB9VpTAQOg+7wjutpInACnYPR
ITggpBDvQxbzCIaooF4IBhqUPsnIhRJu2IL+NrCV+6f/VhSALZ90kBAQUPrnflmVVFU9O5GJwGo5UARE5F7gx4BnnXOvCPv+C/Avgb/HD/J6m3Pu70Ja8keBx8LbP+ece8cKyt0esyr9CbcEogCcxTcFxlbAgHpKoWC/55X3DQwqvEgM+r7yD7Yg3wpvOgeEbKMl/k2UVFXpT1WdyBbVxjGPJfAB4PeAD6l99wPvds5dFZFfB96Nn3MAYMc5d/tSS7kpFHjJi0PlmnwC0SXeNELxGHMKP1NQTsgFoP+ZNty1b0R3F0BIGJAkQizqpSon3QProMsxAjCHCDRNPOKc+1P18nPAv1lusTaYODd8pEkMRumbjj8xdmf8cXVIdZMA6O2oHGPU0KLxTKUVFD5Zaak6DIzVswyfwL8HPqZe3ywif4UPHPtl59xfNL3pxMw70BH3dZzGIKfOEATs7zWhaTubtASqCp895HLITlqNI43K0g891sFCxmpZSARE5JfwIz0/HHYNgS3n3AUR+T7gj0Xke5xzF9P32rwDx4fYFBhbA03NgWkRlJl6U+z7qyrI9jcFisJrQllmVFVllsCaOPK8AyJyJ95h+O9ChmGcc992zl0I2w/gDefvXkZBjfZIJw7NdBMopWEcgB8qqIh9gBS1CISn/+XYJCibDQxj+RzJEhCR1+Mdgf/UOVeq/S8CnnPO7YnILfiZiZ9YSkmN1tCzms3VHNjXa5I0B8pgCVR5GFNcURYVZVnVTsGyMy2t1pmni7Bp4pF3A88H7hcRqLsCXwf8iohcxTtd3+Gce25FZTfWRJxGfNw7gM8ePCEETQIwtgJSS4C6lyBYAmMroMwoy4zL1hZYGxIs+XYLYT6BjeUfA68EXo2KEgzhwv0B9RTkTTETOvuoSklchkCiooS/ethnFb4w9Os4c9kFJnMSzMPpsK4O+b4O8YBz7lXpTosYNKZyPfXEwgN8xuBeD67NkyZByf704TpGwEcDUxShwhf1vAKPfBWGZT1l4bfwSUiOgr5kbEqYGByMiYAxlRuoIwQHA7iu5+cLyHuq2z/WthhAFScf0H6CEQyH8K2w3h2F7RE8SP3UTxORHIbTTCaC0pxoIVjCaFUTAaOR09QDhbYyONf343yyaAXo8RFp6rWku3A0hG8NYWcXdnfhiZHvNtpleROOzuqwgBMsBNssnMnKRMBoJA4JiEOGB3qGkXSA1AFu/NHIV/6dHXiigK/iB5csa1ahU6juS2rDRBfvRBLGYI05ohCYCBiNnCUIQbACxs4/krWuYVW9qsJGRXD67cJTQQAe4ejt/iZ0Z4Qe6Fmq/8MJswZ61A6bSKp+c2IiYEwltfKB6TkUtG9ABQaWlXcIFkU9H8EqcrDqsqZDlBbxNWwsBd6TGrtTRhx5yKWJgDGTCsYzAQHTuwLH0wz5N8XRgBdLLwAXw/CAy6zuiVzhY9i1EJxIAYjE3O5xDvcjYiJgzGaWAOjQYJgIGqqqEA8Qw4Cp5yRYNtolUSXLiSZOPDtc7DRHHjtgrIkWA+djRdo3F6AWAD2wQPUWVCECsBj5mIBVNgV0eXW5T5QPYBoLCgCYCGw+LWUoGj9Fs8ntxgFCTSIQLIGy8E2AODfhqs3zKAAnuhmgWYK5YyKwyWxivkItBNEK6PX2lbMKAqAdgqtOFab9Acb8mAhsMmmbuwUqmDE4CB9CGEVAz7cQHINaANbRFEi3jYMxEdhkWrQEoglflCHOX9fmdIxAsmQ9rwu93qTLYFVatkfdY1bQEV/AEjER2FT0k7UFIYh1/mIFoyAEpcomDKFMWTXpFwi1vteDfljO1rtXRhSCzvgCloiJwKain65RCNZINN9H+Hb9hZD6a6oIJD0Fec8PONLWwHX4EN9VYRbA0ThQBETkXhF5VkQeUfveIyLfFJGHwvJG9b93i8jjIvKYiPzoqgp+omkwsddtCVyitgaiJVDogKCITh6SiEGv58cd9PqThoKxWcxjCXwAeH3D
/t92zt0elk8DiMjLgbcA3xPe899EZJXifzJJK/9BQ+RWxEXqJkERswDvswTCtrYEogj0lTWQ1UJgbBYHioBz7rPMP97jDuCjIeHoN4DH8UlpjHlpagZMJPxfH2NLgGAJ6CZBZCwC+y0BoiUQmwbMKQI9Nqtb9ISziE/gXSLy5dBceGHYdyPwlDrmfNi3DxG5S0S+KCJfXKAMJ4vY997UFGihSRD9AgVwIVgDVWoJ6HkZtQCEzKR5r85IlGfeLzCVDD8q7jb8OPkBZjqsgaOKwPvwP9Pt+MDF94b90nBsY/5A59w9zrlXNeU86ySpADTF6K+5Qb1H3SQo8DlBL+hBANFMYDckDaQezKJSihUhs3hZHRAw1JuymBCslCMNIHLOjfNBiMjvA58KL8/jZ5mMvBR4+sil6woNXWwTT/5M/a9irdOc6fpelMFJWPqpxycigVTFj2JQhRRio+hPoGEcQiRHzXnO/qHKqx540GGOZAmIyEC9fBM+TwTAfcBbROT5InIzft6Bv1ysiCeYjPqmT598qVMwdbytifjgH+GtgCKMDJx0GDQsBQxj12K0BJgRzTfNCkgXY+kcdd6BHxSR2/Gm/pPAzwA4574iIh/Hp467CrzTOWfdtynR8TWtKzAVAZ1Bs6/2rSEg/wqqSRAChy6UMCghH4XypFZAiDAcZxUO4tHUwwjsF8ImS0BbRSdwwtc2mWdW4rc27H7/jON/FfjVRQrVOvGmW2YFa2rvp23/WUtsF6eWwYJJJuehwDv0dBhxUUAeLYEeE1ZBOfJ5BUehOXBR+QP2WQITA5GoxWBWYcIMZsZy6HZSEf3EaaqIMX3TojdczAU3bwVPt/X/YxafNJ/cCitFibcGdJPgQgGDKU2CUZhbYBgtAWqjZZ+uphaANvvThCYRnT/MWJhui8BrmP30HVI/bXePeI0Bvh8ligDJNWDyBm96DbVTsKIxkw+wMiGIcfkTzsFQwfsxu02o6aMCvhUmGRkVfk6B+N7LNNTbVAC0YzB+nnQ7ikBbzYLoxFwgr98m0W0RiJWqqVJqj/QyK1dThT+o/z/e+JnajhTU5vgKrYFLhElCKl/x+zn0RtDL8RORDP3EIk8MYTdMMPJU5ecXGOJnFmpMMa5FJPo5ouBBs7Wk948wi2BBTATSbX2TTbVhD0FTBZ/m9JpVvmmmcRSAHisftD/CjwgcFnDdKAT/ZLCd+Yr/VBCAnaGfinoXHzk2xIvI1JNG/4j+vtPhyroppL+3Hgsn2uw63RaBeEOlT5JYyeLTddlPmlkikFZ8bfI3CUHJpACsUASeo54yrFf48QB55i2Bp2LlH/rKv4MXgG/Mc+LoXIyfQWUt3mf9aAerFsB+eN+Q9VkGB1lwx4Rui8BBuahW0QWXVmQtAlWyrY+f1WSJFWgN1sAQP9NPL3QRZpmvjzsjLwBxerEn8JFjc6EFTFteqXNUb2sR6Ktz9EIBVt17cEIEAEwEJtfp/lVYAZG0dyAKQNWwPa0LMVYEXQGazOklEn0Du0BWQjaErFpAACIx5iAKQZMAxLUWPN1TUqpjdliPEDRZkscME4FUCPS+VVkB07oCU0ugyWpIK0RqBayhWTDURSjhmhHsVF4AjjzHoG4KxMoMk21/LXrpBAMVnK7gSqyUq/wOmppsxxgTgXSmimXOXNHk9GsSgNQn0NQkSAUgLk1WgK4oK2CPWgjAXytaAAtNMqpFQH9H8bPqQCL1e50B8sr/a6eCvfidDFlNF960rt150NGfG0K3RaDJo6wr5aLNAW2mRudXvAH0dnpTNDzlJpoJ08qsxSJ621fEFepIQh0MtBDpZ40WQdP3kLwtrveaLKxlox2Q894f2ppJP2PLgtBtEdABKPpJm7Y/F/mhomk67ek+q/Iflnm6HZdINEBy6vEFC6ObA6kQ6GaCpql3J/6GqxLCwyheFIC46OZMzv7mzZrptghE9A+ih7RO8xnA/An19THRnE1/7FlPu2kWQSR9+ushx4vGOBzA
Fbz/LT4YFx4pVqolCmcUYr0dmdbU0t9H26RDpGNPRpNTM1qMaxYCE4GUNIHHtCCdGDOvY+ebKlxqyscbOVb8Wcz7dEhv/jVaA3v4j77UoaJpMyB9DZNWVJZ8Tel30aa5rXtxdDyDFjvNiptxTZgIpGjTrWlMf9zWc8Pnat+0H1A3C3S34KyFZD3rZm56Cq6pL3tpuf51G3mWT0AzyxrYhC48baFpS0DfC6nVYyLQMrES9dlvYudwOoc8h+fi4KIYnKLb/NMqbJMJOOsGTYVgVpnjOm0aHEcOEsdMrZtIv4e2BvnoMmhLIFZ67SvSMRFrLvM8SUXuBX4MeNY594qw72PAy8IhLwD+zjl3u4jcBDyK7y4G+Jxz7h3LLvTKiO2xWEl133RYX+nBlTjgRbvFm8zUJmb1BKSkN8i0mz6+t6mNedyIYwm0VRA/S3yKJk/6K5n/TUr93cffbUBtha2btOs3jYXQD4L0nlgj81gCHwB+D/hQ3OGc+7dxW0Tey+RXvOOcu31ZBVw7BfBVahFoSnqpk2ikIgDzC8E0p6M+R9q+n9bW1+c4ziKgnWNakHUMRBozEdZXtONQW3RQxwys8zvRTbImEWh6AGxi74Bz7rPhCb8PERHgzcA/X26xWqbEC0FTxhtt0qWqDpOVtekHbbIC0v2RWUFF6Tl1hdHr40h0sqYOtKYcjGl6tvi5owBox5zuqVk0BkTndNDE12myWC0Cqe8jPc8GWgKz+CfAM865r6t9N4vIX+G7jn/ZOfcXC16jPeLNEj22UQiaunjS9upBzHL4pTEF00jNzVQIjqsIQLMA6MCEPNnWIcVQO3jjOjYJogUXLY55xaDBPzTTb5E6Z+PvQsOxLbOoCLwV+Ih6PQS2nHMXROT7gD8Wke9xzl1M3ygidwF3LXj99VBSj0yL7VKY/mPO8krPagY0ebr1+ZLlVAV70/wLx7Ep0IRO7BKFQFd43UTTzsL4PfUazqXP02QV6O3EJzTh5JvV36+bKqj/b8jTX3NkERCRa4B/DXxf3Oec+zbw7bD9gIjsAN8N7JtlyDl3D3BPOFfjBCUbR1P7P+UoloDe1xSXMKUpkGUhTLap7bwhInCKJcUQ6MqVVsAoAKllkDYZ4pM/JiLJ1Rqaf7cmf5AWgSJZp118aa/RBrKIJfAvgK8658ajRkXkRcBzzrk9EbkFP+/AEwuWcbOoqLsED1pmnSNdH2QF6O2sfnml6VwbYgmcYjJKdiliUDHpM2iqXHqEpfbnxB6dVCxmZSVK/UH6tbZSdMp1Lczp75Jafel6Ex2DTfMOOOfej599+CPJ4a8DfkVEruJ/83c45+adzPR4cVRlT5/a6fiFpu3YxtTnyMPl03O1EGzSRBSAWN+0xbw0MWhap4FdoaKf6iVNp7QPv8lHUzGZADUVgVmBSk0O2ibzP72HWhDvo847gHPupxv2fQL4xOLFOuFoT3HTE6FJBPQSHq9VvNlSJ1rLIqAF4LoMqmp/139clhZtqNFCEAqS90Kew6aAKu0TSAuYZkJORWCa9ae7K9OKPU0AovNyzVjEYBvogJe0/dhkETQ5HvMwbr5paZmxAOTQ63kRiMu1FVyt6o906AxEmvQ7ieukkp/q+SjPMvpQUitAd9ulzYx5REBfP24Xyb74usmpDPt/4zViItAG8YbT7cBpAqAdjU3m54aJwBl83YoC0O9PikBcCOthecTmwbQKo7+3UNHzPNT3DIo+7MX2Sepg1IFKOkoxVny9Pc3vk+6LlXtad3DBuxUAAAwhSURBVO4G9OiYCLRF+kSYJgDTbrQoIk0+gZY4zaQA9PqTlgDsF4H+7oLZiDQN1sCpHPLMJ0cdW/x5CP2G/SKQCusUATgDXMrq1xPX1+gK3iQEGxDTYSLQJtp5NE8vQ5O1sCFWwCnqp22vB3m/FgJd6dP12QzKasa8BIclsQKyXv0y/XdJEAMtCHEdfSs6D0BoAsTPWhFClfvq5JFYseOFpgnABjTfTATaJu02
OqirMW0mpIEqLREF4Lq+d8KNLYGQnCVWfBIh6PXg4uiIvQYHNQmCFaBDBjK8VXCVya+vIDQTBkwO7dUiEIg+j8gVJv8/Llu8gL54+v8NwERgk2hyAGYN+/R26sxqgVP4ypb1w5M3rhvCayvV9q0AsmYreiZp5Rwk27GWZnU9HLG/4jf2UDT5WJJ2farb4/frpkEkfeqngUkt/m4RE4FNpulpsYFNgTz3T/9rmwQgjbEPFaGK5Y1m+7wOwjgwaECzAESzPTx9myr91K7JdFBY+iYlApqJqEgdPzBv703LXbomAseBeOPopkCMfW9TBHI4Hbrfsryu/Hmv3r6mB1d1kyUKQFxn4b3lATEDaWKOVATiWsX4n6I29aNBtU9o9BN5nrH+ydsI5544bxSqtKmWBnKlfoOWMBE4LujgIi0EbTQFkv73PPdWQLQIsnSEH0xWNiUEWToMuIk0Yk9X/AGc6kM+8EuVqYUQJZgllXRa95wWgqZjko8C+5v5ExZBFIJ0mrSmMqTNvjViInCcSHMUrtsfkDFRufNQ8a9TYjC2CHQF10EyMPaWV/jKOhaDJhFI++nDcnoAWaj4eR+uDesSP4a9xPc67EFzhZ4lBPoJPsMS0H4Bku2xs1B3NerApGlLC5gIHDeiEOin5zpMyWQo7Wn15NcCMBaC2NZXlTt9wFL6/18T1qdiFGREVfpTfe9vyNVT/9pBvR0DgkaoJnkaR3GQAMS2QyoKSQVt8l2kzv8rcWdMQKMHUswqQwuYCBw30ifWuqwA7XAL7f685wODxkIQ2/dqrQc+ZYTegVBjKmXVZJW3IPZihQiV5lToZpx44qdL0g1YAZfD9pX4faVP4GmiMEssFKkVkK5LghBEK0YLQTxBg8O0DSEwETiORKfgOi0BmLhpnyuT+7cMVkBY52UtCpS+8lehgo2jB/FP66tREHQsv6rVsX1f4W9YHX5cKq99PCfhuAzqJy/sf6qnT99UTJuEIKtPxeSlJ4Rgn5MzCl4csahHVKUnWjMmAscVnbZ61aQ3aLhxL2W+YpdVEIAKrs2CEGTeEsgbzN9KVcQygypX+3K11kE2QQyyKtSdcHxe7X+IRyaiAuN5dH+9juSrku1pvoLwJI+V/Aq+F2LaVzXemTcs8fqpGbFmTASOK9qTvWqijZ0+RQtvbl/JoahCM4DgG4jbWX2TZeF943s+WAFjIdDX1JUlm7zsNdGSqKBUFVM72LVhQRbCe8eFoBYC/QZdAH1BLQQ6VDCwN2V7jLYCku7SxmjCNTNPUpFz+HTjNwD/ANzjnPsdEbke+BhwE/Ak8Gbn3N+GDMS/A7wR/1F/2jn34GqK33HWHWSiRUA/TSs/RHcvC0/IEp4LN/XpTIkDajuIQ6l9A7Gi6F6IaAGkgTqhKZBHIcj3i0A0KAjrIlc+h/gZ0jc1fd5UCHIOj1alJgHYZBHAR1v+vHPuQRE5AzwgIvcDPw18xjn3ayJyN3A38AvAG/BpxW4FXgO8L6yN40raGR5J4xbYv32F8BQOy6msduRlTVZANMvzyXWVmOpZWI+FgP0P9niKa6kdhWUGZa66DnN1zfTp39QcOIoIxMJEh2f0eWyAAMB8mYXijHs45y6JyKPAjcAd+LRjAB8E/hwvAncAH3LOOeBzIvICERmE8xjHjfSG1xUlU+tI07aq0Hu5HzFYBjGIlbvKkz790GU49gckxSjxlZ9gCcSkIRpdv3TdK/LgK9DOvmkVMRUC7bQ8DNOEYAPE4FA+gTAJyfcCnwdeEiu2c24oIi8Oh90IPKXedj7sMxE4Lkx7sjcJwLT3a8aN8/r1HsGdEY4dZ/zRDkL9ONfNgcQKiNuxqaGLnb6OLYAi7LiSHqSZZhFoC2ge4rFROHTa8w2wBuYWARH5Lnz+wJ9zzl30Tf/mQxv27UspfqzmHegiTSZy+v+UtDIdUKn2ZgXO5EEY8joT0JVw/CnqCh+t8wz8WIV8uhjEaMKigHIEV3S24OhkjZmMpzldm5Rl
WkVOg43SZsW+6Kl2mEsERCTDC8CHnXN/GHY/E818ERkAz4b954Fz6u0vBZ5Oz3ks5x3oAk3t/yav+UHnOOjplvbrNYjARF+6Eow9aidkHCQ0vmQZRCCDa7J6O8M3G8oRlAVciinC9XySZbJ9UM/LNEGY9vl0z8AB4xPWyTy9AwK8H3jUOfdb6l/3AXcCvxbWn1T73yUiH8U7BAvzBxwTpjkAD0va1i6T/drKmBZLn/apN1WUIAiXCNZBaBLo7skYiJRlcHlUiwBNIlAl2zqQZ9rClO1ZApdaAy0zjyXwWuAngYdF5KGw7xfxlf/jIvJ2/HQcPxH+92l89+Dj+I/4tqWW2FgN+uZdFrrCx2vofbOeiLFPfp6nZbQO4iFBBCbEgFoA9vS08nG76Wk9zWk4SwDSzzlLBDbACoD5egf+N83tfIAfajjeAe9csFxGWyxbCKB+wubqdaw8qSWQvk9bA72GY5Ljx2KQJf7FLPgBYqXXyzRrBKa3/2nY1mvdlErP2eQTaBGLGOw66c29SvO0TNbzHB+PTUYxTp0fECbE4Ap+4tac4AeIy1Btz2JEnXcwrnW5ZjlDU59KXOspyzZADMQ/uNvFHIPGQsTutyZLQb8+qMLPS85kYtFsyvY0h2rsgVg/DzjnXpXuNEvAOP7E5sa6wqhLvBfshPAdbRfAMIx2MREwjI5jImAYHcdEwDA6jomAYXQcEwHD6DgmAobRcUwEDKPjmAgYRscxETCMjmMiYBgdx0TAMDqOiYBhdBwTAcPoOCYChtFxNiWfwAj4fywv7UMb9Dne5Yfj/xmOe/lhtZ/hHzXt3IjMQgAi8sWmrCfHheNefjj+n+G4lx/a+QzWHDCMjmMiYBgdZ5NE4J62C7Agx738cPw/w3EvP7TwGTbGJ2AYRjtskiVgGEYLtC4CIvJ6EXlMRB4XkbvbLs+8iMiTIvKwiDwkIl8M+64XkftF5Oth/cK2y6kRkXtF5FkReUTtayyzeH43/C5fFpFXtlfycVmbyv8eEflm+B0eEpE3qv+9O5T/MRH50XZKXSMi50Tkz0TkURH5ioj8bNjf7m/gnGttwc8juQPcAjwP+BLw8jbLdIiyPwn0k32/Adwdtu8Gfr3tciblex3wSuCRg8qMn0/yf+KnoPt+4PMbWv73AP+x4diXh/vp+cDN4T471XL5B8Arw/YZ4GuhnK3+Bm1bAq8GHnfOPeGc+3vgo8AdLZdpEe4APhi2Pwj8eItl2Ydz7rPAc8nuaWW+A/iQ83wOeEGYgr41ppR/GncAH3XOfds59w38BLmvXlnh5sA5N3TOPRi2LwGPAjfS8m/QtgjcCDylXp8P+44DDvhTEXlARO4K+17iwjTsYf3i1ko3P9PKfJx+m3cFc/le1QTb6PKLyE3A9wKfp+XfoG0RaJrt+Lh0V7zWOfdK4A3AO0XkdW0XaMkcl9/mfcA2cDt+mtH3hv0bW34R+S7gE8DPOecuzjq0Yd/SP0PbInAeOKdevxR4uqWyHArn3NNh/SzwR3hT85loroX1s+2VcG6mlflY/DbOuWecc3vOuX8Afp/a5N/I8otIhheADzvn/jDsbvU3aFsEvgDcKiI3i8jzgLcA97VcpgMRke8UkTNxG/gR4BF82e8Mh90JfLKdEh6KaWW+D/ip4KH+fqCIJusmkbSR34T/HcCX/y0i8nwRuRm4FfjLdZdPIyICvB941Dn3W+pf7f4GbXpLlQf0a3jv7S+1XZ45y3wL3vP8JeArsdzAWeAzwNfD+vq2y5qU+yN4k7nCP2XePq3MeFP0v4bf5WHgVRta/j8I5ftyqDQDdfwvhfI/BrxhA8r/A3hz/svAQ2F5Y9u/gUUMGkbHabs5YBhGy5gIGEbHMREwjI5jImAYHcdEwDA6jomAYXQcEwHD6DgmAobRcf4/ZU9FRHDzNy0AAAAASUVORK5CYII=\n", "text/plain": [ "
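" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ds = create_dataset(train_path, do_train=True, repeat_num=1, batch_size=32, target=\"GPU\")\n", "# after batching, get_dataset_size() counts batches rather than images\n", "print(\"the cifar dataset size is:\", ds.get_dataset_size())\n", "dict1 = ds.create_dict_iterator()\n", "datas = dict1.get_next()\n", "image = datas[\"image\"]\n", "# convert the first image of the batch from CHW back to HWC for plotting; normalized values outside [0, 1] are clipped by imshow\n", "single_pic = np.transpose(image[0], (1,2,0))\n", "print(\"the tensor of image is:\", image.shape)\n", "plt.imshow(np.array(single_pic))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After augmentation, CIFAR-10 becomes a dataset of 1562 batches, each a tensor of shape (32, 3, 224, 224). The batch count follows from 50000 / 32 = 1562.5, because `drop_remainder=True` discards the last incomplete batch." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the Dynamic Learning Rate Function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define a dynamic learning rate for training ResNet-50. `get_lr` returns one learning rate per training step and supports three modes. `steps` is piecewise constant and becomes 10 times smaller at 30%, 60% and 80% of the total steps. `poly` uses a linear warm-up followed by quadratic decay. Any other value uses a linear warm-up followed by linear decay to `lr_end`. `warmup_cosine_annealing_lr` is an alternative schedule with a linear warm-up followed by a cosine-modulated linear decay." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import math\n", "import numpy as np\n", "\n", "def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):\n", "    lr_each_step = []\n", "    total_steps = steps_per_epoch * total_epochs\n", "    warmup_steps = steps_per_epoch * warmup_epochs\n", "    if lr_decay_mode == 'steps':\n", "        # piecewise constant, no warm-up: lr_max, then 10x smaller at 30%, 60% and 80% of total_steps\n", "        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]\n", "        for i in range(total_steps):\n", "            if i < decay_epoch_index[0]:\n", "                lr = lr_max\n", "            elif i < decay_epoch_index[1]:\n", "                lr = lr_max * 0.1\n", "            elif i < decay_epoch_index[2]:\n", "                lr = lr_max * 0.01\n", "            else:\n", "                lr = lr_max * 0.001\n", "            lr_each_step.append(lr)\n", "\n", "    elif lr_decay_mode == 'poly':\n", "        # linear warm-up from lr_init to lr_max, then quadratic decay towards 0\n", "        if warmup_steps != 0:\n", "            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)\n", "        else:\n", "            inc_each_step = 0\n", "        for i in range(total_steps):\n", "            if i < warmup_steps:\n", "                lr = float(lr_init) + inc_each_step * float(i)\n", "            else:\n", "                base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))\n", "                lr = float(lr_max) * 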
base * base\n", " if lr < 0.0:\n", " lr = 0.0\n", " lr_each_step.append(lr)\n", " else:\n", " for i in range(total_steps):\n", " if i < warmup_steps:\n", " lr = lr_init + (lr_max - lr_init) * i / warmup_steps\n", " else:\n", " lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)\n", " lr_each_step.append(lr)\n", "\n", " lr_each_step = np.array(lr_each_step).astype(np.float32)\n", "\n", " return lr_each_step\n", "\n", "def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):\n", " lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)\n", " lr = float(init_lr) + lr_inc * current_step\n", " return lr\n", "\n", "def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):\n", " base_lr = lr\n", " warmup_init_lr = 0\n", " total_steps = int(max_epoch * steps_per_epoch)\n", " warmup_steps = int(warmup_epochs * steps_per_epoch)\n", " decay_steps = total_steps - warmup_steps\n", "\n", " lr_each_step = []\n", " for i in range(total_steps):\n", " if i < warmup_steps:\n", " lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)\n", " else:\n", " linear_decay = (total_steps - i) / decay_steps\n", " cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))\n", " decayed = linear_decay * cosine_decay + 0.00001\n", " lr = base_lr * decayed\n", " lr_each_step.append(lr)\n", "\n", " lr_each_step = np.array(lr_each_step).astype(np.float32)\n", " learning_rate = lr_each_step[global_step:]\n", " return learning_rate\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 定义损失函数" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from mindspore.nn.loss.loss import _Loss\n", "from mindspore.ops import operations as P\n", "from mindspore.ops import functional as F\n", "from mindspore import Tensor\n", "import mindspore.nn as nn\n", "\n", "class CrossEntropy(_Loss):\n", " def __init__(self, smooth_factor=0., 
num_classes=1001):\n", " super(CrossEntropy, self).__init__()\n", " self.onehot = P.OneHot()\n", " self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)\n", " self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)\n", " self.ce = nn.SoftmaxCrossEntropyWithLogits()\n", " self.mean = P.ReduceMean(False)\n", "\n", " def construct(self, logit, label):\n", " one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)\n", " loss = self.ce(logit, one_hot_label)\n", " loss = self.mean(loss, 0)\n", " return loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 定义深度神经网络" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "本篇使用的MindSpore中的ResNet-50网络模型的源代码。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from mindspore.common.tensor import Tensor\n", "import mindspore.common.initializer as weight_init\n", "\n", "def _weight_variable(shape, factor=0.01):\n", " init_value = np.random.randn(*shape).astype(np.float32) * factor\n", " return Tensor(init_value)\n", "\n", "\n", "def _conv3x3(in_channel, out_channel, stride=1):\n", " weight_shape = (out_channel, in_channel, 3, 3)\n", " weight = _weight_variable(weight_shape)\n", " return nn.Conv2d(in_channel, out_channel,\n", " kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)\n", "\n", "\n", "def _conv1x1(in_channel, out_channel, stride=1):\n", " weight_shape = (out_channel, in_channel, 1, 1)\n", " weight = _weight_variable(weight_shape)\n", " return nn.Conv2d(in_channel, out_channel,\n", " kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)\n", "\n", "\n", "def _conv7x7(in_channel, out_channel, stride=1):\n", " weight_shape = (out_channel, in_channel, 7, 7)\n", " weight = _weight_variable(weight_shape)\n", " return nn.Conv2d(in_channel, out_channel,\n", " kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)\n", "\n", "\n", "def 
_bn(channel):\n", " return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,\n", " gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)\n", "\n", "\n", "def _bn_last(channel):\n", " return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,\n", " gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)\n", "\n", "\n", "def _fc(in_channel, out_channel):\n", " weight_shape = (out_channel, in_channel)\n", " weight = _weight_variable(weight_shape)\n", " return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)\n", "\n", "\n", "class ResidualBlock(nn.Cell):\n", " expansion = 4\n", "\n", " def __init__(self,\n", " in_channel,\n", " out_channel,\n", " stride=1):\n", " super(ResidualBlock, self).__init__()\n", "\n", " channel = out_channel // self.expansion\n", " self.conv1 = _conv1x1(in_channel, channel, stride=1)\n", " self.bn1 = _bn(channel)\n", "\n", " self.conv2 = _conv3x3(channel, channel, stride=stride)\n", " self.bn2 = _bn(channel)\n", "\n", " self.conv3 = _conv1x1(channel, out_channel, stride=1)\n", " self.bn3 = _bn_last(out_channel)\n", "\n", " self.relu = nn.ReLU()\n", "\n", " self.down_sample = False\n", "\n", " if stride != 1 or in_channel != out_channel:\n", " self.down_sample = True\n", " self.down_sample_layer = None\n", "\n", " if self.down_sample:\n", " self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),\n", " _bn(out_channel)])\n", " self.add = P.TensorAdd()\n", "\n", " def construct(self, x):\n", " identity = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv3(out)\n", " out = self.bn3(out)\n", "\n", " if self.down_sample:\n", " identity = self.down_sample_layer(identity)\n", "\n", " out = self.add(out, identity)\n", " out = self.relu(out)\n", "\n", " return out\n", "\n", "class ResNet(nn.Cell):\n", "\n", " def 
__init__(self,\n", " block,\n", " layer_nums,\n", " in_channels,\n", " out_channels,\n", " strides,\n", " num_classes):\n", " super(ResNet, self).__init__()\n", "\n", " if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:\n", " raise ValueError(\"the length of layer_num, in_channels, out_channels list must be 4!\")\n", "\n", " self.conv1 = _conv7x7(3, 64, stride=2)\n", " self.bn1 = _bn(64)\n", " self.relu = P.ReLU()\n", " self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode=\"same\")\n", "\n", " self.layer1 = self._make_layer(block,\n", " layer_nums[0],\n", " in_channel=in_channels[0],\n", " out_channel=out_channels[0],\n", " stride=strides[0])\n", " self.layer2 = self._make_layer(block,\n", " layer_nums[1],\n", " in_channel=in_channels[1],\n", " out_channel=out_channels[1],\n", " stride=strides[1])\n", " self.layer3 = self._make_layer(block,\n", " layer_nums[2],\n", " in_channel=in_channels[2],\n", " out_channel=out_channels[2],\n", " stride=strides[2])\n", " self.layer4 = self._make_layer(block,\n", " layer_nums[3],\n", " in_channel=in_channels[3],\n", " out_channel=out_channels[3],\n", " stride=strides[3])\n", "\n", " self.mean = P.ReduceMean(keep_dims=True)\n", " self.flatten = nn.Flatten()\n", " self.end_point = _fc(out_channels[3], num_classes)\n", "\n", " def _make_layer(self, block, layer_num, in_channel, out_channel, stride):\n", " \n", " layers = []\n", "\n", " resnet_block = block(in_channel, out_channel, stride=stride)\n", " layers.append(resnet_block)\n", "\n", " for _ in range(1, layer_num):\n", " resnet_block = block(out_channel, out_channel, stride=1)\n", " layers.append(resnet_block)\n", "\n", " return nn.SequentialCell(layers)\n", "\n", " def construct(self, x):\n", " x = self.conv1(x)\n", " x = self.bn1(x)\n", " x = self.relu(x)\n", " c1 = self.maxpool(x)\n", "\n", " c2 = self.layer1(c1)\n", " c3 = self.layer2(c2)\n", " c4 = self.layer3(c3)\n", " c5 = self.layer4(c4)\n", "\n", " out = self.mean(c5, (2, 3))\n", " out = 
self.flatten(out)\n", " out = self.end_point(out)\n", "\n", " return out\n", "\n", "def resnet50(class_num=10):\n", "\n", " return ResNet(ResidualBlock,\n", " [3, 4, 6, 3],\n", " [64, 256, 512, 1024],\n", " [256, 512, 1024, 2048],\n", " [1, 2, 2, 2],\n", " class_num)\n", "\n", "def resnet101(class_num=1001):\n", " return ResNet(ResidualBlock,\n", " [3, 4, 23, 3],\n", " [64, 256, 512, 1024],\n", " [256, 512, 1024, 2048],\n", " [1, 2, 2, 2],\n", " class_num)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 定义回调函数Time_per_Step来计算单步训练耗时" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Time_per_Step`用于计算每步训练的时间消耗情况,方便对比混合精度训练和单精度训练的性能区别。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from mindspore.train.callback import Callback\n", "import time\n", "\n", "class Time_per_Step(Callback):\n", " def step_begin(self, run_context):\n", " cb_params = run_context.original_args()\n", " cb_params.init_time = time.time()\n", " \n", " def step_end(selfself, run_context):\n", " cb_params = run_context.original_args()\n", " one_step_time = (time.time() - cb_params.init_time) * 1000\n", " print(one_step_time, \"ms\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 定义训练网络" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 设置混合精度训练并执行训练" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "由于MindSpore已经添加了自动混合精度训练功能,我们这里操作起来非常方便,只需要在Model中添加参数`amp_level=O2`就完成了设置GPU模式下的混合精度训练设置。运行时,将会自动混合精度训练模型。\n", "\n", "`amp_level`的参数详情:\n", "\n", "`O0`:表示不做任何变化,即单精度训练,系统默认`O0`。\n", "\n", "`O2`:表示将网络中的参数计算变为float16。适用于GPU环境。\n", "\n", "`O3`:表示将网络中的参数计算变为float16,同时需要在Model中添加参数`keep_batchnorm_fp32=False`。适用于Ascend环境。\n", "\n", "在`Model`中设置`amp_level=O2`后即可执行混合精度训练:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 1, loss is 2.3015203\n", "37518.837213516235 ms\n", 
"epoch: 1 step: 2, loss is 2.3068979\n", "197.05581665039062 ms\n", "epoch: 1 step: 3, loss is 2.3115108\n", "189.01705741882324 ms\n", "epoch: 1 step: 4, loss is 2.3279507\n", "188.4777545928955 ms\n", "epoch: 1 step: 5, loss is 2.2853572\n", "188.50111961364746 ms\n", "epoch: 1 step: 6, loss is 2.2706618\n", "188.63296508789062 ms\n", "epoch: 1 step: 7, loss is 2.325651\n", "213.5298252105713 ms\n", "epoch: 1 step: 8, loss is 2.3179858\n", "188.95459175109863 ms\n", "epoch: 1 step: 9, loss is 2.3060834\n", "193.02725791931152 ms\n", "epoch: 1 step: 10, loss is 2.39061\n", "192.83699989318848 ms\n", "\n", "\n", "......\n", "\n", "\n", "epoch: 8 step: 323, loss is 0.54335135\n", "190.31238555908203 ms\n", "epoch: 8 step: 324, loss is 0.2202819\n", "190.30189514160156 ms\n", "\n", "\n", "......\n", "\n", "\n", "epoch: 10 step: 1545, loss is 0.21533835\n", "192.63434410095215 ms\n", "epoch: 10 step: 1546, loss is 0.14042784\n", "192.5680637359619 ms\n", "epoch: 10 step: 1547, loss is 0.14810953\n", "192.64483451843262 ms\n", "epoch: 10 step: 1548, loss is 0.3791172\n", "192.7051544189453 ms\n", "epoch: 10 step: 1549, loss is 0.43446764\n", "192.60406494140625 ms\n", "epoch: 10 step: 1550, loss is 0.16453475\n", "192.5489902496338 ms\n", "epoch: 10 step: 1551, loss is 0.43192416\n", "192.45147705078125 ms\n", "epoch: 10 step: 1552, loss is 0.15318932\n", "192.69466400146484 ms\n", "epoch: 10 step: 1553, loss is 0.18142739\n", "192.4266815185547 ms\n", "epoch: 10 step: 1554, loss is 0.23418093\n", "191.3902759552002 ms\n", "epoch: 10 step: 1555, loss is 0.21376474\n", "190.4129981994629 ms\n", "epoch: 10 step: 1556, loss is 0.26256102\n", "190.1836395263672 ms\n", "epoch: 10 step: 1557, loss is 0.11623224\n", "190.07587432861328 ms\n", "epoch: 10 step: 1558, loss is 0.38422704\n", "190.2308464050293 ms\n", "epoch: 10 step: 1559, loss is 0.1297225\n", "190.05846977233887 ms\n", "epoch: 10 step: 1560, loss is 0.03785105\n", "189.89896774291992 ms\n", "epoch: 10 step: 
1561, loss is 0.2947039\n", "190.0768280029297 ms\n", "epoch: 10 step: 1562, loss is 0.41113874\n", "190.03891944885254 ms\n", "Epoch time: 302610.106, per step time: 193.732\n" ] } ], "source": [ "\"\"\"train ResNet-50\"\"\"\n", "import os\n", "import random\n", "import argparse\n", "from mindspore import context\n", "from mindspore.nn.optim.momentum import Momentum\n", "from mindspore.train.model import Model\n", "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor\n", "from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits\n", "from mindspore.train.loss_scale_manager import FixedLossScaleManager\n", "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n", "\n", "\n", "\n", "parser = argparse.ArgumentParser(description='Image classification')\n", "parser.add_argument('--net', type=str, default=\"resnet50\", help='Resnet Model, either resnet50 or resnet101')\n", "parser.add_argument('--dataset', type=str, default=\"cifar10\", help='Dataset, either cifar10 or imagenet2012')\n", "parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')\n", "parser.add_argument('--device_target', type=str, default='GPU', help='Device target')\n", "parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')\n", "args_opt = parser.parse_known_args()[0]\n", "\n", "random.seed(1)\n", "np.random.seed(1)\n", "de.config.set_seed(1)\n", "\n", "if __name__ == '__main__':\n", "\n", " context.set_context(mode=context.GRAPH_MODE,enable_auto_mixed_precision=False, device_target=\"GPU\")\n", " ckpt_save_dir= \"./resnet_ckpt\"\n", " batch_size = 32\n", " epoch_size = 10\n", " dataset_path = \"./datasets/cifar10/train\"\n", " test_path = \"./datasets/cifar10/test\"\n", " \n", " # create dataset\n", " dataset = create_dataset(dataset_path=dataset_path, do_train=True, repeat_num=1,\n", " batch_size=batch_size, target=\"GPU\")\n", " step_size = 
dataset.get_dataset_size()\n", " # define net\n", " net = resnet50(class_num=10)\n", " \n", " # init weight\n", " if args_opt.pre_trained:\n", " param_dict = load_checkpoint(args_opt.pre_trained)\n", " load_param_into_net(net, param_dict)\n", " else:\n", " for _, cell in net.cells_and_names():\n", " if isinstance(cell, nn.Conv2d):\n", " cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),\n", " cell.weight.default_input.shape,\n", " cell.weight.default_input.dtype).to_tensor()\n", " if isinstance(cell, nn.Dense):\n", " cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),\n", " cell.weight.default_input.shape,\n", " cell.weight.default_input.dtype).to_tensor()\n", " # init lr\n", " warmup_epochs = 5\n", " lr_init = 0.01\n", " lr_end = 0.00001\n", " lr_max = 0.1\n", " lr = get_lr(lr_init=lr_init, lr_end=lr_end, lr_max=lr_max,\n", " warmup_epochs=warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size,\n", " lr_decay_mode='poly')\n", " lr = Tensor(lr)\n", "\n", " # define opt\n", " loss_scale = 1024\n", " momentum = 0.9\n", " weight_decay = 1e-4\n", " \n", " # define loss, model\n", " loss = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction='mean')\n", " opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum)\n", " model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},amp_level=\"O2\")\n", " \n", " # define callbacks\n", " steptime_cb = Time_per_Step()\n", " time_cb = TimeMonitor(data_size=step_size)\n", " loss_cb = LossMonitor()\n", "\n", " cb = [time_cb, loss_cb,steptime_cb]\n", " save_checkpoint = 5\n", " if save_checkpoint:\n", " save_checkpoint_epochs = 5\n", " keep_checkpoint_max = 10\n", " config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_epochs * step_size,\n", " keep_checkpoint_max=keep_checkpoint_max)\n", " ckpt_cb = ModelCheckpoint(prefix=\"resnet\", directory=ckpt_save_dir, config=config_ck)\n", " cb += 
[ckpt_cb]\n", "\n", " # train model\n", " model.train(epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 验证模型精度" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用模型进行精度验证可以得出以下代码。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: {'acc': 0.8796073717948718}\n" ] } ], "source": [ "# Eval model\n", "eval_dataset_path = \"./datasets/cifar10/test\"\n", "eval_data = create_dataset(eval_dataset_path,do_train=False)\n", "acc = model.eval(eval_data,dataset_sink_mode=True)\n", "print(\"Accuracy:\",acc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 对比不同网络下的混合精度训练和单精度训练的差别" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "由于篇幅原因,我们这里只展示了ResNet-50网络的混合精度训练情况。可以在主程序入口的Model中设置参数`amp_level = O0`进行单精度训练,训练完毕后,将结果进行对比,看看两者的情况,下面将我测试的情况做成表格如下。(训练时,笔者使用的GPU为Nvidia Tesla P40,不同的硬件对训练的效率影响较大,下述表格中的数据仅供参考)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| 网络 | 是否混合训练 | 单步训练时间 | epoch | Accuracy\n", "|:------ |:-----| :------- |:--- |:------ \n", "|ResNet-50 | 否 | 232ms | 10 | 0.881809 \n", "|ResNet-50 | 是 | 192ms | 10 | 0.879607 " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "经过多次测试,使用ResNet-50网络,CIFAR-10数据集,进行混合精度训练对整体的训练效率提升了16%,而且对最终模型的精度影响不大,对整体性能调优来说是一个不容忽视的性能提升。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "当然,如果你想参考单步训练或者手动设置混合精度训练,可以参考官网教程。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 总结" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "本次体验我们尝试了在ResNet-50网络中使用混合精度来进行模型训练,并对比了单精度下的训练过程,了解到了混合精度训练的原理和对模型训练的提升效果。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": 
"python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }