提交 211f7e38 编写于 作者: T tangwei

fix windows adapter

上级 853cb467
...@@ -310,7 +310,7 @@ class PSRunner(RunnerBase): ...@@ -310,7 +310,7 @@ class PSRunner(RunnerBase):
epochs = int( epochs = int(
envs.get_global_env("runner." + context["runner_name"] + envs.get_global_env("runner." + context["runner_name"] +
".epochs")) ".epochs"))
model_dict = context("env")["phase"][0] model_dict = context["env"]["phase"][0]
for epoch in range(epochs): for epoch in range(epochs):
begin_time = time.time() begin_time = time.time()
self._run(context, model_dict) self._run(context, model_dict)
...@@ -337,7 +337,7 @@ class CollectiveRunner(RunnerBase): ...@@ -337,7 +337,7 @@ class CollectiveRunner(RunnerBase):
epochs = int( epochs = int(
envs.get_global_env("runner." + context["runner_name"] + envs.get_global_env("runner." + context["runner_name"] +
".epochs")) ".epochs"))
model_dict = context("env")["phase"][0] model_dict = context["env"]["phase"][0]
for epoch in range(epochs): for epoch in range(epochs):
begin_time = time.time() begin_time = time.time()
self._run(context, model_dict) self._run(context, model_dict)
...@@ -362,7 +362,7 @@ class PslibRunner(RunnerBase): ...@@ -362,7 +362,7 @@ class PslibRunner(RunnerBase):
def run(self, context): def run(self, context):
context["fleet"].init_worker() context["fleet"].init_worker()
model_dict = context("env")["phase"][0] model_dict = context["env"]["phase"][0]
epochs = int( epochs = int(
envs.get_global_env("runner." + context["runner_name"] + envs.get_global_env("runner." + context["runner_name"] +
".epochs")) ".epochs"))
......
...@@ -73,7 +73,7 @@ class PSStartup(StartupBase): ...@@ -73,7 +73,7 @@ class PSStartup(StartupBase):
pass pass
def startup(self, context): def startup(self, context):
model_dict = context("env")["phase"][0] model_dict = context["env"]["phase"][0]
with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
train_prog = context["model"][model_dict["name"]]["main_program"] train_prog = context["model"][model_dict["name"]]["main_program"]
...@@ -91,7 +91,7 @@ class CollectiveStartup(StartupBase): ...@@ -91,7 +91,7 @@ class CollectiveStartup(StartupBase):
pass pass
def startup(self, context): def startup(self, context):
model_dict = context("env")["phase"][0] model_dict = context["env"]["phase"][0]
with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
train_prog = context["model"][model_dict["name"]][ train_prog = context["model"][model_dict["name"]][
"default_main_program"] "default_main_program"]
......
...@@ -47,7 +47,7 @@ class Startup(StartupBase): ...@@ -47,7 +47,7 @@ class Startup(StartupBase):
def _single_startup(self, context): def _single_startup(self, context):
load_tree_from_numpy = envs.get_global_env( load_tree_from_numpy = envs.get_global_env(
"hyper_parameters.tree.load_tree_from_numpy", False) "hyper_parameters.tree.load_tree_from_numpy", False)
model_dict = context("env")["phase"][0] model_dict = context["env"]["phase"][0]
with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
context["exe"].run(context["model"][model_dict["name"]][ context["exe"].run(context["model"][model_dict["name"]][
"startup_program"]) "startup_program"])
...@@ -106,7 +106,7 @@ class Startup(StartupBase): ...@@ -106,7 +106,7 @@ class Startup(StartupBase):
warmup_model_path = envs.get_global_env( warmup_model_path = envs.get_global_env(
"runner." + context["runner_name"] + ".init_model_path", None) "runner." + context["runner_name"] + ".init_model_path", None)
assert warmup_model_path != None, "set runner.init_model_path for loading model" assert warmup_model_path != None, "set runner.init_model_path for loading model"
model_dict = context("env")["phase"][0] model_dict = context["env"]["phase"][0]
with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
context["exe"].run(context["model"][model_dict["name"]][ context["exe"].run(context["model"][model_dict["name"]][
"startup_program"]) "startup_program"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册