提交 4220b31d 编写于 作者: T tangwei12

update pserver startup

上级 3dd27465
...@@ -520,6 +520,11 @@ class DistributeTranspiler: ...@@ -520,6 +520,11 @@ class DistributeTranspiler:
return var.persistable return var.persistable
def get_train_startup_program(self, checkpoint_load_dir=None): def get_train_startup_program(self, checkpoint_load_dir=None):
"""
Get train startup program.
If checkpoint_load_dir is None, rerurn default startup program.
IF checkpoint_load_dir is Exist, add checkpoint_load op and load Var.
"""
startup_prog = default_startup_program() startup_prog = default_startup_program()
if not checkpoint_load_dir: if not checkpoint_load_dir:
...@@ -536,7 +541,10 @@ class DistributeTranspiler: ...@@ -536,7 +541,10 @@ class DistributeTranspiler:
attrs={"dir": checkpoint_load_dir}) attrs={"dir": checkpoint_load_dir})
return startup_prog return startup_prog
def get_startup_program(self, endpoint, pserver_program): def get_startup_program(self,
endpoint,
pserver_program,
checkpoint_load_dir=None):
""" """
Get startup program for current parameter server. Get startup program for current parameter server.
Modify operator input variables if there are variables that Modify operator input variables if there are variables that
...@@ -561,6 +569,7 @@ class DistributeTranspiler: ...@@ -561,6 +569,7 @@ class DistributeTranspiler:
created_var_map[var.name] = tmpvar created_var_map[var.name] = tmpvar
# 2. rename op outputs # 2. rename op outputs
load_vars = []
for op in orig_s_prog.global_block().ops: for op in orig_s_prog.global_block().ops:
new_inputs = dict() new_inputs = dict()
new_outputs = dict() new_outputs = dict()
...@@ -588,6 +597,16 @@ class DistributeTranspiler: ...@@ -588,6 +597,16 @@ class DistributeTranspiler:
inputs=new_inputs, inputs=new_inputs,
outputs=new_outputs, outputs=new_outputs,
attrs=op.attrs) attrs=op.attrs)
for var in new_outputs.values():
load_vars.append(var.name)
# add checkpoint op
if not checkpoint_load_dir:
return s_prog
s_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
attrs={"dir": checkpoint_load_dir})
return s_prog return s_prog
# transpiler function for dis lookup_table # transpiler function for dis lookup_table
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册