提交 c762d6cb 编写于 作者: T tangwei

merge yaml two to one

上级 3b612ee4
...@@ -43,7 +43,7 @@ class ClusterTrainer(TranspileTrainer): ...@@ -43,7 +43,7 @@ class ClusterTrainer(TranspileTrainer):
self.regist_context_processor('terminal_pass', self.terminal) self.regist_context_processor('terminal_pass', self.terminal)
def build_strategy(self): def build_strategy(self):
mode = envs.get_runtime_environ("trainer.strategy") mode = envs.get_runtime_environ("train.trainer.strategy")
assert mode in ["async", "geo", "sync", "half_async"] assert mode in ["async", "geo", "sync", "half_async"]
strategy = None strategy = None
......
...@@ -39,7 +39,7 @@ class TranspileTrainer(Trainer): ...@@ -39,7 +39,7 @@ class TranspileTrainer(Trainer):
namespace = "train.reader" namespace = "train.reader"
inputs = self.model.get_inputs() inputs = self.model.get_inputs()
threads = int(envs.get_runtime_environ("trainer.threads")) threads = int(envs.get_runtime_environ("train.trainer.threads"))
batch_size = envs.get_global_env("batch_size", None, namespace) batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
......
...@@ -50,7 +50,7 @@ def get_runtime_environ(key): ...@@ -50,7 +50,7 @@ def get_runtime_environ(key):
return os.getenv(key, None) return os.getenv(key, None)
def get_trainer(): def get_trainer():
train_mode = get_runtime_environ("trainer.trainer") train_mode = get_runtime_environ("train.trainer.trainer")
return train_mode return train_mode
......
...@@ -33,7 +33,7 @@ def set_runtime_envs(cluster_envs, engine_yaml): ...@@ -33,7 +33,7 @@ def set_runtime_envs(cluster_envs, engine_yaml):
need_print = {} need_print = {}
for k, v in os.environ.items(): for k, v in os.environ.items():
if k.startswith("trainer."): if k.startswith("train.trainer."):
need_print[k] = v need_print[k] = v
print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value"))) print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value")))
...@@ -55,9 +55,9 @@ def single_engine(args): ...@@ -55,9 +55,9 @@ def single_engine(args):
print("use single engine to run model: {}".format(args.model)) print("use single engine to run model: {}".format(args.model))
single_envs = {} single_envs = {}
single_envs["trainer.trainer"] = "SingleTrainer" single_envs["train.trainer.trainer"] = "SingleTrainer"
single_envs["trainer.threads"] = "2" single_envs["train.trainer.threads"] = "2"
single_envs["trainer.engine"] = "single" single_envs["train.trainer.engine"] = "single"
set_runtime_envs(single_envs, args.model) set_runtime_envs(single_envs, args.model)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
...@@ -67,8 +67,8 @@ def cluster_engine(args): ...@@ -67,8 +67,8 @@ def cluster_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model)) print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {} cluster_envs = {}
cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["train.trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.engine"] = "cluster" cluster_envs["train.trainer.engine"] = "cluster"
set_runtime_envs(cluster_envs, args.model) set_runtime_envs(cluster_envs, args.model)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
...@@ -79,7 +79,7 @@ def cluster_mpi_engine(args): ...@@ -79,7 +79,7 @@ def cluster_mpi_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model)) print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {} cluster_envs = {}
cluster_envs["trainer.trainer"] = "CtrCodingTrainer" cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
set_runtime_envs(cluster_envs, args.model) set_runtime_envs(cluster_envs, args.model)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
...@@ -95,10 +95,10 @@ def local_cluster_engine(args): ...@@ -95,10 +95,10 @@ def local_cluster_engine(args):
cluster_envs["worker_num"] = 1 cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001 cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["train.trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.strategy"] = "async" cluster_envs["train.trainer.strategy"] = "async"
cluster_envs["trainer.threads"] = "2" cluster_envs["train.trainer.threads"] = "2"
cluster_envs["trainer.engine"] = "local_cluster" cluster_envs["train.trainer.engine"] = "local_cluster"
cluster_envs["CPU_NUM"] = "2" cluster_envs["CPU_NUM"] = "2"
set_runtime_envs(cluster_envs, args.model) set_runtime_envs(cluster_envs, args.model)
...@@ -118,9 +118,9 @@ def local_mpi_engine(args): ...@@ -118,9 +118,9 @@ def local_mpi_engine(args):
raise RuntimeError("can not find mpirun, please check environment") raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {} cluster_envs = {}
cluster_envs["mpirun"] = mpi cluster_envs["mpirun"] = mpi
cluster_envs["trainer.trainer"] = "CtrCodingTrainer" cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
cluster_envs["log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.engine"] = "local_cluster" cluster_envs["train.trainer.engine"] = "local_cluster"
set_runtime_envs(cluster_envs, args.model) set_runtime_envs(cluster_envs, args.model)
launch = LocalMPIEngine(cluster_envs, args.model) launch = LocalMPIEngine(cluster_envs, args.model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册