提交 83a1091e 编写于 作者: S songyouwei

new save load in README, and typo fix

test=develop
上级 a4d7a2e4
......@@ -4,3 +4,4 @@
*.pyc
*~
*.vscode
*.idea
\ No newline at end of file
......@@ -136,7 +136,7 @@ def download_cycle_pix(dir_path, dataname):
for zip_file in zip_f.namelist():
zip_f.extract(zip_file, dir_path)
### generator .txt file according to dirs
### generate .txt file according to dirs
dirs = os.listdir(os.path.join(dir_path, '{}'.format(dataname)))
for d in dirs:
txt_file = d + '.txt'
......
......@@ -47,9 +47,9 @@ Loss at epoch 0 step 400: [0.03940877]
```
## 参数保存
调用`fluid.dygraph.save_persistables()`接口可以把模型的参数进行保存。
调用`fluid.save_dygraph()`接口可以把模型的参数进行保存。
```python
fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
fluid.save_dygraph(mnist.state_dict(), "save_dir")
```
## 测试
......@@ -59,14 +59,18 @@ mnist.eval()
```
## 模型评估
我们使用手写数据集中的一张图片来进行评估。为了区别训练模型,我们使用`with fluid.dygraph.guard()`来切换到一个新的参数空间,然后构建一个用于评估的网络`mnist_infer`,并通过`mnist_infer.load_dict()`来加载使用`fluid.dygraph.load_persistables`读取的参数。然后用`mnist_infer.eval()`切换到评估。
我们使用手写数据集中的一张图片来进行评估。
为了区别训练模型,我们使用`with fluid.dygraph.guard()`来切换到一个新的参数空间,
然后构建一个用于评估的网络`mnist_infer`
并通过`mnist_infer.set_dict()`来加载使用`fluid.load_dygraph`读取的参数。
然后用`mnist_infer.eval()`切换到评估。
```python
with fluid.dygraph.guard():
mnist_infer = MNIST("mnist")
# load checkpoint
mnist_infer.load_dict(
fluid.dygraph.load_persistables("save_dir"))
mnist_infer.set_dict(
fluid.load_dygraph("save_dir"))
# start evaluate mode
mnist_infer.eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册