From b6ee59ae2573fbbe66ab574be299d6b6fe52552c Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Fri, 18 May 2018 22:24:24 +0800
Subject: [PATCH] optimize python checkpoint dir config

---
 .../fluid/transpiler/distribute_transpiler.py | 30 +++++++++----------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 8b379ddcf8..dc9d254fa5 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -219,7 +219,8 @@ class DistributeTranspiler:
         # is_chief (no.0 triner) for checkpoint
         # the no.0 trainer will save all variables and its own reader offset to checkpoint
         # other trianers will save its own reader offset to checkpoint
-        self.is_chief = trainer_id == 0
+        self._is_chief = trainer_id == 0
+        self.checkpoint_dir = checkpoint_dir

         # process lookup_table_op
         # 1. check all lookup_table_op is distributed
@@ -327,7 +328,7 @@ class DistributeTranspiler:
                     "sync_mode": self.sync_mode
                 })

-        if checkpoint_dir and self.is_chief:
+        if self.checkpoint_dir and self._is_chief:
             program.global_block().create_var(
                 name=SERIAL_VAR_NAME,
                 persistable=True,
@@ -342,7 +343,7 @@ class DistributeTranspiler:
                 type="checkpoint_save",
                 inputs={"X": save_vars},
                 attrs={"overwrite": True,
-                       "dir": checkpoint_dir})
+                       "dir": self.checkpoint_dir})

         # step4: Concat the parameters splits together after recv.
         for varname, splited_var in param_var_mapping.iteritems():
@@ -524,15 +525,15 @@ class DistributeTranspiler:
         pserver_program.sync_with_cpp()
         return pserver_program

-    def get_train_startup_program(self, checkpoint_load_dir=None):
+    def get_train_startup_program(self):
         """
         Get train startup program.
-        If checkpoint_load_dir is None, rerurn default startup program.
-        IF checkpoint_load_dir is Exist, add checkpoint_load op and load Var.
+        If self.checkpoint_dir is None, return the default startup program.
+        If self.checkpoint_dir exists, add checkpoint_load op and load Var.
         """
         startup_prog = default_startup_program()

-        if not checkpoint_load_dir:
+        if not self.checkpoint_dir:
             return startup_prog

         load_vars = []
@@ -540,20 +541,17 @@ class DistributeTranspiler:
             if self._is_persistable(var):
                 load_vars.append(var.name)

-        serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
+        serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)

         startup_prog.global_block().append_op(
             type="checkpoint_load",
             inputs={"X": load_vars},
             outputs={"Argv": []},
-            attrs={"dir": checkpoint_load_dir,
+            attrs={"dir": self.checkpoint_dir,
                    "Serial": serial_number})
         return startup_prog

-    def get_startup_program(self,
-                            endpoint,
-                            pserver_program,
-                            checkpoint_load_dir=None):
+    def get_startup_program(self, endpoint, pserver_program):
         """
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
@@ -609,16 +607,16 @@ class DistributeTranspiler:
         for var in new_outputs.values():
             load_vars.append(var.name)
         # add checkpoint op
-        if not checkpoint_load_dir:
+        if not self.checkpoint_dir:
             return s_prog

-        serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
+        serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)

         s_prog.global_block().append_op(
             type="checkpoint_load",
             inputs={"X": load_vars},
             outputs={"Argv": []},
-            attrs={"dir": checkpoint_load_dir,
+            attrs={"dir": self.checkpoint_dir,
                    "Serial": serial_number})
         return s_prog

--
GitLab
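
Usage note (not part of the patch): a minimal sketch of how the reworked checkpoint
configuration would be driven after this change. It assumes transpile() accepts a
checkpoint_dir argument alongside trainer_id (the full transpile() signature is not
shown in this diff); the endpoint, trainer count, and "/tmp/ckpt" path below are
placeholders for illustration only.

    import paddle.fluid as fluid
    from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler

    t = DistributeTranspiler()
    # checkpoint_dir is given once here and stored as self.checkpoint_dir, so the
    # startup-program getters below no longer take a directory argument.
    t.transpile(
        trainer_id=0,                  # trainer 0 becomes the chief (self._is_chief)
        program=fluid.default_main_program(),
        pservers="127.0.0.1:6174",     # placeholder endpoint
        trainers=1,
        checkpoint_dir="/tmp/ckpt")    # placeholder path; assumed transpile() parameter

    # Trainer side: checkpoint_load is appended only when checkpoint_dir was set.
    train_startup = t.get_train_startup_program()

    # Pserver side: the directory argument is dropped; self.checkpoint_dir is reused.
    pserver_prog = t.get_pserver_program("127.0.0.1:6174")
    pserver_startup = t.get_startup_program("127.0.0.1:6174", pserver_prog)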