From 62efc8dc58dfc1f1e7ca2efe6645f1e1245bd336 Mon Sep 17 00:00:00 2001 From: tangwei Date: Wed, 10 Jun 2020 21:19:12 +0800 Subject: [PATCH] fix windows adapter --- core/trainer.py | 1 + core/trainers/framework/network.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/core/trainer.py b/core/trainer.py index 264f71b6..1fd73b37 100755 --- a/core/trainer.py +++ b/core/trainer.py @@ -76,6 +76,7 @@ class Trainer(object): _config = envs.load_yaml(config) + self._context["env"] = _config self._context["dataset"] = _config.get("dataset") phases = [] diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 3443ac57..30b73b20 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -96,7 +96,7 @@ class SingleNetwork(NetworkBase): "default_main_program"] = train_program.clone() context["dataset"] = {} - for dataset in envs.get_global_env("dataset"): + for dataset in self._env["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ @@ -113,12 +113,12 @@ class PSNetwork(NetworkBase): def build_network(self, context): context["model"] = {} - if len(envs.get_global_env("phase")) > 1: + if len(self._env["phase"]) > 1: warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, stacklevel=2) - model_dict = envs.get_global_env("phase")[0] + model_dict = self._env["phase"][0] context["model"][model_dict["name"]] = {} dataset_name = model_dict["dataset_name"] @@ -154,7 +154,7 @@ class PSNetwork(NetworkBase): else: context["fleet"].init_worker() context["dataset"] = {} - for dataset in envs.get_global_env("dataset"): + for dataset in self._env["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ @@ -200,12 +200,12 @@ class PslibNetwork(NetworkBase): def build_network(self, context): context["model"] = {} - if len(envs.get_global_env("phase")) > 1: + if len(self._env["phase"]) > 1: warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, stacklevel=2) - model_dict = envs.get_global_env("phase")[0] + model_dict = self._env["phase"][0] train_program = fluid.Program() startup_program = fluid.Program() scope = fluid.Scope() @@ -247,7 +247,7 @@ class PslibNetwork(NetworkBase): self._server(context) else: context["dataset"] = {} - for dataset in envs.get_global_env("dataset"): + for dataset in self._env["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ @@ -267,12 +267,12 @@ class CollectiveNetwork(NetworkBase): def build_network(self, context): context["model"] = {} - if len(envs.get_global_env("phase")) > 1: + if len(self._env["phase"]) > 1: warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, stacklevel=2) - model_dict = envs.get_global_env("phase")[0] + model_dict = self._env["phase"][0] context["model"][model_dict["name"]] = {} dataset_name = model_dict["dataset_name"] @@ -311,7 +311,7 @@ class CollectiveNetwork(NetworkBase): "default_main_program"] = train_program context["dataset"] = {} - for dataset in envs.get_global_env("dataset"): + for dataset in self._env["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ -- GitLab