Commit 9cf47afe authored by: T tangwei12

modify get trainer param

Parent 886897cc
@@ -525,12 +525,15 @@ class DistributeTranspiler:
         if not checkpoint_load_dir:
             return startup_prog
         load_vars = []
         for var in startup_prog.list_vars():
             if self.is_persistable(var):
                 print("var: %s" % var.name)
                 load_vars.append(var.name)
         startup_prog.global_block().append_op(
-            type="checkpoint_load", attrs={"dir": checkpoint_load_dir})
+            type="checkpoint_load",
+            outputs={"Out": load_vars},
+            attrs={"dir": checkpoint_load_dir})
         return startup_prog
 
     def get_startup_program(self, endpoint, pserver_program):
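
The change collects the names of the startup program's persistable variables and passes them as the "Out" outputs of the appended checkpoint_load op, so the op is told explicitly which variables to restore rather than only receiving the checkpoint directory. Below is a minimal standalone sketch of that pattern; it is not PaddlePaddle source, and the Var/Block/Program classes are hypothetical stand-ins for the fluid framework objects.

    # Sketch of the "collect persistable vars, then append the load op" pattern.
    class Var:
        def __init__(self, name, persistable):
            self.name = name
            self.persistable = persistable

    class Block:
        def __init__(self):
            self.ops = []

        def append_op(self, type, outputs=None, attrs=None):
            # Record the op descriptor; a real framework would build an OpDesc here.
            self.ops.append({"type": type, "outputs": outputs or {}, "attrs": attrs or {}})

    class Program:
        def __init__(self, vars):
            self._vars = vars
            self._block = Block()

        def list_vars(self):
            return self._vars

        def global_block(self):
            return self._block

    def add_checkpoint_load(startup_prog, checkpoint_load_dir):
        if not checkpoint_load_dir:
            return startup_prog
        # Only persistable variables (parameters, optimizer state) are restored.
        load_vars = [v.name for v in startup_prog.list_vars() if v.persistable]
        startup_prog.global_block().append_op(
            type="checkpoint_load",
            outputs={"Out": load_vars},
            attrs={"dir": checkpoint_load_dir})
        return startup_prog

    prog = Program([Var("fc_0.w_0", True), Var("tmp_1", False)])
    add_checkpoint_load(prog, "/path/to/checkpoint")
    print(prog.global_block().ops)
    # [{'type': 'checkpoint_load', 'outputs': {'Out': ['fc_0.w_0']}, 'attrs': {'dir': '/path/to/checkpoint'}}]

Only "fc_0.w_0" appears in the outputs because the temporary variable is not persistable, which mirrors the is_persistable filter in the diff above.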