提交 68208bc4 编写于 作者: T tangwei

fix windows adapter

上级 fe6fc5b3
...@@ -49,11 +49,11 @@ class ModelBase(object): ...@@ -49,11 +49,11 @@ class ModelBase(object):
self._slot_inited = True self._slot_inited = True
dataset = {} dataset = {}
model_dict = {} model_dict = {}
for i in envs.get_global_env("phase"): for i in self._env("phase"):
if i["name"] == kargs["name"]: if i["name"] == kargs["name"]:
model_dict = i model_dict = i
break break
for i in envs.get_global_env("dataset"): for i in self._env("dataset"):
if i["name"] == model_dict["dataset_name"]: if i["name"] == model_dict["dataset_name"]:
dataset = i dataset = i
break break
......
...@@ -76,6 +76,8 @@ class Trainer(object): ...@@ -76,6 +76,8 @@ class Trainer(object):
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
self._context["dataset"] = _config.get("dataset")
phases = [] phases = []
if phase_names is None: if phase_names is None:
phases = _config.get("phase") phases = _config.get("phase")
......
...@@ -60,8 +60,8 @@ class SingleNetwork(NetworkBase): ...@@ -60,8 +60,8 @@ class SingleNetwork(NetworkBase):
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
model_path = envs.os_path_adapter( model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"])) envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, model = envs.lazy_instance_by_fliename(
"Model")(None) model_path, "Model")(context["env"])
if context["is_infer"]: if context["is_infer"]:
model._infer_data_var = model.input_data( model._infer_data_var = model.input_data(
...@@ -124,7 +124,8 @@ class PSNetwork(NetworkBase): ...@@ -124,7 +124,8 @@ class PSNetwork(NetworkBase):
model_path = envs.os_path_adapter( model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"])) envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, "Model")(None) model = envs.lazy_instance_by_fliename(model_path,
"Model")(context["env"])
model._data_var = model.input_data( model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"]) dataset_name=model_dict["dataset_name"])
if envs.get_global_env("dataset." + dataset_name + if envs.get_global_env("dataset." + dataset_name +
...@@ -216,8 +217,8 @@ class PslibNetwork(NetworkBase): ...@@ -216,8 +217,8 @@ class PslibNetwork(NetworkBase):
context["model"][model_dict["name"]] = {} context["model"][model_dict["name"]] = {}
model_path = envs.os_path_adapter( model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"])) envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, model = envs.lazy_instance_by_fliename(
"Model")(None) model_path, "Model")(context["env"])
model._data_var = model.input_data( model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"]) dataset_name=model_dict["dataset_name"])
if envs.get_global_env("dataset." + dataset_name + if envs.get_global_env("dataset." + dataset_name +
...@@ -282,8 +283,9 @@ class CollectiveNetwork(NetworkBase): ...@@ -282,8 +283,9 @@ class CollectiveNetwork(NetworkBase):
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
model_path = envs.os_path_adapter( model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"])) envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, model = envs.lazy_instance_by_fliename(model_path,
"Model")(None) "Model")(context["env"])
model._data_var = model.input_data( model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"]) dataset_name=model_dict["dataset_name"])
if envs.get_global_env("dataset." + dataset_name + if envs.get_global_env("dataset." + dataset_name +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册