提交 fe6fc5b3 编写于 作者: T tangwei

fix windows adapter

上级 c6a3a9fd
...@@ -58,7 +58,8 @@ class SingleNetwork(NetworkBase): ...@@ -58,7 +58,8 @@ class SingleNetwork(NetworkBase):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
model_path = model_dict["model"] model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, model = envs.lazy_instance_by_fliename(model_path,
"Model")(None) "Model")(None)
...@@ -121,7 +122,8 @@ class PSNetwork(NetworkBase): ...@@ -121,7 +122,8 @@ class PSNetwork(NetworkBase):
context["model"][model_dict["name"]] = {} context["model"][model_dict["name"]] = {}
dataset_name = model_dict["dataset_name"] dataset_name = model_dict["dataset_name"]
model_path = model_dict["model"] model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, "Model")(None) model = envs.lazy_instance_by_fliename(model_path, "Model")(None)
model._data_var = model.input_data( model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"]) dataset_name=model_dict["dataset_name"])
...@@ -212,7 +214,8 @@ class PslibNetwork(NetworkBase): ...@@ -212,7 +214,8 @@ class PslibNetwork(NetworkBase):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
context["model"][model_dict["name"]] = {} context["model"][model_dict["name"]] = {}
model_path = model_dict["model"] model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, model = envs.lazy_instance_by_fliename(model_path,
"Model")(None) "Model")(None)
model._data_var = model.input_data( model._data_var = model.input_data(
...@@ -277,7 +280,8 @@ class CollectiveNetwork(NetworkBase): ...@@ -277,7 +280,8 @@ class CollectiveNetwork(NetworkBase):
scope = fluid.Scope() scope = fluid.Scope()
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
model_path = model_dict["model"] model_path = envs.os_path_adapter(
envs.workspace_adapter(model_dict["model"]))
model = envs.lazy_instance_by_fliename(model_path, model = envs.lazy_instance_by_fliename(model_path,
"Model")(None) "Model")(None)
model._data_var = model.input_data( model._data_var = model.input_data(
......
...@@ -91,9 +91,10 @@ def set_global_envs(envs): ...@@ -91,9 +91,10 @@ def set_global_envs(envs):
fatten_env_namespace([], envs) fatten_env_namespace([], envs)
workspace_adapter() for name, value in global_envs.items():
os_path_adapter() if isinstance(value, str):
reader_adapter() value = os_path_adapter(workspace_adapter(value))
global_envs[name] = value
def get_global_env(env_name, default_value=None, namespace=None): def get_global_env(env_name, default_value=None, namespace=None):
...@@ -118,27 +119,19 @@ def paddlerec_adapter(path): ...@@ -118,27 +119,19 @@ def paddlerec_adapter(path):
return path return path
def os_path_adapter(): def os_path_adapter(value):
for name, value in global_envs.items():
if isinstance(value, str):
if get_platform() == "WINDOWS": if get_platform() == "WINDOWS":
value = value.replace("/", "\\") value = value.replace("/", "\\")
else: else:
value = value.replace("\\", "/") value = value.replace("\\", "/")
global_envs[name] = value return value
def workspace_adapter(): def workspace_adapter(value):
workspace = global_envs.get("workspace") workspace = global_envs.get("workspace")
if not workspace:
return
workspace = paddlerec_adapter(workspace) workspace = paddlerec_adapter(workspace)
for name, value in global_envs.items():
if isinstance(value, str):
value = value.replace("{workspace}", workspace) value = value.replace("{workspace}", workspace)
global_envs[name] = value return value
def reader_adapter(): def reader_adapter():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册