# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Lenet Tutorial
The sample can be run on CPU, GPU and Ascend 910 AI processor.
"""
import os
import shutil
import urllib.request
from urllib.parse import urlparse
import gzip
import argparse

import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
from mindspore.common.initializer import TruncatedNormal
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
from mindspore.nn.metrics import Accuracy
from mindspore.common import dtype as mstype
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

_GZ_SUFFIX = '.gz'


def _strip_gz_suffix(path):
    """Return `path` without a trailing '.gz' (unchanged if it has none)."""
    if path.endswith(_GZ_SUFFIX):
        return path[:-len(_GZ_SUFFIX)]
    return path


def unzipfile(gzip_path):
    """unzip dataset file

    The decompressed data is written next to the archive, under the same
    name with the trailing '.gz' removed. The archive itself is left in place.

    Args:
        gzip_path: dataset file path, must end in '.gz'

    Raises:
        ValueError: if `gzip_path` does not end in '.gz' (the output path
            would otherwise be the archive itself and overwrite it).
    """
    target_path = _strip_gz_suffix(gzip_path)
    if target_path == gzip_path:
        raise ValueError("expected a path ending in '.gz', got: {}".format(gzip_path))

    # Extract to a temporary name first so that an interrupted or failed
    # extraction never leaves a truncated file under the final name (the
    # downloader treats an existing final file as "already done").
    partial_path = target_path + '.part'
    try:
        with gzip.open(gzip_path, 'rb') as gz_file, open(partial_path, 'wb') as out_file:
            shutil.copyfileobj(gz_file, out_file)
        os.replace(partial_path, target_path)
    finally:
        # Only still present when extraction failed before the rename.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def _download_and_extract(urls, target_dir):
    """Download each gzip archive in `urls` into `target_dir` and extract it.

    An archive is skipped when its extracted file already exists. The
    downloaded '.gz' file is removed after a successful extraction.
    """
    for url in urls:
        # split the file name from url
        archive_name = urlparse(url).path.split('/')[-1]
        archive_path = os.path.join(target_dir, archive_name)
        if os.path.exists(_strip_gz_suffix(archive_path)):
            continue
        urllib.request.urlretrieve(url, archive_path)
        unzipfile(archive_path)
        os.remove(archive_path)


def download_dataset():
    """Download the dataset from http://yann.lecun.com/exdb/mnist/.

    Files are stored extracted under ./MNIST_Data/train/ and
    ./MNIST_Data/test/; both directories are created if missing.
    """
    print("******Downloading the MNIST dataset******")
    train_path = "./MNIST_Data/train/"
    test_path = "./MNIST_Data/test/"
    # Create each directory independently: one of them may already exist
    # while the other does not.
    os.makedirs(train_path, exist_ok=True)
    os.makedirs(test_path, exist_ok=True)

    train_url = ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
                 "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")
    test_url = ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
                "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")
    _download_and_extract(train_url, train_path)
    _download_and_extract(test_url, test_path)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """ create dataset for train or test

    Args:
        data_path: Data path
        batch_size: The number of data records in each group
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers

    Returns:
        The mapped, shuffled, batched and repeated MNIST dataset.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Resize images to (32, 32)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)  # normalize images
    rescale_op = CV.Rescale(rescale, shift)  # rescale images
    hwc2chw_op = CV.HWC2CHW()  # change shape from (height, width, channel) to (channel, height, width) to fit network.
    type_cast_op = C.TypeCast(mstype.int32)  # change data type of label to int32 to fit network

    # apply map operation on labels
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op,
                            num_parallel_workers=num_parallel_workers)
    # apply map operations on images, in this order:
    # resize -> rescale -> normalize -> HWC2CHW
    for image_op in (resize_op, rescale_op, rescale_nml_op, hwc2chw_op):
        mnist_ds = mnist_ds.map(input_columns="image", operations=image_op,
                                num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv layer weight initial.

    Returns a bias-free nn.Conv2d with pad_mode="valid" whose weights are
    initialized by weight_variable().
    """
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Fc layer weight initial.

    Returns an nn.Dense layer; weight and bias initializers both come from
    weight_variable() and are passed positionally as the third and fourth
    arguments.
    """
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """Weight initial."""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    """Lenet network structure."""

    # define the operator required
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # With the 32x32 inputs produced by create_dataset:
        # 32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5,
        # hence 16 * 5 * 5 flattened features.
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    # use the preceding operators to construct networks
    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
    """Define the training method.

    Args:
        args: parsed command-line arguments (currently unused here)
        model: the Model to train
        epoch_size: number of epochs passed to model.train
        mnist_path: root directory containing the "train" sub-directory
        repeat_size: repeat count passed to create_dataset
        ckpoint_cb: checkpoint callback used during training
    """
    print("============== Starting Training ==============")
    # load training dataset
    ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)


def test_net(args, network, model, mnist_path):
    """Define the evaluation method.

    Args:
        args: parsed command-line arguments (currently unused here)
        network: the network into which the checkpoint parameters are loaded
        model: the Model used for evaluation
        mnist_path: root directory containing the "test" sub-directory
    """
    print("============== Starting Testing ==============")
    # load the saved model for evaluation
    # NOTE(review): the file name is hard-coded; it presumably corresponds to
    # the "checkpoint_lenet" prefix and save_checkpoint_steps=1875 configured
    # in the script section after one epoch -- confirm if those settings change.
    param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
    # load parameter to the network
    load_param_into_net(network, param_dict)
    # load testing dataset
    ds_eval = create_dataset(os.path.join(mnist_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
    parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: CPU)')
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                        enable_mem_reuse=False)
    # download mnist dataset
    download_dataset()
    # learning rate setting
    lr = 0.01
    momentum = 0.9
    epoch_size = 1
    mnist_path = "./MNIST_Data"
    # define the loss function
    net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
    repeat_size = epoch_size
    # create the network
    network = LeNet5()
    # define the optimizer
    net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
    # save the network model and parameters for subsequent fine-tuning
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    # group layers into an object with training and evaluation features
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
    test_net(args, network, model, mnist_path)