提交 5600b135 编写于 作者: T tangwei12

bug fix

上级 06f6c213
...@@ -472,8 +472,7 @@ def save_checkpoint(executor, ...@@ -472,8 +472,7 @@ def save_checkpoint(executor,
main_program=None, main_program=None,
max_num_checkpoints=3, max_num_checkpoints=3,
lookup_table=None, lookup_table=None,
ps_endpoint_list=None ps_endpoint_list=None):
):
""" """
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
...@@ -495,14 +494,18 @@ def save_checkpoint(executor, ...@@ -495,14 +494,18 @@ def save_checkpoint(executor,
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir) os.makedirs(checkpoint_dir)
is_chief = trainer_id == 0
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1 serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_serial_dir(checkpoint_dir, serial)
save_trainer_args(cur_dir, trainer_id, trainer_args) save_trainer_args(cur_dir, trainer_id, trainer_args)
if trainer_id == 0: if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program) save_persist_vars_without_grad(executor, cur_dir, main_program)
save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list) if is_chief and lookup_table and ps_endpoint_list:
save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
ps_endpoint_list)
_scroll_delete(checkpoint_dir, max_num_checkpoints) _scroll_delete(checkpoint_dir, max_num_checkpoints)
...@@ -618,7 +621,8 @@ def save_persist_vars_without_grad(executor, dirname, program): ...@@ -618,7 +621,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success(cur_dir) _write_success(cur_dir)
def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list): def save_pserver_vars_by_notify(executor, dirname, lookup_table,
ps_endpoint_list):
""" """
""" """
cur_dir = _get_lookuptable_dir(dirname) cur_dir = _get_lookuptable_dir(dirname)
...@@ -802,4 +806,3 @@ def get_latest_checkpoint_serial(checkpoint_dir): ...@@ -802,4 +806,3 @@ def get_latest_checkpoint_serial(checkpoint_dir):
if success_num > current_dir: if success_num > current_dir:
current_dir = success_num current_dir = success_num
return current_dir return current_dir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册