提交 741fbbd2 编写于 作者: T tangwei

windows path adapt

上级 d536f99f
...@@ -89,18 +89,26 @@ def get_global_envs(): ...@@ -89,18 +89,26 @@ def get_global_envs():
return global_envs return global_envs
def windows_path_adapter(path):
def adapt(l_p):
if get_platform() == "WINDOWS":
adapted_p = l_p.split("paddlerec.")[1].replace(".", "\\")
else:
adapted_p = l_p.split("paddlerec.")[1].replace(".", "/")
return adapted_p
if path.startswith("paddlerec."):
package = get_runtime_environ("PACKAGE_BASE")
return os.path.join(package, adapt(path))
else:
return adapt(path)
def update_workspace(): def update_workspace():
workspace = global_envs.get("train.workspace", None) workspace = global_envs.get("train.workspace", None)
if not workspace: if not workspace:
return return
path = windows_path_adapter(workspace)
# is fleet inner models
if workspace.startswith("paddlerec."):
fleet_package = get_runtime_environ("PACKAGE_BASE")
workspace_dir = workspace.split("paddlerec.")[1].replace(".", "/")
path = os.path.join(fleet_package, workspace_dir)
else:
path = workspace
for name, value in global_envs.items(): for name, value in global_envs.items():
if isinstance(value, str): if isinstance(value, str):
......
...@@ -129,17 +129,10 @@ def single_engine(args): ...@@ -129,17 +129,10 @@ def single_engine(args):
def cluster_engine(args): def cluster_engine(args):
def update_workspace(cluster_envs): def update_workspace(cluster_envs):
workspace = cluster_envs.get("engine_workspace", None) workspace = cluster_envs.get("engine_workspace", None)
if not workspace: if not workspace:
return return
path = envs.windows_path_adapter(workspace)
# is fleet inner models
if workspace.startswith("paddlerec."):
fleet_package = envs.get_runtime_environ("PACKAGE_BASE")
workspace_dir = workspace.split("paddlerec.")[1].replace(".", "/")
path = os.path.join(fleet_package, workspace_dir)
else:
path = workspace
for name, value in cluster_envs.items(): for name, value in cluster_envs.items():
if isinstance(value, str): if isinstance(value, str):
value = value.replace("{workspace}", path) value = value.replace("{workspace}", path)
...@@ -246,9 +239,8 @@ def local_mpi_engine(args): ...@@ -246,9 +239,8 @@ def local_mpi_engine(args):
def get_abs_model(model): def get_abs_model(model):
if model.startswith("paddlerec."): if model.startswith("paddlerec."):
fleet_base = envs.get_runtime_environ("PACKAGE_BASE") dir = envs.windows_path_adapter(model)
workspace_dir = model.split("paddlerec.")[1].replace(".", "/") path = os.path.join(dir, "config.yaml")
path = os.path.join(fleet_base, workspace_dir, "config.yaml")
else: else:
if not os.path.isfile(model): if not os.path.isfile(model):
raise IOError("model config: {} invalid".format(model)) raise IOError("model config: {} invalid".format(model))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册