提交 a12fe9da 编写于 作者: T tangwei

fix windows adapter

上级 cc3d47cf
......@@ -59,8 +59,7 @@ class TrainerFactory(object):
@staticmethod
def create(config):
_config = envs.load_yaml(config)
envs.set_global_envs(_config)
envs.update_workspace()
envs.set_global_envs(_config, True)
trainer = TrainerFactory._build_trainer(config)
return trainer
......
......@@ -27,8 +27,7 @@ class ReaderBase(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config)
envs.set_global_envs(_config)
envs.update_workspace()
envs.set_global_envs(_config, True)
@abc.abstractmethod
def init(self):
......@@ -46,8 +45,7 @@ class SlotReader(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config)
envs.set_global_envs(_config)
envs.update_workspace()
envs.set_global_envs(_config, True)
def init(self, sparse_slots, dense_slots, padding=0):
from operator import mul
......
......@@ -68,7 +68,7 @@ def get_fleet_mode():
return fleet_mode
def set_global_envs(envs):
def set_global_envs(envs, adapter):
assert isinstance(envs, dict)
def fatten_env_namespace(namespace_nests, local_envs):
......@@ -92,6 +92,10 @@ def set_global_envs(envs):
fatten_env_namespace([], envs)
if adapter:
workspace_adapter()
os_path_adapter()
def get_global_env(env_name, default_value=None, namespace=None):
"""
......@@ -106,7 +110,7 @@ def get_global_envs():
return global_envs
def path_adapter(path):
def paddlerec_adapter(path):
if path.startswith("paddlerec."):
package = get_runtime_environ("PACKAGE_BASE")
l_p = path.split("paddlerec.")[1].replace(".", "/")
......@@ -115,23 +119,25 @@ def path_adapter(path):
return path
def windows_path_converter(path):
if get_platform() == "WINDOWS":
return path.replace("/", "\\")
else:
return path.replace("\\", "/")
def os_path_adapter():
for name, value in global_envs.items():
if isinstance(value, str):
if get_platform() == "WINDOWS":
value = value.replace("/", "\\")
else:
value = value.replace("\\", "/")
global_envs[name] = value
def update_workspace():
def workspace_adapter():
workspace = global_envs.get("workspace")
if not workspace:
return
workspace = path_adapter(workspace)
workspace = paddlerec_adapter(workspace)
for name, value in global_envs.items():
if isinstance(value, str):
value = value.replace("{workspace}", workspace)
value = windows_path_converter(value)
global_envs[name] = value
......
......@@ -197,8 +197,7 @@ class Reader(dg.MultiSlotDataGenerator):
def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config)
envs.set_global_envs(_config)
envs.update_workspace()
envs.set_global_envs(_config, True)
@abc.abstractmethod
def init(self):
......
......@@ -260,18 +260,6 @@ def single_infer_engine(args):
def cluster_engine(args):
def update_workspace(cluster_envs):
workspace = cluster_envs.get("engine_workspace", None)
if not workspace:
return
path = envs.path_adapter(workspace)
for name, value in cluster_envs.items():
if isinstance(value, str):
value = value.replace("{workspace}", path)
value = envs.windows_path_converter(value)
cluster_envs[name] = value
def master():
role = "MASTER"
from paddlerec.core.engine.cluster.cluster import ClusterEngine
......@@ -280,8 +268,6 @@ def cluster_engine(args):
flattens["engine_role"] = role
flattens["engine_run_config"] = args.model
flattens["engine_temp_path"] = tempfile.mkdtemp()
update_workspace(flattens)
envs.set_runtime_environs(flattens)
print(envs.pretty_print_envs(flattens, ("Submit Envs", "Value")))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册