提交 3c334bd7 编写于 作者: T tangwei12

bug fix

上级 1dd14a70
...@@ -560,6 +560,9 @@ class Trainer(object): ...@@ -560,6 +560,9 @@ class Trainer(object):
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \ if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
and step_id % self.checkpoint_cfg.step_interval == 0: and step_id % self.checkpoint_cfg.step_interval == 0:
print("_save_checkpoint ...")
exe = executor.Executor(self.place) exe = executor.Executor(self.place)
save_checkpoint( save_checkpoint(
executor=exe, executor=exe,
...@@ -661,12 +664,12 @@ CHECKPOINT_SEPARATOR = "_" ...@@ -661,12 +664,12 @@ CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor, def save_checkpoint(executor,
checkpoint_dir, checkpoint_dir,
trainer_id, main_program=None,
main_program, trainer_id=0,
trainer_args=None, save_trainer_args=None,
max_num_checkpoints=3,
save_lookup_table=None, save_lookup_table=None,
pserver_endpoints=None): pserver_endpoints=None,
max_num_checkpoints=3):
""" """
This function filters out all checkpoint variables from the give This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir` main_program and then saves these variables to the `checkpoint_dir`
...@@ -735,21 +738,18 @@ def save_checkpoint(executor, ...@@ -735,21 +738,18 @@ def save_checkpoint(executor,
if checkpoint_dir is None: if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None") raise ValueError("'checkpoint_dir' should not be None")
if main_program is None:
raise ValueError('main_program should not be None.')
if trainer_args:
assert isinstance(trainer_args, dict)
is_chief = trainer_id == 0
_make_chekcpoint_dirs(checkpoint_dir) _make_chekcpoint_dirs(checkpoint_dir)
serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1 serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial, True) cur_dir = _get_serial_dir(checkpoint_dir, serial, True)
_save_trainer_args(cur_dir, trainer_id, trainer_args) is_chief = trainer_id == 0
if save_trainer_args is not None:
_save_trainer_args(cur_dir, trainer_id, save_trainer_args)
if is_chief: if is_chief:
if main_program is None:
raise ValueError('main_program should not be None.')
_save_persistable_vars(executor, cur_dir, main_program) _save_persistable_vars(executor, cur_dir, main_program)
if is_chief and save_lookup_table and pserver_endpoints: if is_chief and save_lookup_table and pserver_endpoints:
...@@ -764,7 +764,7 @@ def load_checkpoint(executor, ...@@ -764,7 +764,7 @@ def load_checkpoint(executor,
main_program=None, main_program=None,
role_id=0, role_id=0,
is_trainer=True, is_trainer=True,
load_models=True, load_models=False,
load_trainer_args=None, load_trainer_args=None,
load_slice_up_vars=None, load_slice_up_vars=None,
load_lookup_table=None): load_lookup_table=None):
...@@ -827,6 +827,10 @@ def load_checkpoint(executor, ...@@ -827,6 +827,10 @@ def load_checkpoint(executor,
_load_persistable_vars(executor, checkpoint_dir, main_program, True) _load_persistable_vars(executor, checkpoint_dir, main_program, True)
return return
if load_trainer_args: if load_trainer_args:
print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}".
format(checkpoint_dir, role_id, load_trainer_args))
trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id, trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
load_trainer_args) load_trainer_args)
return trainer_args_ret return trainer_args_ret
...@@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir): ...@@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
: param checkpoint_dir : param checkpoint_dir
""" """
if not checkpoint_dir:
return -1
def has_success(checkpoint_dir, cur_dir): def has_success(checkpoint_dir, cur_dir):
""" """
...@@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir): ...@@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
""" """
serial = _get_dir_serial(cur_dir) serial = _get_dir_serial(cur_dir)
if serial == -1 or not os.path.isdir( if serial == -1 or \
os.path.join(checkpoint_dir, cur_dir)): not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1 return -1
success_path = os.path.join( success_path = os.path.join(
...@@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir): ...@@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
if os.path.isfile(success_path): if os.path.isfile(success_path):
return serial return serial
if not os.path.isdir(checkpoint_dir):
return -1
current_dir = -1 current_dir = -1
if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
return current_dir
dirs = os.listdir(checkpoint_dir) dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs: for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir) success_num = has_success(checkpoint_dir, cur_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册