提交 90d94759 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3 简化experiment_1的训练代码,去掉部分高级功能,展示简介高效的效果

Merge pull request !3 from dyonghan/experiment_1
...@@ -134,3 +134,4 @@ dmypy.json ...@@ -134,3 +134,4 @@ dmypy.json
# IDE # IDE
.idea/ .idea/
.vscode/
此差异已折叠。
...@@ -3,19 +3,14 @@ ...@@ -3,19 +3,14 @@
import os import os
# os.environ['DEVICE_ID'] = '0' # os.environ['DEVICE_ID'] = '0'
import matplotlib.pyplot as plt
import numpy as np
import mindspore as ms import mindspore as ms
import mindspore.context as context import mindspore.context as context
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore.dataset.transforms.vision import Inter from mindspore import nn
from mindspore import nn, Tensor
from mindspore.train import Model from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
...@@ -26,26 +21,16 @@ DATA_DIR_TEST = "MNIST/test" # 测试集信息 ...@@ -26,26 +21,16 @@ DATA_DIR_TEST = "MNIST/test" # 测试集信息
def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32), def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),
rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64): rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST) ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)
ds = ds.map(input_columns="image", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
# define map operations
resize_op = CV.Resize(resize)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
# apply map operations on images
ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op])
ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32)) ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch)
ds = ds.shuffle(buffer_size=buffer_size)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(num_epoch)
return ds return ds
class LeNet(nn.Cell): class LeNet5(nn.Cell):
def __init__(self): def __init__(self):
super(LeNet, self).__init__() super(LeNet5, self).__init__()
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid') self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid') self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
...@@ -70,26 +55,22 @@ class LeNet(nn.Cell): ...@@ -70,26 +55,22 @@ class LeNet(nn.Cell):
return output return output
LOOP_SINK = context.get_context('enable_loop_sink')
def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"): def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"):
ds_train = create_dataset(num_epoch=num_epoch) ds_train = create_dataset(num_epoch=num_epoch)
ds_eval = create_dataset(training=False) ds_eval = create_dataset(training=False)
steps_per_epoch = ds_train.get_dataset_size()
net = LeNet() net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum) opt = nn.Momentum(net.trainable_params(), lr, momentum)
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5) loss_cb = LossMonitor(per_print_times=1)
ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg)
loss_cb = LossMonitor(per_print_times=1 if LOOP_SINK else steps_per_epoch)
model = Model(net, loss, opt, metrics={'acc', 'loss'}) model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=True) model.train(num_epoch, ds_train, callbacks=[loss_cb])
metrics = model.eval(ds_eval) metrics = model.eval(ds_eval)
print('Metrics:', metrics) print('Metrics:', metrics)
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -101,6 +82,4 @@ if __name__ == "__main__": ...@@ -101,6 +82,4 @@ if __name__ == "__main__":
import moxing as mox import moxing as mox
mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/') mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')
os.system('rm -f *.ckpt *.ir *.meta') # 清理旧的运行文件
test_train() test_train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册