提交 68208bc4 编写于 作者: T tangwei

fix windows adapter

上级 fe6fc5b3
......@@ -49,11 +49,11 @@ class ModelBase(object):
self._slot_inited = True
dataset = {}
model_dict = {}
for i in envs.get_global_env("phase"):
for i in self._env("phase"):
if i["name"] == kargs["name"]:
model_dict = i
break
for i in envs.get_global_env("dataset"):
for i in self._env("dataset"):
if i["name"] == model_dict["dataset_name"]:
dataset = i
break
......
......@@ -76,6 +76,8 @@ class Trainer(object):
_config = envs.load_yaml(config)
self._context["dataset"] = _config.get("dataset")
phases = []
if phase_names is None:
phases = _config.get("phase")
......
......@@ -60,8 +60,8 @@ class SingleNetwork(NetworkBase):
with fluid.scope_guard(scope):
model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path,
"Model")(None)
model = envs.lazy_instance_by_fliename(
model_path, "Model")(context["env"])
if context["is_infer"]:
model._infer_data_var = model.input_data(
......@@ -124,7 +124,8 @@ class PSNetwork(NetworkBase):
model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, "Model")(None)
model = envs.lazy_instance_by_fliename(model_path,
"Model")(context["env"])
model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"])
if envs.get_global_env("dataset." + dataset_name +
......@@ -216,8 +217,8 @@ class PslibNetwork(NetworkBase):
context["model"][model_dict["name"]] = {}
model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path,
"Model")(None)
model = envs.lazy_instance_by_fliename(
model_path, "Model")(context["env"])
model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"])
if envs.get_global_env("dataset." + dataset_name +
......@@ -282,8 +283,9 @@ class CollectiveNetwork(NetworkBase):
with fluid.scope_guard(scope):
model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path,
"Model")(None)
"Model")(context["env"])
model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"])
if envs.get_global_env("dataset." + dataset_name +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册