diff --git a/core/model.py b/core/model.py index 036de7d619319611ac028b6c67ee5f8b24ff6b35..75324d406d61cdb69726203d022a792632775651 100755 --- a/core/model.py +++ b/core/model.py @@ -37,7 +37,7 @@ class ModelBase(object): self._fetch_interval = 20 self._platform = envs.get_platform() self._init_hyper_parameters() - self._env = config + context["env"] = config self._slot_inited = False def _init_hyper_parameters(self): @@ -49,11 +49,11 @@ class ModelBase(object): self._slot_inited = True dataset = {} model_dict = {} - for i in self._env("phase"): + for i in context["env"]("phase"): if i["name"] == kargs["name"]: model_dict = i break - for i in self._env("dataset"): + for i in context["env"]("dataset"): if i["name"] == model_dict["dataset_name"]: dataset = i break diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 30b73b200226fe12094ea4703165d4ab68fc3b11..7d1e276ea2d1496a2225572d917fcb3f9b2fc1fb 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -96,7 +96,7 @@ class SingleNetwork(NetworkBase): "default_main_program"] = train_program.clone() context["dataset"] = {} - for dataset in self._env["dataset"]: + for dataset in context["env"]["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ @@ -113,12 +113,12 @@ class PSNetwork(NetworkBase): def build_network(self, context): context["model"] = {} - if len(self._env["phase"]) > 1: + if len(context["env"]["phase"]) > 1: warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, stacklevel=2) - model_dict = self._env["phase"][0] + model_dict = context["env"]["phase"][0] context["model"][model_dict["name"]] = {} dataset_name = model_dict["dataset_name"] @@ -154,7 +154,7 @@ class PSNetwork(NetworkBase): else: context["fleet"].init_worker() context["dataset"] = {} - for dataset in self._env["dataset"]: + for dataset in context["env"]["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ @@ -200,12 +200,12 @@ class PslibNetwork(NetworkBase): def build_network(self, context): context["model"] = {} - if len(self._env["phase"]) > 1: + if len(context["env"]["phase"]) > 1: warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, stacklevel=2) - model_dict = self._env["phase"][0] + model_dict = context["env"]["phase"][0] train_program = fluid.Program() startup_program = fluid.Program() scope = fluid.Scope() @@ -247,7 +247,7 @@ class PslibNetwork(NetworkBase): self._server(context) else: context["dataset"] = {} - for dataset in self._env["dataset"]: + for dataset in context["env"]["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[ @@ -267,12 +267,12 @@ class CollectiveNetwork(NetworkBase): def build_network(self, context): context["model"] = {} - if len(self._env["phase"]) > 1: + if len(context["env"]["phase"]) > 1: warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, stacklevel=2) - model_dict = self._env["phase"][0] + model_dict = context["env"]["phase"][0] context["model"][model_dict["name"]] = {} dataset_name = model_dict["dataset_name"] @@ -311,7 +311,7 @@ class CollectiveNetwork(NetworkBase): "default_main_program"] = train_program context["dataset"] = {} - for dataset in self._env["dataset"]: + for dataset in context["env"]["dataset"]: if dataset["type"] != "DataLoader": dataset_class = QueueDataset(context) context["dataset"][dataset[