Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDocCN
d2l-zh
提交
c03f73ce
D
d2l-zh
项目概览
OpenDocCN
/
d2l-zh
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
d2l-zh
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c03f73ce
编写于
9月 14, 2017
作者:
M
muli
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add serialization
上级
c86112db
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
96 addition
and
0 deletion
+96
-0
chapter03_gluon-basics/serialization.md
chapter03_gluon-basics/serialization.md
+95
-0
index.rst
index.rst
+1
-0
未找到文件。
chapter03_gluon-basics/serialization.md
0 → 100644
浏览文件 @
c03f73ce
# 序列化 --- 读写模型
我们现在已经讲了很多,包括
-
如何处理数据
-
如何构建模型
-
如何在数据上训练模型
-
如何使用不同的损失函数来做分类和回归
但即使知道了所有这些,我们还没有完全准备好来构建一个真正的机器学习系统。这是因为我们还没有讲如何读和写模型。因为现实中,我们通常在一个地方训练好模型,然后部署到很多不同的地方。我们需要把内存中的训练好的模型存在硬盘上好下次使用。
## 读写NDArrays
作为开始,我们先看看如何读写NDArray。虽然我们可以使用Python的序列化包例如
`Pickle`
,不过我们更倾向直接
`save`
和
`load`
,通常这样更快,而且别的语言,例如R和Scala也能用到。
```
{.python .input n=2}
from mxnet import nd
x = nd.ones(3)
y = nd.zeros(4)
filename = "../data/test1.params"
nd.save(filename, [x, y])
```
读回来
```
{.python .input n=3}
a, b = nd.load(filename)
print(a, b)
```
不仅可以读写单个NDArray,NDArray list,dict也是可以的:
```
{.python .input n=4}
mydict = {"x": x, "y": y}
filename = "../data/test2.params"
nd.save(filename, mydict)
```
```
{.python .input n=5}
c = nd.load(filename)
print(c)
```
## 读写Gluon模型的参数
跟NDArray类似,Gluon的模型(就是
`nn.Block`
)提供便利的
`save_params`
和
`load_params`
函数来读写数据。我们同前一样创建一个简单的多层感知机
```
{.python .input n=6}
from mxnet.gluon import nn
def get_net():
net = nn.Sequential()
with net.name_scope():
net.add(nn.Dense(10, activation="relu"))
net.add(nn.Dense(2))
return net
net = get_net()
net.initialize()
x = nd.random.uniform(shape=(2,10))
print(net(x))
```
下面我们把模型参数存起来
```
{.python .input}
filename = "../data/mlp.params"
net.save_params(filename)
```
之后我们构建一个一样的多层感知机,但不像前面那样随机初始化,我们直接读取前面的模型参数。这样给定同样的输入,新的模型应该会输出同样的结果。
```
{.python .input n=8}
import mxnet as mx
net2 = get_net()
net2.load_params(filename, mx.cpu())
print(net2(x))
```
更进一步的,下面的代码尝试将模型直接读入GPU中
```
{.python .input}
import sys
sys.path.append('..')
import utils
net3 = get_net()
net3.load_params(filename, utils.try_gpu())
print(net3(x.as_in_context(utils.try_gpu())))
```
## 总结
通过
`load_params`
和
`save_params`
可以很方便的读写模型参数。
index.rst
浏览文件 @
c03f73ce
...
@@ -46,6 +46,7 @@ Github源代码在 `https://github.com/mli/gluon-tutorials-zh <https://github.co
...
@@ -46,6 +46,7 @@ Github源代码在 `https://github.com/mli/gluon-tutorials-zh <https://github.co
block
block
parameters
parameters
serialization
use-gpu
use-gpu
.. toctree::
.. toctree::
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录