提交 fe6fc5b3 编写于 作者: T tangwei

fix windows adapter

上级 c6a3a9fd
......@@ -58,7 +58,8 @@ class SingleNetwork(NetworkBase):
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
with fluid.scope_guard(scope):
model_path = model_dict["model"]
model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path,
"Model")(None)
......@@ -121,7 +122,8 @@ class PSNetwork(NetworkBase):
context["model"][model_dict["name"]] = {}
dataset_name = model_dict["dataset_name"]
model_path = model_dict["model"]
model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, "Model")(None)
model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"])
......@@ -212,7 +214,8 @@ class PslibNetwork(NetworkBase):
with fluid.unique_name.guard():
with fluid.scope_guard(scope):
context["model"][model_dict["name"]] = {}
model_path = model_dict["model"]
model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path,
"Model")(None)
model._data_var = model.input_data(
......@@ -277,7 +280,8 @@ class CollectiveNetwork(NetworkBase):
scope = fluid.Scope()
with fluid.program_guard(train_program, startup_program):
with fluid.scope_guard(scope):
model_path = model_dict["model"]
model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path,
"Model")(None)
model._data_var = model.input_data(
......
......@@ -91,9 +91,10 @@ def set_global_envs(envs):
fatten_env_namespace([], envs)
workspace_adapter()
os_path_adapter()
reader_adapter()
for name, value in global_envs.items():
if isinstance(value, str):
value = os_path_adapter(workspace_adapter(value))
global_envs[name] = value
def get_global_env(env_name, default_value=None, namespace=None):
......@@ -118,27 +119,19 @@ def paddlerec_adapter(path):
return path
def os_path_adapter():
for name, value in global_envs.items():
if isinstance(value, str):
def os_path_adapter(value):
if get_platform() == "WINDOWS":
value = value.replace("/", "\\")
else:
value = value.replace("\\", "/")
global_envs[name] = value
return value
def workspace_adapter():
def workspace_adapter(value):
workspace = global_envs.get("workspace")
if not workspace:
return
workspace = paddlerec_adapter(workspace)
for name, value in global_envs.items():
if isinstance(value, str):
value = value.replace("{workspace}", workspace)
global_envs[name] = value
return value
def reader_adapter():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册