diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index dbcf40d3a3fd9277f39cbfb8499b6ce73e1e127d..e147fbb6d2511c1d9010507b4f093d6df2e9d3f4 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -310,7 +310,7 @@ class PSRunner(RunnerBase): epochs = int( envs.get_global_env("runner." + context["runner_name"] + ".epochs")) - model_dict = envs.get_global_env("phase")[0] + model_dict = context("env")["phase"][0] for epoch in range(epochs): begin_time = time.time() self._run(context, model_dict) @@ -337,7 +337,7 @@ class CollectiveRunner(RunnerBase): epochs = int( envs.get_global_env("runner." + context["runner_name"] + ".epochs")) - model_dict = envs.get_global_env("phase")[0] + model_dict = context("env")["phase"][0] for epoch in range(epochs): begin_time = time.time() self._run(context, model_dict) @@ -362,7 +362,7 @@ class PslibRunner(RunnerBase): def run(self, context): context["fleet"].init_worker() - model_dict = envs.get_global_env("phase")[0] + model_dict = context("env")["phase"][0] epochs = int( envs.get_global_env("runner." + context["runner_name"] + ".epochs")) diff --git a/core/trainers/framework/startup.py b/core/trainers/framework/startup.py index ea3558072bb1d61ecb9d63c07ff5d6b71a37ea22..800906483c0259f4b53942831082285fb5c19fae 100644 --- a/core/trainers/framework/startup.py +++ b/core/trainers/framework/startup.py @@ -73,7 +73,7 @@ class PSStartup(StartupBase): pass def startup(self, context): - model_dict = envs.get_global_env("phase")[0] + model_dict = context("env")["phase"][0] with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): train_prog = context["model"][model_dict["name"]]["main_program"] @@ -91,7 +91,7 @@ class CollectiveStartup(StartupBase): pass def startup(self, context): - model_dict = envs.get_global_env("phase")[0] + model_dict = context("env")["phase"][0] with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): train_prog = context["model"][model_dict["name"]][ "default_main_program"] diff --git a/models/treebased/tdm/tdm_startup.py b/models/treebased/tdm/tdm_startup.py index ccc9b6c96372321a04730f7d08f9c92d9504313a..28cd3f737e6fc10a6c1f0514e42e04aafab97922 100644 --- a/models/treebased/tdm/tdm_startup.py +++ b/models/treebased/tdm/tdm_startup.py @@ -47,7 +47,7 @@ class Startup(StartupBase): def _single_startup(self, context): load_tree_from_numpy = envs.get_global_env( "hyper_parameters.tree.load_tree_from_numpy", False) - model_dict = envs.get_global_env("phase")[0] + model_dict = context("env")["phase"][0] with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): context["exe"].run(context["model"][model_dict["name"]][ "startup_program"]) @@ -106,7 +106,7 @@ class Startup(StartupBase): warmup_model_path = envs.get_global_env( "runner." + context["runner_name"] + ".init_model_path", None) assert warmup_model_path != None, "set runner.init_model_path for loading model" - model_dict = envs.get_global_env("phase")[0] + model_dict = context("env")["phase"][0] with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]): context["exe"].run(context["model"][model_dict["name"]][ "startup_program"])