From a5b274653d82ee49de51f55640709a828308b192 Mon Sep 17 00:00:00 2001 From: tangwei Date: Wed, 15 Apr 2020 19:33:35 +0800 Subject: [PATCH] update setup.py --- fleetrec/core/factory.py | 5 ++++- fleetrec/core/trainers/cluster_trainer.py | 6 ++++++ fleetrec/examples/build_in/cluster_training_local.yaml | 10 ---------- fleetrec/examples/build_in/ctr-dnn_train.yaml | 1 - fleetrec/run.py | 1 + 5 files changed, 11 insertions(+), 12 deletions(-) delete mode 100644 fleetrec/examples/build_in/cluster_training_local.yaml diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 376f4388..11752bca 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -32,7 +32,10 @@ class TrainerFactory(object): def _build_trainer(config, yaml_path): print(envs.pretty_print_envs(envs.get_global_envs())) - train_mode = envs.get_runtime_envion("train.trainer") + train_mode = envs.get_global_env("train.strategy.mode") + + if train_mode is not None: + train_mode = envs.get_runtime_envion("train.trainer") if train_mode == "SingleTraining": trainer = SingleTrainer(yaml_path) diff --git a/fleetrec/core/trainers/cluster_trainer.py b/fleetrec/core/trainers/cluster_trainer.py index c8d05d05..67dab613 100644 --- a/fleetrec/core/trainers/cluster_trainer.py +++ b/fleetrec/core/trainers/cluster_trainer.py @@ -44,6 +44,12 @@ class ClusterTrainer(TranspileTrainer): def build_strategy(self): mode = envs.get_global_env("train.strategy.mode") + + if mode is None: + mode = envs.get_runtime_envion("train.strategy.mode") + + assert mode is not None + strategy = None if mode == "async": diff --git a/fleetrec/examples/build_in/cluster_training_local.yaml b/fleetrec/examples/build_in/cluster_training_local.yaml deleted file mode 100644 index bf878e8c..00000000 --- a/fleetrec/examples/build_in/cluster_training_local.yaml +++ /dev/null @@ -1,10 +0,0 @@ - -trainer: "LocalClusterTraining" - -pserver_num: 2 -trainer_num: 2 -start_port: 36001 -log_dirname: "logs" - -strategy: - mode: "async" diff --git a/fleetrec/examples/build_in/ctr-dnn_train.yaml b/fleetrec/examples/build_in/ctr-dnn_train.yaml index 5c4af64a..a1c1eaa5 100644 --- a/fleetrec/examples/build_in/ctr-dnn_train.yaml +++ b/fleetrec/examples/build_in/ctr-dnn_train.yaml @@ -15,7 +15,6 @@ train: threads: 12 epochs: 10 - trainer: "single_training.yaml" reader: mode: "dataset" diff --git a/fleetrec/run.py b/fleetrec/run.py index 2dc8e1dc..db1b56f2 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -60,6 +60,7 @@ if __name__ == "__main__": cluster_envs["start_port"] = 36001 cluster_envs["log_dir"] = "logs" cluster_envs["train.trainer"] = "ClusterTraining" + cluster_envs["train.strategy.mode"] = "async" local_cluster_engine(cluster_envs, args.model) elif args.engine == "LocalMPI": -- GitLab