Commit b6ee59ae authored by tangwei12

optimize python checkpoint dir config

Parent: ee91e48e
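The change below moves the checkpoint directory from a per-call `checkpoint_load_dir` argument into a single `self.checkpoint_dir` attribute stored during `transpile()`. A minimal usage sketch of the resulting API shape (the `transpile()` keyword arguments and the `fluid` import path are assumptions based on this diff, not confirmed by it):

```python
# Hypothetical usage sketch: the checkpoint directory is configured once
# on the transpiler and reused by every startup-program getter afterwards.
import paddle.fluid as fluid  # import path assumed for this era of Paddle

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,                 # trainer 0 becomes the chief (_is_chief)
    pservers="127.0.0.1:6174",
    trainers=2,
    checkpoint_dir="/tmp/ckpt")   # stored as self.checkpoint_dir

# Before this commit the directory had to be re-passed on every call:
#   t.get_train_startup_program(checkpoint_load_dir="/tmp/ckpt")
# After it, the stored attribute is used implicitly:
startup = t.get_train_startup_program()
```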
@@ -219,7 +219,8 @@ class DistributeTranspiler:
         # is_chief (the no.0 trainer) for checkpoint
         # the no.0 trainer will save all variables and its own reader offset to checkpoint
         # other trainers will save only their own reader offset to checkpoint
-        self.is_chief = trainer_id == 0
+        self._is_chief = trainer_id == 0
+        self.checkpoint_dir = checkpoint_dir
         # process lookup_table_op
         # 1. check all lookup_table_op is distributed
@@ -327,7 +328,7 @@ class DistributeTranspiler:
                 "sync_mode": self.sync_mode
             })
-        if checkpoint_dir and self.is_chief:
+        if self.checkpoint_dir and self._is_chief:
             program.global_block().create_var(
                 name=SERIAL_VAR_NAME,
                 persistable=True,
@@ -342,7 +343,7 @@ class DistributeTranspiler:
             type="checkpoint_save",
             inputs={"X": save_vars},
             attrs={"overwrite": True,
-                   "dir": checkpoint_dir})
+                   "dir": self.checkpoint_dir})

         # step4: Concat the parameter splits together after recv.
         for varname, splited_var in param_var_mapping.iteritems():
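Taken together, the two hunks above gate the full checkpoint save on the chief trainer. A condensed sketch of that save path, with the surrounding transpiler plumbing (`program`, `save_vars`, `SERIAL_VAR_NAME`) treated as given by the enclosing method:

```python
# Condensed sketch of the chief-only save path (simplified from the hunks
# above; program, save_vars and SERIAL_VAR_NAME come from the enclosing
# method). Only trainer 0 writes the full checkpoint, so concurrent
# trainers never race on the same files in self.checkpoint_dir.
if self.checkpoint_dir and self._is_chief:
    program.global_block().append_op(
        type="checkpoint_save",
        inputs={"X": save_vars},
        attrs={"overwrite": True,
               "dir": self.checkpoint_dir})
```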
@@ -524,15 +525,15 @@ class DistributeTranspiler:
         pserver_program.sync_with_cpp()
         return pserver_program

-    def get_train_startup_program(self, checkpoint_load_dir=None):
+    def get_train_startup_program(self):
         """
         Get train startup program.
-        If checkpoint_load_dir is None, rerurn default startup program.
-        IF checkpoint_load_dir is Exist, add checkpoint_load op and load Var.
+        If self.checkpoint_dir is None, return the default startup program.
+        If self.checkpoint_dir exists, add a checkpoint_load op and load the vars.
         """
         startup_prog = default_startup_program()

-        if not checkpoint_load_dir:
+        if not self.checkpoint_dir:
             return startup_prog

         load_vars = []
@@ -540,20 +541,17 @@ class DistributeTranspiler:
             if self._is_persistable(var):
                 load_vars.append(var.name)

-        serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
+        serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)
         startup_prog.global_block().append_op(
             type="checkpoint_load",
             inputs={"X": load_vars},
             outputs={"Argv": []},
-            attrs={"dir": checkpoint_load_dir,
+            attrs={"dir": self.checkpoint_dir,
                    "Serial": serial_number})
         return startup_prog

-    def get_startup_program(self,
-                            endpoint,
-                            pserver_program,
-                            checkpoint_load_dir=None):
+    def get_startup_program(self, endpoint, pserver_program):
         """
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
@@ -609,16 +607,16 @@ class DistributeTranspiler:
         for var in new_outputs.values():
             load_vars.append(var.name)

         # add checkpoint op
-        if not checkpoint_load_dir:
+        if not self.checkpoint_dir:
             return s_prog

-        serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
+        serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)
         s_prog.global_block().append_op(
             type="checkpoint_load",
             inputs={"X": load_vars},
             outputs={"Argv": []},
-            attrs={"dir": checkpoint_load_dir,
+            attrs={"dir": self.checkpoint_dir,
                    "Serial": serial_number})
         return s_prog
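Both load sites rely on `self._get_lastest_checkpoint_dir()` to pick which checkpoint to restore. That helper is not part of this diff; below is a self-contained sketch of the likely semantics, assuming each save creates a serial-numbered subdirectory under `checkpoint_dir` (the function name and layout are hypothetical):

```python
# Hypothetical stand-in for _get_lastest_checkpoint_dir (not shown in this
# diff): assume each checkpoint_save creates a subdirectory named by an
# increasing integer serial, and loading picks the largest one.
import os

def latest_checkpoint_serial(checkpoint_dir):
    """Return the highest integer-named subdirectory serial, or -1 if none."""
    if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
        return -1
    serials = [
        int(name) for name in os.listdir(checkpoint_dir)
        if name.isdigit() and os.path.isdir(os.path.join(checkpoint_dir, name))
    ]
    return max(serials) if serials else -1
```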