提交 90d94759 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3 简化experiment_1的训练代码,去掉部分高级功能,展示简介高效的效果

Merge pull request !3 from dyonghan/experiment_1
......@@ -134,3 +134,4 @@ dmypy.json
# IDE
.idea/
.vscode/
此差异已折叠。
......@@ -3,19 +3,14 @@
import os
# os.environ['DEVICE_ID'] = '0'
import matplotlib.pyplot as plt
import numpy as np
import mindspore as ms
import mindspore.context as context
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore.dataset.transforms.vision import Inter
from mindspore import nn, Tensor
from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import LossMonitor
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
......@@ -26,26 +21,16 @@ DATA_DIR_TEST = "MNIST/test" # 测试集信息
def create_dataset(training=True, num_epoch=1, batch_size=32, resize=(32, 32),
rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)
# define map operations
resize_op = CV.Resize(resize)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
# apply map operations on images
ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op])
ds = ds.map(input_columns="image", operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
ds = ds.shuffle(buffer_size=buffer_size)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(num_epoch)
ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True).repeat(num_epoch)
return ds
class LeNet(nn.Cell):
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
super(LeNet5, self).__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
......@@ -70,26 +55,22 @@ class LeNet(nn.Cell):
return output
LOOP_SINK = context.get_context('enable_loop_sink')
def test_train(lr=0.01, momentum=0.9, num_epoch=3, ckpt_name="a_lenet"):
ds_train = create_dataset(num_epoch=num_epoch)
ds_eval = create_dataset(training=False)
steps_per_epoch = ds_train.get_dataset_size()
net = LeNet()
net = LeNet5()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), lr, momentum)
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix=ckpt_name, config=ckpt_cfg)
loss_cb = LossMonitor(per_print_times=1 if LOOP_SINK else steps_per_epoch)
loss_cb = LossMonitor(per_print_times=1)
model = Model(net, loss, opt, metrics={'acc', 'loss'})
model.train(num_epoch, ds_train, callbacks=[ckpt_cb, loss_cb], dataset_sink_mode=True)
model.train(num_epoch, ds_train, callbacks=[loss_cb])
metrics = model.eval(ds_eval)
print('Metrics:', metrics)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
......@@ -101,6 +82,4 @@ if __name__ == "__main__":
import moxing as mox
mox.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')
os.system('rm -f *.ckpt *.ir *.meta') # 清理旧的运行文件
test_train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册