提交 3b612ee4 编写于 作者: T tangwei

merge yaml two to one

上级 8b9b9cf9
......@@ -43,7 +43,7 @@ class ClusterTrainer(TranspileTrainer):
self.regist_context_processor('terminal_pass', self.terminal)
def build_strategy(self):
mode = envs.get_runtime_envion("trainer.strategy")
mode = envs.get_runtime_environ("trainer.strategy")
assert mode in ["async", "geo", "sync", "half_async"]
strategy = None
......
......@@ -39,7 +39,7 @@ class TranspileTrainer(Trainer):
namespace = "train.reader"
inputs = self.model.get_inputs()
threads = int(envs.get_runtime_envion("trainer.threads"))
threads = int(envs.get_runtime_environ("trainer.threads"))
batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__))
......
......@@ -49,7 +49,6 @@ def set_runtime_environs(environs):
def get_runtime_environ(key):
return os.getenv(key, None)
def get_trainer():
train_mode = get_runtime_environ("trainer.trainer")
return train_mode
......
......@@ -58,7 +58,7 @@ def single_engine(args):
single_envs["trainer.trainer"] = "SingleTrainer"
single_envs["trainer.threads"] = "2"
single_envs["trainer.engine"] = "single"
set_runtime_envs(single_envs, args.engine_extras)
set_runtime_envs(single_envs, args.model)
trainer = TrainerFactory.create(args.model)
return trainer
......@@ -80,7 +80,7 @@ def cluster_mpi_engine(args):
cluster_envs = {}
cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
set_runtime_envs(cluster_envs, args.engine_extras)
set_runtime_envs(cluster_envs, args.model)
trainer = TrainerFactory.create(args.model)
return trainer
......@@ -101,7 +101,7 @@ def local_cluster_engine(args):
cluster_envs["trainer.engine"] = "local_cluster"
cluster_envs["CPU_NUM"] = "2"
set_runtime_envs(cluster_envs, args.engine_extras)
set_runtime_envs(cluster_envs, args.model)
launch = LocalClusterEngine(cluster_envs, args.model)
return launch
......@@ -122,7 +122,7 @@ def local_mpi_engine(args):
cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.engine"] = "local_cluster"
set_runtime_envs(cluster_envs, args.engine_extras)
set_runtime_envs(cluster_envs, args.model)
launch = LocalMPIEngine(cluster_envs, args.model)
return launch
......@@ -142,7 +142,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-e", "--engine", type=str)
parser.add_argument("-ex", "--engine_extras", default=None, type=str)
args = parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册