From 853cb46730bd003df29bb0d805b408049c9612cd Mon Sep 17 00:00:00 2001
From: tangwei
Date: Thu, 11 Jun 2020 13:24:05 +0800
Subject: [PATCH] fix windows adapter

---
 core/trainers/framework/runner.py   | 6 +++---
 core/trainers/framework/startup.py  | 4 ++--
 models/treebased/tdm/tdm_startup.py | 4 ++--
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py
index dbcf40d3..e147fbb6 100644
--- a/core/trainers/framework/runner.py
+++ b/core/trainers/framework/runner.py
@@ -310,7 +310,7 @@ class PSRunner(RunnerBase):
         epochs = int(
             envs.get_global_env("runner." + context["runner_name"] +
                                 ".epochs"))
-        model_dict = envs.get_global_env("phase")[0]
+        model_dict = context["env"]["phase"][0]
         for epoch in range(epochs):
             begin_time = time.time()
             self._run(context, model_dict)
@@ -337,7 +337,7 @@ class CollectiveRunner(RunnerBase):
         epochs = int(
             envs.get_global_env("runner." + context["runner_name"] +
                                 ".epochs"))
-        model_dict = envs.get_global_env("phase")[0]
+        model_dict = context["env"]["phase"][0]
         for epoch in range(epochs):
             begin_time = time.time()
             self._run(context, model_dict)
@@ -362,7 +362,7 @@ class PslibRunner(RunnerBase):
 
     def run(self, context):
         context["fleet"].init_worker()
-        model_dict = envs.get_global_env("phase")[0]
+        model_dict = context["env"]["phase"][0]
         epochs = int(
             envs.get_global_env("runner." + context["runner_name"] +
                                 ".epochs"))
diff --git a/core/trainers/framework/startup.py b/core/trainers/framework/startup.py
index ea355807..80090648 100644
--- a/core/trainers/framework/startup.py
+++ b/core/trainers/framework/startup.py
@@ -73,7 +73,7 @@ class PSStartup(StartupBase):
         pass
 
     def startup(self, context):
-        model_dict = envs.get_global_env("phase")[0]
+        model_dict = context["env"]["phase"][0]
         with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
             train_prog = context["model"][model_dict["name"]]["main_program"]
 
@@ -91,7 +91,7 @@ class CollectiveStartup(StartupBase):
         pass
 
     def startup(self, context):
-        model_dict = envs.get_global_env("phase")[0]
+        model_dict = context["env"]["phase"][0]
         with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
             train_prog = context["model"][model_dict["name"]][
                 "default_main_program"]
diff --git a/models/treebased/tdm/tdm_startup.py b/models/treebased/tdm/tdm_startup.py
index ccc9b6c9..28cd3f73 100644
--- a/models/treebased/tdm/tdm_startup.py
+++ b/models/treebased/tdm/tdm_startup.py
@@ -47,7 +47,7 @@ class Startup(StartupBase):
     def _single_startup(self, context):
         load_tree_from_numpy = envs.get_global_env(
             "hyper_parameters.tree.load_tree_from_numpy", False)
-        model_dict = envs.get_global_env("phase")[0]
+        model_dict = context["env"]["phase"][0]
         with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
             context["exe"].run(context["model"][model_dict["name"]][
                 "startup_program"])
@@ -106,7 +106,7 @@ class Startup(StartupBase):
         warmup_model_path = envs.get_global_env(
             "runner." + context["runner_name"] + ".init_model_path", None)
         assert warmup_model_path != None, "set runner.init_model_path for loading model"
-        model_dict = envs.get_global_env("phase")[0]
+        model_dict = context["env"]["phase"][0]
         with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
             context["exe"].run(context["model"][model_dict["name"]][
                 "startup_program"])
--
GitLab