{
 "cells": [
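  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Saving and Loading Models\n",
    "This tutorial explains how to save and load model parameters using Paddle's high-level API. Training can be interrupted, deliberately or not, by unexpected events, so while a model is still being trained we should save its parameters frequently; if something goes wrong, we can quickly reload the saved parameters and continue training. Alternatively, once the model has finished training, we need the trained parameters to run inference or deploy the model. For both cases, Paddle provides methods for saving and loading models and supports resuming training from the last saved state: as long as we save the model state regularly during training, we never have to retrain from scratch.\n",
    "The sections below use a handwritten digit recognition model to show how to save and load a model in Paddle and how to resume training. The network architecture itself is not explained in detail."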
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment\n",
    "This tutorial is written for paddle-2.0Beta. If your environment uses a different version, please install paddle-2.0Beta first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0.0\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.nn.functional as F\n",
    "from paddle.nn import Layer\n",
    "from paddle.vision.datasets import MNIST\n",
    "from paddle.metric import Accuracy\n",
    "from paddle.nn import Conv2d, MaxPool2d, Linear\n",
    "from paddle.static import InputSpec\n",
    "\n",
    "print(paddle.__version__)\n",
    "paddle.disable_static()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "The MNIST handwritten digit dataset contains 60,000 training examples and 10,000 test examples. The digits have been size-normalized and centered in fixed-size images (28x28 pixels), with pixel values ranging from 0 to 1. The dataset's official page is http://yann.lecun.com/exdb/mnist/\n",
    "This example uses the MNIST dataset that ships with Paddle; import it with `from paddle.vision.datasets import MNIST`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = MNIST(mode='train')\n",
    "test_dataset = MNIST(mode='test')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(Layer):\n",
    "    def __init__(self):\n",
    "        super(MyModel, self).__init__()\n",
    "        self.conv1 = Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n",
    "        self.max_pool1 = MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.conv2 = Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n",
    "        self.max_pool2 = MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.linear1 = Linear(in_features=16*5*5, out_features=120)\n",
    "        self.linear2 = Linear(in_features=120, out_features=84)\n",
    "        self.linear3 = Linear(in_features=84, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.max_pool1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.max_pool2(x)\n",
    "        x = paddle.flatten(x, start_axis=1, stop_axis=-1)\n",
    "        x = self.linear1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.linear2(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.linear3(x)\n",
    "        x = F.softmax(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the Model\n",
    "Wrap the network in a `Model` instance to train it quickly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1\n",
      "step 100/938 - loss: 1.6177 - acc_top1: 0.6119 - acc_top2: 0.6813 - 15ms/step\n",
      "step 200/938 - loss: 1.7720 - acc_top1: 0.7230 - acc_top2: 0.7788 - 15ms/step\n",
      "step 300/938 - loss: 1.6114 - acc_top1: 0.7666 - acc_top2: 0.8164 - 15ms/step\n",
      "step 400/938 - loss: 1.6537 - acc_top1: 0.7890 - acc_top2: 0.8350 - 15ms/step\n",
      "step 500/938 - loss: 1.5229 - acc_top1: 0.8170 - acc_top2: 0.8619 - 15ms/step\n",
      "step 600/938 - loss: 1.5269 - acc_top1: 0.8391 - acc_top2: 0.8821 - 15ms/step\n",
      "step 700/938 - loss: 1.4821 - acc_top1: 0.8561 - acc_top2: 0.8970 - 15ms/step\n",
      "step 800/938 - loss: 1.4860 - acc_top1: 0.8689 - acc_top2: 0.9081 - 15ms/step\n",
      "step 900/938 - loss: 1.5032 - acc_top1: 0.8799 - acc_top2: 0.9174 - 15ms/step\n",
      "step 938/938 - loss: 1.4617 - acc_top1: 0.8835 - acc_top2: 0.9203 - 15ms/step\n",
      "save checkpoint at /Users/dingjiawei/online_repo/book/paddle2.0_docs/save_model/mnist_checkpoint/0\n",
      "Eval begin...\n",
      "step 100/157 - loss: 1.4765 - acc_top1: 0.9636 - acc_top2: 0.9891 - 6ms/step\n",
      "step 157/157 - loss: 1.4612 - acc_top1: 0.9705 - acc_top2: 0.9910 - 6ms/step\n",
      "Eval samples: 10000\n",
      "save checkpoint at /Users/dingjiawei/online_repo/book/paddle2.0_docs/save_model/mnist_checkpoint/final\n"
     ]
    }
   ],
   "source": [
    "inputs = InputSpec([None, 1, 28, 28], 'float32', 'x')\n",
    "labels = InputSpec([None, 1], 'int64', 'label')\n",
    "model = paddle.Model(MyModel(), inputs, labels)\n",
    "\n",
    "optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
    "\n",
    "model.prepare(\n",
    "    optim,\n",
    "    paddle.nn.loss.CrossEntropyLoss(),\n",
    "    Accuracy(topk=(1, 2))\n",
    "    )\n",
    "model.fit(train_dataset,\n",
    "        test_dataset,\n",
    "        epochs=1,\n",
    "        log_freq=100,\n",
    "        batch_size=64,\n",
    "        save_dir='mnist_checkpoint')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving Model Parameters\n",
    "\n",
    "Paddle currently offers three sets of APIs for saving model parameters:\n",
    "#### Paddle high-level API: saving model parameters\n",
    "    * paddle.Model.fit\n",
    "    * paddle.Model.save\n",
    "#### Paddle core framework, dynamic graph: saving model parameters\n",
    "    * paddle.save\n",
    "#### Paddle core framework, static graph: saving model parameters\n",
    "    * paddle.io.save\n",
    "    * paddle.io.save_inference_model\n",
    "\n",
    "The rest of this section covers saving and loading with the high-level API."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "#### Method 1:\n",
    "* paddle.Model.fit(train_data, epochs, batch_size, save_dir, save_freq, log_freq) <br><br>\n",
    "When training with model.fit, set the save_dir argument to the directory where the model should be saved and save_freq to the saving frequency; training and saving then happen together. model.fit() saves only the model parameters, not the optimizer parameters, so each epoch produces a single .pdparams file, written as soon as that epoch ends. \n",
    "\n",
    "#### Method 2:\n",
    "* paddle.Model.save(self, path, training=True) <br><br>\n",
    "model.save(path) can save the model structure, the network parameters, and the optimizer parameters. Use training=True during training; in that case the network parameters and the optimizer parameters are saved. Each epoch yields two files, for example 0.pdparams and 0.pdopt, which store the model parameters and the optimizer parameters respectively, but the files containing the parameters of all epochs are only written after the whole training run has finished. path takes the form 'dirname/file_prefix' or 'file_prefix', where dirname is the directory and file_prefix is the name prefix of the parameter files. training=False means training is complete; in that case the inference model structure and the network parameters are saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Method 1: save the model parameters of every epoch during training\n",
    "model.fit(train_dataset,\n",
    "        test_dataset,\n",
    "        epochs=2,\n",
    "        batch_size=64,\n",
    "        save_dir='mnist_checkpoint'\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Method 2: model.save() saves the model parameters and the optimizer state\n",
    "model.save('mnist_checkpoint/test')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Model Parameters\n",
    "\n",
    "To restore the training state we need to load the saved model data. The load functions read the model parameters and the optimizer parameters from the files that store the model state and the optimizer state. If the optimizer does not need to be restored, the optimizer state file is not required.\n",
    "#### High-level API: loading model parameters\n",
    "    * paddle.Model.load\n",
    "#### Paddle core framework, dynamic graph: loading model parameters\n",
    "    * paddle.load\n",
    "#### Paddle core framework, static graph: loading model parameters\n",
    "    * paddle.io.load\n",
    "    * paddle.io.load_inference_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The high-level API method for loading model parameters is described below.\n",
    "* model.load(self, path, skip_mismatch=False, reset_optimizer=False)<br><br>\n",
    "model.load can load the model parameters and the optimizer parameters together. The reset_optimizer argument controls whether the optimizer parameters are restored: if reset_optimizer is True, the optimizer parameters are reinitialized; if it is False, they are restored from the given path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model with the high-level API\n",
    "model.load('mnist_checkpoint/test')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
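    "## Resuming Training\n",
    "\n",
    "Ideally, resuming training returns the model to the exact state at which training was interrupted, so that the gradient updates after resuming follow exactly the same course as they would have without the interruption. We can therefore check whether the method above resumes training accurately by looking at how the loss changes afterwards: resume from the model parameters and optimizer state saved at the end of epoch 0, and verify that the loss over the following epoch (epoch 1) matches an uninterrupted run exactly.\n",
    "\n",
    "Note:\n",
    "\n",
    "Resuming training correctly depends on two points:\n",
    "\n",
    "* When saving, save both the model parameters and the optimizer parameters.\n",
    "\n",
    "* When restoring, restore both the model parameters and the optimizer parameters."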
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "step 100/938 - loss: 1.4635 - acc_top1: 0.9650 - acc_top2: 0.9898 - 15ms/step\n",
      "step 200/938 - loss: 1.5459 - acc_top1: 0.9659 - acc_top2: 0.9897 - 15ms/step\n",
      "step 300/938 - loss: 1.5109 - acc_top1: 0.9658 - acc_top2: 0.9893 - 15ms/step\n",
      "step 400/938 - loss: 1.4797 - acc_top1: 0.9664 - acc_top2: 0.9899 - 15ms/step\n",
      "step 500/938 - loss: 1.4786 - acc_top1: 0.9673 - acc_top2: 0.9902 - 15ms/step\n",
      "step 600/938 - loss: 1.5082 - acc_top1: 0.9679 - acc_top2: 0.9906 - 15ms/step\n",
      "step 700/938 - loss: 1.4768 - acc_top1: 0.9687 - acc_top2: 0.9909 - 15ms/step\n",
      "step 800/938 - loss: 1.4638 - acc_top1: 0.9696 - acc_top2: 0.9913 - 15ms/step\n",
      "step 900/938 - loss: 1.5058 - acc_top1: 0.9704 - acc_top2: 0.9916 - 15ms/step\n",
      "step 938/938 - loss: 1.4702 - acc_top1: 0.9708 - acc_top2: 0.9917 - 15ms/step\n",
      "Eval begin...\n",
      "step 100/157 - loss: 1.4613 - acc_top1: 0.9755 - acc_top2: 0.9944 - 5ms/step\n",
      "step 157/157 - loss: 1.4612 - acc_top1: 0.9805 - acc_top2: 0.9956 - 5ms/step\n",
      "Eval samples: 10000\n",
      "Epoch 2/2\n",
      "step 100/938 - loss: 1.4832 - acc_top1: 0.9789 - acc_top2: 0.9927 - 15ms/step\n",
      "step 200/938 - loss: 1.4618 - acc_top1: 0.9779 - acc_top2: 0.9932 - 14ms/step\n",
      "step 300/938 - loss: 1.4613 - acc_top1: 0.9779 - acc_top2: 0.9929 - 15ms/step\n",
      "step 400/938 - loss: 1.4765 - acc_top1: 0.9772 - acc_top2: 0.9932 - 15ms/step\n",
      "step 500/938 - loss: 1.4932 - acc_top1: 0.9775 - acc_top2: 0.9934 - 15ms/step\n",
      "step 600/938 - loss: 1.4773 - acc_top1: 0.9773 - acc_top2: 0.9936 - 15ms/step\n",
      "step 700/938 - loss: 1.4612 - acc_top1: 0.9783 - acc_top2: 0.9939 - 15ms/step\n",
      "step 800/938 - loss: 1.4653 - acc_top1: 0.9779 - acc_top2: 0.9939 - 15ms/step\n",
      "step 900/938 - loss: 1.4639 - acc_top1: 0.9780 - acc_top2: 0.9939 - 15ms/step\n",
      "step 938/938 - loss: 1.4678 - acc_top1: 0.9779 - acc_top2: 0.9937 - 15ms/step\n",
      "Eval begin...\n",
      "step 100/157 - loss: 1.4612 - acc_top1: 0.9733 - acc_top2: 0.9945 - 6ms/step\n",
      "step 157/157 - loss: 1.4612 - acc_top1: 0.9778 - acc_top2: 0.9952 - 6ms/step\n",
      "Eval samples: 10000\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "from paddle.vision.datasets import MNIST\n",
    "from paddle.metric import Accuracy\n",
    "from paddle.static import InputSpec\n",
    "\n",
    "# MyModel is the network class defined in the model-building cell above.\n",
    "train_dataset = MNIST(mode='train')\n",
    "test_dataset = MNIST(mode='test')\n",
    "\n",
    "paddle.disable_static()\n",
    "\n",
    "inputs = InputSpec([None, 1, 28, 28], 'float32', 'x')\n",
    "labels = InputSpec([None, 1], 'int64', 'label')\n",
    "model = paddle.Model(MyModel(), inputs, labels)\n",
    "optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
    "model.prepare(\n",
    "      optim,\n",
    "      paddle.nn.loss.CrossEntropyLoss(),\n",
    "      Accuracy(topk=(1, 2))\n",
    "      )\n",
    "# Load after prepare() so that the optimizer exists when its saved state is restored.\n",
    "model.load(\"./mnist_checkpoint/final\")\n",
    "model.fit(train_data=train_dataset,\n",
    "        eval_data=test_dataset,\n",
    "        batch_size=64,\n",
    "        log_freq=100,\n",
    "        epochs=2\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial used MNIST handwritten digit recognition to show how to save a model, load a model, and resume training. Paddle provides many APIs for saving and loading, so you can choose whichever fits your needs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}