diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 376f4388dab6b2590f34210a1abe2ad6c76b03af..11752bcab719666bb785c6f6e8378237198ff532 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -32,7 +32,10 @@ class TrainerFactory(object): def _build_trainer(config, yaml_path): print(envs.pretty_print_envs(envs.get_global_envs())) - train_mode = envs.get_runtime_envion("train.trainer") + train_mode = envs.get_global_env("train.strategy.mode") + + if train_mode is not None: + train_mode = envs.get_runtime_envion("train.trainer") if train_mode == "SingleTraining": trainer = SingleTrainer(yaml_path) diff --git a/fleetrec/core/trainers/cluster_trainer.py b/fleetrec/core/trainers/cluster_trainer.py index c8d05d0507c4db5bb38f2cce51900b2dfc686863..67dab6136749624230b4d6d740bb00e04ae53450 100644 --- a/fleetrec/core/trainers/cluster_trainer.py +++ b/fleetrec/core/trainers/cluster_trainer.py @@ -44,6 +44,12 @@ class ClusterTrainer(TranspileTrainer): def build_strategy(self): mode = envs.get_global_env("train.strategy.mode") + + if mode is None: + mode = envs.get_runtime_envion("train.strategy.mode") + + assert mode is not None + strategy = None if mode == "async": diff --git a/fleetrec/examples/build_in/cluster_training_local.yaml b/fleetrec/examples/build_in/cluster_training_local.yaml deleted file mode 100644 index bf878e8c57cfa88f035ddd7c5d52904ad739f39c..0000000000000000000000000000000000000000 --- a/fleetrec/examples/build_in/cluster_training_local.yaml +++ /dev/null @@ -1,10 +0,0 @@ - -trainer: "LocalClusterTraining" - -pserver_num: 2 -trainer_num: 2 -start_port: 36001 -log_dirname: "logs" - -strategy: - mode: "async" diff --git a/fleetrec/examples/build_in/ctr-dnn_train.yaml b/fleetrec/examples/build_in/ctr-dnn_train.yaml index 5c4af64a840e62c26a3be1b02f57b4489333536b..a1c1eaa57a36a02947f72c336990e72a6a26e7c2 100644 --- a/fleetrec/examples/build_in/ctr-dnn_train.yaml +++ b/fleetrec/examples/build_in/ctr-dnn_train.yaml @@ -15,7 +15,6 @@ train: threads: 12 epochs: 10 - trainer: "single_training.yaml" reader: mode: "dataset" diff --git a/fleetrec/run.py b/fleetrec/run.py index 2dc8e1dc90abb655c470fd879dbe1f62e1b55d6a..db1b56f2adb4a0db0c6ae8cd8aa9d0e093bf4013 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -60,6 +60,7 @@ if __name__ == "__main__": cluster_envs["start_port"] = 36001 cluster_envs["log_dir"] = "logs" cluster_envs["train.trainer"] = "ClusterTraining" + cluster_envs["train.strategy.mode"] = "async" local_cluster_engine(cluster_envs, args.model) elif args.engine == "LocalMPI":