Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d96b4427
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d96b4427
编写于
5月 23, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename checkpoint folder to checkpoint_serial
上级
9d985340
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
39 addition
and
27 deletion
+39
-27
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+39
-27
未找到文件。
python/paddle/fluid/io.py
浏览文件 @
d96b4427
...
...
@@ -455,10 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS_MARK_FILENAME
=
"_SUCCESS"
CHECKPOINT_PREFIX
=
"checkpoint"
CHECKPOINT_SEPARATOR
=
"_"
def
save_checkpoint
(
executor
,
dirname
=
None
,
checkpoint_dir
=
None
,
max_num_checkpoints
=
3
,
save_interval_secs
=
600
,
main_program
=
None
):
...
...
@@ -466,26 +468,27 @@ def save_checkpoint(executor,
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval
time between two save_checkpoint must great than or equal to
save_interval_secs.
The interval
between two saved checkpoints must greater than
save_interval_secs.
:param dirname
:param executor
:param checkpoint_dir
:param max_num_checkpoints
:param save_secs
:param save_
interval_
secs
:param main_program
"""
if
dirname
is
None
:
dirname
=
os
.
getcwd
()
if
checkpoint_dir
is
None
:
checkpoint_dir
=
os
.
getcwd
()
if
not
os
.
path
.
isdir
(
dirname
):
os
.
makedirs
(
dirname
)
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
os
.
makedirs
(
checkpoint_dir
)
serial
=
_get_lastest_checkpoint_dir
(
dirname
)
serial
=
_get_lastest_checkpoint_dir
(
checkpoint_dir
)
if
serial
>=
0
and
not
_interval_secs_exceed
(
os
.
path
.
join
(
dirname
,
str
(
serial
)
),
save_interval_secs
):
_get_serial_dir
(
serial
,
checkpoint_dir
),
save_interval_secs
):
return
serial
=
serial
+
1
cur_dir
=
os
.
path
.
join
(
dirname
,
str
(
serial
)
)
serial
+=
1
cur_dir
=
_get_serial_dir
(
serial
,
checkpoint_dir
)
save_vars
(
executor
,
...
...
@@ -495,27 +498,28 @@ def save_checkpoint(executor,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
_write_success
(
cur_dir
)
_lru_delete
(
dirname
,
max_num_checkpoints
)
_lru_delete
(
checkpoint_dir
,
max_num_checkpoints
)
def
load_checkpoint
(
executor
,
dirname
=
None
,
main_program
=
None
):
def
load_checkpoint
(
executor
,
checkpoint_dir
=
None
,
main_program
=
None
):
"""
Load checkpoint from a directory by executor,
it will find
latest
checkpoint file and load it auto.
it will find
the most recent saved
checkpoint file and load it auto.
:param executor
:param
dirname
:param
checkpoint_dir
:param main_program
"""
if
dirname
is
None
:
dirname
=
os
.
getcwd
()
if
checkpoint_dir
is
None
:
checkpoint_dir
=
os
.
getcwd
()
serial
=
_get_lastest_checkpoint_dir
(
dirname
)
serial
=
_get_lastest_checkpoint_dir
(
checkpoint_dir
)
if
serial
<
0
:
return
cur_dir
=
os
.
path
.
join
(
dirname
,
str
(
serial
))
cur_dir
=
_get_serial_dir
(
serial
,
checkpoint_dir
)
load_vars
(
executor
,
...
...
@@ -525,6 +529,11 @@ def load_checkpoint(executor, dirname=None, main_program=None):
filename
=
None
)
def
_get_serial_dir
(
serial
,
checkpoint_dir
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
return
os
.
path
.
join
(
checkpoint_dir
,
serial_folder
)
def
_is_checkpoint_var
(
var
):
"""
the checkpoint will not save or load all the variables.
...
...
@@ -577,7 +586,8 @@ def _write_success(dirname):
"""
success_file
=
os
.
path
.
join
(
dirname
,
SUCCESS_MARK_FILENAME
)
with
open
(
success_file
,
'a'
):
pass
now
=
time
.
ctime
()
success_file
.
write
(
now
)
def
_get_lastest_checkpoint_dir
(
checkpoint_dir
):
...
...
@@ -593,18 +603,20 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
"""
is _SUCCESS in this dir
"""
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
checkpoint_dir
,
cur_dir
)):
return
-
1
_
,
serial
=
cur_dir
.
split
(
CHECKPOINT_SEPARATOR
)
try
:
int
(
cur_dir
)
int
(
serial
)
except
ValueError
:
return
-
1
success_path
=
os
.
path
.
join
(
checkpoint_dir
,
cur_dir
,
SUCCESS_MARK_FILENAME
)
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
checkpoint_dir
,
cur_dir
)):
return
-
1
success_path
=
os
.
path
.
join
(
_get_serial_dir
(
serial
,
checkpoint_dir
),
SUCCESS_MARK_FILENAME
)
if
os
.
path
.
isfile
(
success_path
):
return
int
(
cur_dir
)
return
int
(
serial
)
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
return
-
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录