Commit b0985fe9 authored by lvmingfu

Unify code format in notebook

Parent 4529029c
@@ -100,15 +100,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Network request module, data download module, decompression module\n",
 "import urllib.request \n",
 "from urllib.parse import urlparse\n",
 "import gzip \n",
 "\n",
-"def unzipfile(gzip_path):\n",
+"def unzip_file(gzip_path):\n",
 "    \"\"\"unzip dataset file\n",
 "    Args:\n",
-"        gzip_path: dataset file path\n",
+"        gzip_path (str): Dataset file path\n",
 "    \"\"\"\n",
 "    open_file = open(gzip_path.replace('.gz',''), 'wb')\n",
 "    gz_file = gzip.GzipFile(gzip_path)\n",
@@ -134,7 +133,7 @@
 "        file_name = os.path.join(train_path,url_parse.path.split('/')[-1])\n",
 "        if not os.path.exists(file_name.replace('.gz', '')):\n",
 "            file = urllib.request.urlretrieve(url, file_name)\n",
-"            unzipfile(file_name)\n",
+"            unzip_file(file_name)\n",
 "            os.remove(file_name)\n",
 "    \n",
 "    for url in test_url:\n",
@@ -143,7 +142,7 @@
 "        file_name = os.path.join(test_path,url_parse.path.split('/')[-1])\n",
 "        if not os.path.exists(file_name.replace('.gz', '')):\n",
 "            file = urllib.request.urlretrieve(url, file_name)\n",
-"            unzipfile(file_name)\n",
+"            unzip_file(file_name)\n",
 "            os.remove(file_name)\n",
 "\n",
 "download_dataset()"
@@ -199,36 +198,36 @@
 "                   num_parallel_workers=1):\n",
 "    \"\"\" create dataset for train or test\n",
 "    Args:\n",
-"        data_path: Data path\n",
-"        batch_size: The number of data records in each group\n",
-"        repeat_size: The number of replicated data records\n",
-"        num_parallel_workers: The number of parallel workers\n",
+"        data_path (str): Data path\n",
+"        batch_size (int): The number of data records in each group\n",
+"        repeat_size (int): The number of replicated data records\n",
+"        num_parallel_workers (int): The number of parallel workers\n",
 "    \"\"\"\n",
 "    # define dataset\n",
 "    mnist_ds = ds.MnistDataset(data_path)\n",
 "\n",
-"    # Define some parameters needed for data enhancement and rough justification\n",
+"    # define some parameters needed for data enhancement and rough justification\n",
 "    resize_height, resize_width = 32, 32\n",
 "    rescale = 1.0 / 255.0\n",
 "    shift = 0.0\n",
 "    rescale_nml = 1 / 0.3081\n",
 "    shift_nml = -1 * 0.1307 / 0.3081\n",
 "\n",
-"    # According to the parameters, generate the corresponding data enhancement method\n",
-"    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Resize images to (32, 32) by bilinear interpolation\n",
-"    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images\n",
-"    rescale_op = CV.Rescale(rescale, shift) # rescale images\n",
-"    hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.\n",
-"    type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network\n",
+"    # according to the parameters, generate the corresponding data enhancement method\n",
+"    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
+"    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
+"    rescale_op = CV.Rescale(rescale, shift)\n",
+"    hwc2chw_op = CV.HWC2CHW()\n",
+"    type_cast_op = C.TypeCast(mstype.int32)\n",
 "\n",
-"    # Using map() to apply operations to a dataset\n",
+"    # using map method to apply operations to a dataset\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
 "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
 "    \n",
-"    # Process the generated dataset\n",
+"    # process the generated dataset\n",
 "    buffer_size = 10000\n",
 "    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n",
 "    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
@@ -268,7 +267,6 @@
 "import mindspore.nn as nn\n",
 "from mindspore.common.initializer import TruncatedNormal\n",
 "\n",
-"# Initialize 2D convolution function\n",
 "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n",
 "    \"\"\"Conv layer weight initial.\"\"\"\n",
 "    weight = weight_variable()\n",
@@ -276,50 +274,46 @@
 "                     kernel_size=kernel_size, stride=stride, padding=padding,\n",
 "                     weight_init=weight, has_bias=False, pad_mode=\"valid\")\n",
 "\n",
-"# Initialize full connection layer\n",
 "def fc_with_initialize(input_channels, out_channels):\n",
 "    \"\"\"Fc layer weight initial.\"\"\"\n",
 "    weight = weight_variable()\n",
 "    bias = weight_variable()\n",
 "    return nn.Dense(input_channels, out_channels, weight, bias)\n",
 "\n",
-"# Set truncated normal distribution\n",
 "def weight_variable():\n",
 "    \"\"\"Weight initial.\"\"\"\n",
 "    return TruncatedNormal(0.02)\n",
 "\n",
 "class LeNet5(nn.Cell):\n",
 "    \"\"\"Lenet network structure.\"\"\"\n",
-"    # define the operator required\n",
 "    def __init__(self):\n",
 "        super(LeNet5, self).__init__()\n",
-"        self.batch_size = 32 # 32 pictures in each group\n",
-"        self.conv1 = conv(1, 6, 5) # Convolution layer 1, 1 channel input (1 Figure), 6 channel output (6 figures), convolution core 5 * 5\n",
-"        self.conv2 = conv(6, 16, 5) # Convolution layer 2, 6-channel input, 16 channel output, convolution kernel 5 * 5\n",
+"        self.batch_size = 32 \n",
+"        self.conv1 = conv(1, 6, 5)\n",
+"        self.conv2 = conv(6, 16, 5)\n",
 "        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n",
 "        self.fc2 = fc_with_initialize(120, 84)\n",
 "        self.fc3 = fc_with_initialize(84, 10)\n",
 "        self.relu = nn.ReLU()\n",
 "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
 "        self.flatten = nn.Flatten()\n",
-"        #Init ImageSummary\n",
+"        # Init ImageSummary\n",
 "        self.sm_image = P.ImageSummary()\n",
 "\n",
-"    # use the preceding operators to construct networks\n",
 "    def construct(self, x):\n",
 "        self.sm_image(\"image\",x)\n",
-"        x = self.conv1(x) # 1*32*32-->6*28*28\n",
-"        x = self.relu(x) # 6*28*28-->6*14*14\n",
-"        x = self.max_pool2d(x) # Pool layer\n",
-"        x = self.conv2(x) # Convolution layer\n",
-"        x = self.relu(x) # Function excitation layer\n",
-"        x = self.max_pool2d(x) # Pool layer\n",
-"        x = self.flatten(x) # Dimensionality reduction\n",
-"        x = self.fc1(x) # Full connection\n",
-"        x = self.relu(x) # Function excitation layer\n",
-"        x = self.fc2(x) # Full connection\n",
-"        x = self.relu(x) # Function excitation layer\n",
-"        x = self.fc3(x) # Full connection\n",
+"        x = self.conv1(x)\n",
+"        x = self.relu(x)\n",
+"        x = self.max_pool2d(x)\n",
+"        x = self.conv2(x)\n",
+"        x = self.relu(x)\n",
+"        x = self.max_pool2d(x)\n",
+"        x = self.flatten(x)\n",
+"        x = self.fc1(x)\n",
+"        x = self.relu(x)\n",
+"        x = self.fc2(x)\n",
+"        x = self.relu(x)\n",
+"        x = self.fc3(x)\n",
 "        return x"
 ]
 },
@@ -350,12 +344,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Training and testing related modules\n",
 "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, SummaryCollector, Callback\n",
 "from mindspore.train import Model\n",
 "import os\n",
 "\n",
-"\n",
 "def train_net(model, epoch_size, mnist_path, repeat_size, ckpoint_cb, summary_collector):\n",
 "    \"\"\"Define the training method.\"\"\"\n",
 "    print(\"============== Starting Training ==============\")\n",
@@ -427,7 +419,7 @@
 "\n",
 "if __name__==\"__main__\":\n",
 "    context.set_context(mode=context.GRAPH_MODE, device_target = \"GPU\")\n",
-"    lr = 0.01 # learning rate\n",
+"    lr = 0.01\n",
 "    momentum = 0.9 \n",
 "    epoch_size = 3\n",
 "    mnist_path = \"./MNIST_Data\"\n",
@@ -618,7 +610,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Enter 127.0.0.1:8090 in the browser to connect to the MindInsight service, then click Model Lineage; the data lineage page appears as shown below:"
+"Enter `127.0.0.1:8090` in the browser to connect to the MindInsight service, then click Model Lineage; the data lineage page appears as shown below:"
 ]
 },
 {
...
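For the markdown cell above to work, a MindInsight service must already be listening on port 8090. A minimal launch command, assuming the notebook's SummaryCollector wrote its records to ./summary_base_dir (that directory name is illustrative, not taken from this commit):

    # start MindInsight and serve the summary records on port 8090
    mindinsight start --summary-base-dir ./summary_base_dir --port 8090

When finished, the service can be shut down with `mindinsight stop`.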