Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
FluidDoc
提交
c1362a6c
F
FluidDoc
项目概览
PaddlePaddle
/
FluidDoc
通知
5
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
23
列表
看板
标记
里程碑
合并请求
111
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
F
FluidDoc
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
23
Issue
23
列表
看板
标记
里程碑
合并请求
111
合并请求
111
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
c1362a6c
编写于
9月 23, 2019
作者:
Z
Zhen Wang
提交者:
GitHub
9月 23, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rewrite the content of load_inference_model_cn.rst. (#1218)
* rewrite the content of load_inference_model_cn.rst.
上级
7ffc29d8
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
27 addition
and
17 deletion
+27
-17
doc/fluid/api_cn/io_cn/load_inference_model_cn.rst
doc/fluid/api_cn/io_cn/load_inference_model_cn.rst
+27
-17
未找到文件。
doc/fluid/api_cn/io_cn/load_inference_model_cn.rst
浏览文件 @
c1362a6c
...
...
@@ -5,26 +5,33 @@ load_inference_model
.. py:function:: paddle.fluid.io.load_inference_model(dirname, executor, model_filename=None, params_filename=None, pserver_endpoints=None)
从指定
目录中加载预测模型(inference model)。通过这个API,您可以获得模型结构(预测程序)和模型参数。如果您只想下载预训练后的模型的参数,请使用load_params API。更多细节请参考 ``模型/变量的保存、载入与增量训练`
` 。
从指定
文件路径中加载预测模型(Inference Model),即调用该接口可获得模型结构(Inference Program)和模型参数。若只想加载预训练后的模型参数,请使用 :ref:`cn_api_fluid_io_load_params` 接口。更多细节请参考 :ref:`api_guide_model_save_reader
` 。
参数
:
- **dirname** (str) –
model的路径
- **executor** (Executor) – 运行
inference model的 ``executor``
- **model_filename** (str
|None) – 存储着预测 Program 的文件名称。如果设置为None,将使用默认的文件名为: ``__model__``
- **params_filename** (str
|None) – 加载所有相关参数的文件名称。如果设置为None,则参数将保存在单独的文件中
。
- **pserver_endpoints** (list
|None) – 只有在分布式预测时需要用到。 当在训练时使用分布式 look up table , 需要这个参数. 该参数是 pserver endpoints 的列表
参数
:
- **dirname** (str) –
待加载模型的存储路径。
- **executor** (Executor) – 运行
Inference Model 的 ``executor`` ,详见 :ref:`api_guide_executor` 。
- **model_filename** (str
,可选) – 存储Inference Program结构的文件名称。如果设置为None,则使用 ``__model__`` 作为默认的文件名。默认值为None。
- **params_filename** (str
,可选) – 存储所有模型参数的文件名称。当且仅当所有模型参数被保存在一个单独的二进制文件中,它才需要被指定。如果模型参数是存储在各自分离的文件中,设置它的值为None。默认值为None
。
- **pserver_endpoints** (list
,可选) – 只有在分布式预测时才需要用到。当训练过程中使用分布式查找表(distributed lookup table)时, 预测时需要指定pserver_endpoints的值。它是 pserver endpoints 的列表,默认值为None。
返回: 这个函数的返回有三个元素的元组(Program,feed_target_names, fetch_targets)。Program 是一个 ``Program`` ,它是预测 ``Program``。 ``feed_target_names`` 是一个str列表,它包含需要在预测 ``Program`` 中提供数据的变量的名称。``fetch_targets`` 是一个 ``Variable`` 列表,从中我们可以得到推断结果。
返回:该接口返回一个包含三个元素的列表(program,feed_target_names, fetch_targets)。它们的含义描述如下:
- **program** (Program)– ``Program`` (详见 :ref:`api_guide_Program` )类的实例。此处它被用于预测,因此可被称为Inference Program。
- **feed_target_names** (list)– 字符串列表,包含着Inference Program预测时所需提供数据的所有变量名称(即所有输入变量的名称)。
- **fetch_targets** (list)– ``Variable`` (详见 :ref:`api_guide_Program` )类型列表,包含着模型的所有输出变量。通过这些输出变量即可得到模型的预测结果。
返回类型:元组(tuple)
**返回类型:** 列表(list)
抛出异常:
- ``ValueError`` – 如果 ``dirname`` 非法
- ``ValueError`` – 如果接口参数 ``dirname`` 指向一个不存在的文件路径,则抛出异常。
**代码示例**
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
# 构建模型
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
...
...
@@ -36,26 +43,29 @@ load_inference_model
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
# 保存预测模型
path = "./infer_model"
fluid.io.save_inference_model(dirname=path, feeded_var_names=['img'],target_vars=[hidden_b], executor=exe, main_program=main_prog)
tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model(dirname=path, executor=exe))
# 示例一: 不需要指定分布式查找表的模型加载示例,即训练时未用到distributed lookup table。
[inference_program, feed_target_names, fetch_targets] = (fluid.io.load_inference_model(dirname=path, executor=exe))
tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
# endpoints是pserver服务器终端列表,下面仅为一个样例
# 示例二: 若训练时使用了distributed lookup table,则模型加载时需要通过endpoints参数指定pserver服务器结点列表。
# pserver服务器结点列表主要用于分布式查找表进行ID查找时使用。下面的["127.0.0.1:2023","127.0.0.1:2024"]仅为一个样例。
endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
# 如果需要查询表格,我们可以使用:
[dist_inference_program, dist_feed_target_names, dist_fetch_targets] = (
fluid.io.load_inference_model(dirname=path,
executor=exe,
pserver_endpoints=endpoints))
# 在
这个示例中,inference program 保存在“ ./infer_model/__model__”中
# 参数保存在“./infer_mode ”单独的若干文件
中
# 加载 inference program 后, executor
使用 fetch_targets 和 feed_target_names 执行Program,得到预测结果
# 在
上述示例中,inference program 被保存在“ ./infer_model/__model__”文件内,
# 参数保存在“./infer_mode ”单独的若干文件
内。
# 加载 inference program 后, executor
可使用 fetch_targets 和 feed_target_names 执行Program,并得到预测结果。
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录