提交 e885d95c 编写于 作者: T tangwei12 提交者: Cheerego

save/load implement tiny optimize. (#801)

* save load implement optimize

* fix.
上级 24a89167
.. _user_guide_save_load_vars:
#############################
模型/变量的保存/载入与增量训练
#############################
################################
模型/变量的保存载入与增量训练
################################
模型变量分类
############
......@@ -37,6 +37,21 @@
那么我们应该将各种长期变量都保存下来,甚至还需要记录一下当前的epoch和step的id。
因为一些模型变量虽然不是参数,但对于模型的训练依然必不可少。
save_vars、save_params、save_persistables 以及 save_inference_model的区别
##########################################################################
1. :code:`save_inference_model` 会根据用户配置的 :code:`feeded_var_names` 和 :code:`target_vars` 进行网络裁剪,保存下裁剪后的网络结构的 ``__model__`` 以及裁剪后网络中的长期变量
2. :code:`save_persistables` 不会保存网络结构,会保存网络中的全部长期变量到指定位置。
3. :code:`save_params` 不会保存网络结构,会保存网络中的全部模型参数到指定位置。
4. :code:`save_vars` 不会保存网络结构,会根据用户指定的 :code:`fluid.framework.Parameter` 列表进行保存。
:code:`save_persistables` 保存的网络参数是最全面的,如果是增量训练或者恢复训练, 请选择 :code:`save_persistables` 进行变量保存。
:code:`save_inference_model` 会保存网络参数及裁剪后的模型,如果后续要做预测相关的工作, 请选择 :code:`save_inference_model` 进行变量和网络的保存。
:code:`save_vars 和 save_params` 仅在用户了解清楚用途及特殊目的情况下使用, 一般不建议使用。
保存模型用于对新样本的预测
==========================
......@@ -97,8 +112,6 @@
另外,需特别注意运行 :code:`fluid.default_startup_program()` 必须在调用 :code:`fluid.io.load_params`
之前。如果在之后运行,可能会覆盖已加载的模型参数导致错误。
预测模型的保存和加载
##############################
......@@ -162,6 +175,7 @@
1. 在训练的最后调用 :code:`fluid.io.save_persistables` 保存长期变量时,不必要所有的trainer都调用这个方法来保存,一般0号trainer来保存即可。
2. 多机增量训练的参数加载在PServer端,trainer端不用加载参数。在PServer全部启动后,trainer会从PServer端同步参数。
3. 在确认需要使用增量的情况下, 多机在调用 :code:`fluid.DistributeTranspiler.transpile` 时需要指定 ``current_endpoint`` 参数。
多机增量(不带分布式大规模稀疏矩阵)训练的一般步骤为:
......@@ -211,7 +225,7 @@
training_role == "PSERVER"
config = fluid.DistributeTranspilerConfig()
t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers, sync_mode=True)
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers, sync_mode=True, current_endpoint=current_endpoint)
if training_role == "PSERVER":
current_endpoint = "127.0.0.1:1001"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册