From 8cf79b116a9046f144af617c327f8009e9f53079 Mon Sep 17 00:00:00 2001 From: tangwei Date: Wed, 15 Apr 2020 19:13:36 +0800 Subject: [PATCH] update setup.py --- fleetrec/core/trainers/transpiler_trainer.py | 2 +- fleetrec/run.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/fleetrec/core/trainers/transpiler_trainer.py b/fleetrec/core/trainers/transpiler_trainer.py index e8ff15b9..02253573 100644 --- a/fleetrec/core/trainers/transpiler_trainer.py +++ b/fleetrec/core/trainers/transpiler_trainer.py @@ -43,7 +43,7 @@ class TranspileTrainer(Trainer): batch_size = envs.get_global_env("batch_size", None, namespace) reader_class = envs.get_global_env("class", None, namespace) abs_dir = os.path.dirname(os.path.abspath(__file__)) - reader = os.path.join(abs_dir, '../reader', 'reader_instance.py') + reader = os.path.join(abs_dir, '../utils', 'reader_instance.py') pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config) train_data_path = envs.get_global_env("train_data_path", None, namespace) diff --git a/fleetrec/run.py b/fleetrec/run.py index 51dbd1fd..4fe2841a 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -54,21 +54,21 @@ if __name__ == "__main__": print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) cluster_envs = {} - cluster_envs["train.server_num"] = 1 - cluster_envs["train.worker_num"] = 1 - cluster_envs["train.start_port"] = 36001 - cluster_envs["train.log_dir"] = "logs" - cluster_envs["train.trainer"] = "SingleTraining" + cluster_envs["server_num"] = 1 + cluster_envs["worker_num"] = 1 + cluster_envs["start_port"] = 36001 + cluster_envs["log_dir"] = "logs" + cluster_envs["train.trainer"] = "ClusterTraining" local_cluster_engine(cluster_envs, args.model) elif args.engine == "LocalMPI": print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) cluster_envs = {} - cluster_envs["train.server_num"] = 1 - cluster_envs["train.worker_num"] = 1 - cluster_envs["train.start_port"] = 36001 - cluster_envs["train.log_dir"] = "logs" + cluster_envs["server_num"] = 1 + cluster_envs["worker_num"] = 1 + cluster_envs["start_port"] = 36001 + cluster_envs["log_dir"] = "logs" cluster_envs["train.trainer"] = "CtrTraining" local_mpi_engine(cluster_envs, args.model) -- GitLab