PaddlePaddle / book · commit f60ccb48
Commit f60ccb48
Authored Sep 08, 2020 by dingjiaweiww

update api

Parent 484da788
Showing 1 changed file with 120 additions and 186 deletions (+120 −186)
paddle2.0_docs/save_model/save_model.ipynb @ f60ccb48
@@ -5,8 +5,8 @@
"metadata": {},
"source": [
     "# Saving and Loading Models\n",
-    "This tutorial demonstrates how to save and load a model on the MNIST dataset. After a neural network has been trained, we usually need to save its parameters, so that at prediction time the training step can be skipped and the parameters loaded directly. Besides, in day-to-day training work we run into unexpected situations that interrupt training, actively or passively; or a model is so large that training it takes several days. For these cases Paddle provides good methods for saving and extracting models, with support for resuming training from the last saved state: as long as we save the model state at any point during training, we never have to retrain from the initial state.\n",
-    "Below we use the MNIST model to explain how paddle saves and loads a model and resumes training; the explanation of the network structure is omitted."
+    "This tutorial demonstrates how to build a model on the MNIST dataset and then save and load it. In day-to-day training we run into unexpected situations that interrupt training, actively or passively, so while the model is not yet fully trained we need to save its parameters frequently, so that when an accident happens the saved parameters can be loaded quickly and training continued. Or the model is already trained and we need the trained parameters for prediction or for deploying the model. For these cases Paddle provides methods for saving and extracting models, with support for resuming training from the last saved state: as long as we save the model state at any point during training, we never have to retrain from the initial state.\n",
+    "Below we use the handwritten-digit-recognition model to explain how paddle saves and loads a model and resumes training; the explanation of the network structure is omitted."
]
},
{
...
...
@@ -19,7 +19,7 @@
},
{
"cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 17,
"metadata": {},
"outputs": [
{
...
...
@@ -32,14 +32,11 @@
],
"source": [
"import paddle\n",
-    "import paddle.fluid as fluid\n",
-    "import paddle.hapi as hapi\n",
+    "from paddle.nn import Layer\n",
+    "from paddle.nn import functional\n",
-    "from paddle.hapi.model import Model\n",
     "from paddle.vision.datasets import MNIST\n",
     "from paddle.metric import Accuracy\n",
-    "from paddle.nn import Conv2d,Pool2D,Linear\n",
-    "from paddle.vision import LeNet\n",
"from paddle.static import InputSpec\n",
"\n",
"print(paddle.__version__)\n",
...
...
@@ -52,12 +49,12 @@
"source": [
     "## Dataset\n",
     "The MNIST dataset of handwritten digits contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in fixed-size images (28x28 pixels) with values from 0 to 1. The official address of the dataset is: http://yann.lecun.com/exdb/mnist/\n",
-    "In this example we use paddle's built-in paddle.dataset to load the mnist dataset."
+    "In this example we use paddle's built-in mnist dataset, imported with from paddle.vision.datasets import MNIST."
]
},
{
"cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
...
...
@@ -74,11 +71,11 @@
},
{
"cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
-    "class MyModel(Model):\n",
+    "class MyModel(Layer):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n",
...
...
@@ -116,7 +113,7 @@
},
{
"cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 20,
"metadata": {},
"outputs": [
{
...
...
@@ -124,133 +121,29 @@
"output_type": "stream",
"text": [
"Epoch 1/1\n",
"step 10/938 - loss: 2.1389 - acc_top1: 0.2828 - acc_top2: 0.4516 - 16ms/step\n",
"step 20/938 - loss: 1.9412 - acc_top1: 0.4047 - acc_top2: 0.5195 - 14ms/step\n",
"step 30/938 - loss: 1.8458 - acc_top1: 0.4714 - acc_top2: 0.5708 - 14ms/step\n",
"step 40/938 - loss: 1.7914 - acc_top1: 0.5195 - acc_top2: 0.6133 - 13ms/step\n",
"step 50/938 - loss: 1.8215 - acc_top1: 0.5637 - acc_top2: 0.6531 - 13ms/step\n",
"step 60/938 - loss: 1.6824 - acc_top1: 0.5885 - acc_top2: 0.6729 - 13ms/step\n",
"step 70/938 - loss: 1.7004 - acc_top1: 0.6136 - acc_top2: 0.6989 - 13ms/step\n",
"step 80/938 - loss: 1.5939 - acc_top1: 0.6410 - acc_top2: 0.7264 - 13ms/step\n",
"step 90/938 - loss: 1.5671 - acc_top1: 0.6681 - acc_top2: 0.7514 - 13ms/step\n",
"step 100/938 - loss: 1.5495 - acc_top1: 0.6920 - acc_top2: 0.7736 - 13ms/step\n",
"step 110/938 - loss: 1.5818 - acc_top1: 0.7105 - acc_top2: 0.7908 - 13ms/step\n",
"step 120/938 - loss: 1.5283 - acc_top1: 0.7258 - acc_top2: 0.8046 - 13ms/step\n",
"step 130/938 - loss: 1.5800 - acc_top1: 0.7397 - acc_top2: 0.8167 - 13ms/step\n",
"step 140/938 - loss: 1.5149 - acc_top1: 0.7528 - acc_top2: 0.8280 - 13ms/step\n",
"step 150/938 - loss: 1.5687 - acc_top1: 0.7641 - acc_top2: 0.8372 - 13ms/step\n",
"step 160/938 - loss: 1.5313 - acc_top1: 0.7746 - acc_top2: 0.8458 - 13ms/step\n",
"step 170/938 - loss: 1.5448 - acc_top1: 0.7838 - acc_top2: 0.8534 - 13ms/step\n",
"step 180/938 - loss: 1.5323 - acc_top1: 0.7924 - acc_top2: 0.8602 - 13ms/step\n",
"step 190/938 - loss: 1.5727 - acc_top1: 0.7990 - acc_top2: 0.8655 - 13ms/step\n",
"step 200/938 - loss: 1.6199 - acc_top1: 0.8047 - acc_top2: 0.8708 - 13ms/step\n",
"step 210/938 - loss: 1.5250 - acc_top1: 0.8103 - acc_top2: 0.8757 - 13ms/step\n",
"step 220/938 - loss: 1.5084 - acc_top1: 0.8160 - acc_top2: 0.8798 - 13ms/step\n",
"step 230/938 - loss: 1.4965 - acc_top1: 0.8213 - acc_top2: 0.8836 - 13ms/step\n",
"step 240/938 - loss: 1.5144 - acc_top1: 0.8256 - acc_top2: 0.8870 - 13ms/step\n",
"step 250/938 - loss: 1.5099 - acc_top1: 0.8306 - acc_top2: 0.8908 - 13ms/step\n",
"step 260/938 - loss: 1.5245 - acc_top1: 0.8351 - acc_top2: 0.8943 - 13ms/step\n",
"step 270/938 - loss: 1.5202 - acc_top1: 0.8390 - acc_top2: 0.8971 - 13ms/step\n",
"step 280/938 - loss: 1.5367 - acc_top1: 0.8429 - acc_top2: 0.9001 - 13ms/step\n",
"step 290/938 - loss: 1.5806 - acc_top1: 0.8460 - acc_top2: 0.9028 - 13ms/step\n",
"step 300/938 - loss: 1.5277 - acc_top1: 0.8485 - acc_top2: 0.9052 - 13ms/step\n",
"step 310/938 - loss: 1.4983 - acc_top1: 0.8515 - acc_top2: 0.9075 - 13ms/step\n",
"step 320/938 - loss: 1.5280 - acc_top1: 0.8541 - acc_top2: 0.9099 - 13ms/step\n",
"step 330/938 - loss: 1.4935 - acc_top1: 0.8572 - acc_top2: 0.9120 - 13ms/step\n",
"step 340/938 - loss: 1.5882 - acc_top1: 0.8593 - acc_top2: 0.9141 - 13ms/step\n",
"step 350/938 - loss: 1.5212 - acc_top1: 0.8619 - acc_top2: 0.9161 - 13ms/step\n",
"step 360/938 - loss: 1.5276 - acc_top1: 0.8644 - acc_top2: 0.9179 - 13ms/step\n",
"step 370/938 - loss: 1.5412 - acc_top1: 0.8671 - acc_top2: 0.9198 - 13ms/step\n",
"step 380/938 - loss: 1.5440 - acc_top1: 0.8690 - acc_top2: 0.9215 - 13ms/step\n",
"step 390/938 - loss: 1.5285 - acc_top1: 0.8709 - acc_top2: 0.9231 - 13ms/step\n",
"step 400/938 - loss: 1.4708 - acc_top1: 0.8732 - acc_top2: 0.9247 - 13ms/step\n",
"step 410/938 - loss: 1.5053 - acc_top1: 0.8753 - acc_top2: 0.9262 - 13ms/step\n",
"step 420/938 - loss: 1.5023 - acc_top1: 0.8774 - acc_top2: 0.9276 - 13ms/step\n",
"step 430/938 - loss: 1.4816 - acc_top1: 0.8794 - acc_top2: 0.9288 - 13ms/step\n",
"step 440/938 - loss: 1.4981 - acc_top1: 0.8810 - acc_top2: 0.9301 - 13ms/step\n",
"step 450/938 - loss: 1.4943 - acc_top1: 0.8829 - acc_top2: 0.9313 - 13ms/step\n",
"step 460/938 - loss: 1.4984 - acc_top1: 0.8845 - acc_top2: 0.9325 - 13ms/step\n",
"step 470/938 - loss: 1.5182 - acc_top1: 0.8861 - acc_top2: 0.9337 - 13ms/step\n",
"step 480/938 - loss: 1.5311 - acc_top1: 0.8875 - acc_top2: 0.9348 - 13ms/step\n",
"step 490/938 - loss: 1.5287 - acc_top1: 0.8888 - acc_top2: 0.9359 - 13ms/step\n",
"step 500/938 - loss: 1.5035 - acc_top1: 0.8903 - acc_top2: 0.9370 - 13ms/step\n",
"step 510/938 - loss: 1.4626 - acc_top1: 0.8914 - acc_top2: 0.9378 - 13ms/step\n",
"step 520/938 - loss: 1.4631 - acc_top1: 0.8928 - acc_top2: 0.9387 - 13ms/step\n",
"step 530/938 - loss: 1.4925 - acc_top1: 0.8942 - acc_top2: 0.9397 - 13ms/step\n",
"step 540/938 - loss: 1.5132 - acc_top1: 0.8954 - acc_top2: 0.9406 - 13ms/step\n",
"step 550/938 - loss: 1.5132 - acc_top1: 0.8966 - acc_top2: 0.9414 - 13ms/step\n",
"step 560/938 - loss: 1.5377 - acc_top1: 0.8976 - acc_top2: 0.9423 - 13ms/step\n",
"step 570/938 - loss: 1.5328 - acc_top1: 0.8989 - acc_top2: 0.9431 - 13ms/step\n",
"step 580/938 - loss: 1.4635 - acc_top1: 0.9000 - acc_top2: 0.9439 - 13ms/step\n",
"step 590/938 - loss: 1.4639 - acc_top1: 0.9011 - acc_top2: 0.9447 - 13ms/step\n",
"step 600/938 - loss: 1.5219 - acc_top1: 0.9022 - acc_top2: 0.9455 - 13ms/step\n",
"step 610/938 - loss: 1.4964 - acc_top1: 0.9033 - acc_top2: 0.9464 - 13ms/step\n",
"step 620/938 - loss: 1.5228 - acc_top1: 0.9044 - acc_top2: 0.9471 - 13ms/step\n",
"step 630/938 - loss: 1.4848 - acc_top1: 0.9055 - acc_top2: 0.9477 - 13ms/step\n",
"step 640/938 - loss: 1.4741 - acc_top1: 0.9064 - acc_top2: 0.9485 - 13ms/step\n",
"step 650/938 - loss: 1.5107 - acc_top1: 0.9070 - acc_top2: 0.9491 - 13ms/step\n",
"step 660/938 - loss: 1.4972 - acc_top1: 0.9081 - acc_top2: 0.9496 - 13ms/step\n",
"step 670/938 - loss: 1.4991 - acc_top1: 0.9089 - acc_top2: 0.9502 - 13ms/step\n",
"step 680/938 - loss: 1.4891 - acc_top1: 0.9096 - acc_top2: 0.9507 - 13ms/step\n",
"step 690/938 - loss: 1.4792 - acc_top1: 0.9103 - acc_top2: 0.9512 - 13ms/step\n",
"step 700/938 - loss: 1.4754 - acc_top1: 0.9111 - acc_top2: 0.9518 - 13ms/step\n",
"step 710/938 - loss: 1.5455 - acc_top1: 0.9119 - acc_top2: 0.9522 - 13ms/step\n",
"step 720/938 - loss: 1.4906 - acc_top1: 0.9128 - acc_top2: 0.9528 - 13ms/step\n",
"step 730/938 - loss: 1.5199 - acc_top1: 0.9135 - acc_top2: 0.9533 - 13ms/step\n",
"step 740/938 - loss: 1.4820 - acc_top1: 0.9141 - acc_top2: 0.9537 - 13ms/step\n",
"step 750/938 - loss: 1.5074 - acc_top1: 0.9149 - acc_top2: 0.9543 - 13ms/step\n",
"step 760/938 - loss: 1.4777 - acc_top1: 0.9156 - acc_top2: 0.9547 - 13ms/step\n",
"step 770/938 - loss: 1.5069 - acc_top1: 0.9162 - acc_top2: 0.9552 - 13ms/step\n",
"step 780/938 - loss: 1.5387 - acc_top1: 0.9169 - acc_top2: 0.9556 - 13ms/step\n",
"step 790/938 - loss: 1.4750 - acc_top1: 0.9175 - acc_top2: 0.9560 - 13ms/step\n",
"step 800/938 - loss: 1.4616 - acc_top1: 0.9182 - acc_top2: 0.9564 - 13ms/step\n",
"step 810/938 - loss: 1.5077 - acc_top1: 0.9187 - acc_top2: 0.9569 - 13ms/step\n",
"step 820/938 - loss: 1.4787 - acc_top1: 0.9194 - acc_top2: 0.9572 - 13ms/step\n",
"step 830/938 - loss: 1.4774 - acc_top1: 0.9200 - acc_top2: 0.9577 - 13ms/step\n",
"step 840/938 - loss: 1.4851 - acc_top1: 0.9206 - acc_top2: 0.9581 - 13ms/step\n",
"step 850/938 - loss: 1.4689 - acc_top1: 0.9211 - acc_top2: 0.9585 - 13ms/step\n",
"step 860/938 - loss: 1.4844 - acc_top1: 0.9216 - acc_top2: 0.9588 - 13ms/step\n",
"step 870/938 - loss: 1.5136 - acc_top1: 0.9223 - acc_top2: 0.9592 - 13ms/step\n",
"step 880/938 - loss: 1.4779 - acc_top1: 0.9229 - acc_top2: 0.9596 - 13ms/step\n",
"step 890/938 - loss: 1.4642 - acc_top1: 0.9235 - acc_top2: 0.9599 - 13ms/step\n",
"step 900/938 - loss: 1.5236 - acc_top1: 0.9238 - acc_top2: 0.9602 - 13ms/step\n",
"step 910/938 - loss: 1.4735 - acc_top1: 0.9243 - acc_top2: 0.9605 - 13ms/step\n",
"step 920/938 - loss: 1.4942 - acc_top1: 0.9249 - acc_top2: 0.9608 - 12ms/step\n",
"step 930/938 - loss: 1.5161 - acc_top1: 0.9254 - acc_top2: 0.9611 - 12ms/step\n",
"step 938/938 - loss: 1.4663 - acc_top1: 0.9259 - acc_top2: 0.9614 - 12ms/step\n",
"save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/mnist_checkpoint/0\n",
"step 100/938 - loss: 1.6498 - acc_top1: 0.6281 - acc_top2: 0.7047 - 17ms/step\n",
"step 200/938 - loss: 1.6331 - acc_top1: 0.7491 - acc_top2: 0.8125 - 17ms/step\n",
"step 300/938 - loss: 1.5107 - acc_top1: 0.8120 - acc_top2: 0.8671 - 16ms/step\n",
"step 400/938 - loss: 1.5154 - acc_top1: 0.8449 - acc_top2: 0.8954 - 16ms/step\n",
"step 500/938 - loss: 1.4851 - acc_top1: 0.8681 - acc_top2: 0.9136 - 16ms/step\n",
"step 600/938 - loss: 1.4739 - acc_top1: 0.8833 - acc_top2: 0.9261 - 16ms/step\n",
"step 700/938 - loss: 1.4688 - acc_top1: 0.8946 - acc_top2: 0.9348 - 16ms/step\n",
"step 800/938 - loss: 1.4646 - acc_top1: 0.9035 - acc_top2: 0.9416 - 16ms/step\n",
"step 900/938 - loss: 1.5141 - acc_top1: 0.9108 - acc_top2: 0.9471 - 16ms/step\n",
"step 938/938 - loss: 1.4745 - acc_top1: 0.9132 - acc_top2: 0.9489 - 16ms/step\n",
"save checkpoint at /Users/dingjiawei/Desktop/教程/mnist_checkpoint/0\n",
"Eval begin...\n",
"step 10/157 - loss: 1.4995 - acc_top1: 0.9734 - acc_top2: 0.9906 - 5ms/step\n",
"step 20/157 - loss: 1.5421 - acc_top1: 0.9602 - acc_top2: 0.9820 - 5ms/step\n",
"step 30/157 - loss: 1.5077 - acc_top1: 0.9615 - acc_top2: 0.9828 - 5ms/step\n",
"step 40/157 - loss: 1.4786 - acc_top1: 0.9586 - acc_top2: 0.9836 - 4ms/step\n",
"step 50/157 - loss: 1.4705 - acc_top1: 0.9606 - acc_top2: 0.9847 - 4ms/step\n",
"step 60/157 - loss: 1.5273 - acc_top1: 0.9602 - acc_top2: 0.9846 - 4ms/step\n",
"step 70/157 - loss: 1.4628 - acc_top1: 0.9607 - acc_top2: 0.9846 - 4ms/step\n",
"step 80/157 - loss: 1.4767 - acc_top1: 0.9617 - acc_top2: 0.9848 - 4ms/step\n",
"step 90/157 - loss: 1.4788 - acc_top1: 0.9648 - acc_top2: 0.9863 - 4ms/step\n",
"step 100/157 - loss: 1.4643 - acc_top1: 0.9661 - acc_top2: 0.9867 - 5ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 110/157 - loss: 1.4631 - acc_top1: 0.9678 - acc_top2: 0.9874 - 4ms/step\n",
"step 120/157 - loss: 1.4806 - acc_top1: 0.9695 - acc_top2: 0.9882 - 4ms/step\n",
"step 130/157 - loss: 1.4823 - acc_top1: 0.9710 - acc_top2: 0.9888 - 4ms/step\n",
"step 140/157 - loss: 1.4612 - acc_top1: 0.9730 - acc_top2: 0.9896 - 4ms/step\n",
"step 150/157 - loss: 1.4668 - acc_top1: 0.9741 - acc_top2: 0.9900 - 4ms/step\n",
"step 157/157 - loss: 1.4613 - acc_top1: 0.9742 - acc_top2: 0.9901 - 4ms/step\n",
"step 100/157 - loss: 1.4721 - acc_top1: 0.9625 - acc_top2: 0.9897 - 5ms/step\n",
"step 157/157 - loss: 1.4613 - acc_top1: 0.9701 - acc_top2: 0.9919 - 5ms/step\n",
"Eval samples: 10000\n",
-    "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/mnist_checkpoint/final\n"
+    "save checkpoint at /Users/dingjiawei/Desktop/教程/mnist_checkpoint/final\n"
]
}
],
"source": [
"inputs = InputSpec([None, 784], 'float32', 'x')\n",
"labels = InputSpec([None, 10], 'float32', 'x')\n",
-    "model = hapi.Model(LeNet(), inputs, labels)\n",
+    "model = paddle.Model(MyModel(), inputs, labels)\n",
"\n",
"optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
"\n",
...
...
@@ -262,6 +155,7 @@
"model.fit(train_dataset,\n",
" test_dataset,\n",
" epochs=1,\n",
+    "          log_freq=100,\n",
" batch_size=64,\n",
" save_dir='mnist_checkpoint')\n"
]
...
...
@@ -271,35 +165,32 @@
"metadata": {},
"source": [
     "## Saving Models\n",
-    "#### Static-graph model-saving APIs:\n",
-    "* fluid.io.save_vars(executor, dirname, main_program=None, vars=None, predicate=None, filename=None)<br>\n",
-    "1) Specify the list of variables to save through the vars argument.<br>\n",
-    "2) Assign an existing Program to main_program, and all variables in that program are saved. The first way takes precedence over the second. \n",
-    "* fluid.io.save_params(executor, dirname, main_program=None, filename=None) <br>\n",
-    "Given the Program specified by main_program, this interface filters out all of its parameters (Parameter) and saves them to the folder given by dirname or the file given by filename. \n",
-    "* fluid.io.save_persistables(executor, dirname, main_program=None, filename=None) <br>\n",
-    "Given the Program specified by main_program, this interface filters out all of its persistable variables (persistable==True) and saves them to the folder given by dirname or the file given by filename. \n",
-    "* fluid.io.save_inference_model(dirname, feeded_var_names, target_vars, executor, main_program=None, model_filename=None, params_filename=None, export_for_deployment=True, program_only=False) <br>\n",
-    "When storing an inference model, fluid.io.save_inference_model is generally used to prune the default fluid.Program down to the part needed to compute predict_var. The pruned program is saved under the given path ./infer_model/__model__ and the parameters are saved as separate files under ./infer_model.\n",
-    "\n",
-    "#### Dynamic-graph model-saving APIs:\n",
-    "* paddle.fluid.dygraph.save_dygraph(state_dict, model_path) <br>\n",
-    "This interface saves the given parameter or optimizer dict to disk. Based on the contents of state_dict it automatically appends .pdparams or .pdopt to model_path, producing a model_path + \".pdparams\" or model_path + \".pdopt\" file; state_dict is obtained from the Layer's state_dict() method. For detailed usage see: https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/dygraph_cn/save_dygraph_cn.html#save-dygraph \n",
-    "* paddle.incubate.hapi.model.Model.fit(train_data, epochs, batch_size, save_dir, log_freq) <br>\n",
-    "When training the network in a loop with model.fit, give the save path in the save_dir argument and the write frequency in save_freq to train and save at the same time. model.fit() saves only the model parameters, not the optimizer parameters, producing a single .pdparams file at the end of each epoch, so the model can be saved as training proceeds. \n",
-    "* paddle.incubate.hapi.model.Model.save(path) <br>\n",
-    "model.save(path) saves both the network parameters and the optimizer parameters. Each epoch produces two kinds of files, 0.pdparams and 0.pdopt, storing the model parameters and optimizer parameters respectively, but the files are only written once the whole training run completes. path has the form 'dirname/file_prefix' or 'file_prefix', where dirname gives the directory and file_prefix names the parameter files."
+    "Paddle currently has three families of APIs for saving and loading models:\n",
+    "#### paddle high-level API - saving and loading models\n",
+    " * paddle.Model.fit\n",
+    " * paddle.Model.save\n",
+    "#### paddle base framework - dynamic graph - saving and loading models\n",
+    " * paddle.fluid.dygraph.save_dygraph\n",
+    "#### paddle base framework - static graph - saving and loading models\n",
+    " * fluid.io.save_vars\n",
+    " * fluid.io.save_params\n",
+    " * fluid.io.save_persistables\n",
+    " * fluid.io.save_inference_model"
]
},
{
-   "cell_type": "code",
-   "execution_count": 5,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "# Method 1: use dynamic-graph model.save() to save the model and optimizer parameters\n",
-    "model.save('mnist_checkpoint/test')"
+    "Below we walk through saving and loading models with paddle's high-level API.\n",
+    "#### Method 1:\n",
+    "* paddle.Model.fit(train_data, epochs, batch_size, save_dir, log_freq) <br><br>\n",
+    "When training the network in a loop with model.fit, give the save path in the save_dir argument and the write frequency in save_freq to train and save the model at the same time. model.fit() saves only the model parameters, not the optimizer parameters, and produces a single .pdparams file at the end of each epoch, so the model is saved as training proceeds. \n",
+    "\n",
+    "#### Method 2:\n",
+    "* paddle.Model.save(path) <br><br>\n",
+    "model.save(path) saves both the network parameters and the optimizer parameters. Each epoch produces two kinds of files, 0.pdparams and 0.pdopt, storing the model parameters and optimizer parameters respectively, but the files are only written once the whole training run completes. path has the form 'dirname/file_prefix' or 'file_prefix', where dirname gives the directory and file_prefix names the parameter files."
]
},
{
...
...
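The difference between the two methods above can be sketched without Paddle itself. Here `fit_and_checkpoint` and `save_final` are hypothetical stand-ins for `model.fit(..., save_dir=...)` and `model.save(path)`, and the "parameters" are just pickled dicts — only the file naming (`{epoch}.pdparams` per epoch versus a `.pdparams`/`.pdopt` pair under a path prefix) follows the description above:

```python
import os
import pickle
import tempfile

def fit_and_checkpoint(params, epochs, save_dir):
    """Like model.fit(..., save_dir=...): one .pdparams file per epoch,
    model parameters only (no optimizer state)."""
    os.makedirs(save_dir, exist_ok=True)
    for epoch in range(epochs):
        params = {k: v + 1 for k, v in params.items()}  # stand-in for one epoch of training
        with open(os.path.join(save_dir, f"{epoch}.pdparams"), "wb") as f:
            pickle.dump(params, f)
    return params

def save_final(params, opt_state, path):
    """Like model.save('dirname/file_prefix'): written once at the end,
    parameters and optimizer state in two sibling files."""
    with open(path + ".pdparams", "wb") as f:
        pickle.dump(params, f)
    with open(path + ".pdopt", "wb") as f:
        pickle.dump(opt_state, f)

root = tempfile.mkdtemp()
ckpt_dir = os.path.join(root, "mnist_checkpoint")
params = fit_and_checkpoint({"w": 0}, epochs=2, save_dir=ckpt_dir)
save_final(params, {"lr": 0.001}, os.path.join(ckpt_dir, "test"))
print(sorted(os.listdir(ckpt_dir)))
# → ['0.pdparams', '1.pdparams', 'test.pdopt', 'test.pdparams']
```

The point of the sketch is the file layout: per-epoch checkpoints carry parameters only, while the final save carries both parameters and optimizer state.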
@@ -308,7 +199,7 @@
"metadata": {},
"outputs": [],
"source": [
-    "# Method 2: save each epoch's model parameters in real time during training\n",
+    "# Method 1: save each epoch's model parameters in real time during training\n",
"model.fit(train_dataset,\n",
" test_dataset,\n",
" epochs=2,\n",
...
...
@@ -317,6 +208,16 @@
" )"
]
},
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Method 2: model.save() saves the model and optimizer parameters\n",
+    "model.save('mnist_checkpoint/test')"
+   ]
+  },
{
"cell_type": "markdown",
"metadata": {},
...
...
@@ -324,25 +225,24 @@
     "## Loading Model Parameters\n",
     "\n",
     "When restoring a training run we need to load the model data. We can use a load function to read the model parameters and optimizer parameters from the files that store the model state and optimizer state; if the optimizer does not need to be restored, the optimizer-state file is not required.\n",
"\n",
-    "#### Static-graph model-parameter loading\n",
-    "* fluid.io.load_vars<br>\n",
-    "Loads variables from the given directory through an Executor. There are two ways to load them:\n",
-    "1) Specify the list of variables to load through the vars argument.\n",
-    "2) Assign an existing Program to main_program, and all variables in that program are loaded.\n",
-    "The first way takes precedence over the second.\n",
-    "* fluid.io.load_params<br>\n",
-    "This interface filters out all parameters (Parameter) from the Program given by main_program and tries to load them from the folder given by dirname or the file given by filename.\n",
-    "* fluid.io.load_persistables<br>\n",
-    "This interface filters out all persistable variables (persistable==True) from the Program given by main_program and tries to load them from the folder given by dirname or the file given by filename.\n",
-    "* fluid.io.load_inference_model<br>\n",
-    "When storing an inference model, fluid.io.save_inference_model is generally used to prune the default fluid.Program down to the part needed to compute predict_var. The pruned program is saved under the given path ./infer_model/__model__ and the parameters are saved as separate files under ./infer_model.\n",
-    "\n",
-    "#### Dynamic-graph model-parameter loading\n",
-    "* paddle.fluid.dygraph.load_dygraph(model_path)<br>\n",
-    "This interface tries to load a parameter or optimizer dict from disk, reading both model_path + \".pdparams\" and model_path + \".pdopt\". model_path is the prefix of the file storing the state_dict and should not include the .pdparams or .pdopt suffix. The function returns two dicts: the parameter dict and the optimizer dict recovered from the files.\n",
-    "* paddle.incubate.hapi.model.Model.load(self, path, skip_mismatch=False, reset_optimizer=False)<br>\n",
-    "Loads from the saved model and optimizer parameter files. If the optimizer does not need to be restored, the optimizer parameter file is not required."
+    "#### High-level API - model loading\n",
+    " * paddle.Model.load()\n",
+    "#### paddle base framework - dynamic graph - model loading\n",
+    " * paddle.fluid.dygraph.load_dygraph\n",
+    "#### paddle base framework - static graph - model loading\n",
+    " * fluid.io.load_vars \n",
+    " * fluid.io.load_params \n",
+    " * fluid.io.load_persistables\n",
+    " * fluid.io.load_inference_model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "Below we walk through model loading with the high-level API\n",
    "* model.load(self, path, skip_mismatch=False, reset_optimizer=False)<br><br>\n",
    "model.load can load the model and optimizer parameters together. The reset_optimizer argument controls whether the optimizer parameters are restored: if reset_optimizer is True the optimizer parameters are re-initialized, and if it is False they are restored from the path."
]
},
{
...
...
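The reset_optimizer switch can be sketched in the same Paddle-free spirit as before — `save_pair` and `load_pair` are hypothetical stand-ins whose file suffixes follow the description above, with plain pickled dicts in place of real state:

```python
import os
import pickle
import tempfile

def save_pair(path, params, opt_state):
    # Hypothetical saver: parameters and optimizer state side by side.
    with open(path + ".pdparams", "wb") as f:
        pickle.dump(params, f)
    with open(path + ".pdopt", "wb") as f:
        pickle.dump(opt_state, f)

def load_pair(path, reset_optimizer=False, fresh_opt_state=None):
    # Like model.load(path, reset_optimizer=...): parameters always come
    # from disk; optimizer state is either reloaded or re-initialized.
    with open(path + ".pdparams", "rb") as f:
        params = pickle.load(f)
    if reset_optimizer:
        return params, dict(fresh_opt_state or {})
    with open(path + ".pdopt", "rb") as f:
        return params, pickle.load(f)

prefix = os.path.join(tempfile.mkdtemp(), "test")
save_pair(prefix, {"w": 3.0}, {"step": 938})
p1, o1 = load_pair(prefix)                                           # resume optimizer too
p2, o2 = load_pair(prefix, reset_optimizer=True, fresh_opt_state={"step": 0})
print(o1, o2)
# → {'step': 938} {'step': 0}
```

Note that, as in model.load, the path is a prefix: the `.pdparams`/`.pdopt` suffixes are appended by the loader, not written by the caller.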
@@ -351,7 +251,7 @@
"metadata": {},
"outputs": [],
"source": [
-    "# Use dynamic-graph model.load() to load the model and optimizer parameters\n",
+    "# Load the model with the high-level API\n",
"model.load('mnist_checkpoint/test')"
]
},
...
...
@@ -359,7 +259,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
-    "## How do we tell whether training resumed accurately?\n",
+    "## Resuming Training\n",
"\n",
    "An ideal resume returns the model to its state at the moment training was interrupted, so that the gradient updates after resuming follow exactly the same trajectory as they would have without the interruption. Based on this, we can judge whether the methods above resume training accurately from the loss after resuming: restore from the model parameters and optimizer state saved at the end of epoch 0, and check whether the subsequent training loss (epoch 1) is identical to an uninterrupted run.\n",
"\n",
...
...
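The check just described — resume from the end-of-epoch-0 state and compare the loss curve with an uninterrupted run — can be sketched with a toy gradient-descent loop (plain Python, hypothetical names, no Paddle): since the updates are deterministic, a faithful resume must reproduce the remaining trajectory bit for bit.

```python
def train(w, steps, lr=0.1, target=5.0):
    # One scalar "model": minimize (w - target)^2, recording the loss per step.
    losses = []
    for _ in range(steps):
        losses.append((w - target) ** 2)
        w -= lr * 2 * (w - target)  # gradient step
    return w, losses

# Uninterrupted run: epoch 0 and epoch 1 back to back (5 steps each).
w_full, losses_full = train(0.0, steps=10)

# Interrupted run: checkpoint after epoch 0, then resume from the checkpoint.
w_ckpt, losses_e0 = train(0.0, steps=5)        # epoch 0; state saved here
w_resumed, losses_e1 = train(w_ckpt, steps=5)  # epoch 1, resumed from the checkpoint

# An accurate resume reproduces the uninterrupted trajectory exactly.
print(losses_e0 + losses_e1 == losses_full)
# → True
```

In the real notebook the same comparison is made by eye on the epoch-1 loss printed by model.fit before and after the restart.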
@@ -374,19 +274,52 @@
},
{
"cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"step 100/938 - loss: 1.4618 - acc_top1: 0.9709 - acc_top2: 0.9917 - 16ms/step\n",
"step 200/938 - loss: 1.5390 - acc_top1: 0.9698 - acc_top2: 0.9921 - 16ms/step\n",
"step 300/938 - loss: 1.4938 - acc_top1: 0.9695 - acc_top2: 0.9911 - 16ms/step\n",
"step 400/938 - loss: 1.4946 - acc_top1: 0.9698 - acc_top2: 0.9914 - 16ms/step\n",
"step 500/938 - loss: 1.4746 - acc_top1: 0.9709 - acc_top2: 0.9918 - 16ms/step\n",
"step 600/938 - loss: 1.4788 - acc_top1: 0.9711 - acc_top2: 0.9917 - 16ms/step\n",
"step 700/938 - loss: 1.4770 - acc_top1: 0.9717 - acc_top2: 0.9920 - 15ms/step\n",
"step 800/938 - loss: 1.4848 - acc_top1: 0.9721 - acc_top2: 0.9920 - 15ms/step\n",
"step 900/938 - loss: 1.5047 - acc_top1: 0.9730 - acc_top2: 0.9920 - 15ms/step\n",
"step 938/938 - loss: 1.4612 - acc_top1: 0.9731 - acc_top2: 0.9921 - 15ms/step\n",
"Eval begin...\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9733 - acc_top2: 0.9938 - 5ms/step\n",
"step 157/157 - loss: 1.4612 - acc_top1: 0.9792 - acc_top2: 0.9952 - 5ms/step\n",
"Eval samples: 10000\n",
"Epoch 2/2\n",
"step 100/938 - loss: 1.4706 - acc_top1: 0.9778 - acc_top2: 0.9938 - 15ms/step\n",
"step 200/938 - loss: 1.4819 - acc_top1: 0.9777 - acc_top2: 0.9940 - 15ms/step\n",
"step 300/938 - loss: 1.4612 - acc_top1: 0.9778 - acc_top2: 0.9936 - 15ms/step\n",
"step 400/938 - loss: 1.4911 - acc_top1: 0.9786 - acc_top2: 0.9940 - 15ms/step\n",
"step 500/938 - loss: 1.4938 - acc_top1: 0.9785 - acc_top2: 0.9940 - 16ms/step\n",
"step 600/938 - loss: 1.4769 - acc_top1: 0.9783 - acc_top2: 0.9941 - 16ms/step\n",
"step 700/938 - loss: 1.4612 - acc_top1: 0.9785 - acc_top2: 0.9942 - 16ms/step\n",
"step 800/938 - loss: 1.4763 - acc_top1: 0.9784 - acc_top2: 0.9942 - 16ms/step\n",
"step 900/938 - loss: 1.4763 - acc_top1: 0.9785 - acc_top2: 0.9942 - 16ms/step\n",
"step 938/938 - loss: 1.4619 - acc_top1: 0.9787 - acc_top2: 0.9943 - 16ms/step\n",
"Eval begin...\n",
"step 100/157 - loss: 1.4612 - acc_top1: 0.9773 - acc_top2: 0.9950 - 6ms/step\n",
"step 157/157 - loss: 1.4612 - acc_top1: 0.9815 - acc_top2: 0.9959 - 6ms/step\n",
"Eval samples: 10000\n"
]
}
],
"source": [
"import paddle\n",
"import paddle.fluid as fluid\n",
"import paddle.hapi as hapi\n",
"from paddle.nn import functional\n",
"from paddle.hapi.model import Model\n",
"from paddle.vision.datasets import MNIST\n",
"from paddle.metric import Accuracy\n",
"from paddle.nn import Conv2d,Pool2D,Linear\n",
"from paddle.vision import LeNet\n",
"from paddle.static import InputSpec\n",
"#\n",
"#\n",
...
...
@@ -398,18 +331,19 @@
"\n",
"inputs = InputSpec([None, 784], 'float32', 'x')\n",
"labels = InputSpec([None, 10], 'float32', 'x')\n",
-    "model = hapi.Model(LeNet(), inputs, labels)\n",
+    "model = paddle.Model(MyModel(), inputs, labels)\n",
"optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n",
+    "model.load(\"../教程/mnist_checkpoint/final\")\n",
"model.prepare( \n",
" optim,\n",
" paddle.nn.loss.CrossEntropyLoss(),\n",
" Accuracy(topk=(1, 2))\n",
" )\n",
-    "model.load(params_path)\n",
"model.fit(train_data=train_dataset,\n",
" eval_data=test_dataset,\n",
" batch_size=64,\n",
-    "          epochs=5\n",
+    "          log_freq=100,\n",
+    "          epochs=2\n",
" )"
]
},
...
...
@@ -424,7 +358,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
-    "That concludes using the Mnist handwritten-digit example to walk through saving and loading models. Paddle provides many save and load API methods; you can choose among them according to your needs."
+    "That concludes using the Mnist handwritten-digit example to walk through saving models, loading models, and resuming training. Paddle provides many save and load API methods; you can choose among them according to your needs."
]
},
{
...
...