Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
3c334bd7
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3c334bd7
编写于
7月 20, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix
上级
1dd14a70
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
52 addition
and
49 deletion
+52
-49
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+52
-49
未找到文件。
python/paddle/fluid/trainer.py
浏览文件 @
3c334bd7
...
@@ -560,6 +560,9 @@ class Trainer(object):
...
@@ -560,6 +560,9 @@ class Trainer(object):
if
epoch_id
%
self
.
checkpoint_cfg
.
epoch_interval
==
0
\
if
epoch_id
%
self
.
checkpoint_cfg
.
epoch_interval
==
0
\
and
step_id
%
self
.
checkpoint_cfg
.
step_interval
==
0
:
and
step_id
%
self
.
checkpoint_cfg
.
step_interval
==
0
:
print
(
"_save_checkpoint ..."
)
exe
=
executor
.
Executor
(
self
.
place
)
exe
=
executor
.
Executor
(
self
.
place
)
save_checkpoint
(
save_checkpoint
(
executor
=
exe
,
executor
=
exe
,
...
@@ -661,12 +664,12 @@ CHECKPOINT_SEPARATOR = "_"
...
@@ -661,12 +664,12 @@ CHECKPOINT_SEPARATOR = "_"
def
save_checkpoint
(
executor
,
def
save_checkpoint
(
executor
,
checkpoint_dir
,
checkpoint_dir
,
trainer_id
,
main_program
=
None
,
main_program
,
trainer_id
=
0
,
trainer_args
=
None
,
save_trainer_args
=
None
,
max_num_checkpoints
=
3
,
save_lookup_table
=
None
,
save_lookup_table
=
None
,
pserver_endpoints
=
None
):
pserver_endpoints
=
None
,
max_num_checkpoints
=
3
):
"""
"""
This function filters out all checkpoint variables from the give
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
main_program and then saves these variables to the `checkpoint_dir`
...
@@ -735,21 +738,18 @@ def save_checkpoint(executor,
...
@@ -735,21 +738,18 @@ def save_checkpoint(executor,
if
checkpoint_dir
is
None
:
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
trainer_args
:
assert
isinstance
(
trainer_args
,
dict
)
is_chief
=
trainer_id
==
0
_make_chekcpoint_dirs
(
checkpoint_dir
)
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
,
True
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
,
True
)
_save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
is_chief
=
trainer_id
==
0
if
save_trainer_args
is
not
None
:
_save_trainer_args
(
cur_dir
,
trainer_id
,
save_trainer_args
)
if
is_chief
:
if
is_chief
:
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
_save_persistable_vars
(
executor
,
cur_dir
,
main_program
)
_save_persistable_vars
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
save_lookup_table
and
pserver_endpoints
:
if
is_chief
and
save_lookup_table
and
pserver_endpoints
:
...
@@ -764,7 +764,7 @@ def load_checkpoint(executor,
...
@@ -764,7 +764,7 @@ def load_checkpoint(executor,
main_program
=
None
,
main_program
=
None
,
role_id
=
0
,
role_id
=
0
,
is_trainer
=
True
,
is_trainer
=
True
,
load_models
=
Tru
e
,
load_models
=
Fals
e
,
load_trainer_args
=
None
,
load_trainer_args
=
None
,
load_slice_up_vars
=
None
,
load_slice_up_vars
=
None
,
load_lookup_table
=
None
):
load_lookup_table
=
None
):
...
@@ -827,6 +827,10 @@ def load_checkpoint(executor,
...
@@ -827,6 +827,10 @@ def load_checkpoint(executor,
_load_persistable_vars
(
executor
,
checkpoint_dir
,
main_program
,
True
)
_load_persistable_vars
(
executor
,
checkpoint_dir
,
main_program
,
True
)
return
return
if
load_trainer_args
:
if
load_trainer_args
:
print
(
"checkpoint_dir: {}, role_id: {}, load_trainer_args: {}"
.
format
(
checkpoint_dir
,
role_id
,
load_trainer_args
))
trainer_args_ret
=
_load_trainer_args
(
checkpoint_dir
,
role_id
,
trainer_args_ret
=
_load_trainer_args
(
checkpoint_dir
,
role_id
,
load_trainer_args
)
load_trainer_args
)
return
trainer_args_ret
return
trainer_args_ret
...
@@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
...
@@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
: param checkpoint_dir
: param checkpoint_dir
"""
"""
if
not
checkpoint_dir
:
return
-
1
def
has_success
(
checkpoint_dir
,
cur_dir
):
def
has_success
(
checkpoint_dir
,
cur_dir
):
"""
"""
...
@@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
...
@@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
"""
"""
serial
=
_get_dir_serial
(
cur_dir
)
serial
=
_get_dir_serial
(
cur_dir
)
if
serial
==
-
1
or
not
os
.
path
.
isdir
(
if
serial
==
-
1
or
\
os
.
path
.
join
(
checkpoint_dir
,
cur_dir
)):
not
os
.
path
.
isdir
(
os
.
path
.
join
(
checkpoint_dir
,
cur_dir
)):
return
-
1
return
-
1
success_path
=
os
.
path
.
join
(
success_path
=
os
.
path
.
join
(
...
@@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
...
@@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
if
os
.
path
.
isfile
(
success_path
):
if
os
.
path
.
isfile
(
success_path
):
return
serial
return
serial
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
return
-
1
current_dir
=
-
1
current_dir
=
-
1
if
not
checkpoint_dir
or
not
os
.
path
.
isdir
(
checkpoint_dir
):
return
current_dir
dirs
=
os
.
listdir
(
checkpoint_dir
)
dirs
=
os
.
listdir
(
checkpoint_dir
)
for
cur_dir
in
dirs
:
for
cur_dir
in
dirs
:
success_num
=
has_success
(
checkpoint_dir
,
cur_dir
)
success_num
=
has_success
(
checkpoint_dir
,
cur_dir
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录