提交 c762d6cb 编写于 作者: T tangwei

merge yaml two to one

上级 3b612ee4
......@@ -43,7 +43,7 @@ class ClusterTrainer(TranspileTrainer):
self.regist_context_processor('terminal_pass', self.terminal)
def build_strategy(self):
mode = envs.get_runtime_environ("trainer.strategy")
mode = envs.get_runtime_environ("train.trainer.strategy")
assert mode in ["async", "geo", "sync", "half_async"]
strategy = None
......
......@@ -39,7 +39,7 @@ class TranspileTrainer(Trainer):
namespace = "train.reader"
inputs = self.model.get_inputs()
threads = int(envs.get_runtime_environ("trainer.threads"))
threads = int(envs.get_runtime_environ("train.trainer.threads"))
batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__))
......
......@@ -50,7 +50,7 @@ def get_runtime_environ(key):
return os.getenv(key, None)
def get_trainer():
train_mode = get_runtime_environ("trainer.trainer")
train_mode = get_runtime_environ("train.trainer.trainer")
return train_mode
......
......@@ -33,7 +33,7 @@ def set_runtime_envs(cluster_envs, engine_yaml):
need_print = {}
for k, v in os.environ.items():
if k.startswith("trainer."):
if k.startswith("train.trainer."):
need_print[k] = v
print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value")))
......@@ -55,9 +55,9 @@ def single_engine(args):
print("use single engine to run model: {}".format(args.model))
single_envs = {}
single_envs["trainer.trainer"] = "SingleTrainer"
single_envs["trainer.threads"] = "2"
single_envs["trainer.engine"] = "single"
single_envs["train.trainer.trainer"] = "SingleTrainer"
single_envs["train.trainer.threads"] = "2"
single_envs["train.trainer.engine"] = "single"
set_runtime_envs(single_envs, args.model)
trainer = TrainerFactory.create(args.model)
return trainer
......@@ -67,8 +67,8 @@ def cluster_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {}
cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.engine"] = "cluster"
cluster_envs["train.trainer.trainer"] = "ClusterTrainer"
cluster_envs["train.trainer.engine"] = "cluster"
set_runtime_envs(cluster_envs, args.model)
trainer = TrainerFactory.create(args.model)
......@@ -79,7 +79,7 @@ def cluster_mpi_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {}
cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
set_runtime_envs(cluster_envs, args.model)
trainer = TrainerFactory.create(args.model)
......@@ -95,10 +95,10 @@ def local_cluster_engine(args):
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.strategy"] = "async"
cluster_envs["trainer.threads"] = "2"
cluster_envs["trainer.engine"] = "local_cluster"
cluster_envs["train.trainer.trainer"] = "ClusterTrainer"
cluster_envs["train.trainer.strategy"] = "async"
cluster_envs["train.trainer.threads"] = "2"
cluster_envs["train.trainer.engine"] = "local_cluster"
cluster_envs["CPU_NUM"] = "2"
set_runtime_envs(cluster_envs, args.model)
......@@ -118,9 +118,9 @@ def local_mpi_engine(args):
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {}
cluster_envs["mpirun"] = mpi
cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.engine"] = "local_cluster"
cluster_envs["train.trainer.engine"] = "local_cluster"
set_runtime_envs(cluster_envs, args.model)
launch = LocalMPIEngine(cluster_envs, args.model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册