提交 9cf47afe 编写于 作者: T tangwei12

modify get trainer param

上级 886897cc
...@@ -525,12 +525,15 @@ class DistributeTranspiler: ...@@ -525,12 +525,15 @@ class DistributeTranspiler:
if not checkpoint_load_dir: if not checkpoint_load_dir:
return startup_prog return startup_prog
load_vars = []
for var in startup_prog.list_vars(): for var in startup_prog.list_vars():
if self.is_persistable(var): if self.is_persistable(var):
print("var: %s" % var.name) load_vars.append(var.name)
startup_prog.global_block().append_op( startup_prog.global_block().append_op(
type="checkpoint_load", attrs={"dir": checkpoint_load_dir}) type="checkpoint_load",
outputs={"Out": load_vars},
attrs={"dir": checkpoint_load_dir})
return startup_prog return startup_prog
def get_startup_program(self, endpoint, pserver_program): def get_startup_program(self, endpoint, pserver_program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册