From a0d93d56f18ac90c77aadf6bbda8c4231d03f254 Mon Sep 17 00:00:00 2001 From: yuejich Date: Sun, 16 May 2021 14:40:17 +0800 Subject: [PATCH] =?UTF-8?q?Add=20new=20file=20=E8=BF=99=E4=B8=AA=E6=98=AF?= =?UTF-8?q?=E8=AE=AD=E7=BB=83python=E6=BA=90=E7=A0=81=20=E5=8A=A0=E8=BD=BD?= =?UTF-8?q?cnn=E5=8D=B7=E7=A7=AF=E7=A5=9E=E7=BB=8F=E7=BD=91=E7=BB=9C?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=EF=BC=8C=E4=BB=A5MNIST=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=9B=86=E4=B8=BA=E8=AE=AD=E7=BB=83=E6=95=B0=E6=8D=AE=EF=BC=8C?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_main.py | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 train_main.py diff --git a/train_main.py b/train_main.py new file mode 100644 index 0000000..29b1374 --- /dev/null +++ b/train_main.py @@ -0,0 +1,109 @@ +import paddle +import paddle.fluid as fluid +from decode_binary_function import * #这里导入解码二进制函数模块 +from cnn_function import * +from cnn_net import ConvolutionalNeuralNetwork +import numpy as np +import struct +import matplotlib.pyplot as plt + +# 训练集文件 +train_images_idx3_ubyte_file = r'./MNIST/raw/train-images-idx3-ubyte' +# 训练集标签文件 +train_labels_idx1_ubyte_file = r'./MNIST/raw/train-labels-idx1-ubyte' + +# 测试集文件 +test_images_idx3_ubyte_file = r'./MNIST/raw/t10k-images-idx3-ubyte' +# 测试集标签文件 +test_labels_idx1_ubyte_file = r'./MNIST/raw/t10k-labels-idx1-ubyte' + + +#train_set为60000条训练数据集图片,每张图片为28*28像素,数据集+数据集标签ndarray +train_x_ori=decode_idx3_ubyte(train_images_idx3_ubyte_file) +train_y_set=decode_idx1_ubyte(train_labels_idx1_ubyte_file) +print('数据集train_x_ori形状:',train_x_ori.shape) +print(type(train_x_ori)) +print('数据集标签train_y_set形状:',train_y_set.shape) +print(type(train_y_set)) +m_train_x_ori=train_x_ori.shape[0] +print('训练集含有图片数量:',m_train_x_ori) + +train_x_flatten=train_x_ori.reshape(m_train_x_ori,-1)#把train_x_ori展开为60000行 +train_x_set=train_x_flatten/255#数据归一化,因为每个像素为0-255,除以255让每一个数据都小于1 +train_set=np.hstack((train_x_set,train_y_set))#60000行数据和60000行标签合并 +print('合并后训练集形状:',train_set.shape) +print('合并后训练集类型:',type(train_set)) +#print(train_set[0])#其中前784个为图片信息,第785个数据是标签信息 + +#定义飞桨动态图工作环境 +with fluid.dygraph.guard(): + #定义模型名称为mnist,调用了自己写的卷积神经网络 + model = ConvolutionalNeuralNetwork('mnist') + #定义优化器optimizer,利用adam优化器,学习率0.001,参数列表为模型参数 + opt=fluid.optimizer.Adam(learning_rate=0.001,parameter_list=model.parameters()) + #迭代次数5次 + EPOCH_NUM=50 + +#train_reader是一个生成器,类型为function +train_reader=read_data(train_set)#read_data()是一个生成器,会不断读取数据集里面的数据yeild +print('观察yeild生成器返回对象类型:',type(train_reader))#生成器不加括号是function类型,加上括号是generation类型 +#train_reader1为generate类型,生成器生成的数据分批次读取,每个批次是16张图片 +train_reader1=paddle.batch(train_reader,batch_size=16) +#观察一下paddle.batch返回类型 +print('paddle.batch返回类型:',type(train_reader1))#为function + + +#训练模型 +with fluid.dygraph.guard(): + # 定义外层循环 + for pass_num in range(EPOCH_NUM): + # 定义内层循环 + for batch_id,data in enumerate(train_reader1()):#这里的train_reader1加上了括号,就是一个生成器,总共60000张图片,分批生成 + #观察data的结构至关重要,train_reader1 yeild的了两个数据,一个是前784个数据是图片,第785个数据是label,所以 + #data是一个列表,列表长度是16,data[0]-data[15]共16个列表元素,data[1]是一个元组, + #print('data的类型:',type(data)) + #print('data的长度:',len(data)) + #print('data列表中元素的类型:',type(data[0])) + #print(data[0]) + images = np.array([x[0].reshape(1, 28, 28) for x in data],np.float32) + labels = np.array([x[1] for x in data]).astype('int64').reshape(-1, 1) + + #print('images类型:',type(images)) + #print(images.shape) + #print('labels类型:', type(labels)) + #print(labels.shape) + #print(images[0]) + #y=images[0].reshape(28,28) + #plt.imshow(y) + #plt.show() + #print(labels) + #print(len(data_label)) + #print(data_label) + + # 将numpy数据转为飞桨动态图variable形式 + image = fluid.dygraph.to_variable(images) + label = fluid.dygraph.to_variable(labels) + # 第一步:前向计算 + predict = model(image) + # 第二步:计算损失 + loss = fluid.layers.cross_entropy(predict,label) + avg_loss = fluid.layers.mean(loss) + # 第三步:计算精度 + acc = fluid.layers.accuracy(predict,label) + #第四步:打印数据 + if batch_id % 500 == 0: + print("pass:{},batch_id:{},train_loss:{},train_acc:{}". + format(pass_num,batch_id,avg_loss.numpy(),acc.numpy())) + + # 第五步:反向传播 + avg_loss.backward() + # 最小化loss,更新参数 + opt.minimize(avg_loss) + # 清除梯度 + model.clear_gradients() + # 保存模型文件到指定路径 + fluid.save_dygraph(model.state_dict(), 'mnist') + + + + -- GitLab