未验证 提交 c1362a6c 编写于 作者: Z Zhen Wang 提交者: GitHub

Rewrite the content of load_inference_model_cn.rst. (#1218)

* rewrite the content of load_inference_model_cn.rst.
上级 7ffc29d8
......@@ -5,26 +5,33 @@ load_inference_model
.. py:function:: paddle.fluid.io.load_inference_model(dirname, executor, model_filename=None, params_filename=None, pserver_endpoints=None)
从指定目录中加载预测模型(inference model)。通过这个API,您可以获得模型结构(预测程序)和模型参数。如果您只想下载预训练后的模型的参数,请使用load_params API。更多细节请参考 ``模型/变量的保存、载入与增量训练`` 。
从指定文件路径中加载预测模型(Inference Model),即调用该接口可获得模型结构(Inference Program)和模型参数。若只想加载预训练后的模型参数,请使用 :ref:`cn_api_fluid_io_load_params` 接口。更多细节请参考 :ref:`api_guide_model_save_reader` 。
参数:
- **dirname** (str) – model的路径
- **executor** (Executor) – 运行 inference model的 ``executor``
- **model_filename** (str|None) – 存储着预测 Program 的文件名称。如果设置为None,将使用默认的文件名为: ``__model__``
- **params_filename** (str|None) – 加载所有相关参数的文件名称。如果设置为None,则参数将保存在单独的文件中
- **pserver_endpoints** (list|None) – 只有在分布式预测时需要用到。 当在训练时使用分布式 look up table , 需要这个参数. 该参数是 pserver endpoints 的列表
参数
- **dirname** (str) – 待加载模型的存储路径。
- **executor** (Executor) – 运行 Inference Model 的 ``executor`` ,详见 :ref:`api_guide_executor` 。
- **model_filename** (str,可选) – 存储Inference Program结构的文件名称。如果设置为None,则使用 ``__model__`` 作为默认的文件名。默认值为None。
- **params_filename** (str,可选) – 存储所有模型参数的文件名称。当且仅当所有模型参数被保存在一个单独的二进制文件中,它才需要被指定。如果模型参数是存储在各自分离的文件中,设置它的值为None。默认值为None
- **pserver_endpoints** (list,可选) – 只有在分布式预测时才需要用到。当训练过程中使用分布式查找表(distributed lookup table)时, 预测时需要指定pserver_endpoints的值。它是 pserver endpoints 的列表,默认值为None。
返回: 这个函数的返回有三个元素的元组(Program,feed_target_names, fetch_targets)。Program 是一个 ``Program`` ,它是预测 ``Program``。 ``feed_target_names`` 是一个str列表,它包含需要在预测 ``Program`` 中提供数据的变量的名称。``fetch_targets`` 是一个 ``Variable`` 列表,从中我们可以得到推断结果。
返回:该接口返回一个包含三个元素的列表(program,feed_target_names, fetch_targets)。它们的含义描述如下:
- **program** (Program)– ``Program`` (详见 :ref:`api_guide_Program` )类的实例。此处它被用于预测,因此可被称为Inference Program。
- **feed_target_names** (list)– 字符串列表,包含着Inference Program预测时所需提供数据的所有变量名称(即所有输入变量的名称)。
- **fetch_targets** (list)– ``Variable`` (详见 :ref:`api_guide_Program` )类型列表,包含着模型的所有输出变量。通过这些输出变量即可得到模型的预测结果。
返回类型:元组(tuple)
**返回类型:** 列表(list)
抛出异常:
- ``ValueError`` – 如果 ``dirname`` 非法
- ``ValueError`` – 如果接口参数 ``dirname`` 指向一个不存在的文件路径,则抛出异常。
**代码示例**
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
# 构建模型
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
......@@ -36,26 +43,29 @@ load_inference_model
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
# 保存预测模型
path = "./infer_model"
fluid.io.save_inference_model(dirname=path, feeded_var_names=['img'],target_vars=[hidden_b], executor=exe, main_program=main_prog)
tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
# 示例一: 不需要指定分布式查找表的模型加载示例,即训练时未用到distributed lookup table。
[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model(dirname=path, executor=exe))
tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
# endpoints是pserver服务器终端列表,下面仅为一个样例
# 示例二: 若训练时使用了distributed lookup table,则模型加载时需要通过endpoints参数指定pserver服务器结点列表。
# pserver服务器结点列表主要用于分布式查找表进行ID查找时使用。下面的["127.0.0.1:2023","127.0.0.1:2024"]仅为一个样例。
endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
# 如果需要查询表格,我们可以使用:
[dist_inference_program, dist_feed_target_names, dist_fetch_targets] = (
fluid.io.load_inference_model(dirname=path,
executor=exe,
pserver_endpoints=endpoints))
# 在这个示例中,inference program 保存在“ ./infer_model/__model__”中
# 参数保存在“./infer_mode ”单独的若干文件
# 加载 inference program 后, executor 使用 fetch_targets 和 feed_target_names 执行Program,得到预测结果
# 在上述示例中,inference program 被保存在“ ./infer_model/__model__”文件内,
# 参数保存在“./infer_mode ”单独的若干文件内。
# 加载 inference program 后, executor可使用 fetch_targets 和 feed_target_names 执行Program,并得到预测结果。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册