diff --git a/fleetrec/core/trainers/cluster_trainer.py b/fleetrec/core/trainers/cluster_trainer.py index 6e7c064f40c3a568df437a42573d6e7c693736a2..a2dd11abab1e4617750e749aab7353b6655d659e 100644 --- a/fleetrec/core/trainers/cluster_trainer.py +++ b/fleetrec/core/trainers/cluster_trainer.py @@ -43,7 +43,7 @@ class ClusterTrainer(TranspileTrainer): self.regist_context_processor('terminal_pass', self.terminal) def build_strategy(self): - mode = envs.get_runtime_environ("trainer.strategy") + mode = envs.get_runtime_environ("train.trainer.strategy") assert mode in ["async", "geo", "sync", "half_async"] strategy = None diff --git a/fleetrec/core/trainers/transpiler_trainer.py b/fleetrec/core/trainers/transpiler_trainer.py index b0fb686e0fb5c83bee998beb29c082a34291acd2..ed3679c0734266effa8ad23ae8db2199b06d0790 100644 --- a/fleetrec/core/trainers/transpiler_trainer.py +++ b/fleetrec/core/trainers/transpiler_trainer.py @@ -39,7 +39,7 @@ class TranspileTrainer(Trainer): namespace = "train.reader" inputs = self.model.get_inputs() - threads = int(envs.get_runtime_environ("trainer.threads")) + threads = int(envs.get_runtime_environ("train.trainer.threads")) batch_size = envs.get_global_env("batch_size", None, namespace) reader_class = envs.get_global_env("class", None, namespace) abs_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/fleetrec/core/utils/envs.py b/fleetrec/core/utils/envs.py index fcf1e1aef7103412e85152f6ecfb374a049f22eb..991ede19190d57f906d0d561cee01be929b935bd 100644 --- a/fleetrec/core/utils/envs.py +++ b/fleetrec/core/utils/envs.py @@ -50,7 +50,7 @@ def get_runtime_environ(key): return os.getenv(key, None) def get_trainer(): - train_mode = get_runtime_environ("trainer.trainer") + train_mode = get_runtime_environ("train.trainer.trainer") return train_mode diff --git a/fleetrec/run.py b/fleetrec/run.py index 26650e7b6317a06a82636b91359869f860acf084..f49d85fa75048768bb70f40a30eed8e9cb80a5c1 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -33,7 +33,7 @@ def set_runtime_envs(cluster_envs, engine_yaml): need_print = {} for k, v in os.environ.items(): - if k.startswith("trainer."): + if k.startswith("train.trainer."): need_print[k] = v print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value"))) @@ -55,9 +55,9 @@ def single_engine(args): print("use single engine to run model: {}".format(args.model)) single_envs = {} - single_envs["trainer.trainer"] = "SingleTrainer" - single_envs["trainer.threads"] = "2" - single_envs["trainer.engine"] = "single" + single_envs["train.trainer.trainer"] = "SingleTrainer" + single_envs["train.trainer.threads"] = "2" + single_envs["train.trainer.engine"] = "single" set_runtime_envs(single_envs, args.model) trainer = TrainerFactory.create(args.model) return trainer @@ -67,8 +67,8 @@ def cluster_engine(args): print("launch cluster engine with cluster to run model: {}".format(args.model)) cluster_envs = {} - cluster_envs["trainer.trainer"] = "ClusterTrainer" - cluster_envs["trainer.engine"] = "cluster" + cluster_envs["train.trainer.trainer"] = "ClusterTrainer" + cluster_envs["train.trainer.engine"] = "cluster" set_runtime_envs(cluster_envs, args.model) trainer = TrainerFactory.create(args.model) @@ -79,7 +79,7 @@ def cluster_mpi_engine(args): print("launch cluster engine with cluster to run model: {}".format(args.model)) cluster_envs = {} - cluster_envs["trainer.trainer"] = "CtrCodingTrainer" + cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer" set_runtime_envs(cluster_envs, args.model) trainer = TrainerFactory.create(args.model) @@ -95,10 +95,10 @@ def local_cluster_engine(args): cluster_envs["worker_num"] = 1 cluster_envs["start_port"] = 36001 cluster_envs["log_dir"] = "logs" - cluster_envs["trainer.trainer"] = "ClusterTrainer" - cluster_envs["trainer.strategy"] = "async" - cluster_envs["trainer.threads"] = "2" - cluster_envs["trainer.engine"] = "local_cluster" + cluster_envs["train.trainer.trainer"] = "ClusterTrainer" + cluster_envs["train.trainer.strategy"] = "async" + cluster_envs["train.trainer.threads"] = "2" + cluster_envs["train.trainer.engine"] = "local_cluster" cluster_envs["CPU_NUM"] = "2" set_runtime_envs(cluster_envs, args.model) @@ -118,9 +118,9 @@ def local_mpi_engine(args): raise RuntimeError("can not find mpirun, please check environment") cluster_envs = {} cluster_envs["mpirun"] = mpi - cluster_envs["trainer.trainer"] = "CtrCodingTrainer" + cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer" cluster_envs["log_dir"] = "logs" - cluster_envs["trainer.engine"] = "local_cluster" + cluster_envs["train.trainer.engine"] = "local_cluster" set_runtime_envs(cluster_envs, args.model) launch = LocalMPIEngine(cluster_envs, args.model)