Commit a5b27465 authored by tangwei

update setup.py

Parent ebd1d64f
......@@ -32,7 +32,10 @@ class TrainerFactory(object):
def _build_trainer(config, yaml_path):
print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_runtime_envion("train.trainer")
train_mode = envs.get_global_env("train.strategy.mode")
if train_mode is not None:
train_mode = envs.get_runtime_envion("train.trainer")
if train_mode == "SingleTraining":
trainer = SingleTrainer(yaml_path)
......
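The hunk above changes how _build_trainer resolves the trainer name: it now consults the yaml-level key train.strategy.mode via envs.get_global_env in addition to the runtime key train.trainer. Below is a minimal, self-contained sketch of that lookup-and-dispatch pattern; the _GLOBAL_ENVS stub and the resolve_train_mode helper are illustrative stand-ins for the project's envs module, and the exact precedence between the two lookups may differ from what the commit does.

    import os

    _GLOBAL_ENVS = {}  # illustrative; in the project this is filled from the yaml config

    def get_global_env(key, default=None):
        # stand-in for envs.get_global_env: values registered from the yaml config
        return _GLOBAL_ENVS.get(key, default)

    def get_runtime_environ(key, default=None):
        # stand-in for the runtime lookup: values injected by the launcher via os.environ
        return os.environ.get(key, default)

    def resolve_train_mode():
        # read the yaml-level strategy mode first, then fall back to the launcher-provided trainer name
        mode = get_global_env("train.strategy.mode")
        if mode is None:
            mode = get_runtime_environ("train.trainer")
        return mode

    if __name__ == "__main__":
        os.environ["train.trainer"] = "SingleTraining"   # what a launcher would export
        if resolve_train_mode() == "SingleTraining":
            print("would construct SingleTrainer(yaml_path)")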
......@@ -44,6 +44,12 @@ class ClusterTrainer(TranspileTrainer):
def build_strategy(self):
mode = envs.get_global_env("train.strategy.mode")
if mode is None:
mode = envs.get_runtime_envion("train.strategy.mode")
assert mode is not None
strategy = None
if mode == "async":
......
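build_strategy applies the same two-level lookup to train.strategy.mode and asserts that a mode was found before constructing a strategy. The table-driven dispatch below is a hedged sketch of that step; the STRATEGIES entries are hypothetical placeholders rather than the fleet strategy objects ClusterTrainer actually builds, and only the "async" mode name comes from the diff.

    # Hypothetical mode -> strategy options; the real code would build fleet strategy objects here.
    STRATEGIES = {
        "async": {"sync_mode": False},  # fully asynchronous parameter updates
        "sync": {"sync_mode": True},    # synchronous updates (assumed second mode, not shown in the diff)
    }

    def build_strategy(mode):
        assert mode is not None, "train.strategy.mode must be set in the yaml or runtime envs"
        if mode not in STRATEGIES:
            raise ValueError("unknown train.strategy.mode: %s" % mode)
        return STRATEGIES[mode]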
trainer: "LocalClusterTraining"
pserver_num: 2
trainer_num: 2
start_port: 36001
log_dirname: "logs"
strategy:
mode: "async"
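The LocalClusterTraining yaml above is where keys such as strategy.mode come from. The sketch below shows one way such a block could be flattened into the dotted keys the trainers read; the flatten helper and the cluster.yaml filename are assumptions, not the project's actual envs loader.

    import yaml  # PyYAML

    def flatten(node, prefix=""):
        # recursively turn nested yaml mappings into dotted keys, e.g. strategy.mode -> "async"
        flat = {}
        for key, value in node.items():
            full = "%s.%s" % (prefix, key) if prefix else key
            if isinstance(value, dict):
                flat.update(flatten(value, full))
            else:
                flat[full] = value
        return flat

    with open("cluster.yaml") as f:   # hypothetical path to the yaml shown above
        config = yaml.safe_load(f)
    print(flatten(config))            # e.g. {"trainer": "LocalClusterTraining", ..., "strategy.mode": "async"}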
......@@ -15,7 +15,6 @@
train:
threads: 12
epochs: 10
trainer: "single_training.yaml"
reader:
mode: "dataset"
......
......@@ -60,6 +60,7 @@ if __name__ == "__main__":
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "ClusterTraining"
cluster_envs["train.strategy.mode"] = "async"
local_cluster_engine(cluster_envs, args.model)
elif args.engine == "LocalMPI":
......
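The last hunk fills cluster_envs before calling local_cluster_engine. The sketch below shows what such a launcher might do with that dict, under stated assumptions: the subprocess command, the TRAINING_ROLE/PADDLE_* variable names, and the launch_local_cluster helper are illustrative and not taken from the commit; only the cluster_envs key names appear in the diff and yaml above.

    import os
    import subprocess

    def launch_local_cluster(cluster_envs, model_yaml):
        # export every cluster_envs entry into the child environment
        base_env = dict(os.environ)
        for key, value in cluster_envs.items():
            base_env[key] = str(value)

        procs = []
        start_port = int(cluster_envs.get("start_port", 36001))
        # spawn parameter servers, one port each
        for i in range(int(cluster_envs.get("pserver_num", 2))):
            env = dict(base_env, TRAINING_ROLE="PSERVER", PADDLE_PORT=str(start_port + i))
            procs.append(subprocess.Popen(["python", "-u", "trainer.py", model_yaml], env=env))
        # spawn trainers
        for i in range(int(cluster_envs.get("trainer_num", 2))):
            env = dict(base_env, TRAINING_ROLE="TRAINER", PADDLE_TRAINER_ID=str(i))
            procs.append(subprocess.Popen(["python", "-u", "trainer.py", model_yaml], env=env))
        return procs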