From 8b9b9cf9ccbaf401d6c01d1c3c3ebf4e3e5d8fa3 Mon Sep 17 00:00:00 2001 From: tangwei Date: Tue, 21 Apr 2020 11:27:38 +0800 Subject: [PATCH] merge yaml two to one --- fleetrec/core/trainer.py | 2 +- fleetrec/core/utils/envs.py | 17 ++++++++++++----- fleetrec/examples/ctr-dnn_train.yaml | 11 +++++++++++ fleetrec/examples/runtime.yaml | 13 ------------- fleetrec/run.py | 25 +++++++++++++------------ 5 files changed, 37 insertions(+), 31 deletions(-) delete mode 100644 fleetrec/examples/runtime.yaml diff --git a/fleetrec/core/trainer.py b/fleetrec/core/trainer.py index b3a6cc2e..06b5800e 100755 --- a/fleetrec/core/trainer.py +++ b/fleetrec/core/trainer.py @@ -89,7 +89,7 @@ def user_define_engine(engine_yaml): _config = yaml.load(rb.read(), Loader=yaml.FullLoader) assert _config is not None - envs.set_runtime_envions(_config) + envs.set_runtime_environs(_config) train_location = envs.get_global_env("engine.file") train_dirname = os.path.dirname(train_location) diff --git a/fleetrec/core/utils/envs.py b/fleetrec/core/utils/envs.py index e4a04440..a4bc20f2 100644 --- a/fleetrec/core/utils/envs.py +++ b/fleetrec/core/utils/envs.py @@ -18,13 +18,14 @@ import copy global_envs = {} -def set_runtime_envions(envs): +def flatten_environs(envs): + flatten_dict = {} assert isinstance(envs, dict) def fatten_env_namespace(namespace_nests, local_envs): if not isinstance(local_envs, dict): global_k = ".".join(namespace_nests) - os.environ[global_k] = str(local_envs) + flatten_dict[global_k] = str(local_envs) else: for k, v in local_envs.items(): if isinstance(v, dict): @@ -33,18 +34,24 @@ def set_runtime_envions(envs): fatten_env_namespace(nests, v) else: global_k = ".".join(namespace_nests + [k]) - os.environ[global_k] = str(v) + flatten_dict[global_k] = str(v) for k, v in envs.items(): fatten_env_namespace([k], v) + return flatten_dict -def get_runtime_envion(key): + +def set_runtime_environs(environs): + for k, v in environs.items(): + os.environ[k] = v + +def get_runtime_environ(key): return os.getenv(key, None) def get_trainer(): - train_mode = get_runtime_envion("trainer.trainer") + train_mode = get_runtime_environ("trainer.trainer") return train_mode diff --git a/fleetrec/examples/ctr-dnn_train.yaml b/fleetrec/examples/ctr-dnn_train.yaml index 3e91d52a..35cd1aeb 100644 --- a/fleetrec/examples/ctr-dnn_train.yaml +++ b/fleetrec/examples/ctr-dnn_train.yaml @@ -13,6 +13,17 @@ # limitations under the License. train: + trainer: + trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py" + threads: 4 + # for cluster training + strategy: "async" + communicator: + send_queue_size: 4 + min_send_grad_num_before_recv: 4 + thread_pool_size: 5 + max_merge_var_num: 4 + epochs: 10 reader: diff --git a/fleetrec/examples/runtime.yaml b/fleetrec/examples/runtime.yaml deleted file mode 100644 index 70fa27c3..00000000 --- a/fleetrec/examples/runtime.yaml +++ /dev/null @@ -1,13 +0,0 @@ -trainer: - trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py" - threads: 4 - - # for cluster training - strategy: "async" - communicator: - send_queue_size: 4 - min_send_grad_num_before_recv: 4 - thread_pool_size: 5 - max_merge_var_num: 4 - - diff --git a/fleetrec/run.py b/fleetrec/run.py index e0afe5a1..9fbb1bba 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -13,16 +13,23 @@ clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"] def set_runtime_envs(cluster_envs, engine_yaml): - if engine_yaml is not None: + def get_engine_extras(): with open(engine_yaml, 'r') as rb: _envs = yaml.load(rb.read(), Loader=yaml.FullLoader) - else: - _envs = {} + + flattens = envs.flatten_environs(_envs) + + engine_extras = {} + for k, v in flattens.items(): + if k.startswith("train.trainer."): + engine_extras[k] = v + return engine_extras if cluster_envs is None: cluster_envs = {} - envs.set_runtime_envions(cluster_envs) - envs.set_runtime_envions(_envs) + + envs.set_runtime_environs(cluster_envs) + envs.set_runtime_environs(get_engine_extras()) need_print = {} for k, v in os.environ.items(): @@ -62,9 +69,8 @@ def cluster_engine(args): cluster_envs = {} cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["trainer.engine"] = "cluster" - set_runtime_envs(cluster_envs, args.engine_extras) + set_runtime_envs(cluster_envs, args.model) - envs.set_runtime_envions(cluster_envs) trainer = TrainerFactory.create(args.model) return trainer @@ -146,11 +152,6 @@ if __name__ == "__main__": if args.engine.upper() not in clusters: raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters)) - if args.engine_extras is not None: - if not os.path.exists(args.engine_extras) or not os.path.isfile(args.engine_extras): - raise ValueError( - "argument engine_extras: {} error, must specify an existed YAML file".format(args.engine_extras)) - which_engine = get_engine(args.engine) engine = which_engine(args) engine.run() -- GitLab