提交 3c334bd7 编写于 作者: T tangwei12

bug fix

上级 1dd14a70
......@@ -560,6 +560,9 @@ class Trainer(object):
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
and step_id % self.checkpoint_cfg.step_interval == 0:
print("_save_checkpoint ...")
exe = executor.Executor(self.place)
save_checkpoint(
executor=exe,
......@@ -661,12 +664,12 @@ CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
checkpoint_dir,
trainer_id,
main_program,
trainer_args=None,
max_num_checkpoints=3,
main_program=None,
trainer_id=0,
save_trainer_args=None,
save_lookup_table=None,
pserver_endpoints=None):
pserver_endpoints=None,
max_num_checkpoints=3):
"""
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
......@@ -735,21 +738,18 @@ def save_checkpoint(executor,
if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None")
if main_program is None:
raise ValueError('main_program should not be None.')
if trainer_args:
assert isinstance(trainer_args, dict)
is_chief = trainer_id == 0
_make_chekcpoint_dirs(checkpoint_dir)
serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial, True)
_save_trainer_args(cur_dir, trainer_id, trainer_args)
is_chief = trainer_id == 0
if save_trainer_args is not None:
_save_trainer_args(cur_dir, trainer_id, save_trainer_args)
if is_chief:
if main_program is None:
raise ValueError('main_program should not be None.')
_save_persistable_vars(executor, cur_dir, main_program)
if is_chief and save_lookup_table and pserver_endpoints:
......@@ -764,7 +764,7 @@ def load_checkpoint(executor,
main_program=None,
role_id=0,
is_trainer=True,
load_models=True,
load_models=False,
load_trainer_args=None,
load_slice_up_vars=None,
load_lookup_table=None):
......@@ -827,6 +827,10 @@ def load_checkpoint(executor,
_load_persistable_vars(executor, checkpoint_dir, main_program, True)
return
if load_trainer_args:
print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}".
format(checkpoint_dir, role_id, load_trainer_args))
trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
load_trainer_args)
return trainer_args_ret
......@@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
: param checkpoint_dir
"""
if not checkpoint_dir:
return -1
def has_success(checkpoint_dir, cur_dir):
"""
......@@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
"""
serial = _get_dir_serial(cur_dir)
if serial == -1 or not os.path.isdir(
os.path.join(checkpoint_dir, cur_dir)):
if serial == -1 or \
not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
success_path = os.path.join(
......@@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
if os.path.isfile(success_path):
return serial
if not os.path.isdir(checkpoint_dir):
return -1
current_dir = -1
if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
return current_dir
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册