Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ad9dfeb0
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ad9dfeb0
编写于
5月 29, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix and optimize
上级
5f5d6a9d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
162 addition
and
43 deletion
+162
-43
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+117
-36
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+45
-7
未找到文件。
python/paddle/fluid/io.py
浏览文件 @
ad9dfeb0
...
...
@@ -456,40 +456,18 @@ def get_parameter_value_by_name(name, executor, program=None):
return
get_parameter_value
(
var
,
executor
)
def
load_persist_vars_without_grad
(
executor
,
dirname
,
program
):
"""
load_persist_vars_without_grad will load variables from a directory by an executor,
the variable named end with "@GRAD" will not be loaded.
"""
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
program
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
def
save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
"""
save_persist_vars_without_grad will save variables to a directory by an executor,
the variable named end with "@GRAD" will not be saved.
"""
save_vars
(
executor
,
dirname
=
dirname
,
main_program
=
program
,
vars
=
None
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
SUCCESS_MARK_FILENAME
=
"_SUCCESS"
CHECKPOINT_PREFIX
=
"checkpoint"
MODEL_DIR
=
"__model__"
TRAINER_PREFIX
=
"trainer"
CHECKPOINT_SEPARATOR
=
"_"
def
save_checkpoint
(
executor
,
checkpoint_dir
,
trainer_id
,
is_chief
=
False
,
trainer_args
=
None
,
main_program
=
None
,
max_num_checkpoints
=
3
):
"""
...
...
@@ -502,22 +480,35 @@ def save_checkpoint(executor,
:param checkpoint_dir
:param main_program
:param max_num_checkpoints
:param is_chief
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"The values of 'checkpoint_dir' should not be None"
)
if
trainer_args
and
not
isinstance
(
trainer_args
,
dict
):
raise
TypeError
(
"The type of 'trainer_args' should be dict"
)
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
os
.
makedirs
(
checkpoint_dir
)
serial
=
_get_lastest_checkpoint_dir
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
_write_success
(
cur_dir
)
if
is_chief
:
save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
_lru_delete
(
checkpoint_dir
,
max_num_checkpoints
)
def
load_checkpoint
(
executor
,
checkpoint_dir
,
main_program
=
None
):
def
need_load_checkpoint
(
checkpoint_dir
):
serial
=
_get_lastest_checkpoint_dir
(
checkpoint_dir
)
if
serial
<
0
:
return
None
return
serial
def
load_checkpoint
(
executor
,
checkpoint_dir
,
serial
,
main_program
):
"""
Load checkpoint from a directory by executor,
it will find the most recent saved checkpoint file and load it auto.
...
...
@@ -528,14 +519,17 @@ def load_checkpoint(executor, checkpoint_dir, main_program=None):
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"The values of 'checkpoint_dir' should not be None"
)
raise
ValueError
(
"The values of 'checkpoint_dir' or 'serial' should not be None"
)
serial
=
_get_lastest_checkpoint_dir
(
checkpoint_dir
)
if
serial
is
None
or
serial
<
0
:
raise
ValueError
(
"The values of 'serial' should not be None or <0 "
)
if
serial
<
0
:
r
eturn
if
main_program
is
None
:
r
aise
ValueError
(
"The values of 'main_program'should not be None"
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_model_dir
(
cur_dir
)
load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
...
...
@@ -552,6 +546,68 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os
.
rmdir
(
checkpoint_dir
)
def
load_persist_vars_without_grad
(
executor
,
dirname
,
program
,
nest
=
True
):
"""
load_persist_vars_without_grad will load variables from a directory by an executor,
the variable named end with "@GRAD" will not be loaded.
"""
if
nest
:
dirname
=
_get_model_dir
(
dirname
)
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
program
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
def
save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
"""
save_persist_vars_without_grad will save variables to a directory by an executor,
the variable named end with "@GRAD" will not be saved.
"""
cur_dir
=
_get_model_dir
(
dirname
)
save_vars
(
executor
,
dirname
=
cur_dir
,
main_program
=
program
,
vars
=
None
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
_write_success
(
cur_dir
)
def
save_trainer_args
(
dirname
,
trainer_id
,
trainer_args
):
if
not
isinstance
(
trainer_args
,
dict
):
raise
TypeError
(
"The type of 'trainer_args' should be dict"
)
cur_dir
=
_get_trainer_dir
(
dirname
,
trainer_id
)
for
name
,
value
in
trainer_args
.
iteritems
():
args_file
=
os
.
path
.
join
(
cur_dir
,
name
)
with
open
(
args_file
,
'w'
)
as
f
:
f
.
write
(
str
(
value
))
_write_success
(
cur_dir
)
def
load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_trainer_dir
(
cur_dir
,
trainer_id
)
if
not
isinstance
(
trainer_args
,
list
):
raise
TypeError
(
"The type of 'trainer_args' should be list"
)
ret_values
=
[]
for
arg
in
trainer_args
:
cur_file
=
os
.
path
.
join
(
cur_dir
,
arg
)
with
open
(
cur_file
,
'r'
)
as
f
:
contents
=
f
.
read
()
ret_values
.
append
(
contents
.
strip
())
return
ret_values
def
_is_checkpoint_var
(
var
):
"""
the checkpoint will not save or load all the variables.
...
...
@@ -583,7 +639,31 @@ def _get_dir_serial(dirname):
def
_get_serial_dir
(
dirname
,
serial
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
return
os
.
path
.
join
(
dirname
,
serial_folder
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
if
not
os
.
path
.
isdir
(
serial_dir
):
os
.
makedirs
(
serial_dir
)
return
serial_dir
def
_get_model_dir
(
dirname
):
model_dir
=
os
.
path
.
join
(
dirname
,
MODEL_DIR
)
if
not
os
.
path
.
isdir
(
model_dir
):
os
.
makedirs
(
model_dir
)
return
model_dir
def
_get_trainer_dir
(
dirname
,
trainer_id
):
trainer_folder
=
TRAINER_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
trainer_id
)
trainer_dir
=
os
.
path
.
join
(
dirname
,
trainer_folder
)
if
not
os
.
path
.
isdir
(
trainer_dir
):
os
.
makedirs
(
trainer_dir
)
return
trainer_dir
def
_lru_delete
(
dirname
,
max_num_checkpoints
=
3
):
...
...
@@ -638,7 +718,8 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
return
-
1
success_path
=
os
.
path
.
join
(
_get_serial_dir
(
checkpoint_dir
,
serial
),
SUCCESS_MARK_FILENAME
)
_get_serial_dir
(
checkpoint_dir
,
serial
),
MODEL_DIR
,
SUCCESS_MARK_FILENAME
)
if
os
.
path
.
isfile
(
success_path
):
return
serial
...
...
python/paddle/fluid/trainer.py
浏览文件 @
ad9dfeb0
...
...
@@ -79,6 +79,9 @@ class CheckpointConfig(object):
else
:
self
.
step_interval
=
step_interval
self
.
epoch_id
=
0
self
.
step_id
=
0
def
check_and_get_place
(
place
):
"""
...
...
@@ -132,6 +135,7 @@ class Trainer(object):
# config for checkpoint
# only chief worker will save variables
self
.
trainer_id
=
0
self
.
chief
=
True
self
.
checkpoint
=
checkpoint_config
if
self
.
checkpoint
and
\
...
...
@@ -139,6 +143,8 @@ class Trainer(object):
raise
TypeError
(
"The checkpoint_config shoule be an instance of CheckpointConfig"
)
self
.
load_checkpoint_serial
=
io
.
need_load_checkpoint
(
self
.
checkpoint
.
checkpoint_dir
)
self
.
scope
=
core
.
Scope
()
...
...
@@ -168,15 +174,25 @@ class Trainer(object):
exe
=
executor
.
Executor
(
place
)
exe
.
run
(
self
.
startup_program
)
if
self
.
checkpoint
:
if
self
.
load_checkpoint_serial
:
exe
=
executor
.
Executor
(
place
)
io
.
load_checkpoint
(
exe
,
self
.
checkpoint
.
checkpoint_dir
,
self
.
load_checkpoint_serial
,
self
.
startup_program
)
if
param_path
:
epoch_id
,
step_id
=
io
.
load_trainer_args
(
self
.
checkpoint
.
checkpoint_dir
,
self
.
load_checkpoint_serial
,
self
.
trainer_id
,
[
"epoch_id"
,
"step_id"
])
self
.
checkpoint
.
epoch_id
=
int
(
epoch_id
)
self
.
checkpoint
.
step_id
=
int
(
step_id
)
if
param_path
and
os
.
path
.
isdir
(
param_path
):
# load params from param_path into scope
io
.
load_persist_vars_without_grad
(
exe
,
dirname
=
param_path
,
program
=
self
.
startup_program
)
exe
,
dirname
=
param_path
,
program
=
self
.
startup_program
,
nest
=
False
)
def
_transpile_nccl2_dist
(
self
):
# PADDLE_TRAINER_IPS
...
...
@@ -333,11 +349,20 @@ class Trainer(object):
self
.
_train_by_any_executor
(
event_handler
,
exe
,
num_epochs
,
reader
)
def
_train_by_any_executor
(
self
,
event_handler
,
exe
,
num_epochs
,
reader
):
for
epoch_id
in
range
(
num_epochs
):
epochs
=
[
epoch_id
for
epoch_id
in
range
(
num_epochs
)
if
epoch_id
>=
self
.
checkpoint
.
epoch_id
]
for
epoch_id
in
epochs
:
event_handler
(
BeginEpochEvent
(
epoch_id
))
for
step_id
,
data
in
enumerate
(
reader
()):
if
self
.
__stop
:
self
.
_clean_checkpoint
()
return
if
self
.
checkpoint
and
self
.
checkpoint
.
step_id
>=
step_id
and
self
.
checkpoint
.
epoch_id
==
epoch_id
:
continue
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
event_handler
(
begin_event
)
if
begin_event
.
fetch_metrics
:
...
...
@@ -352,6 +377,7 @@ class Trainer(object):
event_handler
(
EndStepEvent
(
epoch_id
,
step_id
,
metrics
))
self
.
_save_checkpoint
(
epoch_id
,
step_id
)
event_handler
(
EndEpochEvent
(
epoch_id
))
self
.
_clean_checkpoint
()
def
_test_by_executor
(
self
,
reader
,
feed_order
,
fetch_list
):
with
executor
.
scope_guard
(
self
.
scope
):
...
...
@@ -390,17 +416,29 @@ class Trainer(object):
loss_name
=
self
.
train_func_outputs
[
0
].
name
)
return
self
.
_get_parallel_executor
()
def
_clean_checkpoint
(
self
):
if
not
self
.
checkpoint
:
return
io
.
clean_checkpoint
(
checkpoint_dir
=
self
.
checkpoint
.
checkpoint_dir
)
def
_save_checkpoint
(
self
,
epoch_id
,
step_id
):
if
not
self
.
checkpoint
or
not
self
.
chief
:
if
not
self
.
checkpoint
:
return
if
epoch_id
%
self
.
checkpoint
.
epoch_interval
==
0
and
step_id
%
self
.
checkpoint
.
step_interval
==
0
:
trainer_args
=
{}
trainer_args
[
"epoch_id"
]
=
epoch_id
trainer_args
[
"step_id"
]
=
step_id
exe
=
executor
.
Executor
(
self
.
place
)
io
.
save_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint
.
checkpoint_dir
,
max_num_checkpoints
=
self
.
checkpoint
.
max_num_checkpoints
,
main_program
=
self
.
train_program
)
trainer_id
=
self
.
trainer_id
,
is_chief
=
self
.
chief
,
trainer_args
=
trainer_args
,
main_program
=
self
.
train_program
,
max_num_checkpoints
=
self
.
checkpoint
.
max_num_checkpoints
)
def
build_feed_var_list
(
program
,
feed_order
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录