提交 3c334bd7 编写于 作者: T tangwei12

bug fix

上级 1dd14a70
......@@ -73,7 +73,7 @@ class BeginStepEvent(object):
self.step = step_id
self.fetch_metrics = True
"""
If fetch_metrics is true, the metrics will be fetched at the
If fetch_metrics is true, the metrics will be fetched at the
EndStepEvent. Default is True.
"""
......@@ -560,6 +560,9 @@ class Trainer(object):
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
and step_id % self.checkpoint_cfg.step_interval == 0:
print("_save_checkpoint ...")
exe = executor.Executor(self.place)
save_checkpoint(
executor=exe,
......@@ -604,7 +607,7 @@ class Trainer(object):
self.checkpoint_cfg.epoch_id = int(trainer_args_ret[0])
self.checkpoint_cfg.step_id = int(trainer_args_ret[1])
# Pserver Load
# Pserver Load
else:
# load slice_vars
if self.slice_vars != None and len(self.slice_vars) != 0:
......@@ -661,22 +664,22 @@ CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
checkpoint_dir,
trainer_id,
main_program,
trainer_args=None,
max_num_checkpoints=3,
main_program=None,
trainer_id=0,
save_trainer_args=None,
save_lookup_table=None,
pserver_endpoints=None):
pserver_endpoints=None,
max_num_checkpoints=3):
"""
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
main_program and then saves these variables to the `checkpoint_dir`
directory.
In the training precess, we generally save a checkpoint in each
iteration. So there might be a lot of checkpoints in the
`checkpoint_dir`. To avoid them taking too much disk space, the
`max_num_checkpoints` are introduced to limit the total number of
checkpoints. If the number of existing checkpints is greater than
iteration. So there might be a lot of checkpoints in the
`checkpoint_dir`. To avoid them taking too much disk space, the
`max_num_checkpoints` are introduced to limit the total number of
checkpoints. If the number of existing checkpints is greater than
the `max_num_checkpoints`, oldest ones will be scroll deleted.
A variable is a checkpoint variable and will be saved if it meets
......@@ -688,21 +691,21 @@ def save_checkpoint(executor,
Args:
executor(Executor): The executor to run for save checkpoint.
checkpoint_dir(str): The folder where to save checkpoints.
trainer_id(int): currect trainer id, if id is equal to 0, the trainer
trainer_id(int): currect trainer id, if id is equal to 0, the trainer
is chief.
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
and 'step_id'.
Defaut: None
main_program(Program): The program whose checkpoint variables will
be saved.
max_num_checkpoints(int): The max number of total number of existing
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
save_lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get pserver_endpoints by
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get pserver_endpoints by
distribute arguments.
Returns:
......@@ -735,21 +738,18 @@ def save_checkpoint(executor,
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
if main_program is None:
raise ValueError('main_program should not be None.')
if trainer_args:
assert isinstance(trainer_args, dict)
is_chief = trainer_id == 0
_make_chekcpoint_dirs(checkpoint_dir)
serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial, True)
_save_trainer_args(cur_dir, trainer_id, trainer_args)
is_chief = trainer_id == 0
if save_trainer_args is not None:
_save_trainer_args(cur_dir, trainer_id, save_trainer_args)
if is_chief:
if main_program is None:
raise ValueError('main_program should not be None.')
_save_persistable_vars(executor, cur_dir, main_program)
if is_chief and save_lookup_table and pserver_endpoints:
......@@ -764,7 +764,7 @@ def load_checkpoint(executor,
main_program=None,
role_id=0,
is_trainer=True,
load_models=True,
load_models=False,
load_trainer_args=None,
load_slice_up_vars=None,
load_lookup_table=None):
......@@ -774,8 +774,8 @@ def load_checkpoint(executor,
`checkpoint_dir` directory.
In the training precess, we generally save a checkpoint in each
iteration. So there are more than one checkpoint in the
`checkpoint_dir` (each checkpoint has its own sub folder), use
iteration. So there are more than one checkpoint in the
`checkpoint_dir` (each checkpoint has its own sub folder), use
`serial` to specify which serial of checkpoint you would like to
load.
......@@ -827,6 +827,10 @@ def load_checkpoint(executor,
_load_persistable_vars(executor, checkpoint_dir, main_program, True)
return
if load_trainer_args:
print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}".
format(checkpoint_dir, role_id, load_trainer_args))
trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
load_trainer_args)
return trainer_args_ret
......@@ -842,9 +846,9 @@ def load_checkpoint(executor,
def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
clean the checkpoint dir, when the train exits normally,
clean the checkpoint dir, when the train exits normally,
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
: param checkpoint_dir
: param delete_dir
......@@ -954,7 +958,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars):
def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
"""
The parameter server will load lookup table's local file in
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
......@@ -1005,7 +1009,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
def _save_persistable_vars(executor, dirname, program):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
program and then save these variables to a sub-folder '__model__' of
the given directory.
A variable is a checkpoint variable if it meets all following
......@@ -1034,7 +1038,7 @@ def _save_persistable_vars(executor, dirname, program):
# In this example, `_save_persistable_vars` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
"""
cur_dir = _get_model_dir(dirname)
......@@ -1053,7 +1057,7 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
......@@ -1061,13 +1065,13 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
......@@ -1078,7 +1082,7 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
cur_dir = _get_lookuptable_dir(dirname)
......@@ -1110,7 +1114,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args):
"""
trainer will load some args from it's independent directory,
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
Args:
......@@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
: param checkpoint_dir
"""
if not checkpoint_dir:
return -1
def has_success(checkpoint_dir, cur_dir):
"""
......@@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
"""
serial = _get_dir_serial(cur_dir)
if serial == -1 or not os.path.isdir(
os.path.join(checkpoint_dir, cur_dir)):
if serial == -1 or \
not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
success_path = os.path.join(
......@@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
if os.path.isfile(success_path):
return serial
if not os.path.isdir(checkpoint_dir):
return -1
current_dir = -1
if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
return current_dir
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册