diff --git a/core/trainer.py b/core/trainer.py index 264f71b6ab554bae3041c65e8cd3e1a0c3412a4f..1fd73b37e56ddfedb60f3bc00d58e32dd4108cdd 100755 --- a/core/trainer.py +++ b/core/trainer.py @@ -76,6 +76,7 @@ class Trainer(object): _config = envs.load_yaml(config) + self._context["env"] = _config self._context["dataset"] = _config.get("dataset") phases = [] diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 3443ac57e42202407d8237b519edbfad1b874fae..30b73b200226fe12094ea4703165d4ab68fc3b11 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -96,7 +96,7 @@ class SingleNetwork(NetworkBase): "default_main_program"] = train_program.clone() context["dataset"] = {} - for dataset in envs.get_global_env("dataset"): + for dataset in self._env["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ @@ -113,12 +113,12 @@ class PSNetwork(NetworkBase): def build_network(self, context): context["model"] = {} - if len(envs.get_global_env("phase")) > 1: + if len(self._env["phase"]) > 1: warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, stacklevel=2) - model_dict = envs.get_global_env("phase")[0] + model_dict = self._env["phase"][0] context["model"][model_dict["name"]] = {} dataset_name = model_dict["dataset_name"] @@ -154,7 +154,7 @@ class PSNetwork(NetworkBase): else: context["fleet"].init_worker() context["dataset"] = {} - for dataset in envs.get_global_env("dataset"): + for dataset in self._env["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ @@ -200,12 +200,12 @@ class PslibNetwork(NetworkBase): def build_network(self, context): context["model"] = {} - if len(envs.get_global_env("phase")) > 1: + if len(self._env["phase"]) > 1: warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, stacklevel=2) - model_dict = envs.get_global_env("phase")[0] + model_dict = self._env["phase"][0] train_program = fluid.Program() startup_program = fluid.Program() scope = fluid.Scope() @@ -247,7 +247,7 @@ class PslibNetwork(NetworkBase): self._server(context) else: context["dataset"] = {} - for dataset in envs.get_global_env("dataset"): + for dataset in self._env["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ @@ -267,12 +267,12 @@ class CollectiveNetwork(NetworkBase): def build_network(self, context): context["model"] = {} - if len(envs.get_global_env("phase")) > 1: + if len(self._env["phase"]) > 1: warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, stacklevel=2) - model_dict = envs.get_global_env("phase")[0] + model_dict = self._env["phase"][0] context["model"][model_dict["name"]] = {} dataset_name = model_dict["dataset_name"] @@ -311,7 +311,7 @@ class CollectiveNetwork(NetworkBase): "default_main_program"] = train_program context["dataset"] = {} - for dataset in envs.get_global_env("dataset"): + for dataset in self._env["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[