提交 3c915be5 编写于 作者: T tangwei

windows path adapt

上级 72e086b5
...@@ -89,7 +89,7 @@ def get_global_envs(): ...@@ -89,7 +89,7 @@ def get_global_envs():
return global_envs return global_envs
def windows_path_adapter(path): def path_adapter(path):
def adapt(l_p): def adapt(l_p):
if get_platform() == "WINDOWS": if get_platform() == "WINDOWS":
adapted_p = l_p.split("paddlerec.")[1].replace(".", "\\") adapted_p = l_p.split("paddlerec.")[1].replace(".", "\\")
...@@ -104,16 +104,23 @@ def windows_path_adapter(path): ...@@ -104,16 +104,23 @@ def windows_path_adapter(path):
return adapt(path) return adapt(path)
def windows_path_converter(path):
if get_platform() == "WINDOWS":
return path.replace("/", "\\")
else:
return path.replace("\\", "/")
def update_workspace(): def update_workspace():
workspace = global_envs.get("train.workspace", None) workspace = global_envs.get("train.workspace", None)
if not workspace: if not workspace:
return return
workspace = windows_path_adapter(workspace) workspace = path_adapter(workspace)
for name, value in global_envs.items(): for name, value in global_envs.items():
if isinstance(value, str): if isinstance(value, str):
value = value.replace("{workspace}", workspace) value = value.replace("{workspace}", workspace)
value = windows_path_adapter(value) value = windows_path_converter(value)
global_envs[name] = value global_envs[name] = value
......
...@@ -132,11 +132,11 @@ def cluster_engine(args): ...@@ -132,11 +132,11 @@ def cluster_engine(args):
if not workspace: if not workspace:
return return
path = envs.windows_path_adapter(workspace) path = envs.path_adapter(workspace)
for name, value in cluster_envs.items(): for name, value in cluster_envs.items():
if isinstance(value, str): if isinstance(value, str):
value = value.replace("{workspace}", path) value = value.replace("{workspace}", path)
value = envs.windows_path_adapter(value) value = envs.windows_path_converter(value)
cluster_envs[name] = value cluster_envs[name] = value
def master(): def master():
...@@ -240,7 +240,7 @@ def local_mpi_engine(args): ...@@ -240,7 +240,7 @@ def local_mpi_engine(args):
def get_abs_model(model): def get_abs_model(model):
if model.startswith("paddlerec."): if model.startswith("paddlerec."):
dir = envs.windows_path_adapter(model) dir = envs.path_adapter(model)
path = os.path.join(dir, "config.yaml") path = os.path.join(dir, "config.yaml")
else: else:
if not os.path.isfile(model): if not os.path.isfile(model):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册