Commit f1318a68, authored by lvmingfu

unify code formats in notebook

Parent 2ebef933
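...@@ -14,15 +14,15 @@ ...@@ -14,15 +14,15 @@
"Next, we use an image classification project to try generating and using the computational graph and the data graph.\n", "Next, we use an image classification project to try generating and using the computational graph and the data graph.\n",
" \n", " \n",
"## Overall workflow of this walkthrough\n", "## Overall workflow of this walkthrough\n",
"1The walkthrough uses the MNIST dataset, which is small and therefore well suited to a quick trial.\n", "1. The walkthrough uses the MNIST dataset, which is small and therefore well suited to a quick trial.\n",
"\n", "\n",
"2Initialize a network; this walkthrough uses LeNet.\n", "2. Initialize a network; this walkthrough uses LeNet.\n",
"\n", "\n",
"3Enable the visualization feature and configure it to record only the computational graph and the data graph.\n", "3. Enable the visualization feature and configure it to record only the computational graph and the data graph.\n",
"\n", "\n",
"4Load the training dataset and train; when training finishes, check the results and save the model file.\n", "4. Load the training dataset and train; when training finishes, check the results and save the model file.\n",
"\n", "\n",
"5Open the MindInsight graph visualization page and review the training process." "5. Open the MindInsight graph visualization page and review the training process."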
] ]
}, },
{ {
...@@ -35,8 +35,8 @@ ...@@ -35,8 +35,8 @@
"\n", "\n",
"Download the files from the URLs below, extract them, and place them in the Jupyter working directory.\n", "Download the files from the URLs below, extract them, and place them in the Jupyter working directory.\n",
"\n", "\n",
"- Training dataset: {\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\",\"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n", "- Training dataset: {\"<http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz>\",\"<http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz>\"}\n",
"- Test dataset: {\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\",\"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}\n", "- Test dataset: {\"<http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz>\",\"<http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz>\"}\n",
"\n", "\n",
"Run the code below to check the Jupyter working directory." "Run the code below to check the Jupyter working directory."
] ]
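As a quick sanity check after extracting the archives, the sketch below lists which of the four expected MNIST files are still missing. The `MNIST_Data/train` and `MNIST_Data/test` layout is an assumption borrowed from the companion notebooks in this commit, and `missing_mnist_files` is a hypothetical helper, not part of MindSpore.

```python
import os

# Assumed layout: <working dir>/MNIST_Data/{train,test}/<extracted files>
EXPECTED_FILES = {
    "train": ["train-images-idx3-ubyte", "train-labels-idx1-ubyte"],
    "test": ["t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"],
}


def missing_mnist_files(base_dir):
    """Return the expected MNIST file paths that do not exist under base_dir."""
    missing = []
    for subset, names in EXPECTED_FILES.items():
        for name in names:
            path = os.path.join(base_dir, subset, name)
            if not os.path.isfile(path):
                missing.append(path)
    return missing


if __name__ == "__main__":
    base = os.path.join(os.getcwd(), "MNIST_Data")
    for path in missing_mnist_files(base):
        print("missing:", path)
```

An empty result means all four files are where the dataset loader is expected to look for them.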
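...@@ -187,11 +187,11 @@ ...@@ -187,11 +187,11 @@
"source": [ "source": [
"### Visualization workflow\n", "### Visualization workflow\n",
"\n", "\n",
"1Prepare the training script: specify the hyperparameter information of the computational graph in it, save it to the log with `Summary`, and then run the script.\n", "1. Prepare the training script: specify the hyperparameter information of the computational graph in it, save it to the log with `Summary`, and then run the script.\n",
"\n", "\n",
"2Start MindInsight; once it is running, open the address printed by the start command to view the visualization page.\n", "2. Start MindInsight; once it is running, open the address printed by the start command to view the visualization page.\n",
"\n", "\n",
"3After the page loads, you can query and otherwise explore the graph view." "3. After the page loads, you can query and otherwise explore the graph view."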
] ]
}, },
{ {
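...@@ -200,11 +200,11 @@ ...@@ -200,11 +200,11 @@
"source": [ "source": [
"#### Initializing the network\n", "#### Initializing the network\n",
"\n", "\n",
"1Import the modules used to build the network.\n", "1. Import the modules used to build the network.\n",
"\n", "\n",
"2Define the functions that initialize the parameters.\n", "2. Define the functions that initialize the parameters.\n",
"\n", "\n",
"3Create the network and set its parameters." "3. Create the network and set its parameters."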
] ]
}, },
{ {
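...@@ -273,11 +273,11 @@ ...@@ -273,11 +273,11 @@
"source": [ "source": [
"#### Running the main program\n", "#### Running the main program\n",
"\n", "\n",
"1First, import the required modules and call the relevant interfaces before the main function.\n", "1. First, import the required modules and call the relevant interfaces before the main function.\n",
"\n", "\n",
"2This walkthrough visualizes the computational graph and the data graph. Define the variable `specified={'collect_graph': True,'collect_dataset_graph': True}`; in the `specified` dictionary, setting the key `collect_graph` to `True` records the computational graph, and setting the key `collect_dataset_graph` to `True` records the data graph.\n", "2. This walkthrough visualizes the computational graph and the data graph. Define the variable `specified={'collect_graph': True,'collect_dataset_graph': True}`; in the `specified` dictionary, setting the key `collect_graph` to `True` records the computational graph, and setting the key `collect_dataset_graph` to `True` records the data graph.\n",
"\n", "\n",
"3After defining `specified`, pass it to `summary_collector`, and finally pass `summary_collector` to `model`.\n", "3. After defining `specified`, pass it to `summary_collector`, and finally pass `summary_collector` to `model`.\n",
"\n", "\n",
"At this point, the model can record the computational graph and the data graph for visualization." "At this point, the model can record the computational graph and the data graph for visualization."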
] ]
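The wiring described in these steps can be sketched as follows. The two dictionary keys come from the text above; the `./summary_dir` path, the helper names, and the choice of `keep_default_action=False` (to record only the listed items) are assumptions for illustration. MindSpore is imported lazily so the dictionary can be inspected without it installed.

```python
def build_specified(collect_graph=True, collect_dataset_graph=True):
    """Build the dict passed to SummaryCollector as collect_specified_data."""
    return {
        "collect_graph": collect_graph,                  # record the computational graph
        "collect_dataset_graph": collect_dataset_graph,  # record the data graph
    }


def make_summary_collector(summary_dir, specified):
    """Create the callback; requires MindSpore, so the import happens only here."""
    from mindspore.train.callback import SummaryCollector
    return SummaryCollector(
        summary_dir=summary_dir,
        collect_specified_data=specified,
        keep_default_action=False,  # assumption: collect nothing beyond `specified`
    )


specified = build_specified()
print(specified)

# With MindSpore installed, and assuming `model`, `epoch_size` and `ds_train` already exist:
# summary_collector = make_summary_collector("./summary_dir", specified)
# model.train(epoch_size, ds_train, callbacks=[summary_collector])
```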
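...@@ -332,7 +332,7 @@ ...@@ -332,7 +332,7 @@
"- Command to start the MindInsight service: `mindinsight start --summary-base-dir=/path/ --port=8080`;\n", "- Command to start the MindInsight service: `mindinsight start --summary-base-dir=/path/ --port=8080`;\n",
"- After running the command, open the address it prints to view the MindInsight visualization.\n", "- After running the command, open the address it prints to view the MindInsight visualization.\n",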
"\n", "\n",
"![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/005.png)" "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/mindinsight_map.png)"
] ]
}, },
{ {
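...@@ -346,7 +346,7 @@ ...@@ -346,7 +346,7 @@
"- Node information: shows the name, type, attributes, inputs, and outputs of the selected node, which is useful for checking computational correctness after training.\n", "- Node information: shows the name, type, attributes, inputs, and outputs of the selected node, which is useful for checking computational correctness after training.\n",
"- Legend: covers namespaces, aggregation nodes, virtual nodes, operator nodes, and constant nodes, each drawn with a different shape.\n", "- Legend: covers namespaces, aggregation nodes, virtual nodes, operator nodes, and constant nodes, each drawn with a different shape.\n",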
"\n", "\n",
"![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/004.png)" "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/cast_map.png)"
] ]
}, },
{ {
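...@@ -357,11 +357,11 @@ ...@@ -357,11 +357,11 @@
"\n", "\n",
"The order shown in the data graph matches the order of the dataset code.\n", "The order shown in the data graph matches the order of the dataset code.\n",
"\n", "\n",
"1It starts with loading the dataset, `mnist_ds = ds.MnistDataset(data_path)`, which corresponds to `MnistDataset` in the data graph.\n", "1. It starts with loading the dataset, `mnist_ds = ds.MnistDataset(data_path)`, which corresponds to `MnistDataset` in the data graph.\n",
"\n", "\n",
"2The code below shows the data preprocessing operations, in the same order as they appear in the data graph.\n", "2. The code below shows the data preprocessing operations, in the same order as they appear in the data graph.\n",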
"\n", "\n",
"`\n", "```\n",
"type_cast_op = C.TypeCast(mstype.int32)\n", "type_cast_op = C.TypeCast(mstype.int32)\n",
"resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n", "resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
"rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n", "rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
...@@ -372,21 +372,22 @@ ...@@ -372,21 +372,22 @@
"mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
"mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n", "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
"mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
"`\n", "```\n",
"\n", "\n",
"- `TypeCast`: in the dataset `create_data` function, `TypeCast(mstype.int32)` converts the label data to the specified type.\n", "- `TypeCast`: in the dataset `create_data` function, `TypeCast(mstype.int32)` converts the label data to the specified type.\n",
"- `Resize`: in the dataset `create_data` function, `Resize(resize_height,resize_width = 32,32)` adjusts the height and width of the images.\n", "- `Resize`: in the dataset `create_data` function, `Resize(resize_height,resize_width = 32,32)` adjusts the height and width of the images.\n",
"- `Rescale`: in the dataset `create_data` function, `rescale = 1.0 / 255.0`; `Rescale(rescale,shift)` rescales the pixel values.\n", "- `Rescale`: in the dataset `create_data` function, `rescale = 1.0 / 255.0`; `Rescale(rescale,shift)` rescales the pixel values.\n",
"- `HWC2CHW`: in the dataset `create_data` function, `HWC2CHW()` changes the image layout from (height, width, channel) to (channel, height, width) to fit the network.\n", "- `HWC2CHW`: in the dataset `create_data` function, `HWC2CHW()` changes the image layout from (height, width, channel) to (channel, height, width) to fit the network.\n",
"\n", "\n",
"3、The preceding steps are the dataset preprocessing order; the following steps are the parameters defined when the model loads the dataset, in the same order as in the data graph.\n",
"\n", "\n",
"`\n", "3. The preceding steps are the dataset preprocessing order; the following steps are the parameters defined when the model loads the dataset, in the same order as in the data graph.\n",
"\n",
"```\n",
"buffer_size = 10000\n", "buffer_size = 10000\n",
"mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", "mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n",
"mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", "mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
"mnist_ds = mnist_ds.repeat(repeat_size)\n", "mnist_ds = mnist_ds.repeat(repeat_size)\n",
"`\n", "```\n",
" \n", " \n",
"- `Shuffle`: in the dataset `create_data` function, `buffer_size = 10000` is used; the value is configurable and sets how many records are buffered at a time for shuffling.\n", "- `Shuffle`: in the dataset `create_data` function, `buffer_size = 10000` is used; the value is configurable and sets how many records are buffered at a time for shuffling.\n",
"- `Batch`: in the dataset `create_data` function, `batch_size = 32` is used. The value is configurable; the full dataset is split into mini-batches, and each mini-batch is trained on as a unit.\n", "- `Batch`: in the dataset `create_data` function, `batch_size = 32` is used. The value is configurable; the full dataset is split into mini-batches, and each mini-batch is trained on as a unit.\n",
...@@ -397,7 +398,7 @@ ...@@ -397,7 +398,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/001.png)" "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/data_map.png)"
] ]
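The `Batch` and `Repeat` nodes at the end of the data graph determine how many steps the pipeline yields. A minimal sketch of that arithmetic in plain Python, with no MindSpore required; `steps_per_epoch` is a hypothetical helper, and the sample counts assume the standard MNIST split of 60000 training and 10000 test images.

```python
def steps_per_epoch(num_samples, batch_size, repeat_size=1, drop_remainder=True):
    """Number of batches produced by batch(...) followed by repeat(...)."""
    if drop_remainder:
        batches = num_samples // batch_size        # incomplete last batch is discarded
    else:
        batches = -(-num_samples // batch_size)    # ceiling division keeps it
    return batches * repeat_size


print(steps_per_epoch(60000, 32))                        # 1875
print(steps_per_epoch(60000, 32, repeat_size=2))         # 3750
print(steps_per_epoch(10000, 32))                        # 312
print(steps_per_epoch(10000, 32, drop_remainder=False))  # 313
```

With `drop_remainder=True`, the 16 test samples left over after 312 full batches are skipped, which is worth keeping in mind when comparing sample counts against the data graph.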
}, },
{ {
......
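...@@ -26,38 +26,38 @@ ...@@ -26,38 +26,38 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"1Prepare the dataset; MNIST is used here.\n", "1. Prepare the dataset; MNIST is used here.\n",
"\n", "\n",
"2Build a network; LeNet is used here. (The second recording method, `ImageSummary`, is used in this step.)\n", "2. Build a network; LeNet is used here. (The second recording method, `ImageSummary`, is used in this step.)\n",
"\n", "\n",
"3Build and run the training and test networks. (`SummaryCollector` is initialized in this step and records information about model training and model testing.)\n", "3. Build and run the training and test networks. (`SummaryCollector` is initialized in this step and records information about model training and model testing.)\n",
"\n", "\n",
"4Start the MindInsight service.\n", "4. Start the MindInsight service.\n",
"\n", "\n",
"5Use model lineage. Adjust the model parameters and record data over several runs, then use MindInsight's model lineage feature to compare the models trained under different optimization parameters, to understand how the optimizations in MindSpore affect training and how to tune the training process.\n", "5. Use model lineage. Adjust the model parameters and record data over several runs, then use MindInsight's model lineage feature to compare the models trained under different optimization parameters, to understand how the optimizations in MindSpore affect training and how to tune the training process.\n",
"\n", "\n",
"6Use data lineage. Adjust the data parameters and record data over several runs, then use MindInsight's data lineage feature to compare and analyze the models trained on different datasets, to understand how to tune." "6. Use data lineage. Adjust the data parameters and record data over several runs, then use MindInsight's data lineage feature to compare and analyze the models trained on different datasets, to understand how to tune."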
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"This walkthrough uses the quick-start example as its base and adds MindInsight's model lineage and data lineage recording to it. For the quick-start source code, see: https://gitee.com/mindspore/docs/blob/r0.5/tutorials/tutorial_code/lenet.py ." "This walkthrough uses the quick-start example as its base and adds MindInsight's model lineage and data lineage recording to it. For the quick-start source code, see: <https://gitee.com/mindspore/docs/blob/master/tutorials/tutorial_code/lenet.py>."
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## I. Downloading the training dataset" "## Downloading the training dataset"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1. Preparing the dataset" "### Preparing the dataset"
] ]
}, },
{ {
...@@ -65,8 +65,8 @@ ...@@ -65,8 +65,8 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"#### Method 1:\n", "#### Method 1:\n",
"Download the files from the URLs below, extract them, and place them in the Jupyter working directory:<br/>Training dataset: {\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n", "Download the files from the URLs below, extract them, and place them in the Jupyter working directory:<br/>Training dataset: {\"<http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz>\", \"<http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz>\"}\n",
"<br/>Test dataset: {\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}<br/>Use the code below to check the Jupyter working directory." "<br/>Test dataset: {\"<http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz>\", \"<http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz>\"}<br/>Use the code below to check the Jupyter working directory."
] ]
}, },
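The download cell in this notebook extracts each `.gz` archive with a helper named `unzip_file`. A variant that uses context managers, so both file handles are closed even if decompression fails part-way, might look like the sketch below. This is an illustrative alternative under that assumption, not the notebook's own implementation.

```python
import gzip
import os
import shutil


def unzip_file(gzip_path):
    """Decompress `<name>.gz` next to itself as `<name>` and return the output path."""
    if not gzip_path.endswith(".gz"):
        raise ValueError("expected a .gz file: " + gzip_path)
    out_path = gzip_path[:-len(".gz")]
    # Both handles are released when the with-block exits, including on errors.
    with gzip.open(gzip_path, "rb") as src, open(out_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return out_path
```

Slicing off the suffix instead of calling `replace('.gz', '')` avoids touching a `.gz` substring that happens to appear earlier in the path.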
{ {
...@@ -100,15 +100,14 @@ ...@@ -100,15 +100,14 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Network request module, data download module, decompression module\n",
"import urllib.request \n", "import urllib.request \n",
"from urllib.parse import urlparse\n", "from urllib.parse import urlparse\n",
"import gzip \n", "import gzip \n",
"\n", "\n",
"def unzipfile(gzip_path):\n", "def unzip_file(gzip_path):\n",
" \"\"\"unzip dataset file\n", " \"\"\"unzip dataset file\n",
" Args:\n", " Args:\n",
" gzip_path: dataset file path\n", " gzip_path (str): Dataset file path\n",
" \"\"\"\n", " \"\"\"\n",
" open_file = open(gzip_path.replace('.gz',''), 'wb')\n", " open_file = open(gzip_path.replace('.gz',''), 'wb')\n",
" gz_file = gzip.GzipFile(gzip_path)\n", " gz_file = gzip.GzipFile(gzip_path)\n",
...@@ -134,7 +133,7 @@ ...@@ -134,7 +133,7 @@
" file_name = os.path.join(train_path,url_parse.path.split('/')[-1])\n", " file_name = os.path.join(train_path,url_parse.path.split('/')[-1])\n",
" if not os.path.exists(file_name.replace('.gz', '')):\n", " if not os.path.exists(file_name.replace('.gz', '')):\n",
" file = urllib.request.urlretrieve(url, file_name)\n", " file = urllib.request.urlretrieve(url, file_name)\n",
" unzipfile(file_name)\n", " unzip_file(file_name)\n",
" os.remove(file_name)\n", " os.remove(file_name)\n",
" \n", " \n",
" for url in test_url:\n", " for url in test_url:\n",
...@@ -143,7 +142,7 @@ ...@@ -143,7 +142,7 @@
" file_name = os.path.join(test_path,url_parse.path.split('/')[-1])\n", " file_name = os.path.join(test_path,url_parse.path.split('/')[-1])\n",
" if not os.path.exists(file_name.replace('.gz', '')):\n", " if not os.path.exists(file_name.replace('.gz', '')):\n",
" file = urllib.request.urlretrieve(url, file_name)\n", " file = urllib.request.urlretrieve(url, file_name)\n",
" unzipfile(file_name)\n", " unzip_file(file_name)\n",
" os.remove(file_name)\n", " os.remove(file_name)\n",
"\n", "\n",
"download_dataset()" "download_dataset()"
...@@ -160,7 +159,7 @@ ...@@ -160,7 +159,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2. Processing the dataset" "### Processing the dataset"
] ]
}, },
{ {
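...@@ -169,11 +168,11 @@ ...@@ -169,11 +168,11 @@
"source": [ "source": [
"Dataset processing matters a great deal for training; a good dataset can noticeably improve training accuracy and efficiency. Before loading the dataset, we usually process it first.\n", "Dataset processing matters a great deal for training; a good dataset can noticeably improve training accuracy and efficiency. Before loading the dataset, we usually process it first.\n",
"<br/>We define a function `create_dataset` to create the dataset. In this function, we define the data augmentation and processing operations we need:\n", "<br/>We define a function `create_dataset` to create the dataset. In this function, we define the data augmentation and processing operations we need:\n",
"<br/>1、Define the dataset.\n", "1. Define the dataset.\n",
"<br/>2、Define the parameters needed for data augmentation and processing.\n", "2. Define the parameters needed for data augmentation and processing.\n",
"<br/>3、Create the corresponding data augmentation operations from those parameters.\n", "3. Create the corresponding data augmentation operations from those parameters.\n",
"<br/>4、Use the `map` function to apply the operations to the dataset.\n", "4. Use the `map` function to apply the operations to the dataset.\n",
"<br/>5、Process the generated dataset." "5. Process the generated dataset."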
] ]
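Steps 2 to 4 above can be illustrated on a single pixel value. The constants are the ones `create_dataset` defines in this notebook; the `rescale` helper is a plain-Python stand-in that assumes a rescale operation computes `pixel * scale + shift`, and it is not a MindSpore call.

```python
RESCALE = 1.0 / 255.0
SHIFT = 0.0
RESCALE_NML = 1 / 0.3081
SHIFT_NML = -1 * 0.1307 / 0.3081


def rescale(pixel, scale, shift):
    """Stand-in for a rescale operation: output = pixel * scale + shift."""
    return pixel * scale + shift


def preprocess_pixel(pixel):
    """Apply the two rescale steps in the order the notebook maps them."""
    unit = rescale(pixel, RESCALE, SHIFT)          # 0..255 -> 0..1
    return rescale(unit, RESCALE_NML, SHIFT_NML)   # equals (unit - 0.1307) / 0.3081


for value in (0, 128, 255):
    print(value, round(preprocess_pixel(value), 4))
```

Written this way, the second step is visibly a standard normalization with mean 0.1307 and standard deviation 0.3081, expressed as a scale and a shift.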
}, },
{ {
...@@ -199,36 +198,36 @@ ...@@ -199,36 +198,36 @@
" num_parallel_workers=1):\n", " num_parallel_workers=1):\n",
" \"\"\" create dataset for train or test\n", " \"\"\" create dataset for train or test\n",
" Args:\n", " Args:\n",
" data_path: Data path\n", " data_path (str): Data path\n",
" batch_size: The number of data records in each group\n", " batch_size (int): The number of data records in each group\n",
" repeat_size: The number of replicated data records\n", " repeat_size (int): The number of replicated data records\n",
" num_parallel_workers: The number of parallel workers\n", " num_parallel_workers (int): The number of parallel workers\n",
" \"\"\"\n", " \"\"\"\n",
" # define dataset\n", " # define dataset\n",
" mnist_ds = ds.MnistDataset(data_path)\n", " mnist_ds = ds.MnistDataset(data_path)\n",
"\n", "\n",
" # Define some parameters needed for data augmentation and processing\n", " # define some parameters needed for data augmentation and processing\n",
" resize_height, resize_width = 32, 32\n", " resize_height, resize_width = 32, 32\n",
" rescale = 1.0 / 255.0\n", " rescale = 1.0 / 255.0\n",
" shift = 0.0\n", " shift = 0.0\n",
" rescale_nml = 1 / 0.3081\n", " rescale_nml = 1 / 0.3081\n",
" shift_nml = -1 * 0.1307 / 0.3081\n", " shift_nml = -1 * 0.1307 / 0.3081\n",
"\n", "\n",
" # According to the parameters, generate the corresponding data augmentation operations\n", " # according to the parameters, generate the corresponding data augmentation operations\n",
" resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Resize images to (32, 32) by bilinear interpolation\n", " resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
" rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images\n", " rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
" rescale_op = CV.Rescale(rescale, shift) # rescale images\n", " rescale_op = CV.Rescale(rescale, shift)\n",
" hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.\n", " hwc2chw_op = CV.HWC2CHW()\n",
" type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network\n", " type_cast_op = C.TypeCast(mstype.int32)\n",
"\n", "\n",
" # Using map() to apply operations to a dataset\n", " # using map method to apply operations to a dataset\n",
" mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
" \n", " \n",
" # Process the generated dataset\n", " # process the generated dataset\n",
" buffer_size = 10000\n", " buffer_size = 10000\n",
" mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", " mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n",
" mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", " mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
...@@ -241,7 +240,7 @@ ...@@ -241,7 +240,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## II. Building the LeNet5 network" "## Building the LeNet5 network"
] ]
}, },
{ {
...@@ -260,7 +259,7 @@ ...@@ -260,7 +259,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -268,7 +267,6 @@ ...@@ -268,7 +267,6 @@
"import mindspore.nn as nn\n", "import mindspore.nn as nn\n",
"from mindspore.common.initializer import TruncatedNormal\n", "from mindspore.common.initializer import TruncatedNormal\n",
"\n", "\n",
"# Initialize 2D convolution function\n",
"def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n", "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n",
" \"\"\"Conv layer weight initial.\"\"\"\n", " \"\"\"Conv layer weight initial.\"\"\"\n",
" weight = weight_variable()\n", " weight = weight_variable()\n",
...@@ -276,50 +274,46 @@ ...@@ -276,50 +274,46 @@
" kernel_size=kernel_size, stride=stride, padding=padding,\n", " kernel_size=kernel_size, stride=stride, padding=padding,\n",
" weight_init=weight, has_bias=False, pad_mode=\"valid\")\n", " weight_init=weight, has_bias=False, pad_mode=\"valid\")\n",
"\n", "\n",
"# Initialize full connection layer\n",
"def fc_with_initialize(input_channels, out_channels):\n", "def fc_with_initialize(input_channels, out_channels):\n",
" \"\"\"Fc layer weight initial.\"\"\"\n", " \"\"\"Fc layer weight initial.\"\"\"\n",
" weight = weight_variable()\n", " weight = weight_variable()\n",
" bias = weight_variable()\n", " bias = weight_variable()\n",
" return nn.Dense(input_channels, out_channels, weight, bias)\n", " return nn.Dense(input_channels, out_channels, weight, bias)\n",
"\n", "\n",
"# Set truncated normal distribution\n",
"def weight_variable():\n", "def weight_variable():\n",
" \"\"\"Weight initial.\"\"\"\n", " \"\"\"Weight initial.\"\"\"\n",
" return TruncatedNormal(0.02)\n", " return TruncatedNormal(0.02)\n",
"\n", "\n",
"class LeNet5(nn.Cell):\n", "class LeNet5(nn.Cell):\n",
" \"\"\"Lenet network structure.\"\"\"\n", " \"\"\"Lenet network structure.\"\"\"\n",
" # define the operator required\n",
" def __init__(self):\n", " def __init__(self):\n",
" super(LeNet5, self).__init__()\n", " super(LeNet5, self).__init__()\n",
" self.batch_size = 32 # 32 pictures in each group\n", " self.batch_size = 32 \n",
" self.conv1 = conv(1, 6, 5) # Convolution layer 1, 1 channel input (1 Figure), 6 channel output (6 figures), convolution core 5 * 5\n", " self.conv1 = conv(1, 6, 5)\n",
" self.conv2 = conv(6, 16, 5) # Convolution layer 2,6-channel input, 16 channel output, convolution kernel 5 * 5\n", " self.conv2 = conv(6, 16, 5)\n",
" self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n", " self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n",
" self.fc2 = fc_with_initialize(120, 84)\n", " self.fc2 = fc_with_initialize(120, 84)\n",
" self.fc3 = fc_with_initialize(84, 10)\n", " self.fc3 = fc_with_initialize(84, 10)\n",
" self.relu = nn.ReLU()\n", " self.relu = nn.ReLU()\n",
" self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
" self.flatten = nn.Flatten()\n", " self.flatten = nn.Flatten()\n",
" #Init ImageSummary\n", " # Init ImageSummary\n",
" self.sm_image = P.ImageSummary()\n", " self.sm_image = P.ImageSummary()\n",
"\n", "\n",
" # use the preceding operators to construct networks\n",
" def construct(self, x):\n", " def construct(self, x):\n",
" self.sm_image(\"image\",x)\n", " self.sm_image(\"image\",x)\n",
" x = self.conv1(x) # 1*32*32-->6*28*28\n", " x = self.conv1(x)\n",
" x = self.relu(x) # 6*28*28-->6*14*14\n", " x = self.relu(x)\n",
" x = self.max_pool2d(x) # Pool layer\n", " x = self.max_pool2d(x)\n",
" x = self.conv2(x) # Convolution layer\n", " x = self.conv2(x)\n",
" x = self.relu(x) # Function excitation layer\n", " x = self.relu(x)\n",
" x = self.max_pool2d(x) # Pool layer\n", " x = self.max_pool2d(x)\n",
" x = self.flatten(x) # Dimensionality reduction\n", " x = self.flatten(x)\n",
" x = self.fc1(x) # Full connection\n", " x = self.fc1(x)\n",
" x = self.relu(x) # Function excitation layer\n", " x = self.relu(x)\n",
" x = self.fc2(x) # Full connection\n", " x = self.fc2(x)\n",
" x = self.relu(x) # Function excitation layer\n", " x = self.relu(x)\n",
" x = self.fc3(x) # Full connection\n", " x = self.fc3(x)\n",
" return x" " return x"
] ]
}, },
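The `16 * 5 * 5` input size of `fc1` follows from the layer shapes. A small sketch of that arithmetic in plain Python, assuming the 32x32 input produced by the resize step, convolutions without padding (`pad_mode="valid"` as in `conv`), and 2x2 pooling with stride 2; the helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1):
    """Output size of a convolution with no padding."""
    return (size - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Output size of max pooling with no padding."""
    return (size - kernel) // stride + 1


def lenet5_feature_shapes(size=32):
    """Return (layer, channels, height/width) for each stage before flattening."""
    shapes = [("input", 1, size)]
    size = conv_out(size, 5)
    shapes.append(("conv1", 6, size))
    size = pool_out(size)
    shapes.append(("pool1", 6, size))
    size = conv_out(size, 5)
    shapes.append(("conv2", 16, size))
    size = pool_out(size)
    shapes.append(("pool2", 16, size))
    return shapes


for name, channels, size in lenet5_feature_shapes():
    print(f"{name}: {channels}x{size}x{size}")

_, channels, size = lenet5_feature_shapes()[-1]
print("flattened features:", channels * size * size)  # 16 * 5 * 5 = 400
```

Note that the spatial size changes only in the convolution and pooling layers; `relu` leaves the shape untouched.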
...@@ -327,14 +321,14 @@ ...@@ -327,14 +321,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## III. Building the training and test networks" "## Building the training and test networks"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1. Adding SummaryCollector to the training network to record training data" "### Adding SummaryCollector to the training network to record training data"
] ]
}, },
{ {
...@@ -350,12 +344,10 @@ ...@@ -350,12 +344,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Training and testing related modules\n",
"from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, SummaryCollector, Callback\n", "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, SummaryCollector, Callback\n",
"from mindspore.train import Model\n", "from mindspore.train import Model\n",
"import os\n", "import os\n",
"\n", "\n",
"\n",
"def train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb, summary_collector):\n", "def train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb, summary_collector):\n",
" \"\"\"Define the training method.\"\"\"\n", " \"\"\"Define the training method.\"\"\"\n",
" print(\"============== Starting Training ==============\")\n", " print(\"============== Starting Training ==============\")\n",
...@@ -368,7 +360,7 @@ ...@@ -368,7 +360,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2. Adding SummaryCollector to the test network to record test data" "### Adding SummaryCollector to the test network to record test data"
] ]
}, },
{ {
...@@ -380,7 +372,7 @@ ...@@ -380,7 +372,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -403,14 +395,14 @@ ...@@ -403,14 +395,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 3. Main program entry point" "### Main program entry point"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Initialize `SummaryCollector` and use `collect_specified_data` to control what is recorded. Here we only need model lineage and data lineage, so set `collect_train_lineage` and `collect_eval_lineage` to `True`, and set `keep_default_action` to `False` to turn off the remaining items. For the data that SummaryCollector can record, see the official site: <https://www.mindspore.cn/api/zh-CN/r0.5/api/python/mindspore/mindspore.train.html#mindspore.train.callback.SummaryCollector> ." "Initialize `SummaryCollector` and use `collect_specified_data` to control what is recorded. Here we only need model lineage and data lineage, so set `collect_train_lineage` and `collect_eval_lineage` to `True`, and set `keep_default_action` to `False` to turn off the remaining items. For the data that SummaryCollector can record, see the official site: <https://www.mindspore.cn/api/zh-CN/master/api/python/mindspore/mindspore.train.html?highlight=collector#mindspore.train.callback.SummaryCollector>."
] ]
}, },
{ {
...@@ -427,7 +419,7 @@ ...@@ -427,7 +419,7 @@
"\n", "\n",
"if __name__==\"__main__\":\n", "if __name__==\"__main__\":\n",
" context.set_context(mode=context.GRAPH_MODE, device_target = \"GPU\")\n", " context.set_context(mode=context.GRAPH_MODE, device_target = \"GPU\")\n",
" lr = 0.01 # learning rate\n", " lr = 0.01\n",
" momentum = 0.9 \n", " momentum = 0.9 \n",
" epoch_size = 3\n", " epoch_size = 3\n",
" mnist_path = \"./MNIST_Data\"\n", " mnist_path = \"./MNIST_Data\"\n",
...@@ -453,14 +445,14 @@ ...@@ -453,14 +445,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## IV. Starting and stopping the MindInsight service" "## Starting and stopping the MindInsight service"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"This section mainly shows how to start and stop MindInsight. For the full command set, see the official MindSpore website: https://www.mindspore.cn/tutorial/zh-CN/r0.5/advanced_use/visualization_tutorials.html ." "This section mainly shows how to start and stop MindInsight. For the full command set, see the official MindSpore website: <https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/visualization_tutorials.html>."
] ]
}, },
{ {
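...@@ -489,7 +481,7 @@ ...@@ -489,7 +481,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"To check whether the service started, enter: `127.0.0.1:8090` in the browser address bar; if you see the page below, it started successfully." "To check whether the service started, enter `127.0.0.1:8090` in the browser address bar; if you see the page below, it started successfully."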
] ]
}, },
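Besides opening the address in a browser, reachability of the port can be checked from Python. The sketch below only tests whether something accepts TCP connections on `127.0.0.1:8090` (the address used in this notebook); it cannot tell whether that process is actually MindInsight, and `is_port_open` is a hypothetical helper.

```python
import socket


def is_port_open(host="127.0.0.1", port=8090, timeout=1.0):
    """Return True if something accepts TCP connections on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, timeouts, and unreachable hosts.
        return False


if __name__ == "__main__":
    print("port 8090 reachable:", is_port_open())
```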
{ {
...@@ -514,14 +506,14 @@ ...@@ -514,14 +506,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## V. Model lineage" "## Model lineage"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1. Connecting to the model lineage page" "### Connecting to the model lineage page"
] ]
}, },
{ {
...@@ -575,7 +567,7 @@ ...@@ -575,7 +567,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2. Reviewing the recorded lineage parameters" "### Reviewing the recorded lineage parameters"
] ]
}, },
{ {
...@@ -604,21 +596,21 @@ ...@@ -604,21 +596,21 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## VI. Data lineage" "## Data lineage"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1. Connecting to the data lineage page" "### Connecting to the data lineage page"
] ]
}, },
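{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"In the browser, enter: 127.0.0.1:8090 to connect to the MindInsight service, then click Model Lineage; the data lineage page is shown below:" "In the browser, enter: `127.0.0.1:8090` to connect to the MindInsight service, then click Model Lineage; the data lineage page is shown below:"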
] ]
}, },
{ {
...@@ -654,7 +646,7 @@ ...@@ -654,7 +646,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2. Reviewing the data lineage parameters" "### Reviewing the data lineage parameters"
] ]
}, },
{ {
......
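...@@ -17,31 +17,31 @@ ...@@ -17,31 +17,31 @@
"\n", "\n",
"This example implements a simple image classification task. The overall workflow is:\n", "This example implements a simple image classification task. The overall workflow is:\n",
"\n", "\n",
"1Process the required dataset; MNIST is used here.\n", "1. Process the required dataset; MNIST is used here.\n",
"\n", "\n",
"2Define a network; LeNet is used here.\n", "2. Define a network; LeNet is used here.\n",
"\n", "\n",
"3Define the loss function and the optimizer.\n", "3. Define the loss function and the optimizer.\n",
"\n", "\n",
"4Load the dataset and train; when training finishes, check the results and save the model file.\n", "4. Load the dataset and train; when training finishes, check the results and save the model file.\n",
"\n", "\n",
"5Load the saved model and run inference.\n", "5. Load the saved model and run inference.\n",
"\n", "\n",
"6Validate the model: load the test dataset and the trained model, and verify the accuracy." "6. Validate the model: load the test dataset and the trained model, and verify the accuracy."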
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Note:<br/>The complete runnable sample code is available here: https://gitee.com/mindspore/docs/blob/r0.5/tutorials/tutorial_code/lenet.py" "Note:<br/>The complete runnable sample code is available here: <https://gitee.com/mindspore/docs/blob/master/tutorials/tutorial_code/lenet.py>."
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## I. Downloading the training dataset" "## Downloading the training dataset"
] ]
}, },
{ {
...@@ -49,15 +49,26 @@ ...@@ -49,15 +49,26 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"#### Method 1:\n", "#### Method 1:\n",
"Download the files from the URLs below, extract them, and place them in the Jupyter working directory:<br/>Training dataset: {\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n", "Download the files from the URLs below, extract them, and place them in the Jupyter working directory:<br/>Training dataset: {\"<http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz>\", \"<http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz>\"}\n",
"<br/>Test dataset: {\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}<br/>Use the code below to check the Jupyter working directory." "<br/>Test dataset: {\"<http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz>\", \"<http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz>\"}<br/>Use the code below to check the Jupyter working directory."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"text/plain": [
"'C:\\\\Users\\\\Administrator'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"import os\n", "import os\n",
"os.getcwd()" "os.getcwd()"
...@@ -67,7 +78,7 @@ ...@@ -67,7 +78,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Place the training dataset in Jupyter working directory+\\MNIST_Data\\train\\; the train folder should then contain two files, train-images-idx3-ubyte and train-labels-idx1-ubyte <br/>Place the test dataset in Jupyter working directory+\\MNIST_Data\\test\\; the test folder should then contain two files, t10k-images-idx3-ubyte and t10k-labels-idx1-ubyte" "Place the training dataset in `Jupyter working directory+\\MNIST_Data\\train\\`; the train folder should then contain two files, `train-images-idx3-ubyte` and `train-labels-idx1-ubyte` <br/>Place the test dataset in `Jupyter working directory+\\MNIST_Data\\test\\`; the test folder should then contain two files, `t10k-images-idx3-ubyte` and `t10k-labels-idx1-ubyte`"
] ]
}, },
{ {
...@@ -80,11 +91,18 @@ ...@@ -80,11 +91,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"******Downloading the MNIST dataset******\n"
]
}
],
"source": [ "source": [
"# Network request module, data download module, decompression module\n",
"import urllib.request \n", "import urllib.request \n",
"from urllib.parse import urlparse\n", "from urllib.parse import urlparse\n",
"import gzip \n", "import gzip \n",
...@@ -144,7 +162,7 @@ ...@@ -144,7 +162,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## II. Processing the MNIST dataset" "## Processing the MNIST dataset"
] ]
}, },
{ {
...@@ -158,7 +176,7 @@ ...@@ -158,7 +176,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"LeNet is not described further here; for a detailed introduction, see http://yann.lecun.com/exdb/lenet/ ." "LeNet is not described further here; for a detailed introduction, see <http://yann.lecun.com/exdb/lenet/>."
] ]
}, },
{ {
...@@ -170,9 +188,33 @@ ...@@ -170,9 +188,33 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The type of mnist_ds: <class 'mindspore.dataset.engine.datasets.MnistDataset'>\n",
"Number of pictures contained in the mnist_ds: 60000\n",
"The item of mnist_ds: dict_keys(['image', 'label'])\n",
"Tensor of image in item: (28, 28, 1)\n",
"The label of item: 9\n"
]
},
{
"data": {
"image/png": "[base64-encoded PNG cell output omitted]"
vg0MxqbbNaF1jbqPqs7ZVV9UvLLZhq2J+382S+quYGK2AVs1rbrNYF1jaqadXmabzUCMMuNWLosO8YeP+rmdXaZrUusLZRTaW2QT+zS5qeoVt2SVNi2KVGDBL2JGcm+c8kDyb50BA1rCTJw0nuTrIryfzAtVyZZH+SexbN25Lk5iQPdM/LjrE3UG2XJvlud+x2JTlroNqOT/LvSe5Lcm+SP+rmD3rsVqlrKsdt6p/ZkxwGfAt4M7AHuAM4v6q+OdVCVpDkYWCuqgb/AkaS3waeBK6uqtd08/4CeKyqLuv+UB5ZVR+ckdouBZ4cehjvbrSirYuHGQfOAd7JgMdulbrOYwrHbYiW/VTgwap6qKp+BnwBOHuAOmZeVd0KPLZk9tnAzm56Jwu/LFO3Qm0zoar2VdVd3fQTwLPDjA967FapayqGCPtxwCOLXu9htsZ7L+ArSe5Msn3oYpZxTFXtg4VfHuDogetZas1hvKdpyTDjM3PsRhn+fFxDhH25oaRmqf/vtKp6PfBW4L3d6arWZ13DeE/LMsOMz4RRhz8f1xBh3wMcv+j1K4C9A9SxrKra2z3vB65n9oaifvTZEXS75/0D1/N/ZmkY7+WGGWcGjt2Qw58PEfY7gJOTnJjkRcA7gBsHqON5khzRXTghyRHAGczeUNQ3Ahd20xcCNwxYy3PMyjDeKw0zzsDHbvDhz6tq6g/gLBauyP8X8CdD1LBCXb8CfKN73Dt0bcA1LJzWHWThjOhdwMuAW4AHuuctM1Tb3wN3A7tZCNbWgWr7LRY+Gu4GdnWPs4Y+dqvUNZXj5tdlpUb4DTqpEYZdaoRhlxph2KVGGHapEYZdaoRhlxrxv3jNRdG9OXAOAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [ "source": [
"from mindspore import context\n", "from mindspore import context\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
...@@ -180,20 +222,20 @@ ...@@ -180,20 +222,20 @@
"import numpy as np\n", "import numpy as np\n",
"import mindspore.dataset as ds\n", "import mindspore.dataset as ds\n",
"\n", "\n",
"context.set_context(mode=context.GRAPH_MODE, device_target=\"CPU\") # Windows version, set to use CPU for graph calculation\n", "context.set_context(mode=context.GRAPH_MODE, device_target=\"CPU\") \n",
"train_data_path = \"./MNIST_Data/train\"\n", "train_data_path = \"./MNIST_Data/train\"\n",
"test_data_path = \"./MNIST_Data/test\"\n", "test_data_path = \"./MNIST_Data/test\"\n",
"mnist_ds = ds.MnistDataset(train_data_path) # Load training dataset\n", "mnist_ds = ds.MnistDataset(train_data_path)\n",
"print('The type of mnist_ds:', type(mnist_ds))\n", "print('The type of mnist_ds:', type(mnist_ds))\n",
"print(\"Number of pictures contained in the mnist_ds:\",mnist_ds.get_dataset_size()) # 60000 pictures in total\n", "print(\"Number of pictures contained in the mnist_ds:\",mnist_ds.get_dataset_size())\n",
"\n", "\n",
"dic_ds = mnist_ds.create_dict_iterator() # Convert dataset to dictionary type\n", "dic_ds = mnist_ds.create_dict_iterator()\n",
"item = dic_ds.get_next()\n", "item = dic_ds.get_next()\n",
"img = item[\"image\"]\n", "img = item[\"image\"]\n",
"label = item[\"label\"]\n", "label = item[\"label\"]\n",
"\n", "\n",
"print(\"The item of mnist_ds:\", item.keys()) # Take a single data to view the data structure, including two keys, image and label\n", "print(\"The item of mnist_ds:\", item.keys())\n",
"print(\"Tensor of image in item:\", img.shape) # View the tensor of image (28,28,1)\n", "print(\"Tensor of image in item:\", img.shape) \n",
"print(\"The label of item:\", label)\n", "print(\"The label of item:\", label)\n",
"\n", "\n",
"plt.imshow(np.squeeze(img))\n", "plt.imshow(np.squeeze(img))\n",
...@@ -205,7 +247,7 @@ ...@@ -205,7 +247,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"从上面的运行情况我们可以看到,训练数据集train-images-idx3-ubyte和train-labels-idx1-ubyte对应的是6万张图片和6万个数字下标,载入数据后经过create_dict_iterator()转换字典型的数据集,取其中的一个数据查看,这是一个key为image和label的字典,其中的image的张量(高度28,宽度28,通道1)和label为对应图片的数字。" "从上面的运行情况我们可以看到,训练数据集`train-images-idx3-ubyte`和`train-labels-idx1-ubyte`对应的是6万张图片和6万个数字下标,载入数据后经过`create_dict_iterator`转换字典型的数据集,取其中的一个数据查看,这是一个key为`image`和`label`的字典,其中的`image`的张量(高度28,宽度28,通道1)和`label`为对应图片的数字。"
] ]
}, },
{ {
...@@ -219,23 +261,22 @@ ...@@ -219,23 +261,22 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"数据集对于训练非常重要,好的数据集可以有效提高训练精度和效率在加载数据集前,我们通常会对数据集进行一些处理。\n", "数据集对于训练非常重要,好的数据集可以有效提高训练精度和效率在加载数据集前,我们通常会对数据集进行一些处理。\n",
"#### 定义数据集及数据操作\n", "#### 定义数据集及数据操作\n",
"我们定义一个函数create_dataset()来创建数据集。在这个函数中,我们定义好需要进行的数据增强和处理操作:\n", "我们定义一个函数`create_dataset`来创建数据集。在这个函数中,我们定义好需要进行的数据增强和处理操作:\n",
"<br/>1、定义数据集。\n", "1. 定义数据集。\n",
"<br/>2、定义进行数据增强和处理所需要的一些参数。\n", "2. 定义进行数据增强和处理所需要的一些参数。\n",
"<br/>3、根据参数,生成对应的数据增强操作。\n", "3. 根据参数,生成对应的数据增强操作。\n",
"<br/>4、使用map()映射函数,将数据操作应用到数据集。\n", "4. 使用`map`映射函数,将数据操作应用到数据集。\n",
"<br/>5、对生成的数据集进行处理。" "5. 对生成的数据集进行处理。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Data processing module\n",
"import mindspore.dataset.transforms.vision.c_transforms as CV\n", "import mindspore.dataset.transforms.vision.c_transforms as CV\n",
"import mindspore.dataset.transforms.c_transforms as C\n", "import mindspore.dataset.transforms.c_transforms as C\n",
"from mindspore.dataset.transforms.vision import Inter\n", "from mindspore.dataset.transforms.vision import Inter\n",
...@@ -246,37 +287,38 @@ ...@@ -246,37 +287,38 @@
" num_parallel_workers=1):\n", " num_parallel_workers=1):\n",
" \"\"\" create dataset for train or test\n", " \"\"\" create dataset for train or test\n",
" Args:\n", " Args:\n",
" data_path: Data path\n", " data_path (str): Data path\n",
" batch_size: The number of data records in each group\n", " batch_size (int): The number of data records in each group\n",
" repeat_size: The number of replicated data records\n", " repeat_size (int): The number of replicated data records\n",
" num_parallel_workers: The number of parallel workers\n", " num_parallel_workers (int): The number of parallel workers\n",
" \"\"\"\n", " \"\"\"\n",
" # define dataset\n", " # define dataset\n",
" mnist_ds = ds.MnistDataset(data_path)\n", " mnist_ds = ds.MnistDataset(data_path)\n",
"\n", "\n",
" # Define some parameters needed for data enhancement and rough justification\n", " # define some parameters needed for data enhancement and rough justification\n",
" resize_height, resize_width = 32, 32\n", " resize_height, resize_width = 32, 32\n",
" rescale = 1.0 / 255.0\n", " rescale = 1.0 / 255.0\n",
" shift = 0.0\n", " shift = 0.0\n",
" rescale_nml = 1 / 0.3081\n", " rescale_nml = 1 / 0.3081\n",
" shift_nml = -1 * 0.1307 / 0.3081\n", " shift_nml = -1 * 0.1307 / 0.3081\n",
"\n", "\n",
" # According to the parameters, generate the corresponding data enhancement method\n", " # according to the parameters, generate the corresponding data enhancement method\n",
" resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Resize images to (32, 32) by bilinear interpolation\n", " resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
" rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images\n", " rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
" rescale_op = CV.Rescale(rescale, shift) # rescale images\n", " rescale_op = CV.Rescale(rescale, shift)\n",
" hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.\n", " hwc2chw_op = CV.HWC2CHW()\n",
" type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network\n", " type_cast_op = C.TypeCast(mstype.int32)\n",
"\n", "\n",
" # Using map () to apply operations to a dataset\n", " # using map to apply operations to a dataset\n",
" mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
" # Process the generated dataset\n", " \n",
" # process the generated dataset\n",
" buffer_size = 10000\n", " buffer_size = 10000\n",
" mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", " mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)\n",
" mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", " mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
" mnist_ds = mnist_ds.repeat(repeat_size)\n", " mnist_ds = mnist_ds.repeat(repeat_size)\n",
"\n", "\n",
...@@ -287,10 +329,11 @@ ...@@ -287,10 +329,11 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"其中<br/>\n", "其中\n",
"batch_size:每组包含的数据个数,现设置每组包含32个数据。\n", "- `batch_size`:每组包含的数据个数,现设置每组包含32个数据。\n",
"<br/>repeat_size:数据集复制的数量。\n", "- `repeat_size`:数据集复制的数量。\n",
"<br/>先进行shuffle、batch操作,再进行repeat操作,这样能保证1个epoch内数据不重复。" "\n",
"先进行`shuffle`、`batch`操作,再进行`repeat`操作,这样能保证1个`epoch`内数据不重复。"
] ]
}, },
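The two `Rescale` steps in `create_dataset` can be read as a single standardization. The sketch below is plain Python rather than MindSpore code, and it assumes that a rescale step computes `value * scale + offset`. The helper names `affine`, `chained`, and `standardized` are illustrative only; the constants are copied from the function above, where 0.1307 and 0.3081 are the commonly quoted mean and standard deviation of MNIST pixels scaled to [0, 1].

```python
# Constants copied from create_dataset; everything else here is illustrative.
rescale, shift = 1.0 / 255.0, 0.0
rescale_nml, shift_nml = 1 / 0.3081, -1 * 0.1307 / 0.3081


def affine(value, scale, offset):
    """Assumed behaviour of one rescale step: value * scale + offset."""
    return value * scale + offset


def chained(pixel):
    """Scale a raw pixel to [0, 1], then apply the normalization rescale."""
    return affine(affine(pixel, rescale, shift), rescale_nml, shift_nml)


def standardized(pixel, mean=0.1307, std=0.3081):
    """Textbook standardization of a pixel that was first scaled to [0, 1]."""
    return (pixel / 255.0 - mean) / std


for pixel in (0, 128, 255):
    print(pixel, round(chained(pixel), 4), round(standardized(pixel), 4))
```

Both columns of the printed table agree, which is why the function can express normalization with nothing more than two rescale operations.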
{ {
...@@ -309,34 +352,52 @@ ...@@ -309,34 +352,52 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of groups in the dataset: 1875\n"
]
}
],
"source": [ "source": [
"datas = create_dataset(train_data_path) # Process the train dataset\n", "datas = create_dataset(train_data_path)\n",
"print('Number of groups in the dataset:', datas.get_dataset_size()) # Number of query dataset groups" "print('Number of groups in the dataset:', datas.get_dataset_size())"
] ]
}, },
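The group count printed above follows directly from the dataset size and the batch size. As a minimal sketch (plain Python, with the hypothetical helper `batches_per_epoch`), 60,000 training images split into batches of 32 give 1875 full batches, so `drop_remainder=True` discards nothing for the training set:

```python
def batches_per_epoch(num_samples, batch_size, drop_remainder=True):
    """Count the batches produced from num_samples records."""
    full, leftover = divmod(num_samples, batch_size)
    if drop_remainder or leftover == 0:
        return full
    return full + 1


train_batches = batches_per_epoch(60000, 32)
print(train_batches)  # 1875
```

With a dataset whose size is not a multiple of 32, the same helper shows how many records the last partial batch would lose.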
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"其次,取出其中一组数据,查看包含的key,图片数据的张量,以及下标labels的值。" "其次,取出其中一组数据,查看包含的`key`,图片数据的张量,以及下标`labels`的值。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 6,
"metadata": { "metadata": {
"scrolled": false "scrolled": false
}, },
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['label', 'image'])\n",
"Tensor of image: (32, 1, 32, 32)\n",
"labels: [7 4 0 6 6 6 5 8 3 5 4 8 5 3 2 9 3 7 3 0 9 2 0 6 3 3 6 2 5 9 2 0]\n"
]
}
],
"source": [ "source": [
"data = datas.create_dict_iterator().get_next() # Take a set of datasets\n", "data = datas.create_dict_iterator().get_next()\n",
"print(data.keys())\n", "print(data.keys())\n",
"images = data[\"image\"] # Take out the image data in this dataset\n", "images = data[\"image\"] \n",
"labels = data[\"label\"] # Take out the label (subscript) of this data set\n", "labels = data[\"label\"] \n",
"print('Tensor of image:', images.shape) # Query the tensor of images in each dataset (32,1,32,32)\n", "print('Tensor of image:', images.shape)\n",
"print('labels:', labels)" "print('labels:', labels)"
] ]
}, },
...@@ -344,14 +405,27 @@ ...@@ -344,14 +405,27 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"最后,查看image的图像和下标对应的值。" "最后,查看`image`的图像和下标对应的值。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [ "source": [
"count = 1\n", "count = 1\n",
"for i in images:\n", "for i in images:\n",
...@@ -361,7 +435,7 @@ ...@@ -361,7 +435,7 @@
" plt.xticks([])\n", " plt.xticks([])\n",
" count += 1\n", " count += 1\n",
" plt.axis(\"off\")\n", " plt.axis(\"off\")\n",
"plt.show() # Print a total of 32 pictures in the group" "plt.show()"
] ]
}, },
{ {
...@@ -375,7 +449,7 @@ ...@@ -375,7 +449,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 三、构造神经网络" "## 构造神经网络"
] ]
}, },
{ {
...@@ -396,7 +470,7 @@ ...@@ -396,7 +470,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"<img src=\"https://img-blog.csdnimg.cn/20190305161316701.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L21tbV9qc3c=,size_16,color_FFFFFF,t_70\" alt=\"LeNet5\">" "<img src=\"https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg\" alt=\"LeNet5\">"
] ]
}, },
{ {
...@@ -405,21 +479,21 @@ ...@@ -405,21 +479,21 @@
"source": [ "source": [
"在构建LeNet5前,我们需要对全连接层以及卷积层进行初始化。\n", "在构建LeNet5前,我们需要对全连接层以及卷积层进行初始化。\n",
"\n", "\n",
"TruncatedNormal:参数初始化方法,MindSpore支持TruncatedNormal、Normal、Uniform等多种参数初始化方法,具体可以参考MindSpore API的mindspore.common.initializer模块说明。\n", "`TruncatedNormal`:参数初始化方法,MindSpore支持`TruncatedNormal`、`Normal`、`Uniform`等多种参数初始化方法,具体可以参考MindSpore API的`mindspore.common.initializer`模块说明。\n",
"\n", "\n",
"初始化示例代码如下:" "初始化示例代码如下:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import mindspore.nn as nn\n", "import mindspore.nn as nn\n",
"from mindspore.common.initializer import TruncatedNormal\n", "from mindspore.common.initializer import TruncatedNormal\n",
"\n", "\n",
"# Initialize 2D convolution function\n", "# initialize 2D convolution function\n",
"def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n", "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n",
" \"\"\"Conv layer weight initial.\"\"\"\n", " \"\"\"Conv layer weight initial.\"\"\"\n",
" weight = weight_variable()\n", " weight = weight_variable()\n",
...@@ -427,14 +501,14 @@ ...@@ -427,14 +501,14 @@
" kernel_size=kernel_size, stride=stride, padding=padding,\n", " kernel_size=kernel_size, stride=stride, padding=padding,\n",
" weight_init=weight, has_bias=False, pad_mode=\"valid\")\n", " weight_init=weight, has_bias=False, pad_mode=\"valid\")\n",
"\n", "\n",
"# Initialize full connection layer\n", "# initialize full connection layer\n",
"def fc_with_initialize(input_channels, out_channels):\n", "def fc_with_initialize(input_channels, out_channels):\n",
" \"\"\"Fc layer weight initial.\"\"\"\n", " \"\"\"Fc layer weight initial.\"\"\"\n",
" weight = weight_variable()\n", " weight = weight_variable()\n",
" bias = weight_variable()\n", " bias = weight_variable()\n",
" return nn.Dense(input_channels, out_channels, weight, bias)\n", " return nn.Dense(input_channels, out_channels, weight, bias)\n",
"\n", "\n",
"# Set truncated normal distribution\n", "# set truncated normal distribution\n",
"def weight_variable():\n", "def weight_variable():\n",
" \"\"\"Weight initial.\"\"\"\n", " \"\"\"Weight initial.\"\"\"\n",
" return TruncatedNormal(0.02)" " return TruncatedNormal(0.02)"
...@@ -444,14 +518,14 @@ ...@@ -444,14 +518,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"使用MindSpore定义神经网络需要继承mindspore.nn.cell.Cell,Cell是所有神经网络(Conv2d等)的基类。\n", "使用MindSpore定义神经网络需要继承`mindspore.nn.cell.Cell`,`Cell`是所有神经网络(`Conv2d`等)的基类。\n",
"\n", "\n",
"神经网络的各层需要预先在\\_\\_init\\_\\_()方法中定义,然后通过定义construct()方法来完成神经网络的前向构造,按照LeNet5的网络结构,定义网络各层如下:" "神经网络的各层需要预先在`__init__`方法中定义,然后通过定义`construct`方法来完成神经网络的前向构造,按照LeNet5的网络结构,定义网络各层如下:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -460,9 +534,9 @@ ...@@ -460,9 +534,9 @@
" # define the operator required\n", " # define the operator required\n",
" def __init__(self):\n", " def __init__(self):\n",
" super(LeNet5, self).__init__()\n", " super(LeNet5, self).__init__()\n",
" self.batch_size = 32 # 32 pictures in each group\n", " self.batch_size = 32\n",
" self.conv1 = conv(1, 6, 5) # Convolution layer 1, 1 channel input (1 Figure), 6 channel output (6 figures), convolution core 5 * 5\n", " self.conv1 = conv(1, 6, 5)\n",
" self.conv2 = conv(6, 16, 5) # Convolution layer 2,6-channel input, 16 channel output, convolution kernel 5 * 5\n", " self.conv2 = conv(6, 16, 5)\n",
" self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n", " self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n",
" self.fc2 = fc_with_initialize(120, 84)\n", " self.fc2 = fc_with_initialize(120, 84)\n",
" self.fc3 = fc_with_initialize(84, 10)\n", " self.fc3 = fc_with_initialize(84, 10)\n",
...@@ -472,18 +546,18 @@ ...@@ -472,18 +546,18 @@
"\n", "\n",
" # use the preceding operators to construct networks\n", " # use the preceding operators to construct networks\n",
" def construct(self, x):\n", " def construct(self, x):\n",
" x = self.conv1(x) # 1*32*32-->6*28*28\n", " x = self.conv1(x)\n",
" x = self.relu(x) # 6*28*28-->6*14*14\n", " x = self.relu(x)\n",
" x = self.max_pool2d(x) # Pool layer\n", " x = self.max_pool2d(x)\n",
" x = self.conv2(x) # Convolution layer\n", " x = self.conv2(x) \n",
" x = self.relu(x) # Function excitation layer\n", " x = self.relu(x)\n",
" x = self.max_pool2d(x) # Pool layer\n", " x = self.max_pool2d(x)\n",
" x = self.flatten(x) # Dimensionality reduction\n", " x = self.flatten(x)\n",
" x = self.fc1(x) # Full connection\n", " x = self.fc1(x)\n",
" x = self.relu(x) # Function excitation layer\n", " x = self.relu(x)\n",
" x = self.fc2(x) # Full connection\n", " x = self.fc2(x)\n",
" x = self.relu(x) # Function excitation layer\n", " x = self.relu(x)\n",
" x = self.fc3(x) # Full connection\n", " x = self.fc3(x) \n",
" return x" " return x"
] ]
}, },
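The `16 * 5 * 5` input size of `fc1` comes from tracing the feature-map size through `construct`. The sketch below does that arithmetic in plain Python. It assumes unpadded 5x5 convolutions with stride 1 (as set in `conv`) and 2x2 max pooling with stride 2; the helper names `conv_out`, `pool_out`, and `lenet5_shapes` are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Output width of an unpadded max-pool window."""
    return (size - kernel) // stride + 1


def lenet5_shapes(size=32):
    """Return (stage, channels, height, width) for the feature extractor."""
    shapes = [("input", 1, size, size)]
    for name, channels in (("conv1", 6), ("conv2", 16)):
        size = conv_out(size, kernel=5)
        shapes.append((name, channels, size, size))
        size = pool_out(size)
        shapes.append((name + "+pool", channels, size, size))
    return shapes


shapes = lenet5_shapes()
for row in shapes:
    print(row)

_, channels, height, width = shapes[-1]
flattened = channels * height * width
print("flattened features:", flattened)  # 400
```

The ReLU layers are left out of the trace because an element-wise activation does not change the shape of its input.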
...@@ -496,19 +570,57 @@ ...@@ -496,19 +570,57 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"text/plain": [
"LeNet5<\n",
" (conv1): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5),stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False,weight_init=Parameter (name=conv1.weight), bias_init=None>\n",
" (conv2): Conv2d<input_channels=6, output_channels=16, kernel_size=(5, 5),stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False,weight_init=Parameter (name=conv2.weight), bias_init=None>\n",
" (fc1): Dense<in_channels=400, out_channels=120, weight=Parameter (name=fc1.weight), has_bias=True, bias=Parameter (name=fc1.bias)>\n",
" (fc2): Dense<in_channels=120, out_channels=84, weight=Parameter (name=fc2.weight), has_bias=True, bias=Parameter (name=fc2.bias)>\n",
" (fc3): Dense<in_channels=84, out_channels=10, weight=Parameter (name=fc3.weight), has_bias=True, bias=Parameter (name=fc3.bias)>\n",
" (relu): ReLU<>\n",
" (max_pool2d): MaxPool2d<kernel_size=2, stride=2, pad_mode=VALID>\n",
" (flatten): Flatten<>\n",
" >"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"network = LeNet5()\n", "network = LeNet5()\n",
"print(network)" "network"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"text/plain": [
"[Parameter (name=conv1.weight),\n",
" Parameter (name=conv2.weight),\n",
" Parameter (name=fc1.weight),\n",
" Parameter (name=fc1.bias),\n",
" Parameter (name=fc2.weight),\n",
" Parameter (name=fc2.bias),\n",
" Parameter (name=fc3.weight),\n",
" Parameter (name=fc3.bias)]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"param = network.trainable_params()\n", "param = network.trainable_params()\n",
"param" "param"
...@@ -518,61 +630,59 @@ ...@@ -518,61 +630,59 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 四、搭建训练网络并进行训练" "## 搭建训练网络并进行训练"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"构建完成神经网络后,就可以着手进行训练网络的构建,模型训练函数为Model.train(),参数主要包含:\n", "构建完成神经网络后,就可以着手进行训练网络的构建,模型训练函数为`Model.train`,参数主要包含:\n",
"<br/>1、圈数epoch size(每圈需要遍历完成1875组图片);\n", "1. 每个`epoch`需要遍历完成图片的`batch`数:`epoch_size`;\n",
"<br/>2、数据集ds_train;\n", "2. 数据集`ds_train`;\n",
"<br/>3、回调函数callbacks包含ModelCheckpoint、LossMonitor、SummaryStepckpoint_cb,Callback模型检测参数;\n", "3. 回调函数`callbacks`包含`ModelCheckpoint`、`LossMonitor`和`Callback`模型检测参数;\n",
"<br/>4、底层数据通道dataset_sink_mode,此参数默认True需设置成False,因为此功能不支持CPU模式。" "4. 数据下沉模式`dataset_sink_mode`,此参数默认`True`需设置成`False`,因为此功能不支持CPU模式。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Training and testing related modules\n", "# training related modules\n",
"import argparse\n",
"from mindspore import Tensor\n", "from mindspore import Tensor\n",
"from mindspore.train.serialization import load_checkpoint, load_param_into_net\n", "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
"from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor,SummaryStep,Callback\n", "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor,Callback\n",
"from mindspore.train import Model\n", "from mindspore.train import Model\n",
"from mindspore.nn.metrics import Accuracy\n", "from mindspore.nn.metrics import Accuracy\n",
"from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits\n",
"\n", "\n",
"def train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb, step_loss_info):\n", "def train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb, step_loss_info):\n",
" \"\"\"Define the training method.\"\"\"\n", " \"\"\"Define the training method.\"\"\"\n",
" print(\"============== Starting Training ==============\")\n", " print(\"============== Starting Training ==============\")\n",
" # load training dataset\n", " # load training dataset\n",
" ds_train = create_dataset(os.path.join(mnist_path, \"train\"), 32, repeat_size)\n", " ds_train = create_dataset(os.path.join(mnist_path, \"train\"), 32, repeat_size)\n",
" model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), step_loss_info], dataset_sink_mode=False)" " model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125), step_loss_info], dataset_sink_mode=False)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"自定义一个存储每一步训练的step和对应loss值的类Step_loss_info(),并继承了Callback类,可以自定义训练过程中的处理措施,非常方便,等训练完成后,可将数据绘图查看loss的变化情况。" "自定义一个存储每一步训练的`step`和对应loss值的类`Step_loss_info`,并继承了`Callback`类,可以自定义训练过程中的处理措施,非常方便,等训练完成后,可将数据绘图查看loss的变化情况。"
] ]
}, },
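The callback idea can be shown without MindSpore at all. The following is a minimal sketch with stand-in classes: `RecordingCallback` and `fake_training_loop` are hypothetical names, not part of any library, and the loss values fed in are made up for illustration. It keeps the recorded values on the callback instance as floats, which is one alternative to a module-level dictionary of strings.

```python
class RecordingCallback:
    """Stand-in for a training callback; not a MindSpore class."""

    def __init__(self):
        self.steps = []
        self.losses = []

    def step_end(self, step, loss):
        """Hook invoked after every training step."""
        self.steps.append(step)
        self.losses.append(float(loss))


def fake_training_loop(callback, losses):
    """Hypothetical loop that calls the hook once per step."""
    for step, loss in enumerate(losses, start=1):
        callback.step_end(step, loss)


recorder = RecordingCallback()
fake_training_loop(recorder, ["2.31", "2.30", "0.71"])  # illustrative values
print(recorder.steps, recorder.losses)
```

Storing floats rather than strings means the lists can be passed straight to a plotting call later.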
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Custom callback function\n", "# custom callback function\n",
"class Step_loss_info(Callback):\n", "class StepLossInfo(Callback):\n",
" def step_end(self, run_context):\n", " def step_end(self, run_context):\n",
" cb_params = run_context.original_args()\n", " cb_params = run_context.original_args()\n",
" # step_ Loss dictionary for saving loss value and step number information\n", " # step_loss dictionary for saving loss value and step number information\n",
" step_loss[\"loss_value\"].append(str(cb_params.net_outputs))\n", " step_loss[\"loss_value\"].append(str(cb_params.net_outputs))\n",
" step_loss[\"step\"].append(str(cb_params.cur_step_num))" " step_loss[\"step\"].append(str(cb_params.cur_step_num))"
] ]
...@@ -582,32 +692,63 @@ ...@@ -582,32 +692,63 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 定义损失函数及优化器\n", "### 定义损失函数及优化器\n",
"基本概念\n", "\n",
"在进行定义之前,先简单介绍损失函数及优化器的概念。\n", "在进行定义之前,先简单介绍损失函数及优化器的概念。\n",
"<br/>损失函数:又叫目标函数,用于衡量预测值与实际值差异的程度。深度学习通过不停地迭代来缩小损失函数的值。定义一个好的损失函数,可以有效提高模型的性能。\n", "\n",
"<br/>优化器:用于最小化损失函数,从而在训练过程中改进模型。\n", "损失函数:又叫目标函数,用于衡量预测值与实际值差异的程度。深度学习通过不停地迭代来缩小损失函数的值。定义一个好的损失函数,可以有效提高模型的性能。\n",
"<br/>定义了损失函数后,可以得到损失函数关于权重的梯度。梯度用于指示优化器优化权重的方向,以提高模型性能。\n", "\n",
"<br/>定义损失函数。\n", "优化器:用于最小化损失函数,从而在训练过程中改进模型。\n",
"<br/>MindSpore支持的损失函数有SoftmaxCrossEntropyWithLogits、L1Loss、MSELoss等。这里使用SoftmaxCrossEntropyWithLogits损失函数。" "\n",
"定义了损失函数后,可以得到损失函数关于权重的梯度。梯度用于指示优化器优化权重的方向,以提高模型性能。\n",
"\n",
"MindSpore支持的损失函数有`SoftmaxCrossEntropyWithLogits`、`L1Loss`、`MSELoss`等。这里使用`SoftmaxCrossEntropyWithLogits`损失函数。"
] ]
}, },
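To make the loss concrete, here is a small plain-Python sketch of softmax cross-entropy with integer (sparse) labels and mean reduction. It illustrates the general formula only and makes no claim about how MindSpore implements `SoftmaxCrossEntropyWithLogits`; the function names are illustrative.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one sample: log-sum-exp of the logits minus the true-class logit."""
    peak = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum_exp - logits[label]


def mean_loss(batch_logits, labels):
    """Average the per-sample losses over a batch (mean reduction)."""
    losses = [softmax_cross_entropy(row, y) for row, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


uniform_logits = [0.0] * 10  # a network that has learned nothing yet
print(softmax_cross_entropy(uniform_logits, 3))  # ln(10), about 2.3026
```

A classifier that spreads probability evenly over 10 digits scores ln 10, roughly 2.30, which is the level at which the logged loss sits during the first steps of the training output before it starts to fall.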
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 14,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============== Starting Training ==============\n",
"epoch: 1 step 125, loss is 2.3136098384857178\n",
"epoch: 1 step 250, loss is 2.303882598876953\n",
"epoch: 1 step 375, loss is 2.3046326637268066\n",
"epoch: 1 step 500, loss is 2.3024802207946777\n",
"epoch: 1 step 625, loss is 2.3106091022491455\n",
"epoch: 1 step 750, loss is 2.298833131790161\n",
"epoch: 1 step 875, loss is 2.3070852756500244\n",
"epoch: 1 step 1000, loss is 2.284291982650757\n",
"epoch: 1 step 1125, loss is 0.7130898237228394\n",
"epoch: 1 step 1250, loss is 0.17307262122631073\n",
"epoch: 1 step 1375, loss is 0.3248927891254425\n",
"epoch: 1 step 1500, loss is 0.09352534264326096\n",
"epoch: 1 step 1625, loss is 0.025928258895874023\n",
"epoch: 1 step 1750, loss is 0.0918595939874649\n",
"epoch: 1 step 1875, loss is 0.20610764622688293\n",
"Epoch time: 15709.893, per step time: 8.379, avg loss: 1.440\n",
"************************************************************\n"
]
}
],
"source": [ "source": [
"import os\n", "import os\n",
"from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits\n",
"\n", "\n",
"if os.name == \"nt\":\n", "if os.name == \"nt\":\n",
" os.system('del/f/s/q *.ckpt *.meta')# Clean up old run files before in Windows\n", " # clean up old run files before in Windows\n",
" os.system('del/f/s/q *.ckpt *.meta')\n",
"else:\n", "else:\n",
" os.system('rm -f *.ckpt *.meta *.pb')# Clean up old run files before in Linux\n", " # clean up old run files before in Linux\n",
" os.system('rm -f *.ckpt *.meta *.pb')\n",
"\n", "\n",
"lr = 0.01 # learning rate\n", "lr = 0.01\n",
"momentum = 0.9 #\n", "momentum = 0.9 \n",
"\n", "\n",
"# create the network\n", "# create the network\n",
"network = LeNet5()\n", "network = LeNet5()\n",
...@@ -615,25 +756,23 @@ ...@@ -615,25 +756,23 @@
"# define the optimizer\n", "# define the optimizer\n",
"net_opt = nn.Momentum(network.trainable_params(), lr, momentum)\n", "net_opt = nn.Momentum(network.trainable_params(), lr, momentum)\n",
"\n", "\n",
"\n",
"# define the loss function\n", "# define the loss function\n",
"net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", "net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
"\n",
"# define the model\n", "# define the model\n",
"model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()} )\n", "model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()} )\n",
"\n", "\n",
"\n",
"epoch_size = 1\n", "epoch_size = 1\n",
"mnist_path = \"./MNIST_Data\"\n", "mnist_path = \"./MNIST_Data\"\n",
"\n",
"config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=16)\n",
"# save the network model and parameters for subsequence fine-tuning\n", "# save the network model and parameters for subsequence fine-tuning\n",
"\n", "config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=16)\n",
"ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\", config=config_ck)\n",
"# group layers into an object with training and evaluation features\n", "# group layers into an object with training and evaluation features\n",
"ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\", config=config_ck)\n",
"# define step_loss dictionary for saving loss value and step number information\n",
"step_loss = {\"step\": [], \"loss_value\": []}\n", "step_loss = {\"step\": [], \"loss_value\": []}\n",
"# step_ Loss dictionary for saving loss value and step number information\n", "# save the steps and loss informations\n",
"step_loss_info = Step_loss_info()\n", "step_loss_info = StepLossInfo()\n",
"# save the steps and loss value\n", "\n",
"repeat_size = 1\n", "repeat_size = 1\n",
"train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb, step_loss_info)\n" "train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb, step_loss_info)\n"
] ]
...@@ -642,7 +781,7 @@ ...@@ -642,7 +781,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"训练完成后,能在Jupyter的工作路径上生成多个模型文件,名称具体含义checkpoint_{网络名称}-{第几个epoch}_{第几个step}.ckpt。" "训练完成后,能在Jupyter的工作路径上生成多个模型文件,名称具体含义`checkpoint_{网络名称}-{第几个epoch}_{第几个step}.ckpt`。"
] ]
}, },
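As a rough check on which files to expect, the sketch below lists checkpoint names under the naming pattern just described. It assumes one file is written every 125 steps over a single epoch of 1875 steps with the prefix `checkpoint_lenet`, as configured in the training cell; `expected_checkpoints` is a hypothetical helper and the list is an expectation, not an observed directory listing.

```python
def expected_checkpoints(prefix, epoch, steps_per_epoch, save_every):
    """Build the file names expected under the <prefix>-<epoch>_<step>.ckpt pattern."""
    return [
        f"{prefix}-{epoch}_{step}.ckpt"
        for step in range(save_every, steps_per_epoch + 1, save_every)
    ]


names = expected_checkpoints("checkpoint_lenet", 1, 1875, 125)
print(len(names), names[0], names[-1])
```

Under these assumptions there are 15 files, which stays within `keep_checkpoint_max=16`, so none of them should be rotated out.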
{ {
...@@ -654,11 +793,24 @@ ...@@ -654,11 +793,24 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 15,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
"outputs": [], "outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [ "source": [
"steps = step_loss[\"step\"]\n", "steps = step_loss[\"step\"]\n",
"loss_value = step_loss[\"loss_value\"]\n", "loss_value = step_loss[\"loss_value\"]\n",
...@@ -688,22 +840,39 @@ ...@@ -688,22 +840,39 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 五、数据测试验证模型精度" "## 数据测试验证模型精度"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"搭建测试网络的过程主要为:<br/>1、载入模型.cptk文件中的参数param;<br/>2、将参数param载入到神经网络LeNet5中;<br/>3、载入测试数据集;<br/>4、调用函数model.eval()传入参数测试数据集ds_eval,就生成模型checkpoint_lenet-1_1875.ckpt的精度值。<br/>dataset_sink_mode表示数据集下沉模式,不支持CPU,所以这里设置成False。" "搭建测试网络的过程主要为:\n",
"\n",
"1. 载入模型`.cptk`文件中的参数`param`;\n",
"2. 将参数`param`载入到神经网络LeNet5中;\n",
"3. 载入测试数据集;\n",
"4. 调用函数`model.eval`传入参数测试数据集`ds_eval`,就生成模型`checkpoint_lenet-1_1875.ckpt`的精度值。\n",
"\n",
"> `dataset_sink_mode`表示数据集下沉模式,不支持CPU,所以这里设置成`False`。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 16,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============== Starting Testing ==============\n",
"============== Accuracy:{'Accuracy': 0.9553285256410257} ==============\n"
]
}
],
"source": [ "source": [
"# testing relate modules \n",
"def test_net(network, model, mnist_path):\n", "def test_net(network, model, mnist_path):\n",
" \"\"\"Define the evaluation method.\"\"\"\n", " \"\"\"Define the evaluation method.\"\"\"\n",
" print(\"============== Starting Testing ==============\")\n", " print(\"============== Starting Testing ==============\")\n",
...@@ -731,14 +900,27 @@ ...@@ -731,14 +900,27 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"acc_model_info()函数是将每125步的保存的模型,调用model.eval()函数将测试出的精度返回到步数列表和精度列表,如下:" "`acc_model_info`函数是将每125步的保存的模型,调用`model.eval`函数将测试出的精度返回到步数列表和精度列表,如下:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [ "source": [
"def acc_model_info(network, model, mnist_path, model_numbers):\n", "def acc_model_info(network, model, mnist_path, model_numbers):\n",
" \"\"\"Define the plot info method\"\"\"\n", " \"\"\"Define the plot info method\"\"\"\n",
...@@ -756,7 +938,7 @@ ...@@ -756,7 +938,7 @@
" step_list.append(i*125)\n", " step_list.append(i*125)\n",
" return step_list,acc_list\n", " return step_list,acc_list\n",
"\n", "\n",
"# Draw line chart according to training steps and model accuracy\n", "# draw line chart according to training steps and model accuracy\n",
"l1,l2 = acc_model_info(network, model, mnist_path, 15)\n", "l1,l2 = acc_model_info(network, model, mnist_path, 15)\n",
"plt.xlabel(\"Model of Steps\")\n", "plt.xlabel(\"Model of Steps\")\n",
"plt.ylabel(\"Model accuracy\")\n", "plt.ylabel(\"Model accuracy\")\n",
...@@ -776,7 +958,7 @@ ...@@ -776,7 +958,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 六、模型预测应用" "## 模型预测应用"
] ]
}, },
{ {
...@@ -790,32 +972,56 @@ ...@@ -790,32 +972,56 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"1需要将要测试的数据转换成适应LeNet5的数据类型。\n", "1. 需要将要测试的数据转换成适应LeNet5的数据类型。\n",
"<br/>2、提取出image的数据。\n", "2. 提取出`image`的数据。\n",
"<br/>3、使用函数model.predict()预测image对应的数字。需要说明的是predict返回的是image对应0-9的概率值。\n", "3. 使用函数`model.predict`预测`image`对应的数字。需要说明的是`predict`返回的是`image`对应0-9的概率值。\n",
"<br/>4、调用plot_pie()将预测的各数字的概率显示出来。负概率的数字会被去掉。" "4. 调用`plot_pie`将预测的各数字的概率显示出来。负概率的数字会被去掉。"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"载入要测试的数据集并调用create_dataset()转换成符合格式要求的数据集,并选取其中一组32张图片进行预测。" "载入要测试的数据集并调用`create_dataset`转换成符合格式要求的数据集,并选取其中一组32张图片进行预测。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 18,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Row 1, column 7 is incorrectly identified as 2, the correct value should be 3 \n",
"\n",
"Row 4, column 3 is incorrectly identified as 0, the correct value should be 8 \n",
"\n",
"[2 7 7 0 6 3 2 3 2 5 2 1 8 7 8 3 0 5 2 1 0 8 2 2 1 8 0 3 6 8 8 2] <--Predicted figures\n",
"[2 7 7 0 6 3 3 3 2 5 2 1 8 7 8 3 0 5 2 1 0 8 2 2 1 8 8 3 6 8 8 2] <--The right number\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [ "source": [
"ds_test = create_dataset(test_data_path).create_dict_iterator()\n", "ds_test = create_dataset(test_data_path).create_dict_iterator()\n",
"data = ds_test.get_next()\n", "data = ds_test.get_next()\n",
"images = data[\"image\"]\n", "images = data[\"image\"]\n",
"labels = data[\"label\"] # The subscript of data picture is the standard for us to judge whether it is correct or not\n", "labels = data[\"label\"]\n",
"\n", "\n",
"output =model.predict(Tensor(data['image']))\n", "output =model.predict(Tensor(data['image']))\n",
"# The predict function returns the probability of 0-9 numbers corresponding to each picture\n",
"prb = output.asnumpy()\n", "prb = output.asnumpy()\n",
"pred = np.argmax(output.asnumpy(), axis=1)\n", "pred = np.argmax(output.asnumpy(), axis=1)\n",
"err_num = []\n", "err_num = []\n",
...@@ -828,12 +1034,11 @@ ...@@ -828,12 +1034,11 @@
" plt.axis(\"off\")\n", " plt.axis(\"off\")\n",
" if color == 'red':\n", " if color == 'red':\n",
" index = 0\n", " index = 0\n",
" # Print out the wrong data identified by the current group\n",
" print(\"Row {}, column {} is incorrectly identified as {}, the correct value should be {}\".format(int(i/8)+1, i%8+1, pred[i], labels[i]), '\\n')\n", " print(\"Row {}, column {} is incorrectly identified as {}, the correct value should be {}\".format(int(i/8)+1, i%8+1, pred[i], labels[i]), '\\n')\n",
"if index:\n", "if index:\n",
" print(\"All the figures in this group are predicted correctly\")\n", " print(\"All the figures in this group are predicted correctly!\")\n",
"print(pred, \"<--Predicted figures\") # Print the numbers recognized by each group of pictures\n", "print(pred, \"<--Predicted figures\") \n",
"print(labels, \"<--The right number\") # Print the subscript corresponding to each group of pictures\n", "print(labels, \"<--The right number\")\n",
"plt.show()" "plt.show()"
] ]
}, },
...@@ -843,29 +1048,68 @@ ...@@ -843,29 +1048,68 @@
"source": [ "source": [
"构建一个概率分析的饼图函数。\n", "构建一个概率分析的饼图函数。\n",
"\n", "\n",
"备注:prb为上一段代码中,存储这组数对应的数字概率。" "备注:`prb`为上一段代码中,存储这组数对应的数字概率。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 19,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Figure 1 probability of corresponding numbers [0-9]:\n",
" [-0.48477417 2.0016153 11.054499 2.3544474 -2.7436607 -3.630352\n",
" -3.7523592 0.9330094 2.6389365 -6.602851 ]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Figure 2 probability of corresponding numbers [0-9]:\n",
" [-1.916862 -0.16948226 -0.2352289 -0.5903556 0.8726251 -0.41480547\n",
" -3.0238853 4.210627 -0.70848167 1.8492212 ]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [ "source": [
"# define the pie drawing function of probability analysis\n", "# define the pie drawing function of probability analysis\n",
"def plot_pie(prbs):\n", "def plot_pie(prbs):\n",
" dict1 = {}\n", " dict1 = {}\n",
" # Remove the negative number and build the dictionary dict1. The key is the number and the value is the probability value\n", " # remove the negative number and build the dictionary dict1. The key is the number and the value is the probability value\n",
" for i in range(10):\n", " for i in range(10):\n",
" if prbs[i] > 0:\n", " if prbs[i] > 0:\n",
" dict1[str(i)] = prbs[i]\n", " dict1[str(i)] = prbs[i]\n",
"\n", "\n",
" label_list = dict1.keys() # Label of each part\n", " label_list = dict1.keys()\n",
" size = dict1.values() # Size of each part\n", " size = dict1.values()\n",
" colors = [\"red\", \"green\", \"pink\", \"blue\", \"purple\", \"orange\", \"gray\"] # Building a round cake pigment Library\n", " colors = [\"red\", \"green\", \"pink\", \"blue\", \"purple\", \"orange\", \"gray\"] \n",
" color = colors[: len(size)]# Color of each part\n", " color = colors[: len(size)]\n",
" plt.pie(size, colors=color, labels=label_list, labeldistance=1.1, autopct=\"%1.1f%%\", shadow=False, startangle=90, pctdistance=0.6)\n", " plt.pie(size, colors=color, labels=label_list, labeldistance=1.1, autopct=\"%1.1f%%\", shadow=False, startangle=90, pctdistance=0.6)\n",
" plt.axis(\"equal\") # Set the scale size of x-axis and y-axis to be equal\n", " plt.axis(\"equal\")\n",
" plt.legend()\n", " plt.legend()\n",
" plt.title(\"Image classification\")\n", " plt.title(\"Image classification\")\n",
" plt.show()\n", " plt.show()\n",
...@@ -886,9 +1130,9 @@ ...@@ -886,9 +1130,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python [conda env:root] *", "display_name": "Python 3",
"language": "python", "language": "python",
"name": "conda-root-py" "name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
......
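A note on the prediction cells in this diff: the values printed under "probability of corresponding numbers [0-9]" include negative numbers, so they appear to be raw network scores rather than normalized probabilities. The following is a minimal sketch, assuming those values are unnormalized logits (the diff does not show the final layer of LeNet5). The `softmax` helper is written here for illustration and is not part of the notebook. It shows how such scores could be turned into probabilities before plotting:

```python
import numpy as np


def softmax(logits):
    """Convert raw scores into non-negative values that sum to 1."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the maximum first so the exponentials cannot overflow.
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / np.sum(exp)


# Raw scores copied from the "Figure 1" output shown in the diff.
figure1_scores = [-0.48477417, 2.0016153, 11.054499, 2.3544474, -2.7436607,
                  -3.630352, -3.7523592, 0.9330094, 2.6389365, -6.602851]

probs = softmax(figure1_scores)
# argmax is unchanged by softmax, so the predicted digit stays the same.
predicted_digit = int(np.argmax(probs))
print(predicted_digit, probs.round(4))
```

With normalized values, `plot_pie` would not need to discard negative entries, although filtering out near-zero slices might still help readability.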