Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ac8afe18
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ac8afe18
编写于
9月 11, 2020
作者:
C
Chen Weihang
提交者:
GitHub
9月 11, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use structured name in loaded dict (#27242)
上级
5e0dde02
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
19 addition
and
3 deletion
+19
-3
python/paddle/fluid/dygraph/checkpoint.py
python/paddle/fluid/dygraph/checkpoint.py
+14
-1
python/paddle/fluid/tests/unittests/test_jit_save_load.py
python/paddle/fluid/tests/unittests/test_jit_save_load.py
+5
-2
未找到文件。
python/paddle/fluid/dygraph/checkpoint.py
浏览文件 @
ac8afe18
...
@@ -25,7 +25,7 @@ import warnings
...
@@ -25,7 +25,7 @@ import warnings
from
..
import
core
from
..
import
core
from
.base
import
guard
from
.base
import
guard
from
paddle.fluid.dygraph.jit
import
SaveLoadConfig
,
deprecate_save_load_configs
from
paddle.fluid.dygraph.jit
import
SaveLoadConfig
,
deprecate_save_load_configs
from
paddle.fluid.dygraph.io
import
_construct_program_holders
,
_construct_params_and_buffers
from
paddle.fluid.dygraph.io
import
_construct_program_holders
,
_construct_params_and_buffers
,
EXTRA_VAR_INFO_FILENAME
__all__
=
[
__all__
=
[
'save_dygraph'
,
'save_dygraph'
,
...
@@ -233,6 +233,19 @@ def load_dygraph(model_path, config=None):
...
@@ -233,6 +233,19 @@ def load_dygraph(model_path, config=None):
para_dict
=
dict
()
para_dict
=
dict
()
for
var_name
in
persistable_var_dict
:
for
var_name
in
persistable_var_dict
:
para_dict
[
var_name
]
=
persistable_var_dict
[
var_name
].
numpy
()
para_dict
[
var_name
]
=
persistable_var_dict
[
var_name
].
numpy
()
# if __variables.info__ exists, we can recover structured_name
var_info_path
=
os
.
path
.
join
(
model_prefix
,
EXTRA_VAR_INFO_FILENAME
)
if
os
.
path
.
exists
(
var_info_path
):
with
open
(
var_info_path
,
'rb'
)
as
f
:
extra_var_info
=
pickle
.
load
(
f
)
structured_para_dict
=
dict
()
for
var_name
in
para_dict
:
structured_name
=
extra_var_info
[
var_name
].
get
(
'structured_name'
,
None
)
assert
structured_name
is
not
None
,
"Cannot find saved variable (%s)'s structured name in saved model."
%
var_name
structured_para_dict
[
structured_name
]
=
para_dict
[
var_name
]
para_dict
=
structured_para_dict
else
:
else
:
# Load state dict by `save_dygraph` save format
# Load state dict by `save_dygraph` save format
para_dict
=
{}
para_dict
=
{}
...
...
python/paddle/fluid/tests/unittests/test_jit_save_load.py
浏览文件 @
ac8afe18
...
@@ -255,8 +255,11 @@ class TestJitSaveLoad(unittest.TestCase):
...
@@ -255,8 +255,11 @@ class TestJitSaveLoad(unittest.TestCase):
train_layer
.
eval
()
train_layer
.
eval
()
# construct new model
# construct new model
new_layer
=
LinearNet
(
784
,
1
)
new_layer
=
LinearNet
(
784
,
1
)
model_dict
,
_
=
fluid
.
dygraph
.
load_dygraph
(
self
.
model_path
)
orig_state_dict
=
new_layer
.
state_dict
()
new_layer
.
set_dict
(
model_dict
)
load_state_dict
,
_
=
fluid
.
dygraph
.
load_dygraph
(
self
.
model_path
)
for
structured_name
in
orig_state_dict
:
self
.
assertTrue
(
structured_name
in
load_state_dict
)
new_layer
.
set_state_dict
(
load_state_dict
)
new_layer
.
eval
()
new_layer
.
eval
()
# inference & compare
# inference & compare
x
=
fluid
.
dygraph
.
to_variable
(
x
=
fluid
.
dygraph
.
to_variable
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录