提交 b6ee59ae 编写于 作者: T tangwei12

optimize python checkpoint dir config

上级 ee91e48e
...@@ -219,7 +219,8 @@ class DistributeTranspiler: ...@@ -219,7 +219,8 @@ class DistributeTranspiler:
# is_chief (no.0 trainer) for checkpoint # is_chief (no.0 trainer) for checkpoint
# the no.0 trainer will save all variables and its own reader offset to checkpoint # the no.0 trainer will save all variables and its own reader offset to checkpoint
# other trainers will save their own reader offset to checkpoint # other trainers will save their own reader offset to checkpoint
self.is_chief = trainer_id == 0 self._is_chief = trainer_id == 0
self.checkpoint_dir = checkpoint_dir
# process lookup_table_op # process lookup_table_op
# 1. check all lookup_table_op is distributed # 1. check all lookup_table_op is distributed
...@@ -327,7 +328,7 @@ class DistributeTranspiler: ...@@ -327,7 +328,7 @@ class DistributeTranspiler:
"sync_mode": self.sync_mode "sync_mode": self.sync_mode
}) })
if checkpoint_dir and self.is_chief: if self.checkpoint_dir and self._is_chief:
program.global_block().create_var( program.global_block().create_var(
name=SERIAL_VAR_NAME, name=SERIAL_VAR_NAME,
persistable=True, persistable=True,
...@@ -342,7 +343,7 @@ class DistributeTranspiler: ...@@ -342,7 +343,7 @@ class DistributeTranspiler:
type="checkpoint_save", type="checkpoint_save",
inputs={"X": save_vars}, inputs={"X": save_vars},
attrs={"overwrite": True, attrs={"overwrite": True,
"dir": checkpoint_dir}) "dir": self.checkpoint_dir})
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in param_var_mapping.iteritems():
...@@ -524,15 +525,15 @@ class DistributeTranspiler: ...@@ -524,15 +525,15 @@ class DistributeTranspiler:
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
return pserver_program return pserver_program
def get_train_startup_program(self, checkpoint_load_dir=None): def get_train_startup_program(self):
""" """
Get train startup program. Get train startup program.
If checkpoint_load_dir is None, return default startup program. If self.checkpoint_dir is None, return default startup program.
If checkpoint_load_dir exists, add checkpoint_load op and load Var. If self.checkpoint_dir exists, add checkpoint_load op and load Var.
""" """
startup_prog = default_startup_program() startup_prog = default_startup_program()
if not checkpoint_load_dir: if not self.checkpoint_dir:
return startup_prog return startup_prog
load_vars = [] load_vars = []
...@@ -540,20 +541,17 @@ class DistributeTranspiler: ...@@ -540,20 +541,17 @@ class DistributeTranspiler:
if self._is_persistable(var): if self._is_persistable(var):
load_vars.append(var.name) load_vars.append(var.name)
serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir) serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)
startup_prog.global_block().append_op( startup_prog.global_block().append_op(
type="checkpoint_load", type="checkpoint_load",
inputs={"X": load_vars}, inputs={"X": load_vars},
outputs={"Argv": []}, outputs={"Argv": []},
attrs={"dir": checkpoint_load_dir, attrs={"dir": self.checkpoint_dir,
"Serial": serial_number}) "Serial": serial_number})
return startup_prog return startup_prog
def get_startup_program(self, def get_startup_program(self, endpoint, pserver_program):
endpoint,
pserver_program,
checkpoint_load_dir=None):
""" """
Get startup program for current parameter server. Get startup program for current parameter server.
Modify operator input variables if there are variables that Modify operator input variables if there are variables that
...@@ -609,16 +607,16 @@ class DistributeTranspiler: ...@@ -609,16 +607,16 @@ class DistributeTranspiler:
for var in new_outputs.values(): for var in new_outputs.values():
load_vars.append(var.name) load_vars.append(var.name)
# add checkpoint op # add checkpoint op
if not checkpoint_load_dir: if not self.checkpoint_dir:
return s_prog return s_prog
serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir) serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)
s_prog.global_block().append_op( s_prog.global_block().append_op(
type="checkpoint_load", type="checkpoint_load",
inputs={"X": load_vars}, inputs={"X": load_vars},
outputs={"Argv": []}, outputs={"Argv": []},
attrs={"dir": checkpoint_load_dir, attrs={"dir": self.checkpoint_dir,
"Serial": serial_number}) "Serial": serial_number})
return s_prog return s_prog
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册