Commit b843fa2d authored by: Y yangyaqin1@huawei.com

feedforward fashion-mnist

Parent 18e24a9c
import os
import struct
import sys
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import numpy as np
import mindspore
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore import Tensor
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
def read_image(file_name):
    '''
    :param file_name: file path
    :return: training or test images as a NumPy array

    Binary layout of the image file (IDX format):
    [offset] [type]          [value]           [description]
    0000     32 bit integer  0x00000803(2051)  magic number
    0004     32 bit integer  60000             number of images
    0008     32 bit integer  28                number of rows
    0012     32 bit integer  28                number of columns
    0016     unsigned byte   ??                pixel
    0017     unsigned byte   ??                pixel
    ........
    xxxx     unsigned byte   ??                pixel
    '''
    file_handle = open(file_name, "rb")  # open the file in binary mode
    file_content = file_handle.read()  # read the whole file into memory
    file_handle.close()
    head = struct.unpack_from('>IIII', file_content, 0)  # unpack the first 4 integers into a tuple
    offset = struct.calcsize('>IIII')
    imgNum = head[1]  # number of images
    width = head[2]  # image width
    height = head[3]  # image height
    bits = imgNum * width * height  # total number of pixel values, e.g. 60000*28*28 for the training set
    bitsString = '>' + str(bits) + 'B'  # fmt string, e.g. '>47040000B'
    imgs = struct.unpack_from(bitsString, file_content, offset)  # unpack the pixel data into a tuple
    imgs_array = np.array(imgs).reshape((imgNum, width * height))  # reshape into a [num_images, num_pixels] 2-D array
    return imgs_array
def read_label(file_name):
    '''
    :param file_name: file path
    :return: labels as a NumPy array

    Binary layout of the label file (IDX format):
    [offset] [type]          [value]           [description]
    0000     32 bit integer  0x00000801(2049)  magic number (MSB first)
    0004     32 bit integer  60000             number of items
    0008     unsigned byte   ??                label
    0009     unsigned byte   ??                label
    ........
    xxxx     unsigned byte   ??                label
    The label values are 0 to 9.
    '''
    file_handle = open(file_name, "rb")  # open the file in binary mode
    file_content = file_handle.read()  # read the whole file into memory
    file_handle.close()
    head = struct.unpack_from('>II', file_content, 0)  # unpack the first 2 integers into a tuple
    offset = struct.calcsize('>II')
    labelNum = head[1]  # number of labels
    bitsString = '>' + str(labelNum) + 'B'  # fmt string, e.g. '>60000B'
    label = struct.unpack_from(bitsString, file_content, offset)  # unpack the label data into a tuple
    return np.array(label)
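# Illustrative addition, not part of the original script and not called below:
# a minimal sketch of how the IDX header documented in the docstrings above
# could be validated before parsing. The expected magic numbers (2051 for the
# image file, 2049 for the label file) come from the format description; the
# helper name check_idx_magic is hypothetical.
def check_idx_magic(file_name, expected_magic):
    with open(file_name, "rb") as f:
        magic = struct.unpack_from('>I', f.read(4), 0)[0]  # first big-endian 32-bit integer
    return magic == expected_magic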
def get_data():
    # build the file paths
    train_image = os.path.join(cfg.data_dir_train, 'train-images-idx3-ubyte')
    test_image = os.path.join(cfg.data_dir_test, "t10k-images-idx3-ubyte")
    train_label = os.path.join(cfg.data_dir_train, "train-labels-idx1-ubyte")
    test_label = os.path.join(cfg.data_dir_test, "t10k-labels-idx1-ubyte")
    # read the data
    train_x = read_image(train_image)
    test_x = read_image(test_image)
    train_y = read_label(train_label)
    test_y = read_label(test_label)
    return train_x, train_y, test_x, test_y
# Define the feedforward network
class Forward_fashion(nn.Cell):
    def __init__(self, num_class=10):  # 10 classes in total; the images have a single channel
        super(Forward_fashion, self).__init__()
        self.num_class = num_class
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(cfg.channel * cfg.image_height * cfg.image_width, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Dense(128, self.num_class)
        self.softmax = nn.Softmax()

    def construct(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
args, unknown = parser.parse_known_args()
import moxing as mox
mox.file.copy_parallel(src_url=args.data_url, dst_url='Fashion-MNIST')
cfg = edict({
    'train_size': 60000,  # size of the training set
    'test_size': 10000,  # size of the test set
    'channel': 1,  # number of image channels
    'image_height': 28,  # image height
    'image_width': 28,  # image width
    'batch_size': 60,
    'num_classes': 10,  # number of classes
    'lr': 0.001,  # learning rate
    'epoch_size': 20,  # number of training epochs
    'data_dir_train': os.path.join('Fashion-MNIST', 'train'),
    'data_dir_test': os.path.join('Fashion-MNIST', 'test'),
    'save_checkpoint_steps': 1,  # save a checkpoint every this many steps
    'keep_checkpoint_max': 3,  # maximum number of checkpoints to keep
    'output_directory': './model_fashion',  # directory for saved models
    'output_prefix': "checkpoint_fashion_forward"  # file name prefix for saved models
})
train_x, train_y, test_x, test_y = get_data()
train_x = train_x.reshape(-1, 1, 28, 28)
test_x = test_x.reshape(-1, 1, 28, 28)
train_x = train_x / 255.0
test_x = test_x / 255.0
train_x = train_x.astype('float32')
test_x = test_x.astype('float32')
train_y = train_y.astype('int32')
test_y = test_y.astype('int32')
print('Number of training set samples:', train_x.shape[0])
print('Number of test set samples:', test_y.shape[0])
print('Channels / image height / width:', train_x.shape[1:])
print('Label of one image:', train_y[0])  # 10 classes in total, encoded as the digits 0-9
# Wrap the NumPy arrays into MindSpore Dataset objects
XY_train = list(zip(train_x, train_y))
ds_train = ds.GeneratorDataset(XY_train, ['x', 'y'])
ds_train.set_dataset_size(cfg.train_size)
ds_train = ds_train.shuffle(buffer_size=cfg.train_size).batch(cfg.batch_size, drop_remainder=True).repeat(
cfg.epoch_size)
XY_test = list(zip(test_x, test_y))
ds_test = ds.GeneratorDataset(XY_test, ['x', 'y'])
ds_test.set_dataset_size(cfg.test_size)
ds_test = ds_test.shuffle(buffer_size=cfg.test_size).batch(cfg.batch_size, drop_remainder=True).repeat(cfg.epoch_size)
# Build the network
network = Forward_fashion(cfg.num_classes)
# Define the model's loss function and optimizer
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Adam(network.trainable_params(), cfg.lr)
# Train the model
model = Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={"acc"})
loss_cb = LossMonitor(per_print_times=int(cfg.train_size / cfg.batch_size))
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=cfg.output_prefix, directory=cfg.output_directory, config=config_ck)
print("============== Starting Training ==============")
model.train(cfg.epoch_size, ds_train, callbacks=[ckpoint_cb, loss_cb], dataset_sink_mode=True)
# Evaluate the model on the test set and print the overall accuracy
metric = model.eval(ds_test)
print(metric)
# Prediction
test_ = ds_test.create_dict_iterator().get_next()
test = Tensor(test_['x'], mindspore.float32)
predictions = model.predict(test)
predictions = predictions.asnumpy()
for i in range(10):
    p_np = predictions[i, :]
    p_list = p_np.tolist()
    print('Prediction for sample ' + str(i) + ':', p_list.index(max(p_list)), '  ground truth:', test_['y'][i])
mox.file.copy_parallel(src_url='model_fashion', dst_url=args.train_url)
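# Not part of the original script: a minimal sketch of how the saved checkpoint
# could be restored later for standalone evaluation, reusing Forward_fashion,
# cfg, net_loss and ds_test defined above. The checkpoint file name shown here
# ('checkpoint_fashion_forward-20_1000.ckpt') is an assumption; the actual name
# depends on the training run. Kept commented out so the script's behavior is
# unchanged.
#
# from mindspore.train.serialization import load_checkpoint, load_param_into_net
#
# eval_net = Forward_fashion(cfg.num_classes)
# param_dict = load_checkpoint(os.path.join(cfg.output_directory,
#                                           'checkpoint_fashion_forward-20_1000.ckpt'))
# load_param_into_net(eval_net, param_dict)
# eval_model = Model(eval_net, loss_fn=net_loss, metrics={"acc"})
# print(eval_model.eval(ds_test))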