提交 c98577a3 编写于 作者: L lujun 提交者: Cheerego

fix dygraph load_persistables bug (#817)

上级 ec4bf59a
...@@ -391,7 +391,7 @@ PaddlePaddle DyGraph是一个更加灵活易用的模式,可提供: ...@@ -391,7 +391,7 @@ PaddlePaddle DyGraph是一个更加灵活易用的模式,可提供:

在模型训练中可以使用` fluid.dygraph.save_persistables(your_model_object.state_dict(), "save_dir")`来保存`your_model_object`中所有的模型参数。也可以自定义需要保存的“参数名” - “参数对象”的Python Dictionary传入。 
在模型训练中可以使用` fluid.dygraph.save_persistables(your_model_object.state_dict(), "save_dir")`来保存`your_model_object`中所有的模型参数。也可以自定义需要保存的“参数名” - “参数对象”的Python Dictionary传入。
同样可以使用`your_modle_object.load_dict( 同样可以使用`your_modle_object.load_dict(
fluid.dygraph.load_persistables(your_model_object.state_dict(), "save_dir"))`接口来恢复保存的模型参数从而达到继续训练的目的。 fluid.dygraph.load_persistables("save_dir"))`接口来恢复保存的模型参数从而达到继续训练的目的。
下面的代码展示了如何在“手写数字识别”任务中保存参数并且读取已经保存的参数来继续训练。 下面的代码展示了如何在“手写数字识别”任务中保存参数并且读取已经保存的参数来继续训练。
...@@ -421,7 +421,7 @@ PaddlePaddle DyGraph是一个更加灵活易用的模式,可提供: ...@@ -421,7 +421,7 @@ PaddlePaddle DyGraph是一个更加灵活易用的模式,可提供:
for param in mnist.parameters(): for param in mnist.parameters():
dy_param_init_value[param.name] = param.numpy() dy_param_init_value[param.name] = param.numpy()
mnist.load_dict(fluid.dygraph.load_persistables(mnist.state_dict(), "save_dir")) mnist.load_dict(fluid.dygraph.load_persistables("save_dir"))
restore = mnist.parameters() restore = mnist.parameters()
# check save and load # check save and load
success = True success = True
...@@ -490,7 +490,7 @@ PaddlePaddle DyGraph是一个更加灵活易用的模式,可提供: ...@@ -490,7 +490,7 @@ PaddlePaddle DyGraph是一个更加灵活易用的模式,可提供:
mnist_infer = MNIST("mnist") mnist_infer = MNIST("mnist")
# load checkpoint # load checkpoint
mnist_infer.load_dict( mnist_infer.load_dict(
fluid.dygraph.load_persistables(mnist.state_dict(), "save_dir")) fluid.dygraph.load_persistables("save_dir"))
print("checkpoint loaded") print("checkpoint loaded")
# start evaluate mode # start evaluate mode
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册