提交 77cd1388 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!507 Unify code formats in notebook for r0.5

Merge pull request !507 from lvmingfu/r0.5
...@@ -14,15 +14,15 @@ ...@@ -14,15 +14,15 @@
"接下来我们用一个图片分类的项目来体验计算图与数据图的生成与使用。\n", "接下来我们用一个图片分类的项目来体验计算图与数据图的生成与使用。\n",
" \n", " \n",
"## 本次体验的整体流程\n", "## 本次体验的整体流程\n",
"1体验模型的数据选择使用MNIST数据集,MNIST数据集整体数据量比较小,更适合体验使用。\n", "1. 体验模型的数据选择使用MNIST数据集,MNIST数据集整体数据量比较小,更适合体验使用。\n",
"\n", "\n",
"2初始化一个网络,本次的体验使用LeNet网络。\n", "2. 初始化一个网络,本次的体验使用LeNet网络。\n",
"\n", "\n",
"3增加可视化功能的使用,并设定只记录计算图与数据图。\n", "3. 增加可视化功能的使用,并设定只记录计算图与数据图。\n",
"\n", "\n",
"4加载训练数据集并进行训练,训练完成后,查看结果并保存模型文件。\n", "4. 加载训练数据集并进行训练,训练完成后,查看结果并保存模型文件。\n",
"\n", "\n",
"5启用MindInsight的可视化图界面,进行训练过程的核对。" "5. 启用MindInsight的可视化图界面,进行训练过程的核对。"
] ]
}, },
{ {
...@@ -35,8 +35,8 @@ ...@@ -35,8 +35,8 @@
"\n", "\n",
"从以下网址下载,并将数据包解压后放在Jupyter的工作目录下。\n", "从以下网址下载,并将数据包解压后放在Jupyter的工作目录下。\n",
"\n", "\n",
"- 训练数据集:{\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\",\"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n", "- 训练数据集:{\"<http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz>\",\"<http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz>\"}\n",
"- 测试数据集:{\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\",\"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}\n", "- 测试数据集:{\"<http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz>\",\"<http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz>\"}\n",
"\n", "\n",
"可执行下面代码查看Jupyter的工作目录。" "可执行下面代码查看Jupyter的工作目录。"
] ]
...@@ -187,11 +187,11 @@ ...@@ -187,11 +187,11 @@
"source": [ "source": [
"### 可视化操作流程\n", "### 可视化操作流程\n",
"\n", "\n",
"1准备训练脚本,在训练脚本中指定计算图的超参数信息,使用`Summary`保存到日志中,接着再运行训练脚本。\n", "1. 准备训练脚本,在训练脚本中指定计算图的超参数信息,使用`Summary`保存到日志中,接着再运行训练脚本。\n",
"\n", "\n",
"2启动MindInsight,启动成功后,就可以通过访问命令执行后显示的地址,查看可视化界面。\n", "2. 启动MindInsight,启动成功后,就可以通过访问命令执行后显示的地址,查看可视化界面。\n",
"\n", "\n",
"3访问可视化地址成功后,就可以对图界面进行查询等操作。" "3. 访问可视化地址成功后,就可以对图界面进行查询等操作。"
] ]
}, },
{ {
...@@ -200,11 +200,11 @@ ...@@ -200,11 +200,11 @@
"source": [ "source": [
"#### 初始化网络\n", "#### 初始化网络\n",
"\n", "\n",
"1导入构建网络所使用的模块。\n", "1. 导入构建网络所使用的模块。\n",
"\n", "\n",
"2构建初始化参数的函数。\n", "2. 构建初始化参数的函数。\n",
"\n", "\n",
"3创建网络,在网络中设置参数。" "3. 创建网络,在网络中设置参数。"
] ]
}, },
{ {
...@@ -273,11 +273,11 @@ ...@@ -273,11 +273,11 @@
"source": [ "source": [
"#### 主程序运行\n", "#### 主程序运行\n",
"\n", "\n",
"1首先在主函数之前调用所需要的模块,并在主函数之前使用相应接口。\n", "1. 首先在主函数之前调用所需要的模块,并在主函数之前使用相应接口。\n",
"\n", "\n",
"2本次体验主要完成计算图与数据图的可视化,定义变量`specified={'collect_graph': True,'collect_dataset_graph': True}`,在`specified`字典中,键名`collect_graph`值设置为`True`,表示记录计算图;键名`collect_dataset_graph`值设置为`True`,表示记录数据图。\n", "2. 本次体验主要完成计算图与数据图的可视化,定义变量`specified={'collect_graph': True,'collect_dataset_graph': True}`,在`specified`字典中,键名`collect_graph`值设置为`True`,表示记录计算图;键名`collect_dataset_graph`值设置为`True`,表示记录数据图。\n",
"\n", "\n",
"3定义完`specified`变量后,传参到`summary_collector`中,最后将`summary_collector`传参到`model`中。\n", "3. 定义完`specified`变量后,传参到`summary_collector`中,最后将`summary_collector`传参到`model`中。\n",
"\n", "\n",
"至此,模型中就有了计算图与数据图的可视化功能。" "至此,模型中就有了计算图与数据图的可视化功能。"
] ]
...@@ -332,7 +332,7 @@ ...@@ -332,7 +332,7 @@
"- 启动MindInsigh服务命令:`mindinsigh start --summary-base-dir=/path/ --port=8080`;\n", "- 启动MindInsigh服务命令:`mindinsigh start --summary-base-dir=/path/ --port=8080`;\n",
"- 执行完服务命令后,访问给出的地址,查看MindInsigh可视化结果。\n", "- 执行完服务命令后,访问给出的地址,查看MindInsigh可视化结果。\n",
"\n", "\n",
"![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/005.png)" "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/mindinsight_map.png)"
] ]
}, },
{ {
...@@ -346,7 +346,7 @@ ...@@ -346,7 +346,7 @@
"- 节点信息:显示当前所查看节点的信息,包括名称、类型、属性、输入和输出。便于在训练结束后,核对计算正确性时查看。\n", "- 节点信息:显示当前所查看节点的信息,包括名称、类型、属性、输入和输出。便于在训练结束后,核对计算正确性时查看。\n",
"- 图例:图例中包括命名空间、聚合节点、虚拟节点、算子节点、常量节点,通过不同图形来区分。\n", "- 图例:图例中包括命名空间、聚合节点、虚拟节点、算子节点、常量节点,通过不同图形来区分。\n",
"\n", "\n",
"![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/004.png)" "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/cast_map.png)"
] ]
}, },
{ {
...@@ -357,11 +357,11 @@ ...@@ -357,11 +357,11 @@
"\n", "\n",
"数据图所展示的顺序与数据集使用处代码顺序对应\n", "数据图所展示的顺序与数据集使用处代码顺序对应\n",
"\n", "\n",
"1首先是从加载数据集`mnist_ds = ds.MnistDataset(data_path)`开始,对应数据图中`MnistDataset`。\n", "1. 首先是从加载数据集`mnist_ds = ds.MnistDataset(data_path)`开始,对应数据图中`MnistDataset`。\n",
"\n", "\n",
"2在以下所示代码中,是数据预处理的一些方法,顺序与数据图中所示顺序对应。\n", "2. 在以下所示代码中,是数据预处理的一些方法,顺序与数据图中所示顺序对应。\n",
"\n", "\n",
"`\n", "```\n",
"type_cast_op = C.TypeCast(mstype.int32)\n", "type_cast_op = C.TypeCast(mstype.int32)\n",
"resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n", "resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
"rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n", "rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
...@@ -372,21 +372,22 @@ ...@@ -372,21 +372,22 @@
"mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
"mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n", "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
"mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
"`\n", "```\n",
"\n", "\n",
"- `TypeCast`:在数据集`create_data`函数中,使用:`TypeCase(mstype.int32)`,将数据类型转换成我们所设置的类型。\n", "- `TypeCast`:在数据集`create_data`函数中,使用:`TypeCase(mstype.int32)`,将数据类型转换成我们所设置的类型。\n",
"- `Resize`:在数据集`create_data`函数中,使用:`Resize(resize_height,resize_width = 32,32)`,可以将数据的高和宽做调整。\n", "- `Resize`:在数据集`create_data`函数中,使用:`Resize(resize_height,resize_width = 32,32)`,可以将数据的高和宽做调整。\n",
"- `Rescale`:在数据集`create_data`函数中,使用:`rescale = 1.0 / 255.0`;`Rescale(rescale,shift)`,可以重新数据格式。\n", "- `Rescale`:在数据集`create_data`函数中,使用:`rescale = 1.0 / 255.0`;`Rescale(rescale,shift)`,可以重新数据格式。\n",
"- `HWC2CHW`:在数据集`create_data`函数中,使用:`HWC2CHW()`,此方法可以将数据所带信息与通道结合,一并加载。\n", "- `HWC2CHW`:在数据集`create_data`函数中,使用:`HWC2CHW()`,此方法可以将数据所带信息与通道结合,一并加载。\n",
"\n", "\n",
"3、前面的几个步骤是数据集的预处理顺序,后面几个步骤是模型加载数据集时要定义的参数,顺序与数据图中对应。\n",
"\n", "\n",
"`\n", "3. 前面的几个步骤是数据集的预处理顺序,后面几个步骤是模型加载数据集时要定义的参数,顺序与数据图中对应。\n",
"\n",
"```\n",
"buffer_size = 10000\n", "buffer_size = 10000\n",
"mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", "mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n",
"mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", "mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
"mnist_ds = mnist_ds.repeat(repeat_size)\n", "mnist_ds = mnist_ds.repeat(repeat_size)\n",
"`\n", "```\n",
" \n", " \n",
"- `Shuffle`:在数据集`create_data`函数中,使用:`buffer_size = 10000`,后面数值可以支持自行设置,表示一次缓存数据的数量。\n", "- `Shuffle`:在数据集`create_data`函数中,使用:`buffer_size = 10000`,后面数值可以支持自行设置,表示一次缓存数据的数量。\n",
"- `Batch`:在数据集`create_data`函数中,使用:`batch_size = 32`。支持自行设置,表示将整体数据集划分成小批量数据集,每一个小批次作为一个整体进行训练。\n", "- `Batch`:在数据集`create_data`函数中,使用:`batch_size = 32`。支持自行设置,表示将整体数据集划分成小批量数据集,每一个小批次作为一个整体进行训练。\n",
...@@ -397,7 +398,7 @@ ...@@ -397,7 +398,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/001.png)" "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/data_map.png)"
] ]
}, },
{ {
......
...@@ -26,38 +26,38 @@ ...@@ -26,38 +26,38 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"1数据集的准备,这里使用的是MNIST数据集。\n", "1. 数据集的准备,这里使用的是MNIST数据集。\n",
"\n", "\n",
"2构建一个网络,这里使用LeNet网络。(此处将使用第二种记录方式`ImageSummary`)。\n", "2. 构建一个网络,这里使用LeNet网络。(此处将使用第二种记录方式`ImageSummary`)。\n",
"\n", "\n",
"3训练网络和测试网络的搭建及运行。(此处将操作`SummaryCollector`初始化,并记录模型训练和模型测试相关信息)。\n", "3. 训练网络和测试网络的搭建及运行。(此处将操作`SummaryCollector`初始化,并记录模型训练和模型测试相关信息)。\n",
"\n", "\n",
"4启动MindInsight服务。\n", "4. 启动MindInsight服务。\n",
"\n", "\n",
"5模型溯源的使用。调整模型参数多次存储数据,并使用MindInsight的模型溯源功能对不同优化参数下训练产生的模型作对比,了解MindSpore中的各类优化对训练过程的影响及如何调优训练过程。\n", "5. 模型溯源的使用。调整模型参数多次存储数据,并使用MindInsight的模型溯源功能对不同优化参数下训练产生的模型作对比,了解MindSpore中的各类优化对训练过程的影响及如何调优训练过程。\n",
"\n", "\n",
"6数据溯源的使用。调整数据参数多次存储数据,并使用MindInsight的数据溯源功能对不同数据集下训练产生的模型进行对比分析,了解如何调优。" "6. 数据溯源的使用。调整数据参数多次存储数据,并使用MindInsight的数据溯源功能对不同数据集下训练产生的模型进行对比分析,了解如何调优。"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"本次体验将使用快速入门案例作为基础用例,将MindInsight的模型溯源和数据溯源的数据记录功能加入到案例中,快速入门案例的源码请参考:https://gitee.com/mindspore/docs/blob/r0.5/tutorials/tutorial_code/lenet.py 。" "本次体验将使用快速入门案例作为基础用例,将MindInsight的模型溯源和数据溯源的数据记录功能加入到案例中,快速入门案例的源码请参考:<https://gitee.com/mindspore/docs/blob/master/tutorials/tutorial_code/lenet.py>。"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 一、训练的数据集下载" "## 训练的数据集下载"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1、数据集准备" "### 数据集准备"
] ]
}, },
{ {
...@@ -65,8 +65,8 @@ ...@@ -65,8 +65,8 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"#### 方法一:\n", "#### 方法一:\n",
"从以下网址下载,并将数据包解压缩后放至Jupyter的工作目录下:<br/>训练数据集:{\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n", "从以下网址下载,并将数据包解压缩后放至Jupyter的工作目录下:<br/>训练数据集:{\"<http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz>\", \"<http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz>\"}\n",
"<br/>测试数据集:{\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}<br/>我们用下面代码查询jupyter的工作目录。" "<br/>测试数据集:{\"<http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz>\", \"<http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz>\"}<br/>我们用下面代码查询jupyter的工作目录。"
] ]
}, },
{ {
...@@ -100,15 +100,14 @@ ...@@ -100,15 +100,14 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Network request module, data download module, decompression module\n",
"import urllib.request \n", "import urllib.request \n",
"from urllib.parse import urlparse\n", "from urllib.parse import urlparse\n",
"import gzip \n", "import gzip \n",
"\n", "\n",
"def unzipfile(gzip_path):\n", "def unzip_file(gzip_path):\n",
" \"\"\"unzip dataset file\n", " \"\"\"unzip dataset file\n",
" Args:\n", " Args:\n",
" gzip_path: dataset file path\n", " gzip_path (str): Dataset file path\n",
" \"\"\"\n", " \"\"\"\n",
" open_file = open(gzip_path.replace('.gz',''), 'wb')\n", " open_file = open(gzip_path.replace('.gz',''), 'wb')\n",
" gz_file = gzip.GzipFile(gzip_path)\n", " gz_file = gzip.GzipFile(gzip_path)\n",
...@@ -134,7 +133,7 @@ ...@@ -134,7 +133,7 @@
" file_name = os.path.join(train_path,url_parse.path.split('/')[-1])\n", " file_name = os.path.join(train_path,url_parse.path.split('/')[-1])\n",
" if not os.path.exists(file_name.replace('.gz', '')):\n", " if not os.path.exists(file_name.replace('.gz', '')):\n",
" file = urllib.request.urlretrieve(url, file_name)\n", " file = urllib.request.urlretrieve(url, file_name)\n",
" unzipfile(file_name)\n", " unzip_file(file_name)\n",
" os.remove(file_name)\n", " os.remove(file_name)\n",
" \n", " \n",
" for url in test_url:\n", " for url in test_url:\n",
...@@ -143,7 +142,7 @@ ...@@ -143,7 +142,7 @@
" file_name = os.path.join(test_path,url_parse.path.split('/')[-1])\n", " file_name = os.path.join(test_path,url_parse.path.split('/')[-1])\n",
" if not os.path.exists(file_name.replace('.gz', '')):\n", " if not os.path.exists(file_name.replace('.gz', '')):\n",
" file = urllib.request.urlretrieve(url, file_name)\n", " file = urllib.request.urlretrieve(url, file_name)\n",
" unzipfile(file_name)\n", " unzip_file(file_name)\n",
" os.remove(file_name)\n", " os.remove(file_name)\n",
"\n", "\n",
"download_dataset()" "download_dataset()"
...@@ -160,7 +159,7 @@ ...@@ -160,7 +159,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2、数据集处理" "### 数据集处理"
] ]
}, },
{ {
...@@ -169,11 +168,11 @@ ...@@ -169,11 +168,11 @@
"source": [ "source": [
"数据集处理对于训练非常重要,好的数据集可以有效提高训练精度和效率。在加载数据集前,我们通常会对数据集进行一些处理。\n", "数据集处理对于训练非常重要,好的数据集可以有效提高训练精度和效率。在加载数据集前,我们通常会对数据集进行一些处理。\n",
"<br/>我们定义一个函数`create_dataset`来创建数据集。在这个函数中,我们定义好需要进行的数据增强和处理操作:\n", "<br/>我们定义一个函数`create_dataset`来创建数据集。在这个函数中,我们定义好需要进行的数据增强和处理操作:\n",
"<br/>1、定义数据集。\n", "1. 定义数据集。\n",
"<br/>2、定义进行数据增强和处理所需要的一些参数。\n", "2. 定义进行数据增强和处理所需要的一些参数。\n",
"<br/>3、根据参数,生成对应的数据增强操作。\n", "3. 根据参数,生成对应的数据增强操作。\n",
"<br/>4、使用`map`映射函数,将数据操作应用到数据集。\n", "4. 使用`map`映射函数,将数据操作应用到数据集。\n",
"<br/>5、对生成的数据集进行处理。" "5. 对生成的数据集进行处理。"
] ]
}, },
{ {
...@@ -199,36 +198,36 @@ ...@@ -199,36 +198,36 @@
" num_parallel_workers=1):\n", " num_parallel_workers=1):\n",
" \"\"\" create dataset for train or test\n", " \"\"\" create dataset for train or test\n",
" Args:\n", " Args:\n",
" data_path: Data path\n", " data_path (str): Data path\n",
" batch_size: The number of data records in each group\n", " batch_size (int): The number of data records in each group\n",
" repeat_size: The number of replicated data records\n", " repeat_size (int): The number of replicated data records\n",
" num_parallel_workers: The number of parallel workers\n", " num_parallel_workers (int): The number of parallel workers\n",
" \"\"\"\n", " \"\"\"\n",
" # define dataset\n", " # define dataset\n",
" mnist_ds = ds.MnistDataset(data_path)\n", " mnist_ds = ds.MnistDataset(data_path)\n",
"\n", "\n",
" # Define some parameters needed for data enhancement and rough justification\n", " # define some parameters needed for data enhancement and rough justification\n",
" resize_height, resize_width = 32, 32\n", " resize_height, resize_width = 32, 32\n",
" rescale = 1.0 / 255.0\n", " rescale = 1.0 / 255.0\n",
" shift = 0.0\n", " shift = 0.0\n",
" rescale_nml = 1 / 0.3081\n", " rescale_nml = 1 / 0.3081\n",
" shift_nml = -1 * 0.1307 / 0.3081\n", " shift_nml = -1 * 0.1307 / 0.3081\n",
"\n", "\n",
" # According to the parameters, generate the corresponding data enhancement method\n", " # according to the parameters, generate the corresponding data enhancement method\n",
" resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Resize images to (32, 32) by bilinear interpolation\n", " resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
" rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images\n", " rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
" rescale_op = CV.Rescale(rescale, shift) # rescale images\n", " rescale_op = CV.Rescale(rescale, shift)\n",
" hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.\n", " hwc2chw_op = CV.HWC2CHW()\n",
" type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network\n", " type_cast_op = C.TypeCast(mstype.int32)\n",
"\n", "\n",
" # Using map() to apply operations to a dataset\n", " # using map method to apply operations to a dataset\n",
" mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
" mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
" \n", " \n",
" # Process the generated dataset\n", " # process the generated dataset\n",
" buffer_size = 10000\n", " buffer_size = 10000\n",
" mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", " mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n",
" mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", " mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
...@@ -241,7 +240,7 @@ ...@@ -241,7 +240,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 二、构建LeNet5网络" "## 构建LeNet5网络"
] ]
}, },
{ {
...@@ -260,7 +259,7 @@ ...@@ -260,7 +259,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -268,7 +267,6 @@ ...@@ -268,7 +267,6 @@
"import mindspore.nn as nn\n", "import mindspore.nn as nn\n",
"from mindspore.common.initializer import TruncatedNormal\n", "from mindspore.common.initializer import TruncatedNormal\n",
"\n", "\n",
"# Initialize 2D convolution function\n",
"def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n", "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n",
" \"\"\"Conv layer weight initial.\"\"\"\n", " \"\"\"Conv layer weight initial.\"\"\"\n",
" weight = weight_variable()\n", " weight = weight_variable()\n",
...@@ -276,50 +274,46 @@ ...@@ -276,50 +274,46 @@
" kernel_size=kernel_size, stride=stride, padding=padding,\n", " kernel_size=kernel_size, stride=stride, padding=padding,\n",
" weight_init=weight, has_bias=False, pad_mode=\"valid\")\n", " weight_init=weight, has_bias=False, pad_mode=\"valid\")\n",
"\n", "\n",
"# Initialize full connection layer\n",
"def fc_with_initialize(input_channels, out_channels):\n", "def fc_with_initialize(input_channels, out_channels):\n",
" \"\"\"Fc layer weight initial.\"\"\"\n", " \"\"\"Fc layer weight initial.\"\"\"\n",
" weight = weight_variable()\n", " weight = weight_variable()\n",
" bias = weight_variable()\n", " bias = weight_variable()\n",
" return nn.Dense(input_channels, out_channels, weight, bias)\n", " return nn.Dense(input_channels, out_channels, weight, bias)\n",
"\n", "\n",
"# Set truncated normal distribution\n",
"def weight_variable():\n", "def weight_variable():\n",
" \"\"\"Weight initial.\"\"\"\n", " \"\"\"Weight initial.\"\"\"\n",
" return TruncatedNormal(0.02)\n", " return TruncatedNormal(0.02)\n",
"\n", "\n",
"class LeNet5(nn.Cell):\n", "class LeNet5(nn.Cell):\n",
" \"\"\"Lenet network structure.\"\"\"\n", " \"\"\"Lenet network structure.\"\"\"\n",
" # define the operator required\n",
" def __init__(self):\n", " def __init__(self):\n",
" super(LeNet5, self).__init__()\n", " super(LeNet5, self).__init__()\n",
" self.batch_size = 32 # 32 pictures in each group\n", " self.batch_size = 32 \n",
" self.conv1 = conv(1, 6, 5) # Convolution layer 1, 1 channel input (1 Figure), 6 channel output (6 figures), convolution core 5 * 5\n", " self.conv1 = conv(1, 6, 5)\n",
" self.conv2 = conv(6, 16, 5) # Convolution layer 2,6-channel input, 16 channel output, convolution kernel 5 * 5\n", " self.conv2 = conv(6, 16, 5)\n",
" self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n", " self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n",
" self.fc2 = fc_with_initialize(120, 84)\n", " self.fc2 = fc_with_initialize(120, 84)\n",
" self.fc3 = fc_with_initialize(84, 10)\n", " self.fc3 = fc_with_initialize(84, 10)\n",
" self.relu = nn.ReLU()\n", " self.relu = nn.ReLU()\n",
" self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
" self.flatten = nn.Flatten()\n", " self.flatten = nn.Flatten()\n",
" #Init ImageSummary\n", " # Init ImageSummary\n",
" self.sm_image = P.ImageSummary()\n", " self.sm_image = P.ImageSummary()\n",
"\n", "\n",
" # use the preceding operators to construct networks\n",
" def construct(self, x):\n", " def construct(self, x):\n",
" self.sm_image(\"image\",x)\n", " self.sm_image(\"image\",x)\n",
" x = self.conv1(x) # 1*32*32-->6*28*28\n", " x = self.conv1(x)\n",
" x = self.relu(x) # 6*28*28-->6*14*14\n", " x = self.relu(x)\n",
" x = self.max_pool2d(x) # Pool layer\n", " x = self.max_pool2d(x)\n",
" x = self.conv2(x) # Convolution layer\n", " x = self.conv2(x)\n",
" x = self.relu(x) # Function excitation layer\n", " x = self.relu(x)\n",
" x = self.max_pool2d(x) # Pool layer\n", " x = self.max_pool2d(x)\n",
" x = self.flatten(x) # Dimensionality reduction\n", " x = self.flatten(x)\n",
" x = self.fc1(x) # Full connection\n", " x = self.fc1(x)\n",
" x = self.relu(x) # Function excitation layer\n", " x = self.relu(x)\n",
" x = self.fc2(x) # Full connection\n", " x = self.fc2(x)\n",
" x = self.relu(x) # Function excitation layer\n", " x = self.relu(x)\n",
" x = self.fc3(x) # Full connection\n", " x = self.fc3(x)\n",
" return x" " return x"
] ]
}, },
...@@ -327,14 +321,14 @@ ...@@ -327,14 +321,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 三、训练网络和测试网络构建" "## 训练网络和测试网络构建"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1、使用SummaryCollector放入到训练网络中记录训练数据" "### 使用SummaryCollector放入到训练网络中记录训练数据"
] ]
}, },
{ {
...@@ -350,12 +344,10 @@ ...@@ -350,12 +344,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Training and testing related modules\n",
"from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, SummaryCollector, Callback\n", "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, SummaryCollector, Callback\n",
"from mindspore.train import Model\n", "from mindspore.train import Model\n",
"import os\n", "import os\n",
"\n", "\n",
"\n",
"def train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb, summary_collector):\n", "def train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb, summary_collector):\n",
" \"\"\"Define the training method.\"\"\"\n", " \"\"\"Define the training method.\"\"\"\n",
" print(\"============== Starting Training ==============\")\n", " print(\"============== Starting Training ==============\")\n",
...@@ -368,7 +360,7 @@ ...@@ -368,7 +360,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2、使用SummaryCollector放入到测试网络中记录测试数据" "### 使用SummaryCollector放入到测试网络中记录测试数据"
] ]
}, },
{ {
...@@ -380,7 +372,7 @@ ...@@ -380,7 +372,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -403,14 +395,14 @@ ...@@ -403,14 +395,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 3、主程序运行入口" "### 主程序运行入口"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"初始化`SummaryCollector`,使用`collect_specified_data`控制需要记录的数据,我们这里只需要记录模型溯源和数据溯源,所以将`collect_train_lineage`和`collect_eval_lineage`参数设置成`True`,其他的参数使用`keep_default_action`设置成`False`,SummaryCollector能够记录哪些数据,请参考官网:<https://www.mindspore.cn/api/zh-CN/r0.5/api/python/mindspore/mindspore.train.html#mindspore.train.callback.SummaryCollector> 。" "初始化`SummaryCollector`,使用`collect_specified_data`控制需要记录的数据,我们这里只需要记录模型溯源和数据溯源,所以将`collect_train_lineage`和`collect_eval_lineage`参数设置成`True`,其他的参数使用`keep_default_action`设置成`False`,SummaryCollector能够记录哪些数据,请参考官网:<https://www.mindspore.cn/api/zh-CN/master/api/python/mindspore/mindspore.train.html?highlight=collector#mindspore.train.callback.SummaryCollector>。"
] ]
}, },
{ {
...@@ -427,7 +419,7 @@ ...@@ -427,7 +419,7 @@
"\n", "\n",
"if __name__==\"__main__\":\n", "if __name__==\"__main__\":\n",
" context.set_context(mode=context.GRAPH_MODE, device_target = \"GPU\")\n", " context.set_context(mode=context.GRAPH_MODE, device_target = \"GPU\")\n",
" lr = 0.01 # learning rate\n", " lr = 0.01\n",
" momentum = 0.9 \n", " momentum = 0.9 \n",
" epoch_size = 3\n", " epoch_size = 3\n",
" mnist_path = \"./MNIST_Data\"\n", " mnist_path = \"./MNIST_Data\"\n",
...@@ -453,14 +445,14 @@ ...@@ -453,14 +445,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 四、启动及关闭MindInsight服务" "## 启动及关闭MindInsight服务"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"这里主要展示如何启用及关闭MindInsight,更多的命令集信息,请参考MindSpore官方网站:https://www.mindspore.cn/tutorial/zh-CN/r0.5/advanced_use/visualization_tutorials.html 。" "这里主要展示如何启用及关闭MindInsight,更多的命令集信息,请参考MindSpore官方网站:<https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/visualization_tutorials.html>。"
] ]
}, },
{ {
...@@ -489,7 +481,7 @@ ...@@ -489,7 +481,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"查询是否启动成功,在网址输入:`127.0.0.1:8090`,如果看到如下界面说明启动成功。" "查询是否启动成功,在网址输入`127.0.0.1:8090`,如果看到如下界面说明启动成功。"
] ]
}, },
{ {
...@@ -514,14 +506,14 @@ ...@@ -514,14 +506,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 五、模型溯源" "## 模型溯源"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1、连接到模型溯源地址" "### 连接到模型溯源地址"
] ]
}, },
{ {
...@@ -575,7 +567,7 @@ ...@@ -575,7 +567,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2、观察分析记录下来的溯源参数" "### 观察分析记录下来的溯源参数"
] ]
}, },
{ {
...@@ -604,21 +596,21 @@ ...@@ -604,21 +596,21 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## 六、数据溯源" "## 数据溯源"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 1、连接到数据溯源地址" "### 连接到数据溯源地址"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"浏览器中输入:127.0.0.1:8090连接上MindInsight的服务,点击模型溯源,如下图数据溯源界面:" "浏览器中输入:`127.0.0.1:8090`连接上MindInsight的服务,点击模型溯源,如下图数据溯源界面:"
] ]
}, },
{ {
...@@ -654,7 +646,7 @@ ...@@ -654,7 +646,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2、观察分析数据溯源参数" "### 观察分析数据溯源参数"
] ]
}, },
{ {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册