diff --git a/fleetrec/core/trainers/transpiler_trainer.py b/fleetrec/core/trainers/transpiler_trainer.py index e8ff15b9d2967f00ceb9ddf2c642db7266541ff1..02253573f0980192683404cb579eae80cda7dea8 100644 --- a/fleetrec/core/trainers/transpiler_trainer.py +++ b/fleetrec/core/trainers/transpiler_trainer.py @@ -43,7 +43,7 @@ class TranspileTrainer(Trainer): batch_size = envs.get_global_env("batch_size", None, namespace) reader_class = envs.get_global_env("class", None, namespace) abs_dir = os.path.dirname(os.path.abspath(__file__)) - reader = os.path.join(abs_dir, '../reader', 'reader_instance.py') + reader = os.path.join(abs_dir, '../utils', 'reader_instance.py') pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config) train_data_path = envs.get_global_env("train_data_path", None, namespace) diff --git a/fleetrec/run.py b/fleetrec/run.py index 51dbd1fd998f169cc30607c30146da7da2801dcb..4fe2841ac587c0dab271f4d0c06a002a64e5cf0d 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -54,21 +54,21 @@ if __name__ == "__main__": print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) cluster_envs = {} - cluster_envs["train.server_num"] = 1 - cluster_envs["train.worker_num"] = 1 - cluster_envs["train.start_port"] = 36001 - cluster_envs["train.log_dir"] = "logs" - cluster_envs["train.trainer"] = "SingleTraining" + cluster_envs["server_num"] = 1 + cluster_envs["worker_num"] = 1 + cluster_envs["start_port"] = 36001 + cluster_envs["log_dir"] = "logs" + cluster_envs["train.trainer"] = "ClusterTraining" local_cluster_engine(cluster_envs, args.model) elif args.engine == "LocalMPI": print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) cluster_envs = {} - cluster_envs["train.server_num"] = 1 - cluster_envs["train.worker_num"] = 1 - cluster_envs["train.start_port"] = 36001 - cluster_envs["train.log_dir"] = "logs" + cluster_envs["server_num"] = 1 + cluster_envs["worker_num"] = 1 + cluster_envs["start_port"] = 36001 + cluster_envs["log_dir"] = "logs" cluster_envs["train.trainer"] = "CtrTraining" local_mpi_engine(cluster_envs, args.model)