Commit ac6f99e2 authored by mindspore-ci-bot, committed by Gitee

!746 Fix the mindinsight notebook tutorial

Merge pull request !746 from ougongchang/fix_notebook
@@ -74,51 +74,55 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import urllib.request \n",
-"from urllib.parse import urlparse\n",
-"import gzip \n",
 "import os\n",
+"import gzip\n",
+"import urllib.request\n",
+"from urllib.parse import urlparse\n",
+"\n",
 "\n",
 "def unzip_file(gzip_path):\n",
-"    \"\"\"unzip dataset file\n",
+"    \"\"\"\n",
+"    Unzip a given gzip file.\n",
+"\n",
 "    Args:\n",
-"        gzip_path: dataset file path\n",
+"        gzip_path (str): The gzip file path\n",
 "    \"\"\"\n",
-"    open_file = open(gzip_path.replace('.gz',''), 'wb')\n",
+"    open_file = open(gzip_path.replace('.gz', ''), 'wb')\n",
 "    gz_file = gzip.GzipFile(gzip_path)\n",
 "    open_file.write(gz_file.read())\n",
 "    gz_file.close()\n",
-"    \n",
+"\n",
+"\n",
 "def download_dataset():\n",
 "    \"\"\"Download the dataset from http://yann.lecun.com/exdb/mnist/.\"\"\"\n",
 "    print(\"******Downloading the MNIST dataset******\")\n",
-"    train_path = \"./MNIST_Data/train/\" \n",
+"    train_path = \"./MNIST_Data/train/\"\n",
 "    test_path = \"./MNIST_Data/test/\"\n",
 "    train_path_check = os.path.exists(train_path)\n",
 "    test_path_check = os.path.exists(test_path)\n",
-"    if train_path_check == False and test_path_check == False:\n",
+"    if not train_path_check and not test_path_check:\n",
 "        os.makedirs(train_path)\n",
 "        os.makedirs(test_path)\n",
-"    train_url = {\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n",
-"    test_url = {\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\", \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}\n",
-"    \n",
+"    train_url = {\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\",\n",
+"                 \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n",
+"    test_url = {\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\",\n",
+"                \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}\n",
+"\n",
 "    for url in train_url:\n",
 "        url_parse = urlparse(url)\n",
-"        \"\"\"split the file name from url\"\"\"\n",
-"        file_name = os.path.join(train_path,url_parse.path.split('/')[-1])\n",
-"        if not os.path.exists(file_name.replace('.gz', '')):\n",
-"            file = urllib.request.urlretrieve(url, file_name)\n",
-"            unzipfile(file_name)\n",
-"            os.remove(file_name)\n",
-"    \n",
+"        # split the file name from url\n",
+"        file_name = os.path.join(train_path, url_parse.path.split('/')[-1])\n",
+"        if not os.path.exists(file_name.replace('.gz', '')) and not os.path.exists(file_name):\n",
+"            urllib.request.urlretrieve(url, file_name)\n",
+"            unzip_file(file_name)\n",
+"\n",
 "    for url in test_url:\n",
 "        url_parse = urlparse(url)\n",
-"        \"\"\"split the file name from url\"\"\"\n",
-"        file_name = os.path.join(test_path,url_parse.path.split('/')[-1])\n",
-"        if not os.path.exists(file_name.replace('.gz', '')):\n",
-"            file = urllib.request.urlretrieve(url, file_name)\n",
-"            unzipfile(file_name)\n",
-"            os.remove(file_name)\n",
+"        # split the file name from url\n",
+"        file_name = os.path.join(test_path, url_parse.path.split('/')[-1])\n",
+"        if not os.path.exists(file_name.replace('.gz', '')) and not os.path.exists(file_name):\n",
+"            urllib.request.urlretrieve(url, file_name)\n",
+"            unzip_file(file_name)\n",
 "\n",
 "download_dataset()"
]
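Note: the revised cell downloads a file only when neither the extracted file nor the `.gz` archive already exists, and the file name is split out of the URL path. A minimal standalone sketch of that file-name handling (the URL is one of those listed in the cell):

```python
from urllib.parse import urlparse

# One of the dataset URLs from the cell above.
url = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"

# Split the file name out of the URL, as the cell does.
file_name = urlparse(url).path.split('/')[-1]
print(file_name)                      # train-images-idx3-ubyte.gz
print(file_name.replace('.gz', ''))   # train-images-idx3-ubyte (the extracted name)
```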
@@ -127,9 +131,8 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"#### Using the dataset\n",
-"\n",
-"Set the correct data path so that the dataset can be read in and preprocessed as a whole, letting the data bring out the model's performance. The data graph visualized by MindInsight shows exactly how, and in what order, the dataset is transformed during preprocessing."
+"#### Data augmentation\n",
+"Applying data augmentation operations to the dataset can improve model accuracy.\n"
]
},
{
@@ -148,32 +151,39 @@
 "def create_dataset(data_path, batch_size=32, repeat_size=1,\n",
 "                   num_parallel_workers=1):\n",
 "    \"\"\"\n",
-"    create dataset for train or test\n",
+"    Create dataset for train or test.\n",
+"\n",
+"    Args:\n",
+"        data_path (str): The absolute path of the dataset\n",
+"        batch_size (int): The number of data records in each group\n",
+"        repeat_size (int): The number of replicated data records\n",
+"        num_parallel_workers (int): The number of parallel workers\n",
 "    \"\"\"\n",
-"    \"\"\"define dataset\"\"\"\n",
+"    # define dataset\n",
 "    mnist_ds = ds.MnistDataset(data_path)\n",
 "\n",
+"    # define some parameters needed for data enhancement and rough justification\n",
 "    resize_height, resize_width = 32, 32\n",
 "    rescale = 1.0 / 255.0\n",
 "    shift = 0.0\n",
 "    rescale_nml = 1 / 0.3081\n",
 "    shift_nml = -1 * 0.1307 / 0.3081\n",
 "\n",
-"    \"\"\"define map operations\"\"\"\n",
-"    type_cast_op = C.TypeCast(mstype.int32)\n",
-"    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode\n",
+"    # according to the parameters, generate the corresponding data enhancement method\n",
+"    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
 "    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
 "    rescale_op = CV.Rescale(rescale, shift)\n",
 "    hwc2chw_op = CV.HWC2CHW()\n",
+"    type_cast_op = C.TypeCast(mstype.int32)\n",
 "\n",
-"    \"\"\"apply map operations on images\"\"\"\n",
+"    # using map method to apply operations to a dataset\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
 "\n",
-"    \"\"\"apply DatasetOps\"\"\"\n",
+"    # process the generated dataset\n",
 "    buffer_size = 10000\n",
 "    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script\n",
 "    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
@@ -272,15 +282,13 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"#### Running the main program\n",
+"#### Executing the training\n",
 "\n",
-"1. First, import the required modules before the main function and call the corresponding interfaces there.\n",
+"1. Import the required packages and instantiate the training network.\n",
+"2. Collect the computational graph and the data graph through the `SummaryCollector` interface provided by MindSpore. When instantiating `SummaryCollector`, set `collect_graph` in the `collect_specified_data` parameter to collect the computational graph, and set `collect_dataset_graph` to collect the data graph.\n",
 "\n",
-"2. This walkthrough focuses on visualizing the computational graph and the data graph. Define the variable `specified={'collect_graph': True,'collect_dataset_graph': True}`; in the `specified` dictionary, setting the key `collect_graph` to `True` records the computational graph, and setting `collect_dataset_graph` to `True` records the data graph.\n",
-"\n",
-"3. After defining the `specified` variable, pass it to `summary_collector`, and finally pass `summary_collector` to `model`.\n",
-"\n",
-"With that, the model has computational graph and data graph visualization enabled."
+"For more usage of `SummaryCollector`, see the [API documentation](https://www.mindspore.cn/api/zh-CN/master/api/python/mindspore/mindspore.train.html?highlight=summarycollector#mindspore.train.callback.SummaryCollector).\n",
+"\n"
]
},
{
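The collection switches described above map onto the `SummaryCollector` arguments as follows; this is a minimal annotated sketch of the call that appears in the next code cell:

```python
from mindspore.train.callback import SummaryCollector

specified = {'collect_graph': True,          # record the computational graph
             'collect_dataset_graph': True}  # record the data graph

summary_collector = SummaryCollector(
    summary_dir='./summary_dir',       # directory MindInsight will read from
    collect_specified_data=specified,  # collect only the data named above
    collect_freq=1,                    # collection frequency, in steps
    keep_default_action=False)         # disable the default collection set
```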
@@ -293,9 +301,7 @@
 "from mindspore import context\n",
 "from mindspore.train import Model\n",
 "from mindspore.nn.metrics import Accuracy\n",
-"from mindspore.train.callback import SummaryCollector\n",
-"from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
-"from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor\n",
+"from mindspore.train.callback import LossMonitor, SummaryCollector\n",
 "\n",
 "if __name__ == \"__main__\":\n",
 "    device_target = \"CPU\"\n",
@@ -308,18 +314,15 @@
 "    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction=\"mean\")\n",
 "    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n",
 "    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())\n",
-"    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)\n",
-"    ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\", config=config_ck)\n",
 "    model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n",
-"    specified={'collect_graph': True,'collect_dataset_graph': True}\n",
+"\n",
+"    specified={'collect_graph': True, 'collect_dataset_graph': True}\n",
 "    summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_specified_data=specified, collect_freq=1, keep_default_action=False)\n",
 "    \n",
 "    print(\"============== Starting Training ==============\")\n",
-"    model.train(epoch=2, train_dataset=ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor(), summary_collector], dataset_sink_mode=False)\n",
+"    model.train(epoch=2, train_dataset=ds_train, callbacks=[LossMonitor(), summary_collector], dataset_sink_mode=False)\n",
 "\n",
 "    print(\"============== Starting Testing ==============\")\n",
-"    param_dict = load_checkpoint(\"checkpoint_lenet-3_1875.ckpt\")\n",
-"    load_param_into_net(network, param_dict)\n",
 "    ds_eval = create_dataset(\"./MNIST_Data/test/\")\n",
 "    acc = model.eval(ds_eval, dataset_sink_mode=False)\n",
 "    print(\"============== {} ==============\".format(acc))"
@@ -333,6 +336,8 @@
 "- Start the MindInsight service with the command: `mindinsight start --summary-base-dir=/path/ --port=8080`;\n",
 "- After the command finishes, open the address it reports to view the MindInsight visualization results.\n",
 "\n",
+"> Here /path/ is the directory specified by the `summary_dir` parameter of `SummaryCollector`.\n",
+"\n",
 "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/mindinsight_map.png)"
]
},
@@ -354,45 +359,25 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Data graph information\n",
+"### Data graph display\n",
 "\n",
-"The order shown in the data graph corresponds to the order of the code where the dataset is used.\n",
+"The data graph shows the flow of operations applied to the data during data augmentation.\n",
 "\n",
 "1. It starts from loading the dataset with `mnist_ds = ds.MnistDataset(data_path)`, which corresponds to `MnistDataset` in the data graph.\n",
 "\n",
-"2. The code below shows some of the data preprocessing methods; their order corresponds to the order shown in the data graph.\n",
+"2. The code below contains the data preprocessing and augmentation operations from the `create_dataset` function above. The data graph shows this processing flow clearly, and inspecting it helps analyze whether any data processing step is inappropriate.\n",
 "\n",
 "```\n",
-"type_cast_op = C.TypeCast(mstype.int32)\n",
-"resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
-"rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
-"rescale_op = CV.Rescale(rescale, shift)\n",
-"hwc2chw_op = CV.HWC2CHW()\n",
 "mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n",
 "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n",
 "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
 "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
 "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
-"```\n",
-"\n",
-"- `TypeCast`: in the dataset's `create_data` function, `TypeCase(mstype.int32)` converts the data into the type we set.\n",
-"- `Resize`: in the dataset's `create_data` function, `Resize(resize_height,resize_width = 32,32)` adjusts the height and width of the data.\n",
-"- `Rescale`: in the dataset's `create_data` function, `rescale = 1.0 / 255.0`; `Rescale(rescale,shift)` rescales the data values.\n",
-"- `HWC2CHW`: in the dataset's `create_data` function, `HWC2CHW()` combines the information carried by the data with the channels and loads them together.\n",
-"\n",
-"\n",
-"3. The steps above are the preprocessing order of the dataset; the steps below are the parameters to define when the model loads the dataset, again in the order shown in the data graph.\n",
 "\n",
-"```\n",
-"buffer_size = 10000\n",
 "mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script\n",
 "mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
 "mnist_ds = mnist_ds.repeat(repeat_size)\n",
-"```\n",
-" \n",
-"- `Shuffle`: in the dataset's `create_data` function, `buffer_size = 10000` is used; the value can be set as needed and is the number of records cached at a time.\n",
-"- `Batch`: in the dataset's `create_data` function, `batch_size = 32` is used; it can be set as needed and splits the whole dataset into mini-batches, each trained as one unit.\n",
-"- `Repeat`: in the dataset's `create_data` function, `repeat_size = 1` is used; it can be set as needed and is the number of times the data is repeated in one run."
+"```\n"
]
},
{
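The three trailing operations in the cell above control the dataset geometry rather than the per-record transforms. A standalone sketch of their effect, assuming the notebook's `import mindspore.dataset as ds` and the 60000-record MNIST train split:

```python
import mindspore.dataset as ds

mnist_ds = ds.MnistDataset("./MNIST_Data/train/")
mnist_ds = mnist_ds.shuffle(buffer_size=10000)      # reorder records within a 10000-record buffer
mnist_ds = mnist_ds.batch(32, drop_remainder=True)  # group records: 60000 // 32 = 1875 batches
mnist_ds = mnist_ds.repeat(1)                       # number of passes over the data
print(mnist_ds.get_dataset_size())                  # expected: 1875
```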
@@ -408,7 +393,7 @@
 "source": [
 "### Shutting down MindInsight\n",
 "\n",
-"- When you are finished viewing, run the command`mindinsight stop --port=8080` on the command line to shut down MindInsight."
+"- When you are finished viewing, run the command `mindinsight stop --port=8080` on the command line to shut down MindInsight."
]
}
],