提交 a12fe9da 编写于 作者: T tangwei

fix windows adapter

上级 cc3d47cf
...@@ -59,8 +59,7 @@ class TrainerFactory(object): ...@@ -59,8 +59,7 @@ class TrainerFactory(object):
@staticmethod @staticmethod
def create(config): def create(config):
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
envs.set_global_envs(_config) envs.set_global_envs(_config, True)
envs.update_workspace()
trainer = TrainerFactory._build_trainer(config) trainer = TrainerFactory._build_trainer(config)
return trainer return trainer
......
...@@ -27,8 +27,7 @@ class ReaderBase(dg.MultiSlotDataGenerator): ...@@ -27,8 +27,7 @@ class ReaderBase(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
envs.set_global_envs(_config) envs.set_global_envs(_config, True)
envs.update_workspace()
@abc.abstractmethod @abc.abstractmethod
def init(self): def init(self):
...@@ -46,8 +45,7 @@ class SlotReader(dg.MultiSlotDataGenerator): ...@@ -46,8 +45,7 @@ class SlotReader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
envs.set_global_envs(_config) envs.set_global_envs(_config, True)
envs.update_workspace()
def init(self, sparse_slots, dense_slots, padding=0): def init(self, sparse_slots, dense_slots, padding=0):
from operator import mul from operator import mul
......
...@@ -68,7 +68,7 @@ def get_fleet_mode(): ...@@ -68,7 +68,7 @@ def get_fleet_mode():
return fleet_mode return fleet_mode
def set_global_envs(envs): def set_global_envs(envs, adapter):
assert isinstance(envs, dict) assert isinstance(envs, dict)
def fatten_env_namespace(namespace_nests, local_envs): def fatten_env_namespace(namespace_nests, local_envs):
...@@ -92,6 +92,10 @@ def set_global_envs(envs): ...@@ -92,6 +92,10 @@ def set_global_envs(envs):
fatten_env_namespace([], envs) fatten_env_namespace([], envs)
if adapter:
workspace_adapter()
os_path_adapter()
def get_global_env(env_name, default_value=None, namespace=None): def get_global_env(env_name, default_value=None, namespace=None):
""" """
...@@ -106,7 +110,7 @@ def get_global_envs(): ...@@ -106,7 +110,7 @@ def get_global_envs():
return global_envs return global_envs
def path_adapter(path): def paddlerec_adapter(path):
if path.startswith("paddlerec."): if path.startswith("paddlerec."):
package = get_runtime_environ("PACKAGE_BASE") package = get_runtime_environ("PACKAGE_BASE")
l_p = path.split("paddlerec.")[1].replace(".", "/") l_p = path.split("paddlerec.")[1].replace(".", "/")
...@@ -115,23 +119,25 @@ def path_adapter(path): ...@@ -115,23 +119,25 @@ def path_adapter(path):
return path return path
def windows_path_converter(path): def os_path_adapter():
for name, value in global_envs.items():
if isinstance(value, str):
if get_platform() == "WINDOWS": if get_platform() == "WINDOWS":
return path.replace("/", "\\") value = value.replace("/", "\\")
else: else:
return path.replace("\\", "/") value = value.replace("\\", "/")
global_envs[name] = value
def update_workspace(): def workspace_adapter():
workspace = global_envs.get("workspace") workspace = global_envs.get("workspace")
if not workspace: if not workspace:
return return
workspace = path_adapter(workspace) workspace = paddlerec_adapter(workspace)
for name, value in global_envs.items(): for name, value in global_envs.items():
if isinstance(value, str): if isinstance(value, str):
value = value.replace("{workspace}", workspace) value = value.replace("{workspace}", workspace)
value = windows_path_converter(value)
global_envs[name] = value global_envs[name] = value
......
...@@ -197,8 +197,7 @@ class Reader(dg.MultiSlotDataGenerator): ...@@ -197,8 +197,7 @@ class Reader(dg.MultiSlotDataGenerator):
def __init__(self, config): def __init__(self, config):
dg.MultiSlotDataGenerator.__init__(self) dg.MultiSlotDataGenerator.__init__(self)
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
envs.set_global_envs(_config) envs.set_global_envs(_config, True)
envs.update_workspace()
@abc.abstractmethod @abc.abstractmethod
def init(self): def init(self):
......
...@@ -260,18 +260,6 @@ def single_infer_engine(args): ...@@ -260,18 +260,6 @@ def single_infer_engine(args):
def cluster_engine(args): def cluster_engine(args):
def update_workspace(cluster_envs):
workspace = cluster_envs.get("engine_workspace", None)
if not workspace:
return
path = envs.path_adapter(workspace)
for name, value in cluster_envs.items():
if isinstance(value, str):
value = value.replace("{workspace}", path)
value = envs.windows_path_converter(value)
cluster_envs[name] = value
def master(): def master():
role = "MASTER" role = "MASTER"
from paddlerec.core.engine.cluster.cluster import ClusterEngine from paddlerec.core.engine.cluster.cluster import ClusterEngine
...@@ -280,8 +268,6 @@ def cluster_engine(args): ...@@ -280,8 +268,6 @@ def cluster_engine(args):
flattens["engine_role"] = role flattens["engine_role"] = role
flattens["engine_run_config"] = args.model flattens["engine_run_config"] = args.model
flattens["engine_temp_path"] = tempfile.mkdtemp() flattens["engine_temp_path"] = tempfile.mkdtemp()
update_workspace(flattens)
envs.set_runtime_environs(flattens) envs.set_runtime_environs(flattens)
print(envs.pretty_print_envs(flattens, ("Submit Envs", "Value"))) print(envs.pretty_print_envs(flattens, ("Submit Envs", "Value")))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册