Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e0a52fd7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e0a52fd7
编写于
4月 20, 2021
作者:
W
WeiXin
提交者:
GitHub
4月 20, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
save/load program (#32336)
上级
f6f59e50
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
77 addition
and
40 deletion
+77
-40
python/paddle/fluid/tests/unittests/test_paddle_save_load.py
python/paddle/fluid/tests/unittests/test_paddle_save_load.py
+24
-0
python/paddle/framework/io.py
python/paddle/framework/io.py
+53
-40
未找到文件。
python/paddle/fluid/tests/unittests/test_paddle_save_load.py
浏览文件 @
e0a52fd7
...
...
@@ -447,5 +447,29 @@ class TestSaveLoad(unittest.TestCase):
paddle
.
load
(
"test_paddle_save_load.linear"
)
class
TestSaveLoadProgram
(
unittest
.
TestCase
):
def
test_save_load_program
(
self
):
paddle
.
enable_static
()
with
new_program_scope
():
layer
=
LinearNet
()
data
=
paddle
.
static
.
data
(
name
=
'x_static_save'
,
shape
=
(
None
,
IMAGE_SIZE
),
dtype
=
'float32'
)
y_static
=
layer
(
data
)
main_program
=
paddle
.
static
.
default_main_program
()
startup_program
=
paddle
.
static
.
default_startup_program
()
origin_main
=
main_program
.
desc
.
serialize_to_string
()
origin_startup
=
startup_program
.
desc
.
serialize_to_string
()
path1
=
"test_paddle_save_load_program/main_program.pdmodel"
path2
=
"test_paddle_save_load_program/startup_program.pdmodel"
paddle
.
save
(
main_program
,
path1
)
paddle
.
save
(
startup_program
,
path2
)
with
new_program_scope
():
load_main
=
paddle
.
load
(
path1
).
desc
.
serialize_to_string
()
load_startup
=
paddle
.
load
(
path2
).
desc
.
serialize_to_string
()
self
.
assertTrue
(
origin_main
==
load_main
)
self
.
assertTrue
(
origin_startup
==
load_startup
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/framework/io.py
浏览文件 @
e0a52fd7
...
...
@@ -33,7 +33,7 @@ from paddle.fluid import core
from
paddle.fluid.io
import
_unpack_saved_dict
,
_pack_loaded_dict
,
_pickle_loads_mac
from
paddle.fluid.io
import
_legacy_save
as
_legacy_static_save
from
paddle.fluid.framework
import
Variable
,
_varbase_creator
,
_dygraph_tracer
,
in_dygraph_mode
,
ParamBase
,
_current_expected_place
from
paddle.fluid.framework
import
Variable
,
_varbase_creator
,
_dygraph_tracer
,
in_dygraph_mode
,
ParamBase
,
_current_expected_place
,
Program
from
paddle.fluid.dygraph.jit
import
_SaveLoadConfig
from
paddle.fluid.dygraph.io
import
_construct_program_holders
,
_construct_params_and_buffers
from
paddle.fluid.dygraph.io
import
INFER_MODEL_SUFFIX
,
INFER_PARAMS_SUFFIX
,
INFER_PARAMS_INFO_SUFFIX
...
...
@@ -453,8 +453,11 @@ def save(obj, path, protocol=2, **configs):
warnings
.
warn
(
"'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
)
if
_use_legacy
(
obj
):
if
isinstance
(
obj
,
Program
):
obj
.
desc
.
flush
()
with
open
(
path
,
"wb"
)
as
f
:
f
.
write
(
obj
.
desc
.
serialize_to_string
())
elif
_use_legacy
(
obj
):
if
in_dygraph_mode
():
_legacy_save
(
obj
,
path
,
protocol
)
else
:
...
...
@@ -627,46 +630,56 @@ def load(path, **configs):
if
os
.
path
.
isfile
(
path
):
config
=
_parse_load_config
(
configs
)
with
open
(
path
,
'rb'
)
as
f
:
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if
sys
.
platform
==
'darwin'
and
sys
.
version_info
.
major
==
3
:
load_result
=
_pickle_loads_mac
(
path
,
f
)
else
:
load_result
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
f
,
encoding
=
'latin1'
)
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
if
isinstance
(
load_result
,
dict
):
if
isinstance
(
load_result
,
dict
):
load_result
=
_pack_loaded_dict
(
load_result
)
# paddle2.0: paddle.save/load
if
"StructuredToParameterName@@"
in
load_result
:
if
six
.
PY2
:
exception_type
=
KeyError
else
:
exception_type
=
pickle
.
UnpicklingError
try
:
with
open
(
path
,
'rb'
)
as
f
:
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
if
sys
.
platform
==
'darwin'
and
sys
.
version_info
.
major
==
3
:
load_result
=
_pickle_loads_mac
(
path
,
f
)
else
:
load_result
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
f
,
encoding
=
'latin1'
)
for
key
in
load_result
[
"StructuredToParameterName@@"
]:
load_result
[
key
]
=
_ndarray_to_tensor
(
load_result
[
key
],
config
.
return_numpy
)
# TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
if
isinstance
(
load_result
,
dict
):
if
isinstance
(
load_result
,
dict
):
load_result
=
_pack_loaded_dict
(
load_result
)
# paddle2.0: paddle.save/load
if
"StructuredToParameterName@@"
in
load_result
:
for
key
in
load_result
[
"StructuredToParameterName@@"
]:
load_result
[
key
]
=
_ndarray_to_tensor
(
load_result
[
key
],
config
.
return_numpy
)
if
not
config
.
keep_name_table
and
"StructuredToParameterName@@"
in
load_result
:
del
load_result
[
"StructuredToParameterName@@"
]
else
:
# paddle2.1 static.save/load
for
key
in
load_result
:
load_result
[
key
]
=
_ndarray_to_tensor
(
load_result
[
key
],
config
.
return_numpy
)
if
not
config
.
keep_name_table
and
"StructuredToParameterName@@"
in
load_result
:
del
load_result
[
"StructuredToParameterName@@"
]
else
:
# paddle2.1 static.save/load
for
key
in
load_result
:
load_result
[
key
]
=
_ndarray_to_tensor
(
load_result
[
key
],
config
.
return_numpy
)
else
:
# TODO(weixin): support complex objects such as layer.
# If `obj` is any object, the judgment condition should be more precise.
if
_transformed_from_lodtensor
(
load_result
):
load_result
=
_ndarray_to_tensor
(
load_result
,
config
.
return_numpy
)
elif
_transformed_from_varbase
(
load_result
):
load_result
=
_tuple_to_tensor
(
load_result
,
config
.
return_numpy
)
else
:
raise
NotImplementedError
(
'Only support tensor and state_dict, but received {}.'
.
format
(
type
(
load_result
)))
# TODO(weixin): support complex objects such as layer.
# If `obj` is any object, the judgment condition should be more precise.
if
_transformed_from_lodtensor
(
load_result
):
load_result
=
_ndarray_to_tensor
(
load_result
,
config
.
return_numpy
)
elif
_transformed_from_varbase
(
load_result
):
load_result
=
_tuple_to_tensor
(
load_result
,
config
.
return_numpy
)
else
:
raise
NotImplementedError
(
'Only support tensor and state_dict, but received {}.'
.
format
(
type
(
load_result
)))
except
exception_type
:
with
open
(
path
,
"rb"
)
as
f
:
program_desc_str
=
f
.
read
()
program
=
Program
.
parse_from_string
(
program_desc_str
)
return
program
else
:
load_result
=
_legacy_load
(
path
,
**
configs
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录