From a12fe9dadeca7a36aa04c8ac48cd75e8efae6842 Mon Sep 17 00:00:00 2001 From: tangwei Date: Wed, 10 Jun 2020 17:51:45 +0800 Subject: [PATCH] fix windows adapter --- core/factory.py | 3 +-- core/reader.py | 6 ++---- core/utils/envs.py | 26 ++++++++++++++++---------- doc/design.md | 3 +-- run.py | 14 -------------- 5 files changed, 20 insertions(+), 32 deletions(-) diff --git a/core/factory.py b/core/factory.py index 96b38b22..3c534aaa 100755 --- a/core/factory.py +++ b/core/factory.py @@ -59,8 +59,7 @@ class TrainerFactory(object): @staticmethod def create(config): _config = envs.load_yaml(config) - envs.set_global_envs(_config) - envs.update_workspace() + envs.set_global_envs(_config, True) trainer = TrainerFactory._build_trainer(config) return trainer diff --git a/core/reader.py b/core/reader.py index dd20d588..a05916a3 100755 --- a/core/reader.py +++ b/core/reader.py @@ -27,8 +27,7 @@ class ReaderBase(dg.MultiSlotDataGenerator): def __init__(self, config): dg.MultiSlotDataGenerator.__init__(self) _config = envs.load_yaml(config) - envs.set_global_envs(_config) - envs.update_workspace() + envs.set_global_envs(_config, True) @abc.abstractmethod def init(self): @@ -46,8 +45,7 @@ class SlotReader(dg.MultiSlotDataGenerator): def __init__(self, config): dg.MultiSlotDataGenerator.__init__(self) _config = envs.load_yaml(config) - envs.set_global_envs(_config) - envs.update_workspace() + envs.set_global_envs(_config, True) def init(self, sparse_slots, dense_slots, padding=0): from operator import mul diff --git a/core/utils/envs.py b/core/utils/envs.py index bbb2a824..45619a2b 100755 --- a/core/utils/envs.py +++ b/core/utils/envs.py @@ -68,7 +68,7 @@ def get_fleet_mode(): return fleet_mode -def set_global_envs(envs): +def set_global_envs(envs, adapter): assert isinstance(envs, dict) def fatten_env_namespace(namespace_nests, local_envs): @@ -92,6 +92,10 @@ def set_global_envs(envs): fatten_env_namespace([], envs) + if adapter: + workspace_adapter() + os_path_adapter() + def get_global_env(env_name, default_value=None, namespace=None): """ @@ -106,7 +110,7 @@ def get_global_envs(): return global_envs -def path_adapter(path): +def paddlerec_adapter(path): if path.startswith("paddlerec."): package = get_runtime_environ("PACKAGE_BASE") l_p = path.split("paddlerec.")[1].replace(".", "/") @@ -115,23 +119,25 @@ def path_adapter(path): return path -def windows_path_converter(path): - if get_platform() == "WINDOWS": - return path.replace("/", "\\") - else: - return path.replace("\\", "/") +def os_path_adapter(): + for name, value in global_envs.items(): + if isinstance(value, str): + if get_platform() == "WINDOWS": + value = value.replace("/", "\\") + else: + value = value.replace("\\", "/") + global_envs[name] = value -def update_workspace(): +def workspace_adapter(): workspace = global_envs.get("workspace") if not workspace: return - workspace = path_adapter(workspace) + workspace = paddlerec_adapter(workspace) for name, value in global_envs.items(): if isinstance(value, str): value = value.replace("{workspace}", workspace) - value = windows_path_converter(value) global_envs[name] = value diff --git a/doc/design.md b/doc/design.md index 164fa811..18cdb03c 100644 --- a/doc/design.md +++ b/doc/design.md @@ -197,8 +197,7 @@ class Reader(dg.MultiSlotDataGenerator): def __init__(self, config): dg.MultiSlotDataGenerator.__init__(self) _config = envs.load_yaml(config) - envs.set_global_envs(_config) - envs.update_workspace() + envs.set_global_envs(_config, True) @abc.abstractmethod def init(self): diff --git a/run.py b/run.py index b9256c26..2e3b822a 100755 --- a/run.py +++ b/run.py @@ -260,18 +260,6 @@ def single_infer_engine(args): def cluster_engine(args): - def update_workspace(cluster_envs): - workspace = cluster_envs.get("engine_workspace", None) - - if not workspace: - return - path = envs.path_adapter(workspace) - for name, value in cluster_envs.items(): - if isinstance(value, str): - value = value.replace("{workspace}", path) - value = envs.windows_path_converter(value) - cluster_envs[name] = value - def master(): role = "MASTER" from paddlerec.core.engine.cluster.cluster import ClusterEngine @@ -280,8 +268,6 @@ def cluster_engine(args): flattens["engine_role"] = role flattens["engine_run_config"] = args.model flattens["engine_temp_path"] = tempfile.mkdtemp() - update_workspace(flattens) - envs.set_runtime_environs(flattens) print(envs.pretty_print_envs(flattens, ("Submit Envs", "Value"))) -- GitLab