Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
95545f76
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
95545f76
编写于
7月 09, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
checkpoint api optimized
上级
436bb450
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
104 addition
and
63 deletion
+104
-63
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+63
-41
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+41
-22
未找到文件。
python/paddle/fluid/io.py
浏览文件 @
95545f76
...
...
@@ -25,9 +25,7 @@ __all__ = [
'save_vars'
,
'save_params'
,
'save_persistables'
,
'load_vars'
,
'load_params'
,
'load_persistables'
,
'save_inference_model'
,
'load_inference_model'
,
'get_inference_program'
,
'save_checkpoint'
,
'load_checkpoint'
,
'clean_checkpoint'
,
'load_persist_vars_without_grad'
,
'load_lookup_table_vars'
,
'save_persist_vars_without_grad'
,
'get_latest_checkpoint_serial'
'clean_checkpoint'
]
...
...
@@ -805,11 +803,11 @@ CHECKPOINT_SEPARATOR = "_"
def
save_checkpoint
(
executor
,
checkpoint_dir
,
trainer_id
,
main_program
,
trainer_args
=
None
,
main_program
=
None
,
max_num_checkpoints
=
3
,
lookup_table
=
None
,
ps
_endpoint_list
=
None
):
ps
erver_endpoints
=
None
):
"""
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
...
...
@@ -836,16 +834,16 @@ def save_checkpoint(executor,
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
and 'step_id'.
Defaut: None
main_program(Program
|None
): The program whose checkpoint variables will
be saved.
If it is None, the default main program will be used.
main_program(Program): The program whose checkpoint variables will
be saved.
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps
_endpoint_list
(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get ps
_endpoint_list
by
ps
erver_endpoints
(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get ps
erver_endpoints
by
distribute arguments.
Returns:
...
...
@@ -873,11 +871,13 @@ def save_checkpoint(executor,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
ps
_endpoint_list
= ps_endpoints)
ps
erver_endpoints
= ps_endpoints)
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
assert
checkpoint_dir
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
trainer_args
:
assert
isinstance
(
trainer_args
,
dict
)
...
...
@@ -885,22 +885,28 @@ def save_checkpoint(executor,
is_chief
=
trainer_id
==
0
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
serial
=
_
get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
_
save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
if
is_chief
:
save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
_
save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
lookup_table
and
ps
_endpoint_list
:
save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
ps_endpoint_list
)
if
is_chief
and
lookup_table
and
ps
erver_endpoints
:
_
save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
pserver_endpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
def
load_checkpoint
(
executor
,
checkpoint_dir
,
serial
,
main_program
):
def
load_checkpoint
(
executor
,
checkpoint_dir
,
main_program
,
role_id
=
0
,
is_trainer
=
True
,
load_trainer_args
=
None
,
load_lookup_table
=
None
):
"""
This function filters out all checkpoint variables from the give
main_program and then try to load these variables from the
...
...
@@ -924,13 +930,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
be loaded.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): trainer is True and parameter server is False.
load_trainer_args(list|None): list about load trainer args.
load_lookup_table(str|None): the lookup table name
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
ValueError: If `serial` is None or `serial` is less than 0.
ValueError: If `main_program` is None.
Examples:
...
...
@@ -951,14 +960,27 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
# there are nothing need to be loaded
if
serial
is
None
or
serial
<
0
:
r
aise
ValueError
(
"'serial' should not be None or <0 "
)
r
eturn
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
is_trainer
and
load_trainer_args
is
None
:
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
,
True
)
_load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
,
True
)
return
if
is_trainer
and
load_trainer_args
:
return
_load_trainer_args
(
checkpoint_dir
,
serial
,
role_id
,
load_trainer_args
)
if
not
is_trainer
and
load_lookup_table
:
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
role_id
,
load_lookup_table
)
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
...
...
@@ -979,7 +1001,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os
.
rmdir
(
checkpoint_dir
)
def
load_persist_vars_without_grad
(
executor
,
def
_
load_persist_vars_without_grad
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
...
...
@@ -1011,10 +1033,10 @@ def load_persist_vars_without_grad(executor,
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
fluid.io.load_persist_vars_without_grad(executor=exe,
fluid.io.
_
load_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `load_persist_vars_without_grad` function
# In this example, `
_
load_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
...
...
@@ -1031,7 +1053,7 @@ def load_persist_vars_without_grad(executor,
filename
=
None
)
def
load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
def
_
load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
...
...
@@ -1050,11 +1072,11 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/
__model__
"
dirname = "./checkpoints/checkpoint_9/"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
fluid.io.load_lookup_table_vars(executor=exe,
fluid.io.
_
load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
...
...
@@ -1081,7 +1103,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
executor
.
run
(
load_prog
)
def
save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
def
_
save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
...
...
@@ -1108,10 +1130,10 @@ def save_persist_vars_without_grad(executor, dirname, program):
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
fluid.io.save_persist_vars_without_grad(executor=exe,
fluid.io.
_
save_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog)
# In this example, `save_persist_vars_without_grad` function
# In this example, `
_
save_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
...
...
@@ -1127,7 +1149,7 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success
(
cur_dir
)
def
save_pserver_vars_by_notify
(
executor
,
dirname
,
lookup_table
,
def
_
save_pserver_vars_by_notify
(
executor
,
dirname
,
lookup_table
,
ps_endpoint_list
):
"""
This function will send checkpoint notify message from Trainer 0
...
...
@@ -1156,7 +1178,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_pserver_vars_by_notify(executor=exe,
fluid.io.
_
save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
...
...
@@ -1175,7 +1197,7 @@ def save_pserver_vars_by_notify(executor, dirname, lookup_table,
executor
.
run
(
checkpoint_notify_program
)
def
save_trainer_args
(
dirname
,
trainer_id
,
trainer_args
):
def
_
save_trainer_args
(
dirname
,
trainer_id
,
trainer_args
):
assert
isinstance
(
trainer_args
,
dict
)
cur_dir
=
_get_trainer_dir
(
dirname
,
trainer_id
)
...
...
@@ -1187,7 +1209,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
_write_success
(
cur_dir
)
def
load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
def
_
load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
...
...
@@ -1208,7 +1230,7 @@ def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
fluid.io.
_
load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert
isinstance
(
trainer_args
,
list
)
...
...
@@ -1339,7 +1361,7 @@ def _write_success(dirname):
f
.
write
(
now
)
def
get_latest_checkpoint_serial
(
checkpoint_dir
):
def
_
get_latest_checkpoint_serial
(
checkpoint_dir
):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
...
...
python/paddle/fluid/trainer.py
浏览文件 @
95545f76
...
...
@@ -277,31 +277,14 @@ class Trainer(object):
exe
.
run
(
self
.
startup_program
)
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
:
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
place
)
io
.
load_checkpoint
(
exe
,
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
checkpoint_cfg
.
load_serial
,
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
pserver_id
:
epoch_id
,
step_id
=
io
.
load_trainer_args
(
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
checkpoint_cfg
.
load_serial
,
self
.
trainer_id
,
self
.
_get_checkpoint_load_args
())
self
.
checkpoint_cfg
.
epoch_id
=
int
(
epoch_id
)
self
.
checkpoint_cfg
.
step_id
=
int
(
step_id
)
else
:
if
self
.
checkpoint_cfg
.
lookup_table_name
:
io
.
load_lookup_table_vars
(
exe
,
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
startup_program
,
self
.
checkpoint_cfg
.
pserver_id
,
self
.
checkpoint_cfg
.
lookup_table_name
)
self
.
_load_checkpoint
()
if
param_path
and
os
.
path
.
isdir
(
param_path
):
# load params from param_path into scope
io
.
load_persist_vars_without_grad
(
exe
,
dirname
=
param_path
,
program
=
self
.
startup_program
)
io
.
load_persistables
(
executor
=
exe
,
dirname
=
param_path
,
main_program
=
self
.
startup_program
)
def
_transpile_nccl2_dist
(
self
):
# PADDLE_TRAINER_IPS
...
...
@@ -580,6 +563,42 @@ class Trainer(object):
main_program
=
self
.
train_program
,
max_num_checkpoints
=
self
.
checkpoint_cfg
.
max_num_checkpoints
)
def
_load_checkpoint
(
self
):
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
io
.
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
pserver_id
:
load_trainer_args
=
self
.
_get_checkpoint_load_args
()
trainer_args
=
io
.
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
load_trainer_args
=
load_trainer_args
)
if
len
(
trainer_args
)
!=
2
:
raise
ValueError
(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args
[
0
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args
[
1
])
else
:
if
self
.
checkpoint_cfg
.
lookup_table_name
:
io
.
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
load_trainer_args
=
None
,
load_lookup_table
=
self
.
checkpoint_cfg
.
lookup_table_name
)
def
build_feed_var_list
(
program
,
feed_order
):
if
not
isinstance
(
program
,
framework
.
Program
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录