提交 8a178165 编写于 作者: T tangwei12

add lookup table in python

上级 a8959162
......@@ -500,6 +500,7 @@ def save_checkpoint(executor,
if trainer_id == 0:
save_persist_vars_without_grad(executor, cur_dir, main_program)
save_pserver_vars_by_notify(executor, cur_dir, "")
_scroll_delete(checkpoint_dir, max_num_checkpoints)
......@@ -530,7 +531,8 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
clean the checkpoint dir, when the train exits normally,
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
:param checkpoint_dir
......@@ -598,6 +600,23 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success(cur_dir)
def save_pserver_vars_by_notify(executor, dirname, epmap):
    """
    Build and run a one-op program containing a 'checkpointnotify' op whose
    attributes carry the lookup table directory.

    :param executor: executor used to run the notify program
    :param dirname: checkpoint directory; the lookup table sub-directory is
        resolved (and created if missing) underneath it
    :param epmap: forwarded as the op's 'epmap' attribute when non-empty
        (presumably the parameter server endpoints -- confirm against the
        op definition)
    """
    cur_dir = _get_lookuptable_dir(dirname)

    checkpoint_notify_program = Program()
    checkpoint_notify_block = checkpoint_notify_program.global_block()

    attrs = {
        # An empty epmap keeps the previous behaviour of passing None, so
        # existing callers that pass "" are unaffected.
        'epmap': epmap if epmap else None,
        'dir': cur_dir,
    }

    # The operator constructor's keyword is `outputs` (plural); `output=`
    # is not a recognised argument.
    # NOTE(review): confirm the registered op type name is 'checkpointnotify'.
    checkpoint_notify_block.append_op(
        type='checkpointnotify', inputs={}, outputs={}, attrs=attrs)
    executor.run(checkpoint_notify_program)
def save_trainer_args(dirname, trainer_id, trainer_args):
assert isinstance(trainer_args, dict)
......@@ -680,6 +699,15 @@ def _get_model_dir(dirname):
return model_dir
def _get_lookuptable_dir(dirname):
    """
    Return the lookup table directory under `dirname`, creating it first
    when it does not exist yet.

    :param dirname: parent directory
    :return: path of `dirname` joined with LOOKUP_TABLE_DIR
    """
    target = os.path.join(dirname, LOOKUP_TABLE_DIR)
    if os.path.isdir(target):
        return target

    os.makedirs(target)
    return target
def _get_trainer_dir(dirname, trainer_id):
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
trainer_dir = os.path.join(dirname, trainer_folder)
......
......@@ -446,7 +446,8 @@ class Trainer(object):
def _save_checkpoint(self, epoch_id, step_id):
assert self.checkpoint_cfg
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0:
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
and step_id % self.checkpoint_cfg.step_interval == 0:
exe = executor.Executor(self.place)
io.save_checkpoint(
executor=exe,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册