Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
4ceedec3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4ceedec3
编写于
11月 27, 2020
作者:
S
Shibo Tao
提交者:
GitHub
11月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance doc. add kwargs for backward compatibility. test=develop (#29143)
上级
28280647
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
86 addition
and
17 deletion
+86
-17
python/paddle/static/io.py
python/paddle/static/io.py
+86
-17
未找到文件。
python/paddle/static/io.py
浏览文件 @
4ceedec3
...
@@ -204,7 +204,7 @@ def is_persistable(var):
...
@@ -204,7 +204,7 @@ def is_persistable(var):
@
static_only
@
static_only
def
serialize_program
(
feed_vars
,
fetch_vars
):
def
serialize_program
(
feed_vars
,
fetch_vars
,
**
kwargs
):
"""
"""
:api_attr: Static Graph
:api_attr: Static Graph
...
@@ -213,6 +213,10 @@ def serialize_program(feed_vars, fetch_vars):
...
@@ -213,6 +213,10 @@ def serialize_program(feed_vars, fetch_vars):
Args:
Args:
feed_vars(Variable | list[Variable]): Variables needed by inference.
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
kwargs: Supported keys including 'program'.
Attention please, kwargs is used for backward compatibility mainly.
- program(Program): specify a program if you don't want to use default main program.
Returns:
Returns:
bytes: serialized program.
bytes: serialized program.
...
@@ -235,7 +239,6 @@ def serialize_program(feed_vars, fetch_vars):
...
@@ -235,7 +239,6 @@ def serialize_program(feed_vars, fetch_vars):
predict = paddle.static.nn.fc(image, 10, activation='softmax')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
exe.run(paddle.static.default_startup_program())
...
@@ -252,7 +255,7 @@ def serialize_program(feed_vars, fetch_vars):
...
@@ -252,7 +255,7 @@ def serialize_program(feed_vars, fetch_vars):
# verify fetch_vars
# verify fetch_vars
_check_vars
(
'fetch_vars'
,
fetch_vars
)
_check_vars
(
'fetch_vars'
,
fetch_vars
)
program
=
_get_valid_program
()
program
=
_get_valid_program
(
kwargs
.
get
(
'program'
,
None
)
)
program
=
_normalize_program
(
program
,
feed_vars
,
fetch_vars
)
program
=
_normalize_program
(
program
,
feed_vars
,
fetch_vars
)
return
_serialize_program
(
program
)
return
_serialize_program
(
program
)
...
@@ -265,7 +268,7 @@ def _serialize_program(program):
...
@@ -265,7 +268,7 @@ def _serialize_program(program):
@
static_only
@
static_only
def
serialize_persistables
(
feed_vars
,
fetch_vars
,
executor
):
def
serialize_persistables
(
feed_vars
,
fetch_vars
,
executor
,
**
kwargs
):
"""
"""
:api_attr: Static Graph
:api_attr: Static Graph
...
@@ -274,6 +277,10 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
...
@@ -274,6 +277,10 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
Args:
Args:
feed_vars(Variable | list[Variable]): Variables needed by inference.
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
kwargs: Supported keys including 'program'.
Attention please, kwargs is used for backward compatibility mainly.
- program(Program): specify a program if you don't want to use default main program.
Returns:
Returns:
bytes: serialized program.
bytes: serialized program.
...
@@ -296,7 +303,6 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
...
@@ -296,7 +303,6 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
predict = paddle.static.nn.fc(image, 10, activation='softmax')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
exe.run(paddle.static.default_startup_program())
...
@@ -314,7 +320,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
...
@@ -314,7 +320,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor):
# verify fetch_vars
# verify fetch_vars
_check_vars
(
'fetch_vars'
,
fetch_vars
)
_check_vars
(
'fetch_vars'
,
fetch_vars
)
program
=
_get_valid_program
()
program
=
_get_valid_program
(
kwargs
.
get
(
'program'
,
None
)
)
program
=
_normalize_program
(
program
,
feed_vars
,
fetch_vars
)
program
=
_normalize_program
(
program
,
feed_vars
,
fetch_vars
)
return
_serialize_persistables
(
program
,
executor
)
return
_serialize_persistables
(
program
,
executor
)
...
@@ -380,7 +386,8 @@ def save_to_file(path, content):
...
@@ -380,7 +386,8 @@ def save_to_file(path, content):
@
static_only
@
static_only
def
save_inference_model
(
path_prefix
,
feed_vars
,
fetch_vars
,
executor
):
def
save_inference_model
(
path_prefix
,
feed_vars
,
fetch_vars
,
executor
,
**
kwargs
):
"""
"""
:api_attr: Static Graph
:api_attr: Static Graph
...
@@ -396,6 +403,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
...
@@ -396,6 +403,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
fetch_vars(Variable | list[Variable]): Variables returned by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
executor(Executor): The executor that saves the inference model. You can refer
executor(Executor): The executor that saves the inference model. You can refer
to :ref:`api_guide_executor_en` for more details.
to :ref:`api_guide_executor_en` for more details.
kwargs: Supported keys including 'program'.
Attention please, kwargs is used for backward compatibility mainly.
- program(Program): specify a program if you don't want to use default main program.
Returns:
Returns:
None
None
...
@@ -418,7 +428,6 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
...
@@ -418,7 +428,6 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
predict = paddle.static.nn.fc(image, 10, activation='softmax')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
exe.run(paddle.static.default_startup_program())
...
@@ -456,7 +465,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
...
@@ -456,7 +465,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
# verify fetch_vars
# verify fetch_vars
_check_vars
(
'fetch_vars'
,
fetch_vars
)
_check_vars
(
'fetch_vars'
,
fetch_vars
)
program
=
_get_valid_program
()
program
=
_get_valid_program
(
kwargs
.
get
(
'program'
,
None
)
)
program
=
_normalize_program
(
program
,
feed_vars
,
fetch_vars
)
program
=
_normalize_program
(
program
,
feed_vars
,
fetch_vars
)
# serialize and save program
# serialize and save program
program_bytes
=
_serialize_program
(
program
)
program_bytes
=
_serialize_program
(
program
)
...
@@ -475,8 +484,35 @@ def deserialize_program(data):
...
@@ -475,8 +484,35 @@ def deserialize_program(data):
Args:
Args:
data(bytes): serialized program.
data(bytes): serialized program.
Returns:
Returns:
Program: deserialized program.
Program: deserialized program.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regession example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize the default main program to bytes.
serialized_program = paddle.static.serialize_program([image], [predict])
# deserialize bytes to program
deserialized_program = paddle.static.deserialize_program(serialized_program)
"""
"""
program
=
Program
.
parse_from_string
(
data
)
program
=
Program
.
parse_from_string
(
data
)
if
not
core
.
_is_program_version_supported
(
program
.
_version
()):
if
not
core
.
_is_program_version_supported
(
program
.
_version
()):
...
@@ -496,8 +532,37 @@ def deserialize_persistables(program, data, executor):
...
@@ -496,8 +532,37 @@ def deserialize_persistables(program, data, executor):
program(Program): program that contains parameter names (to deserialize).
program(Program): program that contains parameter names (to deserialize).
data(bytes): serialized parameters.
data(bytes): serialized parameters.
executor(Executor): executor used to run load op.
executor(Executor): executor used to run load op.
Returns:
Returns:
Program: deserialized program.
Program: deserialized program.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regession example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize parameters to bytes.
serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
# deserialize bytes to parameters.
main_program = paddle.static.default_main_program()
deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe)
"""
"""
if
not
isinstance
(
program
,
Program
):
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
raise
TypeError
(
...
@@ -567,7 +632,7 @@ def load_from_file(path):
...
@@ -567,7 +632,7 @@ def load_from_file(path):
@
static_only
@
static_only
def
load_inference_model
(
path_prefix
,
executor
,
**
confi
gs
):
def
load_inference_model
(
path_prefix
,
executor
,
**
kwar
gs
):
"""
"""
:api_attr: Static Graph
:api_attr: Static Graph
...
@@ -580,6 +645,10 @@ def load_inference_model(path_prefix, executor, **configs):
...
@@ -580,6 +645,10 @@ def load_inference_model(path_prefix, executor, **configs):
- Set to None when reading the model from memory.
- Set to None when reading the model from memory.
executor(Executor): The executor to run for loading inference model.
executor(Executor): The executor to run for loading inference model.
See :ref:`api_guide_executor_en` for more details about it.
See :ref:`api_guide_executor_en` for more details about it.
kwargs: Supported keys including 'model_filename', 'params_filename'.
Attention please, kwargs is used for backward compatibility mainly.
- model_filename(str): specify model_filename if you don't want to use default name.
- params_filename(str): specify params_filename if you don't want to use default name.
Returns:
Returns:
list: The return of this API is a list with three elements:
list: The return of this API is a list with three elements:
...
@@ -631,17 +700,17 @@ def load_inference_model(path_prefix, executor, **configs):
...
@@ -631,17 +700,17 @@ def load_inference_model(path_prefix, executor, **configs):
# fetch_targets, we can use an executor to run the inference
# fetch_targets, we can use an executor to run the inference
# program to get the inference result.
# program to get the inference result.
"""
"""
# check
confi
gs
# check
kwar
gs
supported_args
=
(
'model_filename'
,
'params_filename'
)
supported_args
=
(
'model_filename'
,
'params_filename'
)
deprecated_args
=
(
'pserver_endpoints'
,
)
deprecated_args
=
(
'pserver_endpoints'
,
)
caller
=
inspect
.
currentframe
().
f_code
.
co_name
caller
=
inspect
.
currentframe
().
f_code
.
co_name
_check_args
(
caller
,
confi
gs
,
supported_args
,
deprecated_args
)
_check_args
(
caller
,
kwar
gs
,
supported_args
,
deprecated_args
)
# load from memory
# load from memory
if
path_prefix
is
None
:
if
path_prefix
is
None
:
_logger
.
warning
(
"Load inference model from memory is deprecated."
)
_logger
.
warning
(
"Load inference model from memory is deprecated."
)
model_filename
=
confi
gs
.
get
(
'model_filename'
,
None
)
model_filename
=
kwar
gs
.
get
(
'model_filename'
,
None
)
params_filename
=
confi
gs
.
get
(
'params_filename'
,
None
)
params_filename
=
kwar
gs
.
get
(
'params_filename'
,
None
)
if
params_filename
is
None
:
if
params_filename
is
None
:
raise
ValueError
(
raise
ValueError
(
"params_filename cannot be None when path_prefix is None."
)
"params_filename cannot be None when path_prefix is None."
)
...
@@ -655,14 +724,14 @@ def load_inference_model(path_prefix, executor, **configs):
...
@@ -655,14 +724,14 @@ def load_inference_model(path_prefix, executor, **configs):
# set model_path and params_path in new way,
# set model_path and params_path in new way,
# path_prefix represents a file path without suffix in this case.
# path_prefix represents a file path without suffix in this case.
if
not
confi
gs
:
if
not
kwar
gs
:
model_path
=
path_prefix
+
".pdmodel"
model_path
=
path_prefix
+
".pdmodel"
params_path
=
path_prefix
+
".pdiparams"
params_path
=
path_prefix
+
".pdiparams"
# set model_path and params_path in old way for compatible,
# set model_path and params_path in old way for compatible,
# path_prefix represents a directory path.
# path_prefix represents a directory path.
else
:
else
:
model_filename
=
confi
gs
.
get
(
'model_filename'
,
None
)
model_filename
=
kwar
gs
.
get
(
'model_filename'
,
None
)
params_filename
=
confi
gs
.
get
(
'params_filename'
,
None
)
params_filename
=
kwar
gs
.
get
(
'params_filename'
,
None
)
# set model_path
# set model_path
if
model_filename
is
None
:
if
model_filename
is
None
:
model_path
=
os
.
path
.
join
(
path_prefix
,
"__model__"
)
model_path
=
os
.
path
.
join
(
path_prefix
,
"__model__"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录