Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
01eff171
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
01eff171
编写于
4月 19, 2019
作者:
L
lujun
提交者:
GitHub
4月 19, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16983 from junjun315/fix_dy_sl
Fix dy sl bug
上级
c0170255
9d2f7d76
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
42 addition
and
46 deletion
+42
-46
python/paddle/fluid/dygraph/checkpoint.py
python/paddle/fluid/dygraph/checkpoint.py
+29
-38
python/paddle/fluid/dygraph/layers.py
python/paddle/fluid/dygraph/layers.py
+9
-1
python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py
...addle/fluid/tests/unittests/test_imperative_checkpoint.py
+4
-7
未找到文件。
python/paddle/fluid/dygraph/checkpoint.py
浏览文件 @
01eff171
...
...
@@ -75,7 +75,7 @@ def save_persistables(vardict, dirname, filename=None):
_save_var_to_file
(
vardict
,
dirname
,
filename
)
def
load_persistables
(
vardict
,
dirname
,
filename
=
Non
e
):
def
load_persistables
(
dirnam
e
):
"""
This function trys to load persistable variables from the folder
`dirname` or the file `filename`.
...
...
@@ -86,11 +86,7 @@ def load_persistables(vardict, dirname, filename=None):
the file name.
Args:
vardict(dict of Parameters): The parameters will be loaded.
dirname(str): The directory path.
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
saved in differnet files, set it to None.
Default: None
Returns:
dict: The parameter-dict resumed from file
...
...
@@ -104,10 +100,7 @@ def load_persistables(vardict, dirname, filename=None):
param_1 = param_dict['PtbModel_0.w_1']
"""
if
isinstance
(
vardict
,
collections
.
OrderedDict
):
return
_load_var_from_file
(
vardict
,
dirname
,
filename
)
return
{}
return
_load_var_from_file
(
dirname
)
def
_save_var_to_file
(
stat_dict
,
file_dir
,
file_name
):
...
...
@@ -139,42 +132,40 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
})
def
_load_var_from_file
(
stat_dict
,
file_dir
,
file_name
):
def
_load_var_from_file
(
file_dir
):
def
walk_filename
(
file_dir
):
base_path
=
os
.
path
.
join
(
file_dir
)
var_name_list
=
[]
if
os
.
path
.
exists
(
base_path
):
for
dirpath
,
dirnames
,
filenames
in
os
.
walk
(
base_path
):
pt
=
dirpath
.
replace
(
base_path
,
""
,
1
)
if
pt
.
startswith
(
"/"
)
or
pt
.
startswith
(
"
\\
"
):
pt
=
pt
[
1
:]
for
fth_name
in
filenames
:
if
fth_name
[
0
]
!=
'.'
:
name_path
=
os
.
path
.
join
(
pt
,
fth_name
)
if
"
\\
"
in
name_path
:
name_path
=
name_path
.
replace
(
"
\\
"
,
"/"
)
var_name_list
.
append
(
name_path
)
return
var_name_list
load_block
=
default_main_program
().
global_block
()
load_var_map
=
{}
for
var_key
,
each_var
in
stat_dict
.
items
():
assert
isinstance
(
each_var
,
Variable
)
if
each_var
.
type
==
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
_clone_var_in_block_
(
load_block
,
each_var
)
if
file_name
is
None
:
file_var_list
=
walk_filename
(
file_dir
)
for
var_name
in
file_var_list
:
new_var
=
Variable
(
block
=
load_block
,
name
=
var_name
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
new_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
file_dir
,
os
.
path
.
normpath
(
each
_var
.
name
))
os
.
path
.
normpath
(
new
_var
.
name
))
})
load_var_map
[
new_var
.
name
]
=
new_var
if
file_name
is
not
None
:
load_var_list
=
[]
for
name
in
sorted
(
load_var_map
.
keys
()):
load_var_list
.
append
(
load_var_map
[
name
])
load_block
.
append_op
(
type
=
'load_combine'
,
inputs
=
{},
outputs
=
{
"Out"
:
load_var_list
},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
file_dir
,
os
.
path
.
normpath
(
file_name
))
})
for
res_var
in
load_var_list
:
load_var_map
[
res_var
.
name
]
=
res_var
return
load_var_map
...
...
python/paddle/fluid/dygraph/layers.py
浏览文件 @
01eff171
...
...
@@ -45,6 +45,7 @@ class Layer(core.Layer):
self
.
_dtype
=
dtype
self
.
_parameters
=
collections
.
OrderedDict
()
self
.
_sub_layers
=
collections
.
OrderedDict
()
self
.
_loaddict_holder
=
collections
.
OrderedDict
()
self
.
_helper
=
LayerObjectHelper
(
self
.
_full_name
)
...
...
@@ -193,6 +194,9 @@ class Layer(core.Layer):
"""
assert
isinstance
(
parameter
,
framework
.
Parameter
)
self
.
_parameters
[
name
]
=
parameter
if
parameter
.
name
in
self
.
_loaddict_holder
:
self
.
_parameters
[
name
]
=
self
.
_loaddict_holder
[
parameter
.
name
]
parameter
=
self
.
_loaddict_holder
[
parameter
.
name
]
return
parameter
def
__getattr__
(
self
,
name
):
...
...
@@ -207,6 +211,9 @@ class Layer(core.Layer):
if
params
is
None
:
raise
ValueError
(
"super(YourLayer, self).__init__() should be called first"
)
if
value
.
name
in
self
.
_loaddict_holder
:
params
[
name
]
=
self
.
_loaddict_holder
[
value
.
name
]
else
:
params
[
name
]
=
value
elif
isinstance
(
value
,
core
.
Layer
):
layers
=
self
.
__dict__
.
get
(
'_sub_layers'
,
None
)
...
...
@@ -244,6 +251,7 @@ class Layer(core.Layer):
return
destination
def
load_dict
(
self
,
stat_dict
,
include_sublayers
=
True
):
self
.
_loaddict_holder
=
stat_dict
for
name
,
item
in
self
.
__dict__
.
get
(
'_parameters'
,
None
).
items
():
if
item
.
name
in
stat_dict
:
var
=
item
.
_ivar
.
value
()
...
...
python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py
浏览文件 @
01eff171
...
...
@@ -142,14 +142,11 @@ class TestDygraphCheckpoint(unittest.TestCase):
for
param
in
mnist
.
parameters
():
dy_param_init_value
[
param
.
name
]
=
param
.
numpy
()
mnist
.
load_dict
(
fluid
.
dygraph
.
load_persistables
(
mnist
.
state_dict
(),
"save_dir"
))
restore
=
mnist
.
parameters
()
restore
=
fluid
.
dygraph
.
load_persistables
(
"save_dir"
)
mnist
.
load_dict
(
restore
)
self
.
assertEqual
(
len
(
dy_param_init_value
),
len
(
restore
))
for
value
in
restore
:
for
ky
,
value
in
restore
.
items
()
:
self
.
assertTrue
(
np
.
allclose
(
value
.
numpy
(),
dy_param_init_value
[
value
.
name
]))
...
...
@@ -158,7 +155,7 @@ class TestDygraphCheckpoint(unittest.TestCase):
step
+=
1
if
step
>
2
0
:
if
step
>
1
0
:
break
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录