Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
1840bd75
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1840bd75
编写于
7月 12, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into fix_is_taged
上级
f8e83196
72ce4d56
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
773 addition
and
689 deletion
+773
-689
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-1
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+1
-586
python/paddle/fluid/tests/unittests/test_checkpoint.py
python/paddle/fluid/tests/unittests/test_checkpoint.py
+0
-75
python/paddle/fluid/tests/unittests/test_reader_reset.py
python/paddle/fluid/tests/unittests/test_reader_reset.py
+116
-0
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+655
-27
未找到文件。
python/paddle/fluid/__init__.py
浏览文件 @
1840bd75
...
...
@@ -65,7 +65,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
'io'
,
'initializer'
,
'layers'
,
'transpiler'
'transpiler'
,
'nets'
,
'optimizer'
,
'learning_rate_decay'
,
...
...
python/paddle/fluid/io.py
浏览文件 @
1840bd75
...
...
@@ -24,10 +24,7 @@ from . import core
__all__
=
[
'save_vars'
,
'save_params'
,
'save_persistables'
,
'load_vars'
,
'load_params'
,
'load_persistables'
,
'save_inference_model'
,
'load_inference_model'
,
'get_inference_program'
,
'save_checkpoint'
,
'load_checkpoint'
,
'clean_checkpoint'
,
'load_persist_vars_without_grad'
,
'load_lookup_table_vars'
,
'save_persist_vars_without_grad'
,
'get_latest_checkpoint_serial'
'get_inference_program'
]
...
...
@@ -794,588 +791,6 @@ def get_parameter_value_by_name(name, executor, program=None):
return
get_parameter_value
(
var
,
executor
)
SUCCESS_MARK_FILENAME
=
"_SUCCESS"
CHECKPOINT_PREFIX
=
"checkpoint"
MODEL_DIR
=
"__model__"
LOOKUP_TABLE_DIR
=
"__lookup_table__"
TRAINER_PREFIX
=
"trainer"
CHECKPOINT_SEPARATOR
=
"_"
def
save_checkpoint
(
executor
,
checkpoint_dir
,
trainer_id
,
trainer_args
=
None
,
main_program
=
None
,
max_num_checkpoints
=
3
,
lookup_table
=
None
,
ps_endpoint_list
=
None
):
"""
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
directory.
In the training precess, we generally save a checkpoint in each
iteration. So there might be a lot of checkpoints in the
`checkpoint_dir`. To avoid them taking too much disk space, the
`max_num_checkpoints` are introduced to limit the total number of
checkpoints. If the number of existing checkpints is greater than
the `max_num_checkpoints`, oldest ones will be scroll deleted.
A variable is a checkpoint variable and will be saved if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for save checkpoint.
checkpoint_dir(str): The folder where to save checkpoints.
trainer_id(int): currect trainer id, if id is equal to 0, the trainer
is chief.
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
and 'step_id'.
Defaut: None
main_program(Program|None): The program whose checkpoint variables will
be saved. If it is None, the default main program will be used.
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
AssertionError: If `trainer_args` is not a dict.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
trainer_args = {"epoch_id": 200,
"step_id": 20} # just an example
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_checkpoint(executor=exe,
checkpoint_dir=path,
trainer_id=0,
trainer_args=trainer_args,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
ps_endpoint_list = ps_endpoints)
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
assert
checkpoint_dir
if
trainer_args
:
assert
isinstance
(
trainer_args
,
dict
)
is_chief
=
trainer_id
==
0
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
if
is_chief
:
save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
lookup_table
and
ps_endpoint_list
:
save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
ps_endpoint_list
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
def
load_checkpoint
(
executor
,
checkpoint_dir
,
serial
,
main_program
):
"""
This function filters out all checkpoint variables from the give
main_program and then try to load these variables from the
`checkpoint_dir` directory.
In the training precess, we generally save a checkpoint in each
iteration. So there are more than one checkpoint in the
`checkpoint_dir` (each checkpoint has its own sub folder), use
`serial` to specify which serial of checkpoint you would like to
load.
A variable is a checkpoint variable and will be loaded if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
be loaded.
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
ValueError: If `serial` is None or `serial` is less than 0.
ValueError: If `main_program` is None.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
fluid.io.load_checkpoint(executor=exe, checkpoint_dir=path,
serial=9, main_program=prog)
# In this example, `load_checkpoint` function
# will first filters out all checkpoint variables in the default
# main program, and then try to load these variables form the
# folder "./checkpoints/checkpoint_9/__model__".
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
if
serial
is
None
or
serial
<
0
:
raise
ValueError
(
"'serial' should not be None or <0 "
)
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
,
True
)
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
"""
clean the checkpoint dir, when the train exits normally,
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
: param checkpoint_dir
: param delete_dir
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
=
0
)
if
delete_dir
and
not
os
.
listdir
(
checkpoint_dir
):
os
.
rmdir
(
checkpoint_dir
)
def
load_persist_vars_without_grad
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
"""
This function filters out all checkpoint variables from the give
program and then trys to load these variables from the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be loaded.
has_model_dir(bool): if True, the function loads variables
from a sub directory named '__model__'.
Default: False
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
fluid.io.load_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `load_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
"""
if
has_model_dir
:
dirname
=
_get_model_dir
(
dirname
)
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
program
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
def
load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/__model__"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
fluid.io.load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
for
var
in
program
.
list_vars
():
if
var
.
name
==
table_name
:
lookup_table_var
=
var
break
assert
lookup_table_var
is
not
None
lookup_table_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
table_file
=
table_name
+
CHECKPOINT_SEPARATOR
+
str
(
pserver_id
)
load_prog
=
Program
()
load_block
=
load_prog
.
global_block
()
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
lookup_table_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
lookup_table_dir
,
table_file
)})
executor
.
run
(
load_prog
)
def
save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for saving variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be saved.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
fluid.io.save_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog)
# In this example, `save_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
"""
cur_dir
=
_get_model_dir
(
dirname
)
save_vars
(
executor
,
dirname
=
cur_dir
,
main_program
=
program
,
vars
=
None
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
_write_success
(
cur_dir
)
def
save_pserver_vars_by_notify
(
executor
,
dirname
,
lookup_table
,
ps_endpoint_list
):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
cur_dir
=
_get_lookuptable_dir
(
dirname
)
checkpoint_notify_program
=
Program
()
checkpoint_notify_block
=
checkpoint_notify_program
.
global_block
()
attrs
=
{}
attrs
[
'epmap'
]
=
ps_endpoint_list
attrs
[
'dir'
]
=
cur_dir
attrs
[
'lookup_table'
]
=
lookup_table
checkpoint_notify_block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
outputs
=
{},
attrs
=
attrs
)
executor
.
run
(
checkpoint_notify_program
)
def
save_trainer_args
(
dirname
,
trainer_id
,
trainer_args
):
assert
isinstance
(
trainer_args
,
dict
)
cur_dir
=
_get_trainer_dir
(
dirname
,
trainer_id
)
for
name
,
value
in
trainer_args
.
iteritems
():
args_file
=
os
.
path
.
join
(
cur_dir
,
name
)
with
open
(
args_file
,
'w'
)
as
f
:
f
.
write
(
str
(
value
))
_write_success
(
cur_dir
)
def
load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
Args:
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
trainer_id(int): current trainer id.
trainer_args(list): list about load trainer args
Return:
None
Examples:
.. code-block:: python
param_path = "./checkpoint/"
serial = 7
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert
isinstance
(
trainer_args
,
list
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_trainer_dir
(
cur_dir
,
trainer_id
)
ret_values
=
[]
for
arg
in
trainer_args
:
cur_file
=
os
.
path
.
join
(
cur_dir
,
arg
)
with
open
(
cur_file
,
'r'
)
as
f
:
contents
=
f
.
read
()
ret_values
.
append
(
contents
.
strip
())
return
ret_values
def
_is_checkpoint_var
(
var
):
"""
the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var(Variable)
"""
if
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FETCH_LIST
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
RAW
:
return
False
# @GRAD are named for gradient variables, checkpoint will not save it.
if
"@GRAD"
in
var
.
name
:
return
False
# .trainer_ are named for distribute train variables, checkpoint will not save it.
if
".trainer_"
in
var
.
name
:
return
False
# .block is named for distribute train variables, checkpoint will not save it.
if
".block"
in
var
.
name
:
return
False
return
var
.
persistable
def
_make_chekcpoint_dirs
(
dirs
):
"""
_make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
"""
assert
dirs
is
not
None
if
os
.
path
.
isfile
(
dirs
):
raise
OSError
(
errno
.
ENOTDIR
,
"dirs path shoule be a Directory."
,
dirs
)
if
not
os
.
path
.
isdir
(
dirs
):
try
:
os
.
makedirs
(
dirs
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
EEXIST
:
raise
err
def
_get_dir_serial
(
dirname
):
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
try
:
serial_num
=
int
(
serial
)
except
ValueError
:
serial_num
=
-
1
return
serial_num
def
_get_serial_dir
(
dirname
,
serial
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
_make_chekcpoint_dirs
(
serial_dir
)
return
serial_dir
def
_get_model_dir
(
dirname
):
model_dir
=
os
.
path
.
join
(
dirname
,
MODEL_DIR
)
_make_chekcpoint_dirs
(
model_dir
)
return
model_dir
def
_get_lookuptable_dir
(
dirname
):
lookuptable_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
_make_chekcpoint_dirs
(
lookuptable_dir
)
return
lookuptable_dir
def
_get_trainer_dir
(
dirname
,
trainer_id
):
trainer_folder
=
TRAINER_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
trainer_id
)
trainer_dir
=
os
.
path
.
join
(
dirname
,
trainer_folder
)
_make_chekcpoint_dirs
(
trainer_dir
)
return
trainer_dir
def
_scroll_delete
(
dirname
,
max_num_checkpoints
=
3
):
dirs
=
os
.
listdir
(
dirname
)
serial_map
=
{}
for
serial
in
dirs
:
serial_num
=
_get_dir_serial
(
serial
)
serial_map
[
serial_num
]
=
serial
if
len
(
serial_map
.
keys
())
<=
max_num_checkpoints
:
return
serials
=
serial_map
.
keys
()
serials
.
sort
(
reverse
=
True
)
serials
=
serials
[
max_num_checkpoints
:]
for
serial
in
serials
:
cur_dir
=
_get_serial_dir
(
dirname
,
serial
)
try
:
shutil
.
rmtree
(
cur_dir
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
ENOENT
:
raise
err
def
_write_success
(
dirname
):
"""
write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.
: param dirname
"""
success_file
=
os
.
path
.
join
(
dirname
,
SUCCESS_MARK_FILENAME
)
with
open
(
success_file
,
'a'
)
as
f
:
now
=
time
.
ctime
()
f
.
write
(
now
)
def
get_latest_checkpoint_serial
(
checkpoint_dir
):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
: param checkpoint_dir
"""
if
not
checkpoint_dir
:
return
-
1
def
has_success
(
checkpoint_dir
,
cur_dir
):
"""
is _SUCCESS in this dir
"""
serial
=
_get_dir_serial
(
cur_dir
)
if
serial
==
-
1
or
not
os
.
path
.
isdir
(
os
.
path
.
join
(
checkpoint_dir
,
cur_dir
)):
return
-
1
success_path
=
os
.
path
.
join
(
_get_serial_dir
(
checkpoint_dir
,
serial
),
MODEL_DIR
,
SUCCESS_MARK_FILENAME
)
if
os
.
path
.
isfile
(
success_path
):
return
serial
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
return
-
1
current_dir
=
-
1
dirs
=
os
.
listdir
(
checkpoint_dir
)
for
cur_dir
in
dirs
:
success_num
=
has_success
(
checkpoint_dir
,
cur_dir
)
if
success_num
>
current_dir
:
current_dir
=
success_num
return
current_dir
def
get_test_program
(
filelist
,
program
=
None
,
startup_program
=
None
):
"""
Transpile current train program to a program to read test dataset
...
...
python/paddle/fluid/tests/unittests/test_checkpoint.py
已删除
100644 → 0
浏览文件 @
f8e83196
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.fluid
as
fluid
import
unittest
import
os
import
tempfile
class
TestCheckpoint
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
dirname
=
tempfile
.
mktemp
()
self
.
max_num_checkpoints
=
3
self
.
epoch_interval
=
1
self
.
step_interval
=
1
self
.
trainer_id
=
0
self
.
chief
=
self
.
trainer_id
==
0
self
.
place
=
fluid
.
CPUPlace
()
self
.
epoch_id
=
100
self
.
step_id
=
20
def
test_checkpoint
(
self
):
self
.
save_checkpoint
()
serial
=
fluid
.
io
.
get_latest_checkpoint_serial
(
self
.
dirname
)
self
.
assertTrue
(
serial
>=
0
)
trainer_args
=
[
"epoch_id"
,
"step_id"
]
epoch_id
,
step_id
=
fluid
.
io
.
load_trainer_args
(
self
.
dirname
,
serial
,
self
.
trainer_id
,
trainer_args
)
self
.
assertEqual
(
self
.
step_id
,
int
(
step_id
))
self
.
assertEqual
(
self
.
epoch_id
,
int
(
epoch_id
))
program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
program
):
exe
=
fluid
.
Executor
(
self
.
place
)
fluid
.
io
.
load_checkpoint
(
exe
,
self
.
dirname
,
serial
,
program
)
fluid
.
io
.
clean_checkpoint
(
self
.
dirname
,
delete_dir
=
True
)
self
.
assertFalse
(
os
.
path
.
isdir
(
self
.
dirname
))
def
save_checkpoint
(
self
):
config
=
fluid
.
CheckpointConfig
(
self
.
dirname
,
self
.
max_num_checkpoints
,
self
.
epoch_interval
,
self
.
step_interval
)
trainer_args
=
{}
trainer_args
[
"epoch_id"
]
=
self
.
epoch_id
trainer_args
[
"step_id"
]
=
self
.
step_id
program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
program
):
program
.
global_block
().
create_var
(
name
=
"scale_0"
,
psersistable
=
True
,
dtype
=
"float32"
,
shape
=
[
32
,
32
])
exe
=
fluid
.
Executor
(
self
.
place
)
for
i
in
xrange
(
10
):
fluid
.
io
.
save_checkpoint
(
exe
,
config
.
checkpoint_dir
,
self
.
trainer_id
,
trainer_args
,
program
,
config
.
max_num_checkpoints
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_reader_reset.py
0 → 100644
浏览文件 @
1840bd75
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.fluid
as
fluid
import
paddle.v2
as
paddle
import
numpy
as
np
import
unittest
class
TestReaderReset
(
unittest
.
TestCase
):
def
prepare_data
(
self
):
def
fake_data_generator
():
for
n
in
xrange
(
self
.
total_ins_num
):
yield
np
.
ones
(
self
.
ins_shape
)
*
n
,
n
# Prepare data
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
reader
=
paddle
.
batch
(
fake_data_generator
,
batch_size
=
1
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
fluid
.
layers
.
data
(
name
=
'data'
,
shape
=
[
3
],
dtype
=
'float32'
),
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
),
],
place
=
fluid
.
CPUPlace
())
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
self
.
data_file_name
,
reader
,
feeder
)
def
setUp
(
self
):
self
.
use_cuda
=
fluid
.
core
.
is_compiled_with_cuda
()
self
.
data_file_name
=
'./reader_reset_test.recordio'
self
.
ins_shape
=
[
3
]
self
.
batch_size
=
5
self
.
total_ins_num
=
self
.
batch_size
*
20
self
.
test_pass_num
=
100
self
.
prepare_data
()
def
main
(
self
,
with_double_buffer
):
main_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
data_reader_handle
=
fluid
.
layers
.
io
.
open_files
(
filenames
=
[
self
.
data_file_name
],
shapes
=
[[
-
1
]
+
self
.
ins_shape
,
[
-
1
,
1
]],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
],
thread_num
=
1
,
pass_num
=
1
)
data_reader
=
fluid
.
layers
.
io
.
batch
(
data_reader_handle
,
self
.
batch_size
)
if
with_double_buffer
:
data_reader
=
fluid
.
layers
.
double_buffer
(
data_reader
)
image
,
label
=
fluid
.
layers
.
read_file
(
data_reader
)
fetch_list
=
[
image
.
name
,
label
.
name
]
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
use_cuda
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
build_strategy
=
fluid
.
BuildStrategy
()
if
with_double_buffer
:
build_strategy
.
enable_data_balance
=
True
exec_strategy
=
fluid
.
ExecutionStrategy
()
parallel_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
self
.
use_cuda
,
main_program
=
main_prog
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
data_appeared
=
[
False
]
*
self
.
total_ins_num
pass_count
=
0
while
(
True
):
try
:
data_val
,
label_val
=
parallel_exe
.
run
(
fetch_list
,
return_numpy
=
True
)
ins_num
=
data_val
.
shape
[
0
]
broadcasted_label
=
np
.
ones
((
ins_num
,
)
+
tuple
(
self
.
ins_shape
))
*
label_val
.
reshape
((
ins_num
,
1
))
self
.
assertEqual
(
data_val
.
all
(),
broadcasted_label
.
all
())
for
l
in
label_val
:
self
.
assertFalse
(
data_appeared
[
l
[
0
]])
data_appeared
[
l
[
0
]]
=
True
except
fluid
.
core
.
EOFException
:
pass_count
+=
1
if
with_double_buffer
:
data_appeared
=
data_appeared
[:
-
parallel_exe
.
device_count
*
self
.
batch_size
]
for
i
in
data_appeared
:
self
.
assertTrue
(
i
)
if
pass_count
<
self
.
test_pass_num
:
data_appeared
=
[
False
]
*
self
.
total_ins_num
data_reader_handle
.
reset
()
else
:
break
def
test_all
(
self
):
self
.
main
(
with_double_buffer
=
False
)
self
.
main
(
with_double_buffer
=
True
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/trainer.py
浏览文件 @
1840bd75
...
...
@@ -14,6 +14,9 @@
import
contextlib
import
os
import
errno
import
shutil
import
time
import
core
...
...
@@ -94,7 +97,7 @@ class EndStepEvent(object):
class
CheckpointConfig
(
object
):
"""
Parameter object for :code:`
fluid.io.
save_checkpoint` and
Parameter object for :code:`save_checkpoint` and
:code:`fluid.Trainer`. Used to configuration how to save checkpoint.
Args:
...
...
@@ -237,7 +240,7 @@ class Trainer(object):
self
.
checkpoint_cfg
=
checkpoint_config
if
self
.
checkpoint_cfg
:
assert
isinstance
(
self
.
checkpoint_cfg
,
CheckpointConfig
)
serial
=
io
.
get_latest_checkpoint_serial
(
serial
=
_
get_latest_checkpoint_serial
(
self
.
checkpoint_cfg
.
checkpoint_dir
)
self
.
checkpoint_cfg
.
load_serial
=
serial
if
serial
>=
0
else
None
...
...
@@ -276,32 +279,15 @@ class Trainer(object):
exe
=
executor
.
Executor
(
place
)
exe
.
run
(
self
.
startup_program
)
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
:
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
place
)
io
.
load_checkpoint
(
exe
,
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
checkpoint_cfg
.
load_serial
,
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
pserver_id
:
epoch_id
,
step_id
=
io
.
load_trainer_args
(
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
checkpoint_cfg
.
load_serial
,
self
.
trainer_id
,
self
.
_get_checkpoint_load_args
())
self
.
checkpoint_cfg
.
epoch_id
=
int
(
epoch_id
)
self
.
checkpoint_cfg
.
step_id
=
int
(
step_id
)
else
:
if
self
.
checkpoint_cfg
.
lookup_table_name
:
io
.
load_lookup_table_vars
(
exe
,
self
.
checkpoint_cfg
.
checkpoint_dir
,
self
.
startup_program
,
self
.
checkpoint_cfg
.
pserver_id
,
self
.
checkpoint_cfg
.
lookup_table_name
)
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
is
not
None
:
self
.
_load_checkpoint
()
if
param_path
and
os
.
path
.
isdir
(
param_path
):
# load params from param_path into scope
io
.
load_persist_vars_without_grad
(
exe
,
dirname
=
param_path
,
program
=
self
.
startup_program
)
io
.
load_persistables
(
executor
=
exe
,
dirname
=
param_path
,
main_program
=
self
.
startup_program
)
def
_transpile_nccl2_dist
(
self
):
# PADDLE_TRAINER_IPS
...
...
@@ -549,7 +535,7 @@ class Trainer(object):
def
_clean_checkpoint
(
self
):
assert
self
.
checkpoint_cfg
io
.
clean_checkpoint
(
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
)
clean_checkpoint
(
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
)
def
_get_checkpoint_load_args
(
self
):
"""
...
...
@@ -572,7 +558,7 @@ class Trainer(object):
if
epoch_id
%
self
.
checkpoint_cfg
.
epoch_interval
==
0
\
and
step_id
%
self
.
checkpoint_cfg
.
step_interval
==
0
:
exe
=
executor
.
Executor
(
self
.
place
)
io
.
save_checkpoint
(
save_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
trainer_id
=
self
.
trainer_id
,
...
...
@@ -580,6 +566,41 @@ class Trainer(object):
main_program
=
self
.
train_program
,
max_num_checkpoints
=
self
.
checkpoint_cfg
.
max_num_checkpoints
)
def
_load_checkpoint
(
self
):
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
pserver_id
:
load_trainer_args
=
self
.
_get_checkpoint_load_args
()
trainer_args
=
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
load_trainer_args
=
load_trainer_args
)
if
len
(
trainer_args
)
!=
2
:
raise
ValueError
(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args
[
0
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args
[
1
])
else
:
if
self
.
checkpoint_cfg
.
lookup_table_name
:
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
load_trainer_args
=
None
,
load_lookup_table
=
self
.
checkpoint_cfg
.
lookup_table_name
)
def
build_feed_var_list
(
program
,
feed_order
):
if
not
isinstance
(
program
,
framework
.
Program
):
...
...
@@ -602,3 +623,610 @@ def build_feed_var_list(program, feed_order):
program
.
global_block
().
var
(
pair
[
0
])
for
pair
in
sorted_pair_list
]
return
feed_var_list
# move Checkpoint APIs from io.py to trainer.py, make all of them are private.
SUCCESS_MARK_FILENAME
=
"_SUCCESS"
CHECKPOINT_PREFIX
=
"checkpoint"
MODEL_DIR
=
"__model__"
LOOKUP_TABLE_DIR
=
"__lookup_table__"
TRAINER_PREFIX
=
"trainer"
CHECKPOINT_SEPARATOR
=
"_"
def
save_checkpoint
(
executor
,
checkpoint_dir
,
trainer_id
,
main_program
,
trainer_args
=
None
,
max_num_checkpoints
=
3
,
lookup_table
=
None
,
pserver_endpoints
=
None
):
"""
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
directory.
In the training precess, we generally save a checkpoint in each
iteration. So there might be a lot of checkpoints in the
`checkpoint_dir`. To avoid them taking too much disk space, the
`max_num_checkpoints` are introduced to limit the total number of
checkpoints. If the number of existing checkpints is greater than
the `max_num_checkpoints`, oldest ones will be scroll deleted.
A variable is a checkpoint variable and will be saved if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for save checkpoint.
checkpoint_dir(str): The folder where to save checkpoints.
trainer_id(int): currect trainer id, if id is equal to 0, the trainer
is chief.
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
and 'step_id'.
Defaut: None
main_program(Program): The program whose checkpoint variables will
be saved.
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get pserver_endpoints by
distribute arguments.
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
AssertionError: If `trainer_args` is not a dict.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
trainer_args = {"epoch_id": 200,
"step_id": 20} # just an example
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
save_checkpoint(executor=exe,
checkpoint_dir=path,
trainer_id=0,
trainer_args=trainer_args,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
pserver_endpoints = ps_endpoints)
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
trainer_args
:
assert
isinstance
(
trainer_args
,
dict
)
is_chief
=
trainer_id
==
0
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
_save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
if
is_chief
:
_save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
lookup_table
and
pserver_endpoints
:
_save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
pserver_endpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
def
load_checkpoint
(
executor
,
checkpoint_dir
,
main_program
,
role_id
=
0
,
is_trainer
=
True
,
load_trainer_args
=
None
,
load_lookup_table
=
None
):
"""
This function filters out all checkpoint variables from the give
main_program and then try to load these variables from the
`checkpoint_dir` directory.
In the training precess, we generally save a checkpoint in each
iteration. So there are more than one checkpoint in the
`checkpoint_dir` (each checkpoint has its own sub folder), use
`serial` to specify which serial of checkpoint you would like to
load.
A variable is a checkpoint variable and will be loaded if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
be loaded.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): trainer is True and parameter server is False.
load_trainer_args(list|None): list about load trainer args.
load_lookup_table(str|None): the lookup table name
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
ValueError: If `main_program` is None.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
load_checkpoint(executor=exe, checkpoint_dir=path,
serial=9, main_program=prog)
# In this example, `load_checkpoint` function
# will first filters out all checkpoint variables in the default
# main program, and then try to load these variables form the
# folder "./checkpoints/checkpoint_9/__model__".
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
# there are nothing need to be loaded
if
serial
is
None
or
serial
<
0
:
return
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
is_trainer
and
load_trainer_args
is
None
:
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
_load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
,
True
)
return
if
is_trainer
and
load_trainer_args
:
return
_load_trainer_args
(
checkpoint_dir
,
serial
,
role_id
,
load_trainer_args
)
if
not
is_trainer
and
load_lookup_table
:
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
role_id
,
load_lookup_table
)
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
"""
clean the checkpoint dir, when the train exits normally,
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
: param checkpoint_dir
: param delete_dir
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
=
0
)
if
delete_dir
and
not
os
.
listdir
(
checkpoint_dir
):
os
.
rmdir
(
checkpoint_dir
)
def
_load_persist_vars_without_grad
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
"""
This function filters out all checkpoint variables from the give
program and then trys to load these variables from the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be loaded.
has_model_dir(bool): if True, the function loads variables
from a sub directory named '__model__'.
Default: False
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_load_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `_load_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
"""
if
has_model_dir
:
dirname
=
_get_model_dir
(
dirname
)
io
.
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
program
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
def
_load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
_load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
for
var
in
program
.
list_vars
():
if
var
.
name
==
table_name
:
lookup_table_var
=
var
break
assert
lookup_table_var
is
not
None
lookup_table_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
table_file
=
table_name
+
CHECKPOINT_SEPARATOR
+
str
(
pserver_id
)
load_prog
=
framework
.
Program
()
load_block
=
load_prog
.
global_block
()
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
lookup_table_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
lookup_table_dir
,
table_file
)})
executor
.
run
(
load_prog
)
def
_save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for saving variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be saved.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_save_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog)
# In this example, `_save_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
"""
cur_dir
=
_get_model_dir
(
dirname
)
io
.
save_vars
(
executor
,
dirname
=
cur_dir
,
main_program
=
program
,
vars
=
None
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
_write_success
(
cur_dir
)
def
_save_pserver_vars_by_notify
(
executor
,
dirname
,
lookup_table
,
ps_endpoint_list
):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
cur_dir
=
_get_lookuptable_dir
(
dirname
)
checkpoint_notify_program
=
framework
.
Program
()
checkpoint_notify_block
=
checkpoint_notify_program
.
global_block
()
attrs
=
{}
attrs
[
'epmap'
]
=
ps_endpoint_list
attrs
[
'dir'
]
=
cur_dir
attrs
[
'lookup_table'
]
=
lookup_table
checkpoint_notify_block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
outputs
=
{},
attrs
=
attrs
)
executor
.
run
(
checkpoint_notify_program
)
def
_save_trainer_args
(
dirname
,
trainer_id
,
trainer_args
):
assert
isinstance
(
trainer_args
,
dict
)
cur_dir
=
_get_trainer_dir
(
dirname
,
trainer_id
)
for
name
,
value
in
trainer_args
.
iteritems
():
args_file
=
os
.
path
.
join
(
cur_dir
,
name
)
with
open
(
args_file
,
'w'
)
as
f
:
f
.
write
(
str
(
value
))
_write_success
(
cur_dir
)
def
_load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
Args:
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
trainer_id(int): current trainer id.
trainer_args(list): list about load trainer args
Return:
None
Examples:
.. code-block:: python
param_path = "./checkpoint/"
serial = 7
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
_load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert
isinstance
(
trainer_args
,
list
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_trainer_dir
(
cur_dir
,
trainer_id
)
ret_values
=
[]
for
arg
in
trainer_args
:
cur_file
=
os
.
path
.
join
(
cur_dir
,
arg
)
with
open
(
cur_file
,
'r'
)
as
f
:
contents
=
f
.
read
()
ret_values
.
append
(
contents
.
strip
())
return
ret_values
def
_is_checkpoint_var
(
var
):
"""
the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var(Variable)
"""
if
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FETCH_LIST
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
RAW
:
return
False
# @GRAD are named for gradient variables, checkpoint will not save it.
if
"@GRAD"
in
var
.
name
:
return
False
# .trainer_ are named for distribute train variables, checkpoint will not save it.
if
".trainer_"
in
var
.
name
:
return
False
# .block is named for distribute train variables, checkpoint will not save it.
if
".block"
in
var
.
name
:
return
False
return
var
.
persistable
def
_make_chekcpoint_dirs
(
dirs
):
"""
_make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
"""
assert
dirs
is
not
None
if
os
.
path
.
isfile
(
dirs
):
raise
OSError
(
errno
.
ENOTDIR
,
"dirs path shoule be a Directory."
,
dirs
)
if
not
os
.
path
.
isdir
(
dirs
):
try
:
os
.
makedirs
(
dirs
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
EEXIST
:
raise
err
def
_get_dir_serial
(
dirname
):
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
try
:
serial_num
=
int
(
serial
)
except
ValueError
:
serial_num
=
-
1
return
serial_num
def
_get_serial_dir
(
dirname
,
serial
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
_make_chekcpoint_dirs
(
serial_dir
)
return
serial_dir
def
_get_model_dir
(
dirname
):
model_dir
=
os
.
path
.
join
(
dirname
,
MODEL_DIR
)
_make_chekcpoint_dirs
(
model_dir
)
return
model_dir
def
_get_lookuptable_dir
(
dirname
):
lookuptable_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
_make_chekcpoint_dirs
(
lookuptable_dir
)
return
lookuptable_dir
def
_get_trainer_dir
(
dirname
,
trainer_id
):
trainer_folder
=
TRAINER_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
trainer_id
)
trainer_dir
=
os
.
path
.
join
(
dirname
,
trainer_folder
)
_make_chekcpoint_dirs
(
trainer_dir
)
return
trainer_dir
def
_scroll_delete
(
dirname
,
max_num_checkpoints
=
3
):
dirs
=
os
.
listdir
(
dirname
)
serial_map
=
{}
for
serial
in
dirs
:
serial_num
=
_get_dir_serial
(
serial
)
serial_map
[
serial_num
]
=
serial
if
len
(
serial_map
.
keys
())
<=
max_num_checkpoints
:
return
serials
=
serial_map
.
keys
()
serials
.
sort
(
reverse
=
True
)
serials
=
serials
[
max_num_checkpoints
:]
for
serial
in
serials
:
cur_dir
=
_get_serial_dir
(
dirname
,
serial
)
try
:
shutil
.
rmtree
(
cur_dir
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
ENOENT
:
raise
err
def
_write_success
(
dirname
):
"""
write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.
: param dirname
"""
success_file
=
os
.
path
.
join
(
dirname
,
SUCCESS_MARK_FILENAME
)
with
open
(
success_file
,
'a'
)
as
f
:
now
=
time
.
ctime
()
f
.
write
(
now
)
def
_get_latest_checkpoint_serial
(
checkpoint_dir
):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
: param checkpoint_dir
"""
if
not
checkpoint_dir
:
return
-
1
def
has_success
(
checkpoint_dir
,
cur_dir
):
"""
is _SUCCESS in this dir
"""
serial
=
_get_dir_serial
(
cur_dir
)
if
serial
==
-
1
or
not
os
.
path
.
isdir
(
os
.
path
.
join
(
checkpoint_dir
,
cur_dir
)):
return
-
1
success_path
=
os
.
path
.
join
(
_get_serial_dir
(
checkpoint_dir
,
serial
),
MODEL_DIR
,
SUCCESS_MARK_FILENAME
)
if
os
.
path
.
isfile
(
success_path
):
return
serial
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
return
-
1
current_dir
=
-
1
dirs
=
os
.
listdir
(
checkpoint_dir
)
for
cur_dir
in
dirs
:
success_num
=
has_success
(
checkpoint_dir
,
cur_dir
)
if
success_num
>
current_dir
:
current_dir
=
success_num
return
current_dir
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录