提交 5600b135 编写于 作者: T tangwei12

bug fix

上级 06f6c213
......@@ -472,8 +472,7 @@ def save_checkpoint(executor,
main_program=None,
max_num_checkpoints=3,
lookup_table=None,
ps_endpoint_list=None
):
ps_endpoint_list=None):
"""
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
......@@ -495,14 +494,18 @@ def save_checkpoint(executor,
if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)
is_chief = trainer_id == 0
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
save_trainer_args(cur_dir, trainer_id, trainer_args)
if trainer_id == 0:
if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program)
save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list)
if is_chief and lookup_table and ps_endpoint_list:
save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
ps_endpoint_list)
_scroll_delete(checkpoint_dir, max_num_checkpoints)
......@@ -618,7 +621,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success(cur_dir)
def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list):
def save_pserver_vars_by_notify(executor, dirname, lookup_table,
ps_endpoint_list):
"""
"""
cur_dir = _get_lookuptable_dir(dirname)
......@@ -802,4 +806,3 @@ def get_latest_checkpoint_serial(checkpoint_dir):
if success_num > current_dir:
current_dir = success_num
return current_dir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册