提交 3b612ee4 编写于 作者: T tangwei

merge yaml two to one

上级 8b9b9cf9
...@@ -43,7 +43,7 @@ class ClusterTrainer(TranspileTrainer): ...@@ -43,7 +43,7 @@ class ClusterTrainer(TranspileTrainer):
self.regist_context_processor('terminal_pass', self.terminal) self.regist_context_processor('terminal_pass', self.terminal)
def build_strategy(self): def build_strategy(self):
mode = envs.get_runtime_envion("trainer.strategy") mode = envs.get_runtime_environ("trainer.strategy")
assert mode in ["async", "geo", "sync", "half_async"] assert mode in ["async", "geo", "sync", "half_async"]
strategy = None strategy = None
......
...@@ -39,7 +39,7 @@ class TranspileTrainer(Trainer): ...@@ -39,7 +39,7 @@ class TranspileTrainer(Trainer):
namespace = "train.reader" namespace = "train.reader"
inputs = self.model.get_inputs() inputs = self.model.get_inputs()
threads = int(envs.get_runtime_envion("trainer.threads")) threads = int(envs.get_runtime_environ("trainer.threads"))
batch_size = envs.get_global_env("batch_size", None, namespace) batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
......
...@@ -49,7 +49,6 @@ def set_runtime_environs(environs): ...@@ -49,7 +49,6 @@ def set_runtime_environs(environs):
def get_runtime_environ(key): def get_runtime_environ(key):
return os.getenv(key, None) return os.getenv(key, None)
def get_trainer(): def get_trainer():
train_mode = get_runtime_environ("trainer.trainer") train_mode = get_runtime_environ("trainer.trainer")
return train_mode return train_mode
......
...@@ -58,7 +58,7 @@ def single_engine(args): ...@@ -58,7 +58,7 @@ def single_engine(args):
single_envs["trainer.trainer"] = "SingleTrainer" single_envs["trainer.trainer"] = "SingleTrainer"
single_envs["trainer.threads"] = "2" single_envs["trainer.threads"] = "2"
single_envs["trainer.engine"] = "single" single_envs["trainer.engine"] = "single"
set_runtime_envs(single_envs, args.engine_extras) set_runtime_envs(single_envs, args.model)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
...@@ -80,7 +80,7 @@ def cluster_mpi_engine(args): ...@@ -80,7 +80,7 @@ def cluster_mpi_engine(args):
cluster_envs = {} cluster_envs = {}
cluster_envs["trainer.trainer"] = "CtrCodingTrainer" cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.model)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
...@@ -101,7 +101,7 @@ def local_cluster_engine(args): ...@@ -101,7 +101,7 @@ def local_cluster_engine(args):
cluster_envs["trainer.engine"] = "local_cluster" cluster_envs["trainer.engine"] = "local_cluster"
cluster_envs["CPU_NUM"] = "2" cluster_envs["CPU_NUM"] = "2"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.model)
launch = LocalClusterEngine(cluster_envs, args.model) launch = LocalClusterEngine(cluster_envs, args.model)
return launch return launch
...@@ -122,7 +122,7 @@ def local_mpi_engine(args): ...@@ -122,7 +122,7 @@ def local_mpi_engine(args):
cluster_envs["log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.engine"] = "local_cluster" cluster_envs["trainer.engine"] = "local_cluster"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.model)
launch = LocalMPIEngine(cluster_envs, args.model) launch = LocalMPIEngine(cluster_envs, args.model)
return launch return launch
...@@ -142,7 +142,6 @@ if __name__ == "__main__": ...@@ -142,7 +142,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run') parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("-m", "--model", type=str) parser.add_argument("-m", "--model", type=str)
parser.add_argument("-e", "--engine", type=str) parser.add_argument("-e", "--engine", type=str)
parser.add_argument("-ex", "--engine_extras", default=None, type=str)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册