Commit 4dbd827b authored by lsqtina

test=develop

Parent 9fb40dac
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境\n",
"本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。"
]
},
{
"cell_type": "code",
"execution_count": 295,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.0.0-alpha0'"
]
},
"execution_count": 295,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import paddle\n",
"paddle.__version__"
]
},
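{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the version printed above does not match, the alpha build can usually be installed from PyPI. The exact package specifier used below (`paddlepaddle==2.0.0a0`) is an assumption; check the official installation guide for the right wheel for your platform."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install sketch: assumes the 2.0-alpha build is published on PyPI as 2.0.0a0.\n",
"# Verify the exact version string in the official install guide before running.\n",
"!python -m pip install paddlepaddle==2.0.0a0"
]
},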
{
"cell_type": "code",
"execution_count": 296,
"metadata": {},
"outputs": [],
"source": [
"#数据准备\n",
"#数据处理部分之前的代码,加入部分数据处理的库\n",
"import paddle\n",
"from paddle.imperative import to_variable\n",
"import numpy as np\n",
"import os\n",
"import gzip #解压缩包,python自带的包\n",
"import json\n",
"import random\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1.数据读取与数据集划分\n",
"加载json数据文件。"
]
},
{
"cell_type": "code",
"execution_count": 297,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading mnist dataset from /Users/liushuangqiao/Downloads/mnist.json.gz ......\n",
"mnist dataset load done\n",
"训练数据集数量: 50000 50000\n",
"验证数据集数量: 10000 10000\n",
"测试数据集数量: 10000 10000\n"
]
}
],
"source": [
"# 声明数据集文件位置\n",
"datafile = '/Users/liushuangqiao/Downloads/mnist.json.gz'\n",
"print('loading mnist dataset from {} ......'.format(datafile))\n",
"# 加载json数据文件\n",
"data = json.load(gzip.open(datafile))\n",
"print('mnist dataset load done')\n",
"# 读取到的数据区分训练集,验证集,测试集\n",
"train_set, val_set, eval_set = data\n",
"\n",
"# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLS\n",
"IMG_ROWS = 28\n",
"IMG_COLS = 28\n",
"\n",
"# 打印数据信息\n",
"imgs, labels = train_set[0], train_set[1]\n",
"print(\"训练数据集数量: \", len(imgs),len(labels))\n",
"\n",
"# 观察验证集数量\n",
"imgs, labels = val_set[0], val_set[1]\n",
"print(\"验证数据集数量: \", len(imgs),len(labels))\n",
"\n",
"# 观察测试集数量\n",
"imgs, labels = val= eval_set[0], eval_set[1]\n",
"print(\"测试数据集数量: \", len(imgs),len(labels))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. 通过DataSet与DataLoader获取数据"
]
},
{
"cell_type": "code",
"execution_count": 298,
"metadata": {},
"outputs": [],
"source": [
"from paddle.io import Dataset\n",
"\n",
"#定义Dataset类对象\n",
"class RandomDataset(Dataset):\n",
" def __init__(self, imgs, labels):\n",
" self.imgs = imgs\n",
" self.labels = labels\n",
" \n",
" def __getitem__(self, idx):\n",
" img = self.imgs[idx]\n",
" label = self.labels[idx]\n",
" return img, label\n",
" \n",
" def __len__(self):\n",
" return len(self.imgs)\n"
]
},
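{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at the `Dataset` protocol that `DataLoader` relies on: `__getitem__` returns one `(img, label)` pair and `__len__` reports the number of samples. The sketch below uses two dummy samples that are purely illustrative and independent of the MNIST file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of the Dataset protocol on dummy data (illustrative only)\n",
"dummy_imgs = [np.zeros(784, dtype='float32'), np.ones(784, dtype='float32')]\n",
"dummy_labels = [0, 1]\n",
"demo = RandomDataset(dummy_imgs, dummy_labels)\n",
"print(len(demo))          # __len__ -> 2\n",
"img, label = demo[1]      # __getitem__ -> (img, label)\n",
"print(img.shape, label)   # (784,) 1"
]
},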
{
"cell_type": "code",
"execution_count": 299,
"metadata": {},
"outputs": [],
"source": [
"#通过DataLoader读取dataset数据,涉及必要参数 :dataset、places=None、batch_size\n",
"def load_data_new(mode='train'):\n",
" datafile = '/Users/liushuangqiao/Downloads/mnist.json.gz'\n",
" print('loading mnist dataset from {} ......'.format(datafile))\n",
" # 定义批大小\n",
" BATCH_SIZE = 64\n",
" # 加载json数据文件\n",
" data = json.load(gzip.open(datafile))\n",
" print('mnist dataset load done')\n",
" # 读取到的数据区分训练集,验证集,测试集\n",
" train_set, val_set, eval_set = data\n",
" if mode=='train':\n",
" # 获得训练数据集\n",
" imgs, labels = train_set[0], train_set[1]\n",
" elif mode=='valid':\n",
" # 获得验证数据集\n",
" imgs, labels = val_set[0], val_set[1]\n",
" elif mode=='eval':\n",
" # 获得测试数据集\n",
" imgs, labels = eval_set[0], eval_set[1]\n",
" else:\n",
" raise Exception(\"mode can only be one of ['train', 'valid', 'eval']\")\n",
" dataset = RandomDataset(imgs, labels)\n",
" loader = paddle.io.DataLoader(dataset, places=paddle.CPUPlace(),batch_size=BATCH_SIZE, drop_last=True)\n",
" return loader"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. 数据校验"
]
},
{
"cell_type": "code",
"execution_count": 300,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading mnist dataset from /Users/liushuangqiao/Downloads/mnist.json.gz ......\n",
"mnist dataset load done\n",
"[64, 784] [64] <class 'paddle.fluid.core_avx.VarBase'> <class 'paddle.fluid.core_avx.VarBase'>\n",
"\n",
"打印第一个batch的第一个图像,对应标签数字为[5]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAQbUlEQVR4nO3dfbBU9X3H8fdHHhV8AB8oIlWC+BgjJndQIzV2Eqk6ddTplErbFBktmpHWTGir0mlk0jqjHaMhjtpiJWKqJFhjJDPmwTBOTKaReDEoqFURMIrXi4gKPgQvl2//2AOz6t3fvezu3V3u7/OauXPPnu85e7539cPZ3d/Z/SkiMLOBb59mN2BmjeGwm2XCYTfLhMNulgmH3SwTDrtZJhz2AUrSM5LOanYf1jrkcXazPPjMbpYJh32AkrRB0pckzZd0v6T/lrRN0mpJx0i6VtImSa9Imla23yxJzxXbrpN0+cfu958kdUh6TdJlkkLS0UVtmKSbJP1OUqek/5C0b6P/duuZw56H84HvAqOA3wI/pfTffhzwDeA/y7bdBPwpcAAwC7hF0mcBJJ0DfA34EnA0cNbHjnMDcAwwuaiPA77eH3+Q7Tm/Zh+gJG0ALgOmAmdExNnF+vOBJcCBEdEtaX9gKzAqIt7u4X5+CDwaEQskLQI6I+LaonY08CIwCXgJeBf4TES8VNRPB+6LiAn9+9daXwxudgPWEJ1lyx8AmyOiu+w2wEjgbUnnAtdROkPvA+wHrC62ORxoL7uvV8qWDy22XSlp1zoBg+r0N1iNHHbbTdIw4AHgb4CHIqKrOLPvSm8HcETZLuPLljdT+ofjxIjY2Ih+bc/4NbuVGwoMA94AdhRn+Wll9aXALEnHS9oP+JddhYjYCdxJ6TX+YQCSxkn6k4Z1b0kOu+0WEduAv6cU6reAvwSWldV/DHwbeBRYCzxelLYXv6/etV7SVuDnwLENad565TforGqSjgfWAMMiYkez+7E0n9ltj0i6qBhPHwXcCPzIQd87OOy2py6nNBb/EtANfKW57Vhf+Wm8WSZ8ZjfLREPH2YdqWAxnRCMPaZaV3/MeH8Z29VSrKezFtdILKF0l9V8RcUNq++GM4FR9sZZDmlnCilhesVb103hJg4DbgHOBE4AZkk6o9v7MrH/V8pp9CrA2ItZFxIfA94AL6tOWmdVbLWEfx0c/CPFqse4jJM2W1C6pvWv3hVZm1mj9/m58RCyMiLaIaBvCsP4+nJlVUEvYN/LRTz0dUawzsxZUS9ifACZJmiBpKHAxZR+aMLPWUvXQW0TskDSH0lccDQIWRcQzdevMzOqqpnH2iHgYeLhOvZhZP/LlsmaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulomaZnG1vcA+g5LlwYcd0q+Hf/4fJlSsdY/Ymdz3yImbkvURV6SP3fGtYRVrT7Z9P7nv5u73kvXTls5N1ifOfTxZb4aawi5pA7AN6AZ2RERbPZoys/qrx5n9jyNicx3ux8z6kV+zm2Wi1rAH8DNJKyXN7mkDSbMltUtq72J7jYczs2rV+jR+akRslHQY8Iik/4uIx8o3iIiFwEKAAzQ6ajyemVWppjN7RGwsfm8CHgSm1KMpM6u/qsMuaYSk/XctA9OANfVqzMzqq5an8WOAByXtup/7IuIndelqgBl07NHJegwfkqx3fGFUsv7+6ZXHhEcfmB4v/uXJ6fHmZvrx+/sn6zfefk6yvuKk+yrW1nd9kNz3hs6zk/XDf5m+RqAVVR32iFgHnFzHXsysH3nozSwTDrtZJhx2s0w47GaZcNjNMuGPuNbBzi+ckqzfvPj2ZP2YIUPr2c5eoyu6k/Wvf/uSZH3Ie+kLMj+/dE7F2siNXcl9h21OD83tu/I3yXor8pndLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEx9nrYOjzryXrK38/Plk/ZkhnPdupq7kdpyXr695NfxX13RP/p2LtnZ3pcfIxt/5vst6fBuJXKvnMbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlQhGNG1E8QKPjVH2xYcdrFW9dcnqy/s656a97HvzUyGR91ZW37nFPu/zb5s8k60/80cHJevfWrekDnFb5/tdfpeSuE2Y8lb5v+4QVsZytsaXHB9ZndrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEx5nbwGDDh6drHe/uSVZX7+k8mS6z5y5KLnvlOv/Llk/7Pbmfabc9lxN4+ySFknaJGlN2brRkh6R9GLxOz2BuJk1XV+ext8NfHzW+2uA5RExCVhe3DazFtZr2CPiMeDjzyMvABYXy4uBC+vcl5nVWbXfQTcmIjqK5deBMZU2lDQbmA0wnP2qPJyZ1armd+Oj9A5fxXf5ImJhRLRFRNsQhtV6ODOrUrVh75Q0FqD4val+LZlZf6g27MuAmcXyTOCh+rRjZv2l19fskpYAZwGHSHoVuA64AVgq6VLgZWB6fzY50PU2jt6brq3Vz+9+4l8/m6y/cUf6M+c08DoNq02vYY+IGRVKvjrGbC/iy2XNMuGwm2XCYTfLhMNulgmH3SwTnrJ5ADj+H5+vWJt1UnrQ5DtHLk/Wz/rzK5P1kUsfT9atdfjMbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwuPsA0Bq2uQ3rzguue/vfvRBsn719fck69f+xUXJejx5YMXa+Ov9NdWN5DO7WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJT9mcuS2zTk/W751/U7I+YfDwqo994uI5yfqkO19L1nesf7nqYw9UNU3ZbGYDg8NulgmH3SwTDrtZJhx2s0w47GaZcNjNMuFxdkuKz5+crB9w48Zkfcmnflr1sY979LJk/dj5byfr3WvXV33svVVN4+ySFknaJGlN2br5kjZKWlX8nFfPhs2s/vryNP5u4Jwe1t8SEZOLn4fr25aZ1VuvYY+Ix4AtDejFzPpRLW/QzZH0dPE0f1SljSTNltQuqb2L7TUczsxqUW3Y7wAmApOBDuCblTaMiIUR0RYRbUMYVuXhzKxWVYU9IjojojsidgJ3AlPq25aZ1VtVYZc0tuzmRcCaStuaWWvodZxd0hLgLOAQoBO4rrg9GQhgA3B5RHT0djCPsw88gw49NFl/7eJJFWsrrlmQ3HefXs5Ff7V+WrL+ztQ3k/WBKDXO3uskERExo4fVd9XclZk1lC+XNcuEw26WCYfdLBMOu1kmHHazTPgjrtY0S1/9dbK+n4Ym6+/Hh8n6+XOuqljb94e/Se67t/JXSZuZw26WC4fdLBMOu1kmHHazTDjsZplw2M0y0eun3ixvccbkZH3t9PSUzZ+evKFirbdx9N7cuuWUZH3fh56o6f4HGp/ZzTLhsJtlwmE3y4TDbpYJh
90sEw67WSYcdrNMeJx9gNPnTkzWX7gqPUvPnWcsTtbPHJ7+THkttkdXsv74lgnpO+j9282z4jO7WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpaJXsfZJY0H7gHGUJqieWFELJA0Gvg+cBSlaZunR8Rb/ddqvgYf9YfJ+kuzjqhYm3/xkuS+fzZyc1U91cO8zrZk/bFvnZasH3RP+nvn7aP6cmbfAcyNiBOA04ArJZ0AXAMsj4hJwPLitpm1qF7DHhEdEfFksbwNeA4YB1wA7Lq8ajFwYX81aWa126PX7JKOAk4BVgBjInZfj/g6paf5Ztai+hx2SSOBB4CvRsTW8lqUJozrcdI4SbMltUtq72J7Tc2aWfX6FHZJQygF/d6I+EGxulPS2KI+FtjU074RsTAi2iKibQjpD12YWf/pNeySBNwFPBcRN5eVlgEzi+WZwEP1b8/M6qUvH3E9A/gysFrSqmLdPOAGYKmkS4GXgen90+Leb/CR45P1rZ87PFmf/q8/SdavOOjBPe6pXuZ2pIfHfn1b5eG10Xc/ntz3oPDQWj31GvaI+BXQ43zPgCdbN9tL+Ao6s0w47GaZcNjNMuGwm2XCYTfLhMNulgl/lXQfDf6Dypf+b/nOyOS+X5nwi2R9xv6dVfVUD3M2Tk3Wf3t7esrmg+9/Olkf/Z7HyluFz+xmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSayGWfvmpb+2uLtX9uSrM87+uGKtWn7vldVT/XS2f1BxdqZy+Ym9z1u3nPJ+qit6XHyncmqtRKf2c0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTGQzzr7+ovS/ay+cdH+/Hfu2tycm6wt+MS1ZV3elb/IuOe4b6yrWJr2xIrlvd7JqA4nP7GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhQR6Q2k8cA9wBgggIURsUDSfOBvgTeKTedFROUPfQMHaHScKs/ybNZfVsRytsaWHi/M6MtFNTuAuRHxpKT9gZWSHilqt0TETfVq1Mz6T69hj4gOoKNY3ibpOWBcfzdmZvW1R6/ZJR0FnALsugZzjqSnJS2SNKrCPrMltUtq72J7Tc2aWfX6HHZJI4EHgK9GxFbgDmAiMJnSmf+bPe0XEQsjoi0i2oYwrA4tm1k1+hR2SUMoBf3eiPgBQER0RkR3ROwE7gSm9F+bZlarXsMuScBdwHMRcXPZ+rFlm10ErKl/e2ZWL315N/4M4MvAakmrinXzgBmSJlMajtsAXN4vHZpZXfTl3fhfAT2N2yXH1M2stfgKOrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpaJXr9Kuq4Hk94AXi5bdQiwuWEN7JlW7a1V+wL3Vq169nZkRBzaU6GhYf/EwaX2iGhrWgMJrdpbq/YF7q1ajerNT+PNMuGwm2Wi2WFf2OTjp7Rqb63aF7i3ajWkt6a+Zjezxmn2md3MGsRhN8tEU8Iu6RxJz0taK+maZvRQiaQNklZLWiWpvcm9LJK0SdKasnWjJT0i6cXid49z7DWpt/mSNhaP3SpJ5zWpt/GSHpX0rKRnJF1VrG/qY5foqyGPW8Nfs0saBLwAnA28CjwBzIiIZxvaSAWSNgBtEdH0CzAknQm8C9wTEZ8u1v07sCUibij+oRwVEVe3SG/zgXebPY13MVvR2PJpxoELgUto4mOX6Gs6DXjcmnFmnwKsjYh1EfEh8D3ggib00fIi4jFgy8dWXwAsLpYXU/qfpeEq9NYSIqIjIp4slrcBu6YZb+pjl+irIZoR9nHAK2W3X6W15nsP4GeSVkqa3exmejAmIjqK5deBMc1spge9TuPdSB+bZrxlHrtqpj+vld+g+6SpEfFZ4FzgyuLpakuK0muwVho77dM03o3SwzTjuzXzsat2+vNaNSPsG4HxZbePKNa1hIjYWPzeBDxI601F3blrBt3i96Ym97NbK03j3dM047TAY9fM6c+bEfYngEmSJkgaClwMLGtCH58gaUTxxgmSRgDTaL2pqJcBM4vlmcBDTezlI1plGu9K04zT5Meu6dOfR0TDf4DzKL0j/xLwz83ooUJfnwKeKn6eaXZvwBJKT+u6KL23cSlwMLAceBH4OTC6hXr7LrAaeJpSsMY2qbeplJ6iPw2sKn7Oa/Zjl+irIY+bL5c1y4TfoDPLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMvH/TswJIRNpLrYAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"#声明数据读取函数,从训练集中读取数据\n",
"paddle.enable_imperative()\n",
"train_loader = load_data_new('train')\n",
"for batch_id, data in enumerate(train_loader()):\n",
"\n",
" image_data, label_data = data[0], data[1] \n",
" if batch_id == 0:\n",
" # 打印数据shape和类型\n",
" print(image_data.shape, label_data.shape, type(image_data), type(label_data))\n",
" print(\"\\n打印第一个batch的第一个图像,对应标签数字为{}\".format(label_data[0].numpy()))\n",
" # 原始数据是归一化后的数据,因此这里需要反归一化\n",
" img = np.array(image_data[0]+1)*127.5\n",
" img = np.reshape(img, [28, 28]).astype(np.uint8)\n",
" plt.figure(\"Image\") # 图像窗口名称\n",
" plt.imshow(img)\n",
" plt.axis('on') # 关掉坐标轴为 off\n",
" plt.title('image') # 图像题目\n",
" plt.show()\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. 组成网络"
]
},
{
"cell_type": "code",
"execution_count": 301,
"metadata": {},
"outputs": [],
"source": [
"from paddle.nn import Conv2D, Pool2D, Linear\n",
"#定义网络结构,这里使用最简单的线性网络\n",
"class Mnist(paddle.nn.Layer):\n",
" def __init__(self, name_scope):\n",
" super(Mnist, self).__init__()\n",
" self.fc = Linear(input_dim=784, output_dim=10, act='softmax', dtype='float64')\n",
"\n",
" # 定义网络结构的前向计算过程\n",
" def forward(self, inputs,label=None):\n",
" outputs = self.fc(inputs)\n",
" if label is not None:\n",
" acc = paddle.metric.accuracy(input=outputs, label=label)\n",
" return outputs, acc\n",
" else:\n",
" return outputs"
]
},
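{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the network definition, the sketch below runs a forward pass on a random batch. With `input_dim=784` and `output_dim=10`, an input of shape `[4, 784]` should yield an output of shape `[4, 10]`. The batch size of 4 and the random input are illustrative assumptions, not part of the tutorial's data pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Forward-pass sanity check on random data (illustrative only).\n",
"# The Linear layer is declared with dtype='float64', so the input must be float64 as well.\n",
"check_net = Mnist(\"mnist_check\")\n",
"x = to_variable(np.random.rand(4, 784).astype('float64'))\n",
"y = check_net(x)\n",
"print(y.shape)  # expected: [4, 10]"
]
},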
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. 训练模型\n",
"在训练模型前,需要设置模型的运行环境,这里我们设置模型在cpu上运行,并将其设置为动态图模式。"
]
},
{
"cell_type": "code",
"execution_count": 302,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, batch: 0, loss is: [3.13129952]\n",
"epoch: 0, batch: 100, loss is: [2.82868815]\n",
"epoch: 0, batch: 200, loss is: [2.5842488]\n",
"epoch: 0, batch: 300, loss is: [3.21580688]\n",
"epoch: 0, batch: 400, loss is: [3.03717391]\n",
"epoch: 0, batch: 500, loss is: [2.84022745]\n",
"epoch: 0, batch: 600, loss is: [2.85783756]\n",
"epoch: 0, batch: 700, loss is: [2.76853633]\n",
"epoch: 1, batch: 0, loss is: [3.13129952]\n",
"epoch: 1, batch: 100, loss is: [2.82868815]\n",
"epoch: 1, batch: 200, loss is: [2.5842488]\n",
"epoch: 1, batch: 300, loss is: [3.21580688]\n",
"epoch: 1, batch: 400, loss is: [3.03717391]\n",
"epoch: 1, batch: 500, loss is: [2.84022745]\n",
"epoch: 1, batch: 600, loss is: [2.85783756]\n",
"epoch: 1, batch: 700, loss is: [2.76853633]\n",
"epoch: 2, batch: 0, loss is: [3.13129952]\n",
"epoch: 2, batch: 100, loss is: [2.82868815]\n",
"epoch: 2, batch: 200, loss is: [2.5842488]\n",
"epoch: 2, batch: 300, loss is: [3.21580688]\n",
"epoch: 2, batch: 400, loss is: [3.03717391]\n",
"epoch: 2, batch: 500, loss is: [2.84022745]\n",
"epoch: 2, batch: 600, loss is: [2.85783756]\n",
"epoch: 2, batch: 700, loss is: [2.76853633]\n",
"epoch: 3, batch: 0, loss is: [3.13129952]\n",
"epoch: 3, batch: 100, loss is: [2.82868815]\n",
"epoch: 3, batch: 200, loss is: [2.5842488]\n",
"epoch: 3, batch: 300, loss is: [3.21580688]\n",
"epoch: 3, batch: 400, loss is: [3.03717391]\n",
"epoch: 3, batch: 500, loss is: [2.84022745]\n",
"epoch: 3, batch: 600, loss is: [2.85783756]\n",
"epoch: 3, batch: 700, loss is: [2.76853633]\n",
"epoch: 4, batch: 0, loss is: [3.13129952]\n",
"epoch: 4, batch: 100, loss is: [2.82868815]\n",
"epoch: 4, batch: 200, loss is: [2.5842488]\n",
"epoch: 4, batch: 300, loss is: [3.21580688]\n",
"epoch: 4, batch: 400, loss is: [3.03717391]\n",
"epoch: 4, batch: 500, loss is: [2.84022745]\n",
"epoch: 4, batch: 600, loss is: [2.85783756]\n",
"epoch: 4, batch: 700, loss is: [2.76853633]\n",
"epoch: 5, batch: 0, loss is: [3.13129952]\n",
"epoch: 5, batch: 100, loss is: [2.82868815]\n",
"epoch: 5, batch: 200, loss is: [2.5842488]\n",
"epoch: 5, batch: 300, loss is: [3.21580688]\n",
"epoch: 5, batch: 400, loss is: [3.03717391]\n",
"epoch: 5, batch: 500, loss is: [2.84022745]\n",
"epoch: 5, batch: 600, loss is: [2.85783756]\n",
"epoch: 5, batch: 700, loss is: [2.76853633]\n",
"epoch: 6, batch: 0, loss is: [3.13129952]\n",
"epoch: 6, batch: 100, loss is: [2.82868815]\n",
"epoch: 6, batch: 200, loss is: [2.5842488]\n",
"epoch: 6, batch: 300, loss is: [3.21580688]\n",
"epoch: 6, batch: 400, loss is: [3.03717391]\n",
"epoch: 6, batch: 500, loss is: [2.84022745]\n",
"epoch: 6, batch: 600, loss is: [2.85783756]\n",
"epoch: 6, batch: 700, loss is: [2.76853633]\n",
"epoch: 7, batch: 0, loss is: [3.13129952]\n",
"epoch: 7, batch: 100, loss is: [2.82868815]\n",
"epoch: 7, batch: 200, loss is: [2.5842488]\n",
"epoch: 7, batch: 300, loss is: [3.21580688]\n",
"epoch: 7, batch: 400, loss is: [3.03717391]\n",
"epoch: 7, batch: 500, loss is: [2.84022745]\n",
"epoch: 7, batch: 600, loss is: [2.85783756]\n",
"epoch: 7, batch: 700, loss is: [2.76853633]\n",
"epoch: 8, batch: 0, loss is: [3.13129952]\n",
"epoch: 8, batch: 100, loss is: [2.82868815]\n",
"epoch: 8, batch: 200, loss is: [2.5842488]\n",
"epoch: 8, batch: 300, loss is: [3.21580688]\n",
"epoch: 8, batch: 400, loss is: [3.03717391]\n",
"epoch: 8, batch: 500, loss is: [2.84022745]\n",
"epoch: 8, batch: 600, loss is: [2.85783756]\n",
"epoch: 8, batch: 700, loss is: [2.76853633]\n",
"epoch: 9, batch: 0, loss is: [3.13129952]\n",
"epoch: 9, batch: 100, loss is: [2.82868815]\n",
"epoch: 9, batch: 200, loss is: [2.5842488]\n",
"epoch: 9, batch: 300, loss is: [3.21580688]\n",
"epoch: 9, batch: 400, loss is: [3.03717391]\n",
"epoch: 9, batch: 500, loss is: [2.84022745]\n",
"epoch: 9, batch: 600, loss is: [2.85783756]\n",
"epoch: 9, batch: 700, loss is: [2.76853633]\n"
]
}
],
"source": [
"# 定义MNIST类的对象,以及优化器\n",
"mnist = Mnist(\"mnist\")\n",
"\n",
"# 定义优化器\n",
"optimizer = paddle.optimizer.Adam(learning_rate=0.1,parameter_list=mnist.parameters())\n",
"\n",
"EPOCH_NUM = 10\n",
"for epoch_id in range(EPOCH_NUM):\n",
" for batch_id, data in enumerate(train_loader()):\n",
" #准备数据\n",
" image_data, label_data = data[0], data[1]\n",
"\n",
" #前向计算的过程\n",
" predict = mnist(image_data)\n",
"\n",
" #计算损失,取一个批次样本损失的平均值\n",
" loss = paddle.nn.functional.cross_entropy(predict,label_data)\n",
" avg_loss = paddle.mean(loss)\n",
"\n",
" #每训练了100批次的数据,打印下当前Loss的情况\n",
" if batch_id % 100 == 0:\n",
" print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch_id, batch_id, avg_loss.numpy()))\n",
"\n",
" #后向传播,更新参数的过程\n",
" avg_loss.backward()\n",
" optimizer.minimize(avg_loss)\n",
" mnist.clear_gradients()\n",
"\n",
"#保存模型参数\n",
"model_dict = mnist.state_dict()\n",
"paddle.imperative.save(model_dict, \"save_temp\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6. 评估测试"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"mnist_eval = Mnist(\"mnist\") \n",
"model_dict, _ = paddle.imperative.load(\"save_temp\")\n",
"mnist_eval.load_dict(model_dict)\n",
"\n",
"#切换到评估模式\n",
"mnist_eval.eval()\n",
"\n",
"acc_set = []\n",
"avg_loss_set = []\n",
"\n",
"# 定义数据加载器\n",
"test_loader = load_data_new('eval')\n",
"for batch_id, data in enumerate(test_loader()):\n",
" image_data, label_data = data[0],data[1]\n",
" label_data = paddle.reshape(label_data,[-1,1])\n",
" \n",
" #前向计算的过程\n",
" predict, acc = mnist_eval(image_data, label_data)\n",
"\n",
" #计算损失,取一个批次样本损失的平均值\n",
" loss = paddle.nn.functional.cross_entropy(predict,label_data)\n",
" avg_loss = paddle.mean(loss)\n",
" acc_set.append(float(acc.numpy()))\n",
" avg_loss_set.append(float(avg_loss.numpy()))\n",
" \n",
"acc_val_mean = np.array(acc_set).mean()\n",
"avg_loss_val_mean = np.array(avg_loss_set).mean()\n",
"print(\"Eval avg_loss is: {}, acc is: {}\".format(avg_loss_val_mean, acc_val_mean))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}