BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 3c334bd7
Authored on Jul 20, 2018 by tangwei12
Parent: 1dd14a70

Commit message: bug fix
Showing 1 changed file with 52 additions and 49 deletions:

python/paddle/fluid/trainer.py (+52, -49)
@@ -73,7 +73,7 @@ class BeginStepEvent(object):
         self.step = step_id
         self.fetch_metrics = True
         """
-        If fetch_metrics is true, the metrics will be fetched at the
+        If fetch_metrics is true, the metrics will be fetched at the
         EndStepEvent. Default is True.
         """
@@ -560,6 +560,9 @@ class Trainer(object):
                 if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
                         and step_id % self.checkpoint_cfg.step_interval == 0:
+                    print("_save_checkpoint ...")
+                    exe = executor.Executor(self.place)
                     save_checkpoint(executor=exe,
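This hunk gates checkpoint saving on both counters at once: a checkpoint is written only when the epoch index and the step index each land on their configured interval. A self-contained sketch of that modulo gating, using invented interval values in place of the trainer's checkpoint_cfg fields:

# Standalone illustration of the gating above; epoch_interval and
# step_interval are invented stand-ins for the checkpoint_cfg fields.
epoch_interval, step_interval = 2, 100

for epoch_id in range(4):
    for step_id in range(0, 301, 50):
        if epoch_id % epoch_interval == 0 \
                and step_id % step_interval == 0:
            # this is the point where trainer.py calls save_checkpoint()
            print("checkpoint at epoch %d, step %d" % (epoch_id, step_id))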
@@ -604,7 +607,7 @@ class Trainer(object):
                 self.checkpoint_cfg.epoch_id = int(trainer_args_ret[0])
                 self.checkpoint_cfg.step_id = int(trainer_args_ret[1])
-            # Pserver Load
+            # Pserver Load
         else:
             # load slice_vars
             if self.slice_vars != None and len(self.slice_vars) != 0:
@@ -661,22 +664,22 @@ CHECKPOINT_SEPARATOR = "_"
 def save_checkpoint(executor,
                     checkpoint_dir,
-                    trainer_id,
-                    main_program,
                     trainer_args=None,
-                    max_num_checkpoints=3,
-                    pserver_endpoints=None):
+                    main_program=None,
+                    trainer_id=0,
+                    save_trainer_args=None,
+                    save_lookup_table=None,
+                    pserver_endpoints=None,
+                    max_num_checkpoints=3):
     """
     This function filters out all checkpoint variables from the give
-    main_program and then saves these variables to the `checkpoint_dir`
+    main_program and then saves these variables to the `checkpoint_dir`
     directory.

     In the training precess, we generally save a checkpoint in each
-    iteration. So there might be a lot of checkpoints in the
-    `checkpoint_dir`. To avoid them taking too much disk space, the
-    `max_num_checkpoints` are introduced to limit the total number of
-    checkpoints. If the number of existing checkpints is greater than
+    iteration. So there might be a lot of checkpoints in the
+    `checkpoint_dir`. To avoid them taking too much disk space, the
+    `max_num_checkpoints` are introduced to limit the total number of
+    checkpoints. If the number of existing checkpints is greater than
     the `max_num_checkpoints`, oldest ones will be scroll deleted.

     A variable is a checkpoint variable and will be saved if it meets
@@ -688,21 +691,21 @@ def save_checkpoint(executor,
     Args:
         executor(Executor): The executor to run for save checkpoint.
         checkpoint_dir(str): The folder where to save checkpoints.
-        trainer_id(int): currect trainer id, if id is equal to 0, the trainer
+        trainer_id(int): currect trainer id, if id is equal to 0, the trainer
         is chief.
-        trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
+        trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
         and 'step_id'.
         Defaut: None
         main_program(Program): The program whose checkpoint variables will
         be saved.
-        max_num_checkpoints(int): The max number of total number of existing
+        max_num_checkpoints(int): The max number of total number of existing
         checkpoints.
         Default: 3
+        save_lookup_table(string|None): the lookup table name, when use distribute
+        lookup table, we can get lookup table name by DistributeTranspiler.
-        table_name
-        pserver_endpoints(list|None): the parameter server ip:port list.
-        when use distribute lookup table, we can get pserver_endpoints by
+        table_name
+        pserver_endpoints(list|None): the parameter server ip:port list.
+        when use distribute lookup table, we can get pserver_endpoints by
         distribute arguments.

     Returns:
@@ -735,21 +738,18 @@ def save_checkpoint(executor,
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
-    if main_program is None:
-        raise ValueError('main_program should not be None.')
     if trainer_args:
         assert isinstance(trainer_args, dict)
-    is_chief = trainer_id == 0

     _make_chekcpoint_dirs(checkpoint_dir)
     serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
     cur_dir = _get_serial_dir(checkpoint_dir, serial, True)

-    _save_trainer_args(cur_dir, trainer_id, trainer_args)
+    is_chief = trainer_id == 0
+    if save_trainer_args is not None:
+        _save_trainer_args(cur_dir, trainer_id, save_trainer_args)

     if is_chief:
+        if main_program is None:
+            raise ValueError('main_program should not be None.')
         _save_persistable_vars(executor, cur_dir, main_program)

     if is_chief and save_lookup_table and pserver_endpoints:
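Taken together, these hunks rework the entry point: the serial directory is created first, trainer args are saved only when save_trainer_args is passed, and the main_program check moves inside the chief-trainer branch. A hypothetical call site under the new signature; the path, the tiny program, and the args dicts are placeholders, and importing save_checkpoint from paddle.fluid.trainer is assumed from this file's location:

import paddle.fluid as fluid
from paddle.fluid.trainer import save_checkpoint  # assumed import path

# Build a trivial program so there are persistable variables to save.
x = fluid.layers.data(name="x", shape=[1], dtype="float32")
y = fluid.layers.fc(input=x, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

save_checkpoint(executor=exe,
                checkpoint_dir="/tmp/ckpt_demo",   # placeholder path
                trainer_args={"epoch_id": 1, "step_id": 100},
                main_program=fluid.default_main_program(),
                trainer_id=0,                      # id 0 makes this the chief
                save_trainer_args={"epoch_id": 1, "step_id": 100},
                max_num_checkpoints=3)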
@@ -764,7 +764,7 @@ def load_checkpoint(executor,
                     main_program=None,
                     role_id=0,
                     is_trainer=True,
-                    load_models=True,
+                    load_models=False,
                     load_trainer_args=None,
                     load_slice_up_vars=None,
                     load_lookup_table=None):
@@ -774,8 +774,8 @@ def load_checkpoint(executor,
     `checkpoint_dir` directory.

     In the training precess, we generally save a checkpoint in each
-    iteration. So there are more than one checkpoint in the
-    `checkpoint_dir` (each checkpoint has its own sub folder), use
+    iteration. So there are more than one checkpoint in the
+    `checkpoint_dir` (each checkpoint has its own sub folder), use
     `serial` to specify which serial of checkpoint you would like to
     load.
@@ -827,6 +827,10 @@ def load_checkpoint(executor,
             _load_persistable_vars(executor, checkpoint_dir, main_program, True)
         return

+    if load_trainer_args:
+        print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}"
+              .format(checkpoint_dir, role_id, load_trainer_args))
+
         trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
                                               load_trainer_args)
         return trainer_args_ret
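On the restore side, the Trainer (per the @@ -604 hunk) reads trainer_args_ret[0] and trainer_args_ret[1] back as epoch_id and step_id. A hypothetical two-call restore continuing the sketch above; passing the wanted arg names as a list to load_trainer_args is an assumption inferred from the _load_trainer_args call:

from paddle.fluid.trainer import load_checkpoint  # assumed import path

# First call: recover training progress only (load_models stays at its
# new default of False, per the signature hunk above).
trainer_args_ret = load_checkpoint(
    executor=exe,
    checkpoint_dir="/tmp/ckpt_demo",               # placeholder path
    role_id=0,
    is_trainer=True,
    load_trainer_args=["epoch_id", "step_id"])     # assumed format
epoch_id = int(trainer_args_ret[0])
step_id = int(trainer_args_ret[1])

# Second call: load the persistable variables into the program.
load_checkpoint(
    executor=exe,
    checkpoint_dir="/tmp/ckpt_demo",
    main_program=fluid.default_main_program(),
    load_models=True)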
@@ -842,9 +846,9 @@ def load_checkpoint(executor,
 def clean_checkpoint(checkpoint_dir, delete_dir=False):
     """
-    clean the checkpoint dir, when the train exits normally,
+    clean the checkpoint dir, when the train exits normally,
     the trainer will call clean_checkpoint to delete checkpoint directory saved before.
-    delete_dir only works when the directory is empty, otherwise, OSError is raised.
+    delete_dir only works when the directory is empty, otherwise, OSError is raised.

    : param checkpoint_dir
    : param delete_dir
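The docstring pins down the failure mode: with delete_dir=True the directory itself is only removed once it is empty, otherwise OSError propagates. A minimal defensive call, continuing the placeholder path from the sketches above:

from paddle.fluid.trainer import clean_checkpoint  # assumed import path

try:
    # Delete saved checkpoints; also remove the directory itself,
    # which per the docstring only succeeds once it is empty.
    clean_checkpoint("/tmp/ckpt_demo", delete_dir=True)
except OSError:
    # Directory still had content; leave it in place.
    pass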
@@ -954,7 +958,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars):
 def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     """
-    The parameter server will load lookup table's local file in
+    The parameter server will load lookup table's local file in
     selectedrows variable.

     Args:
@@ -1005,7 +1009,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
 def _save_persistable_vars(executor, dirname, program):
     """
     This function filters out all checkpoint variables from the give
-    program and then save these variables to a sub-folder '__model__' of
+    program and then save these variables to a sub-folder '__model__' of
     the given directory.

     A variable is a checkpoint variable if it meets all following
@@ -1034,7 +1038,7 @@ def _save_persistable_vars(executor, dirname, program):
         # In this example, `_save_persistable_vars` function
         # will first filters out all checkpoint variables in the default
-        # main program, and then saves these variables to the folder
+        # main program, and then saves these variables to the folder
         # "./my_paddle_model/__model__".
     """
     cur_dir = _get_model_dir(dirname)
@@ -1053,7 +1057,7 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
     """
     This function will send checkpoint notify message from Trainer 0
     to all the pservers.
-    The checkpoint notify message contains lookup table name,
+    The checkpoint notify message contains lookup table name,
     the absolute path on pserver to save lookup_table.

     Args:
@@ -1061,13 +1065,13 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
         dirname(str): The folder where to save checkpoints.
         lookup_table(string): the lookup table name, when use distribute
         lookup table, we can get lookup table name by DistributeTranspiler.
-        table_name
-        ps_endpoint_list(list): the parameter server ip:port list.
-        when use distribute lookup table, we can get ps_endpoint_list by
+        table_name
+        ps_endpoint_list(list): the parameter server ip:port list.
+        when use distribute lookup table, we can get ps_endpoint_list by
         distribute arguments.

     Return:
         None

     Examples:
         .. code-block:: python
@@ -1078,7 +1082,7 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
             ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
             _save_pserver_vars_by_notify(executor=exe,
-                    dirname=param_path, lookup_table=table_name,
+                    dirname=param_path, lookup_table=table_name,
                     ps_endpoint_list=ps_endpoints)
     """
     cur_dir = _get_lookuptable_dir(dirname)
@@ -1110,7 +1114,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
 def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args):
     """
-    trainer will load some args from it's independent directory,
+    trainer will load some args from it's independent directory,
     such as epoch_id and step_id.

     Args:
@@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
    : param checkpoint_dir
     """
-    if not checkpoint_dir:
-        return -1
-
     def has_success(checkpoint_dir, cur_dir):
         """
@@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
         """
         serial = _get_dir_serial(cur_dir)
-        if serial == -1 or not os.path.isdir(
-                os.path.join(checkpoint_dir, cur_dir)):
+        if serial == -1 or \
+                not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
             return -1

         success_path = os.path.join(
@@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
         if os.path.isfile(success_path):
             return serial

-    if not os.path.isdir(checkpoint_dir):
-        return -1
     current_dir = -1
+    if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
+        return current_dir
+
     dirs = os.listdir(checkpoint_dir)
     for cur_dir in dirs:
         success_num = has_success(checkpoint_dir, cur_dir)
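These last two hunks harden _get_latest_checkpoint_serial: the early checkpoint_dir guard is folded into a single check next to current_dir, so a missing or falsy directory returns -1 through one path. A standalone sketch of the scan it performs, where the "checkpoint_<serial>" directory naming and the "_SUCCESS" marker file are assumptions; only CHECKPOINT_SEPARATOR = "_" is taken from the hunk header above:

import os

CHECKPOINT_SEPARATOR = "_"  # from the @@ -661 hunk header

def latest_serial(checkpoint_dir):
    # Mirrors the hardened logic: bail out with -1 when the directory
    # is absent, then keep the highest serial whose success marker
    # exists. The "checkpoint_<serial>" layout and "_SUCCESS" marker
    # name are assumptions for this sketch.
    current_dir = -1
    if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
        return current_dir
    for cur_dir in os.listdir(checkpoint_dir):
        parts = cur_dir.rsplit(CHECKPOINT_SEPARATOR, 1)
        if len(parts) != 2 or not parts[1].isdigit():
            continue  # not a serial-numbered checkpoint directory
        serial = int(parts[1])
        marker = os.path.join(checkpoint_dir, cur_dir, "_SUCCESS")
        if os.path.isfile(marker) and serial > current_dir:
            current_dir = serial
    return current_dir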