From 4220b31d4f45918fbc0a74cc05ba14ffd4ab093c Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Wed, 16 May 2018 20:50:24 +0800
Subject: [PATCH] update pserver startup

---
 .../fluid/transpiler/distribute_transpiler.py | 21 ++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 04aa51d2cd..84cfc6e011 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -520,6 +520,11 @@ class DistributeTranspiler:
         return var.persistable
 
     def get_train_startup_program(self, checkpoint_load_dir=None):
+        """
+        Get the train startup program.
+        If checkpoint_load_dir is None, return the default startup program.
+        If checkpoint_load_dir exists, append a checkpoint_load op that restores the variables from it.
+        """
         startup_prog = default_startup_program()
 
         if not checkpoint_load_dir:
@@ -536,7 +541,10 @@
             attrs={"dir": checkpoint_load_dir})
         return startup_prog
 
-    def get_startup_program(self, endpoint, pserver_program):
+    def get_startup_program(self,
+                            endpoint,
+                            pserver_program,
+                            checkpoint_load_dir=None):
         """
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
@@ -561,6 +569,7 @@
             created_var_map[var.name] = tmpvar
 
         # 2. rename op outputs
+        load_vars = []
         for op in orig_s_prog.global_block().ops:
             new_inputs = dict()
             new_outputs = dict()
@@ -588,6 +597,16 @@
                 inputs=new_inputs,
                 outputs=new_outputs,
                 attrs=op.attrs)
+            for var in new_outputs.values():
+                load_vars.append(var.name)
+        # add checkpoint op
+        if not checkpoint_load_dir:
+            return s_prog
+
+        s_prog.global_block().append_op(
+            type="checkpoint_load",
+            inputs={"X": load_vars},
+            attrs={"dir": checkpoint_load_dir})
         return s_prog
 
     # transpiler function for dis lookup_table
-- 
GitLab
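
Usage note (not part of the patch): the sketch below shows one way the new checkpoint_load_dir argument might be wired into parameter-server startup. It assumes the Fluid 0.x transpiler API from around the time of this patch; the endpoint string, trainer count, and checkpoint directory are illustrative assumptions, not values taken from the patch.

    import paddle.fluid as fluid
    from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler

    pserver_endpoint = "127.0.0.1:6174"  # illustrative endpoint

    t = DistributeTranspiler()
    t.transpile(trainer_id=0, pservers=pserver_endpoint, trainers=2)

    pserver_prog = t.get_pserver_program(pserver_endpoint)

    # With checkpoint_load_dir set, the startup program ends with a
    # checkpoint_load op over the renamed startup outputs, so the pserver
    # restores saved state instead of re-initializing its parameters.
    startup_prog = t.get_startup_program(
        pserver_endpoint,
        pserver_prog,
        checkpoint_load_dir="/tmp/ckpt")  # hypothetical directory

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)
    exe.run(pserver_prog)

A trainer process would similarly call t.get_train_startup_program(checkpoint_load_dir="/tmp/ckpt") to have its startup program load from the same directory; with checkpoint_load_dir left as None, both methods keep the unmodified startup behavior.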