diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 04aa51d2cdd381c802c6c8fa0f094f7ac985833d..84cfc6e0117e8b79b7d501f7a13f70eef0bc88ed 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -520,6 +520,11 @@ class DistributeTranspiler:
         return var.persistable
 
     def get_train_startup_program(self, checkpoint_load_dir=None):
+        """
+        Get the train startup program.
+        If checkpoint_load_dir is None, return the default startup program.
+        If checkpoint_load_dir exists, append a checkpoint_load op to load the variables.
+        """
         startup_prog = default_startup_program()
 
         if not checkpoint_load_dir:
@@ -536,7 +541,10 @@ class DistributeTranspiler:
             attrs={"dir": checkpoint_load_dir})
         return startup_prog
 
-    def get_startup_program(self, endpoint, pserver_program):
+    def get_startup_program(self,
+                            endpoint,
+                            pserver_program,
+                            checkpoint_load_dir=None):
         """
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
@@ -561,6 +569,7 @@ class DistributeTranspiler:
             created_var_map[var.name] = tmpvar
 
         # 2. rename op outputs
+        load_vars = []
         for op in orig_s_prog.global_block().ops:
             new_inputs = dict()
             new_outputs = dict()
@@ -588,6 +597,16 @@ class DistributeTranspiler:
                     inputs=new_inputs,
                     outputs=new_outputs,
                     attrs=op.attrs)
+                for var in new_outputs.values():
+                    load_vars.append(var.name)
+        # add checkpoint_load op
+        if not checkpoint_load_dir:
+            return s_prog
+
+        s_prog.global_block().append_op(
+            type="checkpoint_load",
+            inputs={"X": load_vars},
+            attrs={"dir": checkpoint_load_dir})
         return s_prog
 
     # transpiler function for dis lookup_table
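
The sketch below (not part of the patch) shows how the new checkpoint_load_dir arguments might be used from a cluster training script. It assumes a program with optimize ops has already been built, that the checkpoint_load operator is compiled into this build, and that the endpoint string and "/path/to/ckpt" directory are illustrative placeholders.

    # hypothetical usage sketch; network/optimizer construction is omitted
    import paddle.fluid as fluid
    from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler

    t = DistributeTranspiler()
    # transpile the default main program for one trainer and one pserver
    t.transpile(trainer_id=0, pservers="127.0.0.1:6174", trainers=1)

    # trainer side: restore persistable variables if a checkpoint dir is given
    train_startup = t.get_train_startup_program(
        checkpoint_load_dir="/path/to/ckpt")

    # pserver side: build the pserver program, then a startup program that
    # also appends a checkpoint_load op for the pserver's variables
    pserver_prog = t.get_pserver_program("127.0.0.1:6174")
    pserver_startup = t.get_startup_program(
        "127.0.0.1:6174", pserver_prog, checkpoint_load_dir="/path/to/ckpt")

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(pserver_startup)

If checkpoint_load_dir is left as None, both methods behave as before and simply return the untouched startup program.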