diff --git a/core/factory.py b/core/factory.py index 3c534aaa14cdd5c2ac056d0f78554a4ebc69c3f4..9430c88283800e69db7043aa141b6f735212c79f 100755 --- a/core/factory.py +++ b/core/factory.py @@ -59,7 +59,7 @@ class TrainerFactory(object): @staticmethod def create(config): _config = envs.load_yaml(config) - envs.set_global_envs(_config, True) + envs.set_global_envs(_config) trainer = TrainerFactory._build_trainer(config) return trainer diff --git a/core/reader.py b/core/reader.py index 6c2af005f23b790e1686f3e2fd894399a3dbffaf..589e6c192330baf04c202952cefb04177f3e4297 100755 --- a/core/reader.py +++ b/core/reader.py @@ -26,7 +26,7 @@ class ReaderBase(dg.MultiSlotDataGenerator): def __init__(self, config): dg.MultiSlotDataGenerator.__init__(self) _config = envs.load_yaml(config) - envs.set_global_envs(_config, True) + envs.set_global_envs(_config) @abc.abstractmethod def init(self): @@ -44,7 +44,7 @@ class SlotReader(dg.MultiSlotDataGenerator): def __init__(self, config): dg.MultiSlotDataGenerator.__init__(self) _config = envs.load_yaml(config) - envs.set_global_envs(_config, True) + envs.set_global_envs(_config) def init(self, sparse_slots, dense_slots, padding=0): from operator import mul diff --git a/core/trainer.py b/core/trainer.py index 53a9cc1eb89e6eda2d08bc8f684bc4e8e4441e5c..83a7ea1ac98da07fd4ee729a57957e50b36e87a2 100755 --- a/core/trainer.py +++ b/core/trainer.py @@ -16,7 +16,6 @@ import abc import os import time import sys -import yaml import traceback from paddle import fluid @@ -74,11 +73,14 @@ class Trainer(object): phase_names = envs.get_global_env( "runner." 
+ self._runner_name + ".phases", None) + + _config = envs.load_yaml(config) + phases = [] if phase_names is None: - phases = envs.get_global_env("phase") + phases = _config.get("phase") else: - for phase in envs.get_global_env("phase"): + for phase in _config.get("phase"): if phase["name"] in phase_names: phases.append(phase) @@ -244,15 +246,3 @@ class Trainer(object): self.context_process(self._context) if self._context['is_exit']: break - - -def user_define_engine(engine_yaml): - _config = envs.load_yaml(engine_yaml) - envs.set_runtime_environs(_config) - train_location = envs.get_global_env("engine.file") - train_dirname = os.path.dirname(train_location) - base_name = os.path.splitext(os.path.basename(train_location))[0] - sys.path.append(train_dirname) - trainer_class = envs.lazy_instance_by_fliename(base_name, - "UserDefineTraining") - return trainer_class diff --git a/core/utils/envs.py b/core/utils/envs.py index 91e97f7a27e7a16914b0656f58fb7f4f80f698c2..526d098e1e1a2d1e206d98cbf9689714c2c00636 100755 --- a/core/utils/envs.py +++ b/core/utils/envs.py @@ -20,9 +20,8 @@ import socket import sys import traceback -import yaml - global_envs = {} +global_envs_flatten = {} def flatten_environs(envs, separator="."): @@ -68,7 +67,7 @@ def get_fleet_mode(): return fleet_mode -def set_global_envs(envs, adapter): +def set_global_envs(envs): assert isinstance(envs, dict) def fatten_env_namespace(namespace_nests, local_envs): @@ -92,10 +91,9 @@ def set_global_envs(envs, adapter): fatten_env_namespace([], envs) - if adapter: - workspace_adapter() - os_path_adapter() - reader_adapter() + workspace_adapter() + os_path_adapter() + reader_adapter() def get_global_env(env_name, default_value=None, namespace=None): @@ -134,6 +132,7 @@ def workspace_adapter(): workspace = global_envs.get("workspace") if not workspace: return + workspace = paddlerec_adapter(workspace) for name, value in global_envs.items(): diff --git a/doc/design.md b/doc/design.md index 
18cdb03cbaa0e67412492913b01afe7281ec4bae..d9eeaf6f16a38ba445629da7d800cfb1bf79416d 100644 --- a/doc/design.md +++ b/doc/design.md @@ -197,7 +197,7 @@ class Reader(dg.MultiSlotDataGenerator): def __init__(self, config): dg.MultiSlotDataGenerator.__init__(self) _config = envs.load_yaml(config) - envs.set_global_envs(_config, True) + envs.set_global_envs(_config) @abc.abstractmethod def init(self): diff --git a/run.py b/run.py index 2e3b822ac953e9f1d73327a2de641ee2cdba8683..881abcc6ae6bc8022d76b9672a9fabb424ebcebf 100755 --- a/run.py +++ b/run.py @@ -110,7 +110,6 @@ def get_modes(running_config): def get_engine(args, running_config, mode): transpiler = get_transpiler() - _envs = envs.load_yaml(args.model) engine_class = ".".join(["runner", mode, "class"]) engine_device = ".".join(["runner", mode, "device"]) @@ -122,11 +121,14 @@ def get_engine(args, running_config, mode): mode, engine_class)) device = running_config.get(engine_device, None) + engine = engine.upper() + device = device.upper() if device is not None else None + if device is None: print("not find device be specified in yaml, set CPU as default") device = "CPU" - if device.upper() == "GPU": + if device == "GPU": selected_gpus = running_config.get(device_gpu_choices, None) if selected_gpus is None: @@ -142,7 +144,6 @@ def get_engine(args, running_config, mode): if selected_gpus_num > 1: engine = "LOCAL_CLUSTER" - engine = engine.upper() if engine not in engine_choices: raise ValueError("{} can not be chosen in {}".format(engine_class, engine_choices)) @@ -180,9 +181,7 @@ def set_runtime_envs(cluster_envs, engine_yaml): def single_train_engine(args): - _envs = envs.load_yaml(args.model) run_extras = get_all_inters_from_yaml(args.model, ["runner."]) - mode = envs.get_runtime_environ("mode") trainer_class = ".".join(["runner", mode, "trainer_class"]) fleet_class = ".".join(["runner", mode, "fleet_mode"]) @@ -435,7 +434,7 @@ def local_mpi_engine(args): def get_abs_model(model): if model.startswith("paddlerec."): - dir = 
envs.path_adapter(model) + dir = envs.paddlerec_adapter(model) path = os.path.join(dir, "config.yaml") else: if not os.path.isfile(model): @@ -453,13 +452,12 @@ if __name__ == "__main__": envs.set_runtime_environs({"PACKAGE_BASE": abs_dir}) args = parser.parse_args() - model_name = args.model.split('.')[-1] args.model = get_abs_model(args.model) if not validation.yaml_validation(args.model): sys.exit(-1) - engine_registry() + engine_registry() running_config = get_all_inters_from_yaml(args.model, ["mode", "runner."]) modes = get_modes(running_config)