Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
aa42bc25
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
aa42bc25
编写于
8月 10, 2022
作者:
A
Aurelius84
提交者:
GitHub
8月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[BugFix]Fix save/load_inference_model API BUG while program contains no param (#45038)
上级
b1e33bea
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
41 addition
and
2 deletion
+41
-2
python/paddle/fluid/tests/unittests/test_static_save_load.py
python/paddle/fluid/tests/unittests/test_static_save_load.py
+29
-0
python/paddle/static/io.py
python/paddle/static/io.py
+12
-2
未找到文件。
python/paddle/fluid/tests/unittests/test_static_save_load.py
浏览文件 @
aa42bc25
...
...
@@ -1626,6 +1626,35 @@ class TestStaticSaveLoadPickle(unittest.TestCase):
np
.
testing
.
assert_array_equal
(
new_t
,
base_t
)
class
TestSaveLoadInferenceModel
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
temp_dir
=
tempfile
.
TemporaryDirectory
()
self
.
model_path
=
os
.
path
.
join
(
self
.
temp_dir
.
name
,
'no_params'
)
def
tearDown
(
self
):
self
.
temp_dir
.
cleanup
()
def
test_no_params
(
self
):
main_program
=
framework
.
Program
()
with
framework
.
program_guard
(
main_program
):
x
=
paddle
.
static
.
data
(
name
=
"x"
,
shape
=
[
10
,
10
],
dtype
=
'float32'
)
y
=
x
+
x
place
=
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
paddle
.
static
.
save_inference_model
(
self
.
model_path
,
[
x
],
[
y
],
exe
)
[
inference_program
,
feed_target_names
,
fetch_targets
]
=
(
paddle
.
static
.
load_inference_model
(
self
.
model_path
,
exe
))
self
.
assertEqual
(
feed_target_names
,
[
'x'
])
self
.
assertEqual
(
fetch_targets
[
0
].
shape
,
(
10
,
10
))
ops
=
[
op
.
type
for
op
in
inference_program
.
block
(
0
).
ops
]
self
.
assertEqual
(
ops
,
[
'feed'
,
'elementwise_add'
,
'scale'
,
'fetch'
])
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
python/paddle/static/io.py
浏览文件 @
aa42bc25
...
...
@@ -542,7 +542,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
save_to_file
(
model_path
,
program_bytes
)
# serialize and save params
params_bytes
=
_serialize_persistables
(
program
,
executor
)
save_to_file
(
params_path
,
params_bytes
)
# program may not contain any parameter and just compute operation
if
params_bytes
is
not
None
:
save_to_file
(
params_path
,
params_bytes
)
@
static_only
...
...
@@ -660,6 +662,12 @@ def deserialize_persistables(program, data, executor):
check_vars
.
append
(
var
)
load_var_map
[
var_copy
.
name
]
=
var_copy
if
data
is
None
:
assert
len
(
origin_shape_map
)
==
0
,
"Required 'data' shall be not None if program contains parameter, but received 'data' is None."
return
# append load_combine op to load parameters,
load_var_list
=
[]
for
name
in
sorted
(
load_var_map
.
keys
()):
...
...
@@ -849,7 +857,9 @@ def load_inference_model(path_prefix, executor, **kwargs):
params_filename
=
os
.
path
.
basename
(
params_path
)
# load params data
params_path
=
os
.
path
.
join
(
load_dirname
,
params_filename
)
params_bytes
=
load_from_file
(
params_path
)
params_bytes
=
None
if
os
.
path
.
exists
(
params_path
):
params_bytes
=
load_from_file
(
params_path
)
# deserialize bytes to program
program
=
deserialize_program
(
program_bytes
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录