diff --git a/core/model.py b/core/model.py index 83181875fa3eb64d9c4085aa68356ade6f2ef720..036de7d619319611ac028b6c67ee5f8b24ff6b35 100755 --- a/core/model.py +++ b/core/model.py @@ -49,11 +49,11 @@ class ModelBase(object): self._slot_inited = True dataset = {} model_dict = {} - for i in envs.get_global_env("phase"): + for i in self._env("phase"): if i["name"] == kargs["name"]: model_dict = i break - for i in envs.get_global_env("dataset"): + for i in self._env("dataset"): if i["name"] == model_dict["dataset_name"]: dataset = i break diff --git a/core/trainer.py b/core/trainer.py index 83a7ea1ac98da07fd4ee729a57957e50b36e87a2..264f71b6ab554bae3041c65e8cd3e1a0c3412a4f 100755 --- a/core/trainer.py +++ b/core/trainer.py @@ -76,6 +76,8 @@ class Trainer(object): _config = envs.load_yaml(config) + self._context["dataset"] = _config.get("dataset") + phases = [] if phase_names is None: phases = _config.get("phase") diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 6685842ccd4c3378b1ed0b019a06249626f3a412..3443ac57e42202407d8237b519edbfad1b874fae 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -60,8 +60,8 @@ class SingleNetwork(NetworkBase): with fluid.scope_guard(scope): model_path = envs.os_path_adapter( envs.workspace_adapter(model_dict["model"])) - model = envs.lazy_instance_by_fliename(model_path, - "Model")(None) + model = envs.lazy_instance_by_fliename( + model_path, "Model")(context["env"]) if context["is_infer"]: model._infer_data_var = model.input_data( @@ -124,7 +124,8 @@ class PSNetwork(NetworkBase): model_path = envs.os_path_adapter( envs.workspace_adapter(model_dict["model"])) - model = envs.lazy_instance_by_fliename(model_path, "Model")(None) + model = envs.lazy_instance_by_fliename(model_path, + "Model")(context["env"]) model._data_var = model.input_data( dataset_name=model_dict["dataset_name"]) if envs.get_global_env("dataset." + dataset_name + @@ -216,8 +217,8 @@ class PslibNetwork(NetworkBase): context["model"][model_dict["name"]] = {} model_path = envs.os_path_adapter( envs.workspace_adapter(model_dict["model"])) - model = envs.lazy_instance_by_fliename(model_path, - "Model")(None) + model = envs.lazy_instance_by_fliename( + model_path, "Model")(context["env"]) model._data_var = model.input_data( dataset_name=model_dict["dataset_name"]) if envs.get_global_env("dataset." + dataset_name + @@ -282,8 +283,9 @@ class CollectiveNetwork(NetworkBase): with fluid.scope_guard(scope): model_path = envs.os_path_adapter( envs.workspace_adapter(model_dict["model"])) + model = envs.lazy_instance_by_fliename(model_path, - "Model")(None) + "Model")(context["env"]) model._data_var = model.input_data( dataset_name=model_dict["dataset_name"]) if envs.get_global_env("dataset." + dataset_name +