未验证 提交 6c91b6a8 编写于 作者: W wuzhihua 提交者: GitHub

Merge pull request #99 from MrChengmo/fix_load

fix load
...@@ -38,14 +38,9 @@ class StartupBase(object): ...@@ -38,14 +38,9 @@ class StartupBase(object):
if dirname is None or dirname == "": if dirname is None or dirname == "":
return return
print("going to load ", dirname) print("going to load ", dirname)
if is_fleet:
if context["fleet_mode"].upper() == "PS":
return
# For Pslib
context["fleet"].load_persistables(context["exe"], dirname)
else:
fluid.io.load_persistables( fluid.io.load_persistables(
context["exe"], dirname, main_program=main_program) context["exe"], dirname, main_program=main_program)
print("load from {} success".format(dirname))
class SingleStartup(StartupBase): class SingleStartup(StartupBase):
...@@ -84,7 +79,6 @@ class PSStartup(StartupBase): ...@@ -84,7 +79,6 @@ class PSStartup(StartupBase):
"startup_program"] "startup_program"]
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
context["exe"].run(startup_prog) context["exe"].run(startup_prog)
self.load(context, True)
context["status"] = "train_pass" context["status"] = "train_pass"
...@@ -102,7 +96,7 @@ class CollectiveStartup(StartupBase): ...@@ -102,7 +96,7 @@ class CollectiveStartup(StartupBase):
"startup_program"] "startup_program"]
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
context["exe"].run(startup_prog) context["exe"].run(startup_prog)
self.load(context, True) self.load(context, main_program=train_prog)
context["status"] = "train_pass" context["status"] = "train_pass"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册