提交 4220b31d 编写于 作者: T tangwei12

update pserver startup

上级 3dd27465
......@@ -520,6 +520,11 @@ class DistributeTranspiler:
return var.persistable
def get_train_startup_program(self, checkpoint_load_dir=None):
"""
Get train startup program.
If checkpoint_load_dir is None, rerurn default startup program.
IF checkpoint_load_dir is Exist, add checkpoint_load op and load Var.
"""
startup_prog = default_startup_program()
if not checkpoint_load_dir:
......@@ -536,7 +541,10 @@ class DistributeTranspiler:
attrs={"dir": checkpoint_load_dir})
return startup_prog
def get_startup_program(self, endpoint, pserver_program):
def get_startup_program(self,
endpoint,
pserver_program,
checkpoint_load_dir=None):
"""
Get startup program for current parameter server.
Modify operator input variables if there are variables that
......@@ -561,6 +569,7 @@ class DistributeTranspiler:
created_var_map[var.name] = tmpvar
# 2. rename op outputs
load_vars = []
for op in orig_s_prog.global_block().ops:
new_inputs = dict()
new_outputs = dict()
......@@ -588,6 +597,16 @@ class DistributeTranspiler:
inputs=new_inputs,
outputs=new_outputs,
attrs=op.attrs)
for var in new_outputs.values():
load_vars.append(var.name)
# add checkpoint op
if not checkpoint_load_dir:
return s_prog
s_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
attrs={"dir": checkpoint_load_dir})
return s_prog
# transpiler function for dis lookup_table
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册