提交 1314bdfd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!21 upgrade lenet experiment to r0.5, unify codes for different platf orm

Merge pull request !21 from dyonghan/update_to_0.5
...@@ -129,14 +129,15 @@ dmypy.json ...@@ -129,14 +129,15 @@ dmypy.json
.pyre/ .pyre/
# MindSpore files # MindSpore files
.dat *.dat
.ir *.ir
.meta *.meta
.ckpt *.ckpt
*.pb
# system files # system files
.DS_Store .DS_Store
.swap *.swap
# IDE # IDE
.idea/ .idea/
......
...@@ -13,7 +13,6 @@ import mindspore.context as context ...@@ -13,7 +13,6 @@ import mindspore.context as context
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore.dataset.transforms.vision import Inter
from mindspore import nn, Tensor from mindspore import nn, Tensor
from mindspore.train import Model from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
...@@ -21,120 +20,140 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net ...@@ -21,120 +20,140 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
import logging; logging.getLogger('matplotlib.font_manager').disabled = True import logging; logging.getLogger('matplotlib.font_manager').disabled = True
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') # Ascend, CPU, GPU
DATA_DIR_TRAIN = "MNIST/train" # 训练集信息
DATA_DIR_TEST = "MNIST/test" # 测试集信息
def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32), repeat=1,
def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),
rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64): rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) data_train = os.path.join(data_dir, 'train') # 训练集信息
data_test = os.path.join(data_dir, 'test') # 测试集信息
# define map operations ds = ms.dataset.MnistDataset(data_train if training else data_test)
resize_op = CV.Resize(resize)
rescale_op = CV.Rescale(rescale, shift) ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
hwc2chw_op = CV.HWC2CHW() ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32))
ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(repeat)
# apply map operations on images
ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op])
ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
ds = ds.shuffle(buffer_size=buffer_size)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(num_epoch)
return ds return ds
class LeNet5(nn.Cell): class LeNet5(nn.Cell):
def __init__(self): def __init__(self):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
self.fc1 = nn.Dense(400, 120) self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84) self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10) self.fc3 = nn.Dense(84, 10)
def construct(self, input_x): def construct(self, x):
output = self.conv1(input_x) x = self.relu(self.conv1(x))
output = self.relu(output) x = self.pool(x)
output = self.pool(output) x = self.relu(self.conv2(x))
output = self.conv2(output) x = self.pool(x)
output = self.relu(output) x = self.flatten(x)
output = self.pool(output) x = self.fc1(x)
output = self.flatten(output) x = self.fc2(x)
output = self.fc1(output) x = self.fc3(x)
output = self.fc2(output)
output = self.fc3(output) return x
return output
def train(data_dir, lr=0.01, momentum=0.9, num_epochs=2, ckpt_name="lenet"):
dataset_sink = context.get_context('device_target') == 'Ascend'
def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name="b_lenet"): repeat = num_epochs if dataset_sink else 1
ds_train = create_dataset(num_epoch=num_epoch) ds_train = create_dataset(data_dir, repeat=repeat)
ds_eval = create_dataset(training=False) ds_eval = create_dataset(data_dir, training=False)
steps_per_epoch = ds_train.get_dataset_size() steps_per_epoch = ds_train.get_dataset_size()
net = LeNet5() net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum) opt = nn.Momentum(net.trainable_params(), lr, momentum)
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg) ckpt_cb = ModelCheckpoint(prefix=ckpt_name, directory='ckpt', config=ckpt_cfg)
loss_cb = LossMonitor(steps_per_epoch) loss_cb = LossMonitor(steps_per_epoch)
model = Model(net, loss, opt, metrics={'acc', 'loss'}) model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=True) model.train(num_epochs, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=dataset_sink)
metrics = model.eval(ds_eval) metrics = model.eval(ds_eval, dataset_sink_mode=dataset_sink)
print('Metrics:', metrics) print('Metrics:', metrics)
CKPT = 'b_lenet-2_1875.ckpt' CKPT_1 = 'ckpt/lenet-2_1875.ckpt'
def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name="b_lenet"): def resume_train(data_dir, lr=0.001, momentum=0.9, num_epochs=2, ckpt_name="lenet"):
ds_train = create_dataset(num_epoch=num_epoch) dataset_sink = context.get_context('device_target') == 'Ascend'
ds_eval = create_dataset(training=False) repeat = num_epochs if dataset_sink else 1
ds_train = create_dataset(data_dir, repeat=repeat)
ds_eval = create_dataset(data_dir, training=False)
steps_per_epoch = ds_train.get_dataset_size() steps_per_epoch = ds_train.get_dataset_size()
net = LeNet5() net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum) opt = nn.Momentum(net.trainable_params(), lr, momentum)
param_dict = load_checkpoint(CKPT) param_dict = load_checkpoint(CKPT_1)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
load_param_into_net(opt, param_dict) load_param_into_net(opt, param_dict)
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg) ckpt_cb = ModelCheckpoint(prefix=ckpt_name, directory='ckpt', config=ckpt_cfg)
loss_cb = LossMonitor(steps_per_epoch) loss_cb = LossMonitor(steps_per_epoch)
model = Model(net, loss, opt, metrics={'acc', 'loss'}) model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb]) model.train(num_epochs, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=dataset_sink)
metrics = model.eval(ds_eval) metrics = model.eval(ds_eval, dataset_sink_mode=dataset_sink)
print('Metrics:', metrics) print('Metrics:', metrics)
CKPT_2 = 'ckpt/lenet_1-2_1875.ckpt'
def infer(data_dir):
ds = create_dataset(data_dir, training=False).create_dict_iterator()
data = ds.get_next()
images = data['image']
labels = data['label']
net = LeNet5()
load_checkpoint(CKPT_2, net=net)
model = Model(net)
output = model.predict(Tensor(data['image']))
preds = np.argmax(output.asnumpy(), axis=1)
for i in range(1, 5):
plt.subplot(2, 2, i)
plt.imshow(np.squeeze(images[i]))
color = 'blue' if preds[i] == labels[i] else 'red'
plt.title("prediction: {}, truth: {}".format(preds[i], labels[i]), color=color)
plt.xticks([])
plt.show()
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--data_url', required=True, default=None, help='Location of data.') parser.add_argument('--data_url', required=False, default='MNIST', help='Location of data.')
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.') parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.')
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
import moxing as mox if args.data_url.startswith('s3'):
mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/') import moxing
moxing.file.copy_parallel(src_url=args.data_url, dst_url='MNIST')
os.system('rm -f *.ckpt *.ir *.meta') # 清理旧的运行文件 args.data_url = 'MNIST'
test_train() # 请先删除旧的checkpoint目录`ckpt`
print('\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')]))) train(args.data_url)
print('Checkpoints after first training:')
resume_train() print('\n'.join(sorted([x for x in os.listdir('ckpt') if x.startswith('lenet')])))
print('\n'.join(sorted([x for x in os.listdir('.') if x.startswith('b_lenet')])))
resume_train(args.data_url)
\ No newline at end of file print('Checkpoints after resuming training:')
print('\n'.join(sorted([x for x in os.listdir('ckpt') if x.startswith('lenet')])))
infer(args.data_url)
if args.data_url.startswith('s3'):
import moxing
# 将ckpt目录拷贝至OBS后,可在OBS的`args.train_url`目录下看到ckpt目录
moxing.file.copy_parallel(src_url='ckpt', dst_url=os.path.join(args.data_url, 'ckpt'))
此差异已折叠。
此差异已折叠。
# 在Windows上运行LeNet_MNIST
## 实验介绍
LeNet5 + MINST被誉为深度学习领域的“Hello world”。本实验主要介绍使用MindSpore在Windows环境下MNIST数据集上开发和训练一个LeNet5模型,并验证模型精度。
## 实验目的
- 了解如何使用MindSpore进行简单卷积神经网络的开发。
- 了解如何使用MindSpore进行简单图片分类任务的训练。
- 了解如何使用MindSpore进行简单图片分类任务的验证。
## 预备知识
- 熟练使用Python,了解Shell及Linux操作系统基本知识。
- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略等。
- 了解并熟悉MindSpore AI计算框架,MindSpore官网:[https://www.mindspore.cn](https://www.mindspore.cn/)
## 实验环境
- Windows-x64版本MindSpore 0.3.0;安装命令可见官网:
[https://www.mindspore.cn/install](https://www.mindspore.cn/install)(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套)。
## 实验准备
### 创建目录
创建一个experiment文件夹,用于存放实验所需的文件代码等。
### 数据集准备
MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)
从MNIST官网下载如下4个文件到本地并解压:
```
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
```
### 脚本准备
[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。
### 准备文件
将脚本和数据集放到到experiment文件夹中,组织为如下形式:
```
experiment
├── MNIST
│ ├── test
│ │ ├── t10k-images-idx3-ubyte
│ │ └── t10k-labels-idx1-ubyte
│ └── train
│ ├── train-images-idx3-ubyte
│ └── train-labels-idx1-ubyte
└── main.py
```
## 实验步骤
### 导入MindSpore模块和辅助模块
```python
import matplotlib.pyplot as plt
import mindspore as ms
import mindspore.context as context
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore import nn
from mindspore.model_zoo.lenet import LeNet5
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
```
### 数据处理
在使用数据集训练网络前,首先需要对数据进行预处理,如下:
```python
DATA_DIR_TRAIN = "MNIST/train" # 训练集信息
DATA_DIR_TEST = "MNIST/test" # 测试集信息
def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64):
ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)
ds = ds.map(input_columns="image", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch)
return ds
```
对其中几张图片进行可视化,可以看到图片中的手写数字,图片的大小为32x32。
```python
def show_dataset():
ds = create_dataset(training=False)
data = ds.create_dict_iterator().get_next()
images = data['image']
labels = data['label']
for i in range(1, 5):
plt.subplot(2, 2, i)
plt.imshow(images[i][0])
plt.title('Number: %s' % labels[i])
plt.xticks([])
plt.show()
```
![img](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAATsAAAD7CAYAAAAVQzPHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcm0lEQVR4nO3deZRV1Zk28OepQWaBYrIQAkZBIKyICjjE1U3aEDHdaU268QuiTRxCVqKt+aJGErOiMdqxTaL9pfOZDh0ZooKxo+0QtQnNEhLRBis4oSggDhArTIIWU0FVvf3HPexzCupW3enc4ezntxar3nvGXfCy795n2JtmBhGRpKsqdQFERIpBlZ2IeEGVnYh4QZWdiHhBlZ2IeEGVnYh4QZVdBkguIHlbqcshUmg+5XZFVnYk3yG5lWSvyLIrSS4vYbEKiuRnSK4huZfkZpIXlbpMEr+k5zbJO4N8/ojkuyRvKta5K7KyC9QAuLbUhcgWyeoMthkHYBGAmwD0BTABwB9jLpqUj8TmNoB7AYwxs2MBnA3gYpJfjLdkKZVc2f0IwPUk+x25guRIkkayJrJsOckrg/jLJFeSvJvkbpKbSJ4dLN9MchvJWUccdiDJpSSbSK4gOSJy7DHBug9IvhlthQXdhJ+TfIrkXgCfzuB3+y6AX5jZ02bWYmY7zeytLP9+pHIlNrfN7E0z2xtZ1AbgpIz/ZvJQyZVdA4DlAK7Pcf8zALwCYABSragHAUxC6i/+EgA/I9k7sv1MAD8AMBDASwAeAICgu7E0OMZgADMA3EPyE5F9LwZwO4A+AJ4leTHJVzop25nBsV8l2UjyfpJ1Of6eUnmSnNsgOYfkHgBbAPQKjh+7Sq7sAOB7AP6R5KAc9n3bzOabWSuAXwMYDuBWM2s2s98BOIj23zhPmtnvzawZqe7lWSSHA/gbAO8Ex2oxszUAHgbw95F9HzOzlWbWZmYHzGyRmX2yk7INA3ApgL8DMApADwD/msPvKJUrqbkNM7sDqcrxNAD3Afgwh98xaxVd2ZnZWgC/BTAnh923RuL9wfGOXBb99tscOe8eAB8AGApgBIAzgi7DbpK7kfqmPK6jfTO0H8B8M1sfnOufAHwuy2NIBUtwbh8+j5nZi0FZvp/LMbJV0/UmZe9mAGsA/CSy7PA1gZ4APgri6D9QLoYfDoIuQB2A95H6x15hZlM72TfboWVeyWEfSZ4k5vaRagCcmOcxMlLRLTsAMLONSDXVr4ks2w7gTwAuIVlN8nLk/xf6OZLnkDwGqesbq8xsM1LfvqNJXkqyNvgzieTYPM41H8BlJD9OsieAG4PziEeSltskq0h+lWR/pkwGcBWAZXmWPyMVX9kFbkXqQmfUVwDcAGAngE8AeC7PcyxC6pv2AwCnI9Wch5k1AfgsgC8h9W34ZwD/DKBbugORnEnytXTrzWwegF8BWAXgXQDNiCS8eCVRuQ3gCwDeAtAE4H6krkUX5Xo0NXiniPggKS07EZFOqbITES+oshMRL+RV2ZGcFrxCspFkLs8DiZQl5Xby5HyDgqmXftcDmIrUax8vAJhhZq8XrngixafcTqZ8HiqeDGCjmW0CAJIPArgAQNqEOIbdrPtRd9GlFJqwa4eZ5fIqkg+U2xXqAPbioDWzo3X5VHbHo/2rIluQegG5HZKzAcwGgO7oiTN4bh6nlEL5b/vNu6UuQxlTbleoVZb++eR8rtl1VHse1Sc2s7lmNtHMJtamfxZRpJwotxMon8puCyLv1CE1Usf7+RVHpCwotxMon8ruBQCjSJ4QvFP3JQCPF6ZYIiWl3E6gnK/ZmVkLyasBLAFQDWCemXX2TpxIRVBuJ1NeQzyZ2VMAnipQWUTKhnI7efQGhYh4QZWdiHghCSMVl43G6852cdO4g1ntywPhLHQnzwmfXW1rasq/YCKilp2I+EGVnYh4Qd3YAqo7L3zu9JXxj2a177qD+1x83a0XhivUjfVaVc+eLn7v2gkubiuTFzYGvdTi4h6Pri5hSbqmlp2IeEGVnYh4Qd3YPO37YjgYxuQBL5SwJJIU1YPCkbca/88oFy/52p0uHlbTG+Vg6rrPu/jg/okuPmZJQymK0ym17ETEC6rsRMQLquxExAu6ZpdGVZ8+Lm457aS02116+xMunt03uyHPtrXudfFdWz8brmhp6WBr8cWhccNc/OJ37omsKY/rdFFLx4b5f8nNU1y89cCpLmZLOO5p1aq1LrYi57ladiLiBVV2IuIFdWMj2C18LH3flLEuXvGLubGcb+6u01383hl7I2v2Hr2xSJm7f+Ty8MPiMF5/KMznb5x/WbjNzt0ubNv9YbtjWXNzoYunlp2I+EGVnYh4Qd3YiN3TwztIi2//cWRN+d0FE6kUJ9b0cPG/PD3fxa0Wzlh52Xe/2W6fvvf/T8HLoZadiHhBlZ2IeMH7buy2q8Oh1L99zQMuPqG2cF3XU1bPcPHgu7u7uHrvochWayECADVrNrr4L78628ULfnaXizPJz3R5l6np/7bExdk+MB9VzbBNNbq2V4fbtNayw+WF1GXLjuQ8kttIro0sqyO5lOSG4Gf/eIspUnjKbb9k0o1dAGDaEcvmAFhmZqMALAs+i1SaBVBue6PLbqyZ/Z7kyCMWXwBgShAvBLAcwI0FLFfRHBgYxhf1/jD9hnn4aGfYdD9uRTjOl3W0sRRNueZ2dEa5Hr972cUzbrrexZl0+4as3+9irnyxw22q+/V18Z8WDG23bkrPDZFPHXc/03loT3jcH/50pouf+FbpxuTL9QbFEDNrBIDg5+B0G5KcTbKBZMMhFP6paJECU24nVOx3Y81srplNNLOJtSiTWUJECkC5XVlyvRu7lWS9mTWSrAewrZCFitsHl53l4jOnvRrLOaasDWcIG/6EnvCpIGWV29F3RPN50NY+Fc5MtuHy8L99VbdWFzec/v/b7dO/Oruua9T6A/UuPm7eSy4+r8+3XBydIW3kml3t9m/L+czp5fq/8HEAs4J4FoDHClMckZJTbidUJo+eLAbwPICTSW4heQWAOwBMJbkBwNTgs0hFUW77JZO7sTPSrDq3wGWJ1e5/CLuu478aPsA7/2N/iP3c2yeEf83HHhuWo9+vno/93JJeUnI7ndYpp7l4y9fDB9jfPmdBmj16plmemQeaBrj4vsc/7eKR+8I8H/bD5zrcN45u65F0MUlEvKDKTkS84M27sT0uaXRxMbquy8c/Gn4YH4bXNYZdi9V7wgm2ez6yKvYySTIdPC+cnHrP0FoXN52/x8VvnHNf7OVYs2eEiwevKUbHNDtq2YmIF1TZiYgXvOnGlouf1K9x8W237HPxHx7JfggeEQCw63a4+IXo5ZMii+b23Nv/7OLfROZE1ryxIiIxU2UnIl5IdDe2ekCdi7vXHOpky9LoVhWWqXrQcBe37gi7JTANBCWVJzqy8ZRF4Tu30Xlj29ZvcnExurRq2YmIF1TZiYgXEt2NPf6pcHicu49/KrKmPO58XtP/DRePem6ri+eeHb4/27p9e1HLJFJo6eaNvfriq1zMlS8hbmrZiYgXVNmJiBdU2YmIFxJ9zW5Ej50u7l3V9XW6S96Z4uLXfzXWxWu+9/OCluuwbgxf2j6/Zzgs9YZnwlvyy74cXr+zBk2kLUfr8c0wt0+5LRyi7+XJi0tRnKOkmyTbasIZ0uKfIlstOxHxhCo7EfFCIrqx6Sb6vajvLyNbdT1T0pY9/Vx83MMbXTyp+Wtp97nh24vC8+UxyXa0S3tD3Vsu/l2vv3CxvpmkI21rw0eYhvwonEVs0uj0eZutv7g6HG8x+sJ/JdH/HxHxgio7EfFCIrqx6BbOtvvghHtdHL3zk63omwt189O/xfDDXjNdfPPAcHl08u1iDAMvArR/E6FuZeGO+8cZI8MPSe3GkhxO8hmS60i+RvLaYHkdyaUkNwQ/+8dfXJHCUW77JZNubAuA68xsLIAzAVxFchyAOQCWmdkoAMuCzyKVRLntkUwmyW4E0BjETSTXATgewAUApgSbLQSwHMCNsZSySP56aNj1nHfLeSUsiRRDueR21Slj231+5wvl15C8fOiSUhchb1ndoCA5EsCpAFYBGBIky+GkGZxmn9kkG0g2HEJzR5uIlJxyO/kyruxI9gbwMIBvmNlHme5nZnPNbKKZTaxFt653ECky5bYfMrobS7IWqWR4wMweCRZvJVlvZo0k6wFsi6uQxRJ9mPeG2feUsCRSLOWQ2ztO69fu8zrlXiwyuRtLAPcCWGdmd0VWPQ5gVhDPAvBY4YsnEh/ltl8yadl9CsClAF4lefghnu8AuAPAQySvAPAegOnxFFEkNsptj2RyN/ZZpB+B5dzCFkekeEqZ2zXDh7m4aWQxBjgSvS4mIl5QZSciXkjGu7Ft4UTSbx6KPhIV3kQbXhPW65mMWlwMzRZOkr3pUMeTeLNFk2Qn0aYrPubiN76SvLuv5ZjbatmJiBdU2YmIFxLRjW3dscPF0QmmURXe5WpbHI4E/F9jnixKubry011jXPzMuSd2uE3VznCSHXVopVKUY26rZSciXlBlJyJeSEQ3FhY2gqMjDEc1t4wsUmGOdsrqcC7PwXeHd4Kr94Z3qWyr5oT1ycfvfc/FY/j1dusq9e5sdN7lHdeED02XS26rZSciXlBlJyJeSEY3NgP8STgbzqShhZtPMxND1u8Py7HyRRfr7qq/WjZvcfGJC9v/NxyDsFtb7l3aqes+7+KWO4e4+JiGhlIUp1Nq2YmIF1TZiYgXVNmJiBe8uWZ3zJLwGkJdCcshcqSWTe+0+3ziL1tcPNa+jnI26KWwrD2WrC5hSbqmlp2IeEGVnYh4wZturEiliD6W8rFbtnSypWRDLTsR8YIqOxHxQibzxnYnuZrkyyRfI/n9YHkdyaUkNwQ/+8dfXJHCUW77JZOWXTOAvzKzUwBMADCN5JkA5gBYZmajACwLPotUEuW2R7qs7CxlT/CxNvhjAC4AsDBYvhDAhbGUUCQmym2/ZHTNjmR1MGP6NgBLzWwVgCFm1ggAwc/BnR1DpBwpt/2RUWVnZq1mNgHAMACTSY7P9AQkZ5NsINlwCM25llMkFsptf2R1N9bMdgNYDmAagK0k6wEg+LktzT5zzWyimU2sRbc8iysSD+V28mVyN3YQyX5B3APAZwC8AeBxALOCzWYBeCyuQorEQbntl0zeoKgHsJBkNVKV40Nm9luSzwN4iOQVAN4DMD3GcorEQbntkS4rOzN7BcCpHSzfCeDcOAolUgzKbb/QrHiDg5PcDuDdop1QOjPCzAaVuhBJodwuG2nzuqiVnYhIqejdWBHxgio7EfGCKrsMkFxA8rZSl0Ok0HzK7Yqs7Ei+Q3IryV6RZVeSXF7CYhUMyTtJbib5Ecl3Sd5U6jJJcSQ9twGA5GdIriG5N8jzi4px3oqs7AI1AK4tdSGyFTzT1ZV7AYwxs2MBnA3gYpJfjLdkUkYSm9skxwFYBOAmAH2RGm3mjzEXDUBlV3Y/AnD94Sfgo0iOJGkkayLLlpO8Moi/THIlybtJ7ia5ieTZwfLNJLeRnHXEYQcGY5s1kVxBckTk2GOCdR+QfDP6TRV0E35O8imSewF8uqtfzMzeNLO9kUVtAE7K+G9GKl1icxvAdwH8wsyeNrMWM9tpZm9l+feTk0qu7BqQepfx+hz3PwPAKwAGIPVN8yCASUhVKpcA+BnJ3pHtZwL4AYCBAF4C8AAABN2NpcExBgOYAeAekp+I7HsxgNsB9AHwLMmLSb7SWeFIziG5B8AWAL2C44sfkpzbZwbHfpVkI8n7SRZldtNKruwA4HsA/pFkLg/Hvm1m882sFcCvAQwHcKuZNZvZ7wAcRPvW1JNm9nsza0aqCX4WyeEA/gbAO8GxWsxsDYCHAfx9ZN/HzGylmbWZ2QEzW2Rmn+yscGZ2B1IJdBqA+wB8mMPvKJUrqbk9DMClAP4OwCgAPQD8aw6/Y9YqurIzs7UAfovcRpLdGon3B8c7cln0229z5Lx7AHwAYCiAEQDOCLoMu0nuRuqb8riO9s1GMLjki0FZvp/LMaQyJTi39wOYb2brg3P9E4DPZXmMnCRhKsWbAawB8JPIssPXu3oC+CiIo/9AuRh+OAi6AHUA3kfqH3uFmU3tZN98X1OpAXBinseQypPE3H4lh30KoqJbdgBgZhuRaqpfE1m2HcCfAFzC1Ei0lyP/yuJzJM8heQxS1zdWmdlmpL59R5O8lGRt8GcSybG5nIRkFcmvkuzPlMkArkJqLgTxSNJyOzAfwGUkP06yJ4Abg/PEruIru8CtSF3Ej/oKgBsA7ATwCQDP5XmORUh9034A4HSkmvMwsyYAnwXwJaS+Df8M4J+B9KM5kpxJ8rVOzvUFAG8BaAJwP1LXNIpyXUPKTqJy28zmAfgVgFVIDZzQjEhlHicNBCAiXkhKy05EpFOq7ETEC3lVdiSnBU9VbySpiYQlMZTbyZPzNTum3oNbD2AqUk/5vwBghpm9XrjiiRSfcjuZ8nnObjKAjWa2CQBIPojUTOppE+IYdrPuR91YklJowq4dGpY9LeV2hTqAvThozexoXT6V3fFo//T0FqTeyUurO3rhDGoek3Lw3/YbzZeQnnK7Qq2y9I+j5lPZdVR7HtUnJjkbwGwA6I6eeZxOpGiU2wmUzw2KLYi8ZoLUC77vH7mRZk2XCqTcTqB8KrsXAIwieULwmsmXkJpJXaTSKbcTKOdurJm1kLwawBIA1QDmmVlnr0CJVATldjLlNeqJmT0F4KkClUWkbCQttw+eN9HFdt2OjPbp8c3uLm5b+0bBy1RseoNCRLygyk5EvJCEwTtFpAt7hta6+IXxj2a0z9QBl7k4Ca2iJPwOIiJdUmUnIl5IdDd229Vnu/jAwPjPN/I/d7m47eV18Z9QRDKmlp2IeEGVnYh4IRHdWHYL30vcPf1UF3/7mgdcfFHv+OeYPmHwbBcPfOEsF/dfv9/FXPlS7OUQkaOpZSciXlBlJyJeUGUnIl5IxDW7qn59XTz/trtcPPaY4g6o+PaFc8MPF4bhKatnuHjoh2NcnISXq6V81Qwf5uKmkR2OVO4VtexExAuq7ETEC4noxpa7lycvdvGUu8L+bbfPlqI0kmRVffq4eP1V4cjyG/7hnlIUp6yoZSciXlBlJyJeUDdWJEHevGOci//nb38cWaMJvNWyExEvqLITES8kohvbtvMDF1878+sutpqOH6Tc9n8PuDh6p1Sk0ln3VhcPrlbXNarLlh3JeSS3kVwbWVZHcinJDcHP/vEWU6TwlNt+yaQbuwDAtCOWzQGwzMxGAVgWfBapNAug3PZGl91YM/s9yZFHLL4AwJQgXghgOYAbC1iurFhLi4uj48WlextwSMsEF08a/bUuj9/SKzzSE9+6s926YTW9MyyllJtKyO24bWnZ4+LP3/mtduvqX9/g4lZUvlxvUAwxs0YACH4OLlyRREpKuZ1Qsd+gIDkbwGwA6I7ijkIiEifldmXJtbLbSrLezBpJ1gPYlm5DM5sLYC4AHMs6y/F8BRXt6tat7Hr76iHhl3vT9XpaJ+EqOrez1dQW5nP9f2xst651+/ZiFydWuf7PfRzArCCeBeCxwhRHpOSU2wmVyaMniwE8D+BkkltIXgHgDgBTSW4AMDX4LFJRlNt+yeRu7Iw0q84tcFnKSnSU1+hQOQOqK7K3Ih3wNbd9pQtQIuIFVXYi4oVEvBubLftU+FDxrtE9OtwmOkFJ+1Fe9b6h+OfgeRNdvGdobVb7Vh8KL/30+48XXWzNzfkXLAtq2YmIF1TZiYgXvOzGbrg8/LXfPv/nRT33sN67Xbxj4ngXW8PajjYX6VLV+HAu4mMH7O1y+22t4TZ3bY3M+hR5xxwAGMnP428OHzi+f+TyrMr39qHw/dsvf/hNF/dcvs7FbU1NWR0zF2rZiYgXVNmJiBe87MaWUrQLcNu8sPvxh092L0FpJAn23xUZeXv8o11uP3fX6S5+78x9Lq4e2H4wg3MXPO/iG+reyrl8J9SGw6Ct+MVcF0+dcZmLq1a8iLipZSciXlBlJyJeUDdWxGPVAwe6ePZzz7dbd37PXZFP2T1IXI7UshMRL6iyExEvqLITES/omp2Iz6rCAS9Orm0/An03dj2vximrwyEBm18Op9h94yv3dLR5O9P/bYmL77vp8+3W9XxkVZf7Z0stOxHxgio7EfGCl93YUfPCF54nPdv1JNmdueHbi1x8Ue8P8zqWSDFc1PePLn528YkuHl6TWdtnzLOXuvhj/6/axbtGZzdlwey+77v43/u2P3ccE1OqZSciXlBlJyJe8LIbm+0k2VV9+rj4zTvGtVs3snZH5FPlP2UuyTe6Npxa4L/GPBlZk34wimjXddg9YZ5z5ZrIgc8qSPniksm8scNJPkNyHcnXSF4bLK8juZTkhuBn/66OJVJOlNt+yaQb2wLgOjMbC+BMAFeRHAdgDoBlZjYKwLLgs0glUW57JJNJshsBNAZxE8l1AI4HcAGAKcFmCwEsB3BjLKUsMfYMZyB78q/vbrdu7DFx3DeSYlBuZ67P0+GYdAf7tbp4+y1nu9jGZTe0evSB5CHr9+dRusxkdYOC5EgApwJYBWBIkCyHk2ZwoQsnUizK7eTLuLIj2RvAwwC+YWYfZbHfbJINJBsOobjzRIpkQrnth4zuxpKsRSoZHjCzR4LFW0nWm1kjyXoA2zra18zmApgLAMeyLrunDjPEbt1cvHv6qS5urWVHm2etpVd4nD5VbXkda/n+8Pvl3tXnuHg0GvI6ruSm3HM7E++vqXfxAyMGuHhmn50FO8eOSWHX9aSTG128buwTOR9z8N3h3V+uLINh2UkSwL0A1pnZXZFVjwOYFcSzADxW+OKJxEe57ZdMWnafAnApgFdJHn5A7TsA7gDwEMkrALwHYHo8RRSJjXLbI5ncjX0WQLr+4LmFLU7mog/67psy1sWLb/+xi6OzGhVOfse85a2/dfHoK9V1LaVyze1snTAnHE79u4O+4OKZ5/+yYOd4+8K5XW+URrMdcvFPd4Uz6lXvDZcX4xqAXhcTES+oshMRL1Tsu7Etp53k4ujEu/l2M+OwqzWciHjXvvAB5eNKURhJNB4Ih1xadzDMu+hTBMNq4v8/Eu26Pr0vfNvumU9/3MW2fW3s5YhSy05EvKDKTkS8ULHd2EoyccVVLj756k0ubu1oY5E8nDzndRdfd+uFLm6cHl72efE7XU+Gk6/oXddo17V1x46ONi8KtexExAuq7ETEC+rGFkFbc3iHrHW3JuWR+LQ1RYZZisT1vw4f25366mWxl6PdA8NFvuuajlp2IuIFVXYi4oWK7cbWvr7FxZNu6nju11LO6RqdoCQ6T61IKbRu3+7iqhXbO9myMEo23lUn1LITES+oshMRL6iyExEvVOw1u+g1iLr5HV+D+GGvmS6+eWDsRWpn2PIDLm43kbCIlIRadiLiBVV2IuKFiu3GZmLwz54rdRFEpEyoZSciXlBlJyJeUGUnIl7IZJLs7iRXk3yZ5Gskvx8sryO5lOSG4Gf/ro4lUk6U237JpGXXDOCvzOwUABMATCN5JoA5AJaZ2SgAy4LPIpVEue2RLis7S9kTfKwN/hiACwAsDJYvBHBhB7uLlC3ltl8yumZHsprkSwC2AVhqZqsADDGzRgAIfg6Or5gi8VBu+yOjys7MWs1sAoBhACaTHJ/pCUjOJtlAsuEQmnMtp0gslNv+yOpurJntBrAcwDQAW0nWA0Dwc1uafeaa2UQzm1iLbnkWVyQeyu3ky+Ru7CCS/YK4B4DPAHgDwOMAZgWbzQLwWFyFFImDctsvmbwuVg9gIclqpCrHh8zstySfB/AQySsAvAdgeozlFImDctsjNCveAMoktwN4t2gnlM6MMLNBpS5EUii3y0bavC5qZSciUip6XUxEvKDKTkS8oMpORLygyk5EvKDKTkS8oMpORLygyk5EvKDKTkS8oMpORLzwv9NPrlrn6D7QAAAAAElFTkSuQmCC)
### 定义模型
MindSpore model_zoo中提供了多种常见的模型,可以直接使用。这里使用其中的LeNet5模型,模型结构如下图所示:
![img](https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg)
图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf
### 训练
使用MNIST数据集对上述定义的LeNet5模型进行训练。训练策略如下表所示,可以调整训练策略并查看训练效果,要求验证精度大于95%。
| batch size | number of epochs | learning rate | optimizer |
| ---------: | ---------------: | ------------: | -----------: |
| 32 | 3 | 0.01 | Momentum 0.9 |
```python
def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"):
ds_train = create_dataset(num_epoch=num_epoch)
ds_eval = create_dataset(training=False)
net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum)
loss_cb = LossMonitor(per_print_times=1)
model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[loss_cb], dataset_sink_mode=False)
metrics = model.eval(ds_eval, dataset_sink_mode=False)
print('Metrics:', metrics)
```
### 实验结果
1. 在训练日志中可以看到`epoch: 1 step: 1875, loss is 0.29772663`等字段,即训练过程的loss值;
2. 在训练日志中可以看到`Metrics: {'loss': 0.06830393138807267, 'acc': 0.9785657051282052}`字段,即训练完成后的验证精度。
```python
...
>>> epoch: 1 step: 1875, loss is 0.29772663
...
>>> epoch: 2 step: 1875, loss is 0.049111396
...
>>> epoch: 3 step: 1875, loss is 0.08183163
>>> Metrics: {'loss': 0.06830393138807267, 'acc': 0.9785657051282052}
```
## 实验小结
本实验展示了如何使用MindSpore进行手写数字识别,以及开发和训练LeNet5模型。通过对LeNet5模型做几代的训练,然后使用训练后的LeNet5模型对手写数字进行识别,识别准确率大于95%。即LeNet5学习到了如何进行手写数字识别。
\ No newline at end of file
# LeNet5 mnist
import matplotlib.pyplot as plt
import mindspore as ms
import mindspore.context as context
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore import nn
from mindspore.model_zoo.lenet import LeNet5
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
DATA_DIR_TRAIN = "MNIST/train" # 训练集信息
DATA_DIR_TEST = "MNIST/test" # 测试集信息
def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),
rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64):
ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)
ds = ds.map(input_columns="image", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch)
return ds
def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"):
ds_train = create_dataset(num_epoch=num_epoch)
ds_eval = create_dataset(training=False)
net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum)
loss_cb = LossMonitor(per_print_times=1)
model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[loss_cb], dataset_sink_mode=False)
metrics = model.eval(ds_eval, dataset_sink_mode=False)
print('Metrics:', metrics)
def show_dataset():
ds = create_dataset(training=False)
data = ds.create_dict_iterator().get_next()
images = data['image']
labels = data['label']
for i in range(1, 5):
plt.subplot(2, 2, i)
plt.imshow(images[i][0])
plt.title('Number: %s' % labels[i])
plt.xticks([])
plt.show()
if __name__ == "__main__":
show_dataset()
test_train()
\ No newline at end of file
# Save and load model
import matplotlib.pyplot as plt
import numpy as np
import mindspore as ms
import mindspore.context as context
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore import nn, Tensor
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
DATA_DIR_TRAIN = "MNIST/train" # 训练集信息
DATA_DIR_TEST = "MNIST/test" # 测试集信息
def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),
rescale=1 / (255 * 0.3081), shift=-0.1307 / 0.3081, buffer_size=64):
ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)
# define map operations
resize_op = CV.Resize(resize)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
# apply map operations on images
ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op])
ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
ds = ds.shuffle(buffer_size=buffer_size)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(num_epoch)
return ds
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.flatten(output)
output = self.fc1(output)
output = self.fc2(output)
output = self.fc3(output)
return output
def test_train(lr=0.01, momentum=0.9, num_epoch=2, check_point_name="b_lenet"):
ds_train = create_dataset(num_epoch=num_epoch)
ds_eval = create_dataset(training=False)
steps_per_epoch = ds_train.get_dataset_size()
net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum)
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix=check_point_name, config=ckpt_cfg)
loss_cb = LossMonitor(steps_per_epoch)
model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False)
metrics = model.eval(ds_eval, dataset_sink_mode=False)
print('Metrics:', metrics)
CKPT = 'b_lenet-2_1875.ckpt'
def resume_train(lr=0.001, momentum=0.9, num_epoch=2, ckpt_name="b_lenet"):
ds_train = create_dataset(num_epoch=num_epoch)
ds_eval = create_dataset(training=False)
steps_per_epoch = ds_train.get_dataset_size()
net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum)
param_dict = load_checkpoint(CKPT)
load_param_into_net(net, param_dict)
load_param_into_net(opt, param_dict)
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg)
loss_cb = LossMonitor(steps_per_epoch)
model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=False)
metrics = model.eval(ds_eval, dataset_sink_mode=False)
print('Metrics:', metrics)
def plot_images(pred_fn, ds, net):
for i in range(1, 5):
pred, image, label = pred_fn(ds, net)
plt.subplot(2, 2, i)
plt.imshow(np.squeeze(image))
color = 'blue' if pred == label else 'red'
plt.title("prediction: {}, truth: {}".format(pred, label), color=color)
plt.xticks([])
plt.show()
CKPT = 'b_lenet_1-2_1875.ckpt'
def infer(ds, model):
data = ds.get_next()
images = data['image']
labels = data['label']
output = model.predict(Tensor(data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
return pred[0], images[0], labels[0]
def test_infer():
ds = create_dataset(training=False, batch_size=1).create_dict_iterator()
net = LeNet5()
param_dict = load_checkpoint(CKPT, net)
model = Model(net)
plot_images(infer, ds, model)
if __name__ == "__main__":
test_train()
resume_train()
test_infer()
\ No newline at end of file
# 基于LeNet5的手写数字识别
## 实验介绍
LeNet5 + MINST被誉为深度学习领域的“Hello world”。本实验主要介绍使用MindSpore在MNIST数据集上开发和训练一个LeNet5模型,并验证模型精度。
## 实验目的
- 了解如何使用MindSpore进行简单卷积神经网络的开发。
- 了解如何使用MindSpore进行简单图片分类任务的训练。
- 了解如何使用MindSpore进行简单图片分类任务的验证。
## 预备知识
- 熟练使用Python,了解Shell及Linux操作系统基本知识。
- 具备一定的深度学习理论知识,如卷积神经网络、损失函数、优化器,训练策略等。
- 了解华为云的基本使用方法,包括[OBS(对象存储)](https://www.huaweicloud.com/product/obs.html)[ModelArts(AI开发平台)](https://www.huaweicloud.com/product/modelarts.html)[Notebook(开发工具)](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0033.html)[训练作业](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0046.html)等服务。华为云官网:https://www.huaweicloud.com
- 了解并熟悉MindSpore AI计算框架,MindSpore官网:https://www.mindspore.cn
## 实验环境
- MindSpore 0.5.0(MindSpore版本会定期更新,本指导也会定期刷新,与版本配套);
- 华为云ModelArts:ModelArts是华为云提供的面向开发者的一站式AI开发平台,集成了昇腾AI处理器资源池,用户可以在该平台下体验MindSpore。ModelArts官网:https://www.huaweicloud.com/product/modelarts.html
- Windows/Ubuntu x64笔记本,NVIDIA GPU服务器,或Atlas Ascend服务器等。
## 实验准备
### 创建OBS桶
本实验需要使用华为云OBS存储实验脚本和数据集,可以参考[快速通过OBS控制台上传下载文件](https://support.huaweicloud.com/qs-obs/obs_qs_0001.html)了解使用OBS创建桶、上传文件、下载文件的使用方法。
> **提示:** 华为云新用户使用OBS时通常需要创建和配置“访问密钥”,可以在使用OBS时根据提示完成创建和配置。也可以参考[获取访问密钥并完成ModelArts全局配置](https://support.huaweicloud.com/prepare-modelarts/modelarts_08_0002.html)获取并配置访问密钥。
创建OBS桶的参考配置如下:
- 区域:华北-北京四
- 数据冗余存储策略:单AZ存储
- 桶名称:全局唯一的字符串
- 存储类别:标准存储
- 桶策略:公共读
- 归档数据直读:关闭
- 企业项目、标签等配置:免
### 数据集准备
MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。MNIST数据集的官网:[THE MNIST DATABASE](http://yann.lecun.com/exdb/mnist/)
从MNIST官网下载如下4个文件到本地并解压:
```
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
```
### 脚本准备
[课程gitee仓库](https://gitee.com/mindspore/course)上下载本实验相关脚本。
### 上传文件
将脚本和数据集上传到OBS桶中,组织为如下形式:
```
lenet5
├── MNIST
│   ├── test
│   │   ├── t10k-images-idx3-ubyte
│   │   └── t10k-labels-idx1-ubyte
│   └── train
│   ├── train-images-idx3-ubyte
│   └── train-labels-idx1-ubyte
└── main.py
```
## 实验步骤(ModelArts Notebook)
### 创建Notebook
可以参考[创建并打开Notebook](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0034.html)来创建并打开本实验的Notebook脚本。
创建Notebook的参考配置:
- 计费模式:按需计费
- 名称:lenet5
- 工作环境:Python3
- 资源池:公共资源
- 类型:Ascend
- 规格:单卡1*Ascend 910
- 存储位置:对象存储服务(OBS)->选择上述新建的OBS桶中的lenet5文件夹
- 自动停止等配置:默认
> **注意:**
> - 打开Notebook前,在Jupyter Notebook文件列表页面,勾选目录里的所有文件/文件夹(实验脚本和数据集),并点击列表上方的“Sync OBS”按钮,使OBS桶中的所有文件同时同步到Notebook工作环境中,这样Notebook中的代码才能访问数据集。参考[使用Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html)。
> - 打开Notebook后,选择MindSpore环境作为Kernel。
> **提示:** 上述数据集和脚本的准备工作也可以在Notebook环境中完成,在Jupyter Notebook文件列表页面,点击右上角的"New"->"Terminal",进入Notebook环境所在终端,进入`work`目录,可以使用常用的linux shell命令,如`wget, gzip, tar, mkdir, mv`等,完成数据集和脚本的下载和准备。
> **提示:** 请从上至下阅读提示并执行代码框进行体验。代码框执行过程中左侧呈现[\*],代码框执行完毕后左侧呈现如[1],[2]等。请等上一个代码框执行完毕后再执行下一个代码框。
导入MindSpore模块和辅助模块:
```python
import os
# os.environ['DEVICE_ID'] = '0'
import mindspore as ms
import mindspore.context as context
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') # Ascend, CPU, GPU
```
### 数据处理
在使用数据集训练网络前,首先需要对数据进行预处理,如下:
```python
def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32),
rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
data_train = os.path.join(data_dir, 'train') # 训练集信息
data_test = os.path.join(data_dir, 'test') # 测试集信息
ds = ms.dataset.MnistDataset(data_train if training else data_test)
ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32))
# When `dataset_sink_mode=True` on Ascend, append `ds = ds.repeat(num_epochs) to the end
ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True)
return ds
```
对其中几张图片进行可视化,可以看到图片中的手写数字,图片的大小为32x32。
```python
ds = create_dataset('MNIST', training=False)
data = ds.create_dict_iterator().get_next()
images = data['image']
labels = data['label']
for i in range(1, 5):
plt.subplot(2, 2, i)
plt.imshow(images[i][0])
plt.title('Number: %s' % labels[i])
plt.xticks([])
plt.show()
```
![png](images/mnist.png)
### 定义模型
MindSpore model_zoo中提供了多种常见的模型,可以直接使用。LeNet5模型结构如下图所示:
![LeNet5](https://www.mindspore.cn/tutorial/zh-CN/master/_images/LeNet_5.jpg)
[1] 图片来源于http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf
```python
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, x):
x = self.relu(self.conv1(x))
x = self.pool(x)
x = self.relu(self.conv2(x))
x = self.pool(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
```
### 训练
使用MNIST数据集对上述定义的LeNet5模型进行训练。训练策略如下表所示,可以调整训练策略并查看训练效果,要求验证精度大于95%。
| batch size | number of epochs | learning rate | optimizer |
| -- | -- | -- | -- |
| 32 | 3 | 0.01 | Momentum 0.9 |
```python
def train(data_dir, lr=0.01, momentum=0.9, num_epochs=3):
ds_train = create_dataset(data_dir)
ds_eval = create_dataset(data_dir, training=False)
net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum)
loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())
model = Model(net, loss, opt, metrics={'acc', 'loss'})
# dataset_sink_mode can be True when using Ascend
model.train(num_epochs, ds_train, callbacks=[loss_cb], dataset_sink_mode=False)
metrics = model.eval(ds_eval, dataset_sink_mode=False)
print('Metrics:', metrics)
train('MNIST')
```
epoch: 1 step 1875, loss is 0.23394052684307098
Epoch time: 23049.360, per step time: 12.293, avg loss: 2.049
************************************************************
epoch: 2 step 1875, loss is 0.4737345278263092
Epoch time: 26768.848, per step time: 14.277, avg loss: 0.155
************************************************************
epoch: 3 step 1875, loss is 0.07734094560146332
Epoch time: 25687.625, per step time: 13.700, avg loss: 0.094
************************************************************
Metrics: {'loss': 0.10531254443608654, 'acc': 0.9701522435897436}
## 实验步骤(ModelArts训练作业)
除了Notebook,ModelArts还提供了训练作业服务。相比Notebook,训练作业资源池更大,且具有作业排队等功能,适合大规模并发使用。使用训练作业时,也会有修改代码和调试的需求,有如下三个方案:
1. 在本地修改代码后重新上传;
2. 使用[PyCharm ToolKit](https://support.huaweicloud.com/tg-modelarts/modelarts_15_0001.html)配置一个本地Pycharm+ModelArts的开发环境,便于上传代码、提交训练作业和获取训练日志。
3. 在ModelArts上创建Notebook,然后设置[Sync OBS功能](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0038.html),可以在线修改代码并自动同步到OBS中。因为只用Notebook来编辑代码,所以创建CPU类型最低规格的Notebook就行。
### 适配训练作业
创建训练作业时,运行参数会通过脚本传参的方式输入给脚本代码,脚本必须解析传参才能在代码中使用相应参数。如data_url和train_url,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。脚本对传参进行解析后赋值到`args`变量里,在后续代码里可以使用。
```python
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
args, unknown = parser.parse_known_args()
```
MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing提供的API与OBS交互。将OBS中存储的数据拷贝至执行容器:
```python
import moxing
moxing.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')
```
如需将训练输出(如模型Checkpoint)从执行容器拷贝至OBS,请参考:
```python
import moxing
# dst_url形如's3://OBS/PATH',将ckpt目录拷贝至OBS后,可在OBS的`args.train_url`目录下看到ckpt目录
moxing.file.copy_parallel(src_url='ckpt', dst_url=os.path.join(args.train_url, 'ckpt'))
```
### 创建训练作业
可以参考[使用常用框架训练模型](https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0238.html)来创建并启动训练作业。
创建训练作业的参考配置:
- 算法来源:常用框架->Ascend-Powered-Engine->MindSpore
- 代码目录:选择上述新建的OBS桶中的lenet5目录
- 启动文件:选择上述新建的OBS桶中的lenet5目录下的`main.py`
- 数据来源:数据存储位置->选择上述新建的OBS桶中的lenet5目录下的MNIST目录
- 训练输出位置:选择上述新建的OBS桶中的lenet5目录并在其中创建output目录
- 作业日志路径:同训练输出位置
- 规格:Ascend:1*Ascend 910
- 其他均为默认
启动并查看训练过程:
1. 点击提交以开始训练;
2. 在训练作业列表里可以看到刚创建的训练作业,在训练作业页面可以看到版本管理;
3. 点击运行中的训练作业,在展开的窗口中可以查看作业配置信息,以及训练过程中的日志,日志会不断刷新,等训练作业完成后也可以下载日志到本地进行查看;
4. 参考实验步骤(Notebook),在日志中找到对应的打印信息,检查实验是否成功。
## 实验步骤(本地CPU/GPU/Ascend)
MindSpore还支持在本地CPU/GPU/Ascend环境上运行,如Windows/Ubuntu x64笔记本,NVIDIA GPU服务器,以及Atlas Ascend服务器等。在本地环境运行实验前,需要先参考[安装教程](https://www.mindspore.cn/install/)配置环境。
在Windows/Ubuntu x64笔记本上运行实验:
```shell script
vim main.py # 将第15行的context设置为`device_target='CPU'`
python main.py --data_url=D:\dataset\MNIST
```
在Ascend服务器上运行实验:
```shell script
vim main.py # 将第15行的context设置为`device_target='Ascend'`
python main.py --data_url=/PATH/TO/MNIST
```
## 实验小结
本实验展示了如何使用MindSpore进行手写数字识别,以及开发和训练LeNet5模型。通过对LeNet5模型做几代的训练,然后使用训练后的LeNet5模型对手写数字进行识别,识别准确率大于95%。即LeNet5学习到了如何进行手写数字识别。
# LeNet5 mnist # LeNet5 MNIST
import os import os
# os.environ['DEVICE_ID'] = '0' # os.environ['DEVICE_ID'] = '0'
...@@ -9,52 +9,77 @@ import mindspore.dataset.transforms.c_transforms as C ...@@ -9,52 +9,77 @@ import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore import nn from mindspore import nn
from mindspore.model_zoo.lenet import LeNet5
from mindspore.train import Model from mindspore.train import Model
from mindspore.train.callback import LossMonitor from mindspore.train.callback import LossMonitor
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') # Ascend, CPU, GPU
DATA_DIR_TRAIN = "MNIST/train" # 训练集信息
DATA_DIR_TEST = "MNIST/test" # 测试集信息
def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32),
def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),
rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64): rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) data_train = os.path.join(data_dir, 'train') # 训练集信息
data_test = os.path.join(data_dir, 'test') # 测试集信息
ds = ms.dataset.MnistDataset(data_train if training else data_test)
ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32))
# When `dataset_sink_mode=True` on Ascend, append `ds = ds.repeat(num_epochs) to the end
ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True)
ds = ds.map(input_columns="image", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch)
return ds return ds
def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"): class LeNet5(nn.Cell):
ds_train = create_dataset(num_epoch=num_epoch) def __init__(self):
ds_eval = create_dataset(training=False) super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, x):
x = self.relu(self.conv1(x))
x = self.pool(x)
x = self.relu(self.conv2(x))
x = self.pool(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
def train(data_dir, lr=0.01, momentum=0.9, num_epochs=3):
ds_train = create_dataset(data_dir)
ds_eval = create_dataset(data_dir, training=False)
net = LeNet5() net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum) opt = nn.Momentum(net.trainable_params(), lr, momentum)
loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())
loss_cb = LossMonitor(per_print_times=1)
model = Model(net, loss, opt, metrics={'acc', 'loss'}) model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[loss_cb]) # dataset_sink_mode can be True when using Ascend
metrics = model.eval(ds_eval) model.train(num_epochs, ds_train, callbacks=[loss_cb], dataset_sink_mode=False)
metrics = model.eval(ds_eval, dataset_sink_mode=False)
print('Metrics:', metrics) print('Metrics:', metrics)
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--data_url', required=True, default=None, help='Location of data.') parser.add_argument('--data_url', required=False, default='MNIST', help='Location of data.')
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.') parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
parser.add_argument('--num_epochs', type=int, default=1, help='Number of training epochs.')
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
import moxing as mox if args.data_url.startswith('s3'):
mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/') import moxing
moxing.file.copy_parallel(src_url=args.data_url, dst_url='MNIST')
args.data_url = 'MNIST'
test_train() train(args.data_url)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册