From c762d6cb0fe8e98fd6b16868929c4acd468eff62 Mon Sep 17 00:00:00 2001
From: tangwei
Date: Tue, 21 Apr 2020 11:39:26 +0800
Subject: [PATCH] merge yaml two to one

---
 fleetrec/core/trainers/cluster_trainer.py    |  2 +-
 fleetrec/core/trainers/transpiler_trainer.py |  2 +-
 fleetrec/core/utils/envs.py                  |  2 +-
 fleetrec/run.py                              | 26 ++++++++++----------
 4 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/fleetrec/core/trainers/cluster_trainer.py b/fleetrec/core/trainers/cluster_trainer.py
index 6e7c064f..a2dd11ab 100644
--- a/fleetrec/core/trainers/cluster_trainer.py
+++ b/fleetrec/core/trainers/cluster_trainer.py
@@ -43,7 +43,7 @@ class ClusterTrainer(TranspileTrainer):
         self.regist_context_processor('terminal_pass', self.terminal)

     def build_strategy(self):
-        mode = envs.get_runtime_environ("trainer.strategy")
+        mode = envs.get_runtime_environ("train.trainer.strategy")
         assert mode in ["async", "geo", "sync", "half_async"]

         strategy = None
diff --git a/fleetrec/core/trainers/transpiler_trainer.py b/fleetrec/core/trainers/transpiler_trainer.py
index b0fb686e..ed3679c0 100644
--- a/fleetrec/core/trainers/transpiler_trainer.py
+++ b/fleetrec/core/trainers/transpiler_trainer.py
@@ -39,7 +39,7 @@ class TranspileTrainer(Trainer):
         namespace = "train.reader"

         inputs = self.model.get_inputs()
-        threads = int(envs.get_runtime_environ("trainer.threads"))
+        threads = int(envs.get_runtime_environ("train.trainer.threads"))
         batch_size = envs.get_global_env("batch_size", None, namespace)
         reader_class = envs.get_global_env("class", None, namespace)
         abs_dir = os.path.dirname(os.path.abspath(__file__))
diff --git a/fleetrec/core/utils/envs.py b/fleetrec/core/utils/envs.py
index fcf1e1ae..991ede19 100644
--- a/fleetrec/core/utils/envs.py
+++ b/fleetrec/core/utils/envs.py
@@ -50,7 +50,7 @@ def get_runtime_environ(key):
     return os.getenv(key, None)


 def get_trainer():
-    train_mode = get_runtime_environ("trainer.trainer")
+    train_mode = get_runtime_environ("train.trainer.trainer")
     return train_mode

diff --git a/fleetrec/run.py b/fleetrec/run.py
index 26650e7b..f49d85fa 100644
--- a/fleetrec/run.py
+++ b/fleetrec/run.py
@@ -33,7 +33,7 @@ def set_runtime_envs(cluster_envs, engine_yaml):

     need_print = {}
     for k, v in os.environ.items():
-        if k.startswith("trainer."):
+        if k.startswith("train.trainer."):
             need_print[k] = v

     print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value")))
@@ -55,9 +55,9 @@ def single_engine(args):
     print("use single engine to run model: {}".format(args.model))

     single_envs = {}
-    single_envs["trainer.trainer"] = "SingleTrainer"
-    single_envs["trainer.threads"] = "2"
-    single_envs["trainer.engine"] = "single"
+    single_envs["train.trainer.trainer"] = "SingleTrainer"
+    single_envs["train.trainer.threads"] = "2"
+    single_envs["train.trainer.engine"] = "single"
     set_runtime_envs(single_envs, args.model)
     trainer = TrainerFactory.create(args.model)
     return trainer
@@ -67,8 +67,8 @@ def cluster_engine(args):
     print("launch cluster engine with cluster to run model: {}".format(args.model))

     cluster_envs = {}
-    cluster_envs["trainer.trainer"] = "ClusterTrainer"
-    cluster_envs["trainer.engine"] = "cluster"
+    cluster_envs["train.trainer.trainer"] = "ClusterTrainer"
+    cluster_envs["train.trainer.engine"] = "cluster"
     set_runtime_envs(cluster_envs, args.model)
     trainer = TrainerFactory.create(args.model)

@@ -79,7 +79,7 @@ def cluster_mpi_engine(args):
     print("launch cluster engine with cluster to run model: {}".format(args.model))

     cluster_envs = {}
-    cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
+    cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
     set_runtime_envs(cluster_envs, args.model)
     trainer = TrainerFactory.create(args.model)

@@ -95,10 +95,10 @@ def local_cluster_engine(args):
     cluster_envs["worker_num"] = 1
     cluster_envs["start_port"] = 36001
     cluster_envs["log_dir"] = "logs"
-    cluster_envs["trainer.trainer"] = "ClusterTrainer"
-    cluster_envs["trainer.strategy"] = "async"
-    cluster_envs["trainer.threads"] = "2"
-    cluster_envs["trainer.engine"] = "local_cluster"
+    cluster_envs["train.trainer.trainer"] = "ClusterTrainer"
+    cluster_envs["train.trainer.strategy"] = "async"
+    cluster_envs["train.trainer.threads"] = "2"
+    cluster_envs["train.trainer.engine"] = "local_cluster"
     cluster_envs["CPU_NUM"] = "2"
     set_runtime_envs(cluster_envs, args.model)

@@ -118,9 +118,9 @@ def local_mpi_engine(args):
         raise RuntimeError("can not find mpirun, please check environment")
     cluster_envs = {}
     cluster_envs["mpirun"] = mpi
-    cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
+    cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
     cluster_envs["log_dir"] = "logs"
-    cluster_envs["trainer.engine"] = "local_cluster"
+    cluster_envs["train.trainer.engine"] = "local_cluster"
     set_runtime_envs(cluster_envs, args.model)

     launch = LocalMPIEngine(cluster_envs, args.model)
--
GitLab