提交 d081256c 编写于 作者: T tangwei12

add api in distribute transpiler

上级 0334d494
......@@ -315,10 +315,21 @@ class DistributeTranspiler:
"sync_mode": self.sync_mode
})
serial_var = program.global_block().create_var(
name="SERIAL_NUMBER",
persistable=True,
type=core.VarDesc.VarType.RAW)
save_vars = []
for var in self.origin_program.list_vars():
if self.is_persistable(var):
save_vars.append(var.name)
program.global_block().append_op(
type="checkpoint_save",
inputs={"X": send_outputs},
attrs={"overwrite": True,
inputs={"X": save_vars},
outputs={"Serial": serial_var},
attrs={"overwrite": False,
"dir": "/workspace/ckpt/"})
# step4: Concat the parameters splits together after recv.
......@@ -501,6 +512,27 @@ class DistributeTranspiler:
pserver_program.sync_with_cpp()
return pserver_program
def is_persistable(self, var):
    """
    Tell whether a variable counts as persistable for this transpiler.

    Variables whose descriptor type is FEED_MINIBATCH, FETCH_LIST or RAW
    are always reported as not persistable. For every other type the
    variable's own `persistable` attribute is returned unchanged.

    Args:
        var: a variable exposing `desc.type()` and `persistable`.

    Returns:
        False for the excluded descriptor types, otherwise `var.persistable`.
    """
    # Descriptor types that are excluded regardless of their persistable flag.
    excluded_types = (
        core.VarDesc.VarType.FEED_MINIBATCH,
        core.VarDesc.VarType.FETCH_LIST,
        core.VarDesc.VarType.RAW,
    )
    if var.desc.type() in excluded_types:
        return False
    return var.persistable
def get_train_startup_program(self, checkpoint_load_dir=None):
    """
    Get the trainer startup program, optionally extended to load a checkpoint.

    Args:
        checkpoint_load_dir: directory passed as the "dir" attribute of the
            appended "checkpoint_load" op. When it is falsy (None or an
            empty string) the startup program is returned untouched.

    Returns:
        The object returned by `default_startup_program()`. When a
        directory is given, a "checkpoint_load" op has been appended to
        its global block, i.e. the program is modified in place.
    """
    train_startup = default_startup_program()

    if checkpoint_load_dir:
        # Print the name of every variable that is_persistable() accepts.
        # NOTE(review): this looks like leftover diagnostic output; confirm
        # whether callers rely on it before removing.
        for candidate in train_startup.list_vars():
            if not self.is_persistable(candidate):
                continue
            print("var: %s" % candidate.name)

        load_attrs = {"dir": checkpoint_load_dir}
        train_startup.global_block().append_op(
            type="checkpoint_load", attrs=load_attrs)

    return train_startup
def get_startup_program(self, endpoint, pserver_program):
"""
Get startup program for current parameter server.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册