diff --git a/core/factory.py b/core/factory.py index 96b38b2210f94c08cab69e465923120dbd706cd4..3c534aaa14cdd5c2ac056d0f78554a4ebc69c3f4 100755 --- a/core/factory.py +++ b/core/factory.py @@ -59,8 +59,7 @@ class TrainerFactory(object): @staticmethod def create(config): _config = envs.load_yaml(config) - envs.set_global_envs(_config) - envs.update_workspace() + envs.set_global_envs(_config, True) trainer = TrainerFactory._build_trainer(config) return trainer diff --git a/core/reader.py b/core/reader.py index dd20d588a5c36754190e708484505895d19c41a4..a05916a3a1088956554be80798674cbb23f4c32d 100755 --- a/core/reader.py +++ b/core/reader.py @@ -27,8 +27,7 @@ class ReaderBase(dg.MultiSlotDataGenerator): def __init__(self, config): dg.MultiSlotDataGenerator.__init__(self) _config = envs.load_yaml(config) - envs.set_global_envs(_config) - envs.update_workspace() + envs.set_global_envs(_config, True) @abc.abstractmethod def init(self): @@ -46,8 +45,7 @@ class SlotReader(dg.MultiSlotDataGenerator): def __init__(self, config): dg.MultiSlotDataGenerator.__init__(self) _config = envs.load_yaml(config) - envs.set_global_envs(_config) - envs.update_workspace() + envs.set_global_envs(_config, True) def init(self, sparse_slots, dense_slots, padding=0): from operator import mul diff --git a/core/utils/envs.py b/core/utils/envs.py index bbb2a824e6d3c5268d6ae17caed99410f714f353..45619a2b051b61780316ecd511255c64bde8b8fa 100755 --- a/core/utils/envs.py +++ b/core/utils/envs.py @@ -68,7 +68,7 @@ def get_fleet_mode(): return fleet_mode -def set_global_envs(envs): +def set_global_envs(envs, adapter): assert isinstance(envs, dict) def fatten_env_namespace(namespace_nests, local_envs): @@ -92,6 +92,10 @@ def set_global_envs(envs): fatten_env_namespace([], envs) + if adapter: + workspace_adapter() + os_path_adapter() + def get_global_env(env_name, default_value=None, namespace=None): """ @@ -106,7 +110,7 @@ def get_global_envs(): return global_envs -def path_adapter(path): +def paddlerec_adapter(path): if path.startswith("paddlerec."): package = get_runtime_environ("PACKAGE_BASE") l_p = path.split("paddlerec.")[1].replace(".", "/") @@ -115,23 +119,25 @@ def path_adapter(path): return path -def windows_path_converter(path): - if get_platform() == "WINDOWS": - return path.replace("/", "\\") - else: - return path.replace("\\", "/") +def os_path_adapter(): + for name, value in global_envs.items(): + if isinstance(value, str): + if get_platform() == "WINDOWS": + value = value.replace("/", "\\") + else: + value = value.replace("\\", "/") + global_envs[name] = value -def update_workspace(): +def workspace_adapter(): workspace = global_envs.get("workspace") if not workspace: return - workspace = path_adapter(workspace) + workspace = paddlerec_adapter(workspace) for name, value in global_envs.items(): if isinstance(value, str): value = value.replace("{workspace}", workspace) - value = windows_path_converter(value) global_envs[name] = value diff --git a/doc/design.md b/doc/design.md index 164fa811b1ae800c0d8b84e8fdebfc7f59ab5430..18cdb03cbaa0e67412492913b01afe7281ec4bae 100644 --- a/doc/design.md +++ b/doc/design.md @@ -197,8 +197,7 @@ class Reader(dg.MultiSlotDataGenerator): def __init__(self, config): dg.MultiSlotDataGenerator.__init__(self) _config = envs.load_yaml(config) - envs.set_global_envs(_config) - envs.update_workspace() + envs.set_global_envs(_config, True) @abc.abstractmethod def init(self): diff --git a/run.py b/run.py index b9256c268722b5250dac604c8843340ced11dd75..2e3b822ac953e9f1d73327a2de641ee2cdba8683 100755 --- a/run.py +++ b/run.py @@ -260,18 +260,6 @@ def single_infer_engine(args): def cluster_engine(args): - def update_workspace(cluster_envs): - workspace = cluster_envs.get("engine_workspace", None) - - if not workspace: - return - path = envs.path_adapter(workspace) - for name, value in cluster_envs.items(): - if isinstance(value, str): - value = value.replace("{workspace}", path) - value = envs.windows_path_converter(value) - cluster_envs[name] = value - def master(): role = "MASTER" from paddlerec.core.engine.cluster.cluster import ClusterEngine @@ -280,8 +268,6 @@ def cluster_engine(args): flattens["engine_role"] = role flattens["engine_run_config"] = args.model flattens["engine_temp_path"] = tempfile.mkdtemp() - update_workspace(flattens) - envs.set_runtime_environs(flattens) print(envs.pretty_print_envs(flattens, ("Submit Envs", "Value")))