From 211f7e38b8f6f478fc38ed87b49f6718db619c8a Mon Sep 17 00:00:00 2001 From: tangwei Date: Thu, 11 Jun 2020 13:47:47 +0800 Subject: [PATCH] fix windows adapter --- core/trainers/framework/runner.py | 6 +++--- core/trainers/framework/startup.py | 4 ++-- models/treebased/tdm/tdm_startup.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index e147fbb6..7c61c412 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -310,7 +310,7 @@ class PSRunner(RunnerBase): epochs = int( envs.get_global_env("runner." + context["runner_name"] + ".epochs")) - model_dict = context("env")["phase"][0] + model_dict = context["env"]["phase"][0] for epoch in range(epochs): begin_time = time.time() self._run(context, model_dict) @@ -337,7 +337,7 @@ class CollectiveRunner(RunnerBase): epochs = int( envs.get_global_env("runner." + context["runner_name"] + ".epochs")) - model_dict = context("env")["phase"][0] + model_dict = context["env"]["phase"][0] for epoch in range(epochs): begin_time = time.time() self._run(context, model_dict) @@ -362,7 +362,7 @@ class PslibRunner(RunnerBase): def run(self, context): context["fleet"].init_worker() - model_dict = context("env")["phase"][0] + model_dict = context["env"]["phase"][0] epochs = int( envs.get_global_env("runner." + context["runner_name"] + ".epochs")) diff --git a/core/trainers/framework/startup.py b/core/trainers/framework/startup.py index 80090648..82e24472 100644 --- a/core/trainers/framework/startup.py +++ b/core/trainers/framework/startup.py @@ -73,7 +73,7 @@ class PSStartup(StartupBase): pass def startup(self, context): - model_dict = context("env")["phase"][0] + model_dict = context["env"]["phase"][0] with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): train_prog = context["model"][model_dict["name"]]["main_program"] @@ -91,7 +91,7 @@ class CollectiveStartup(StartupBase): pass def startup(self, context): - model_dict = context("env")["phase"][0] + model_dict = context["env"]["phase"][0] with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): train_prog = context["model"][model_dict["name"]][ "default_main_program"] diff --git a/models/treebased/tdm/tdm_startup.py b/models/treebased/tdm/tdm_startup.py index 28cd3f73..3f1b87db 100644 --- a/models/treebased/tdm/tdm_startup.py +++ b/models/treebased/tdm/tdm_startup.py @@ -47,7 +47,7 @@ class Startup(StartupBase): def _single_startup(self, context): load_tree_from_numpy = envs.get_global_env( "hyper_parameters.tree.load_tree_from_numpy", False) - model_dict = context("env")["phase"][0] + model_dict = context["env"]["phase"][0] with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): context["exe"].run(context["model"][model_dict["name"]][ "startup_program"]) @@ -106,7 +106,7 @@ class Startup(StartupBase): warmup_model_path = envs.get_global_env( "runner." + context["runner_name"] + ".init_model_path", None) assert warmup_model_path != None, "set runner.init_model_path for loading model" - model_dict = context("env")["phase"][0] + model_dict = context["env"]["phase"][0] with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): context["exe"].run(context["model"][model_dict["name"]][ "startup_program"]) -- GitLab