diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 3aa3197deafc38c95f78a55c10961a1b1c3ee0dc..ef882e9ee21a79340c28266701dc6d8834862d82 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -17,10 +17,6 @@ import sys import yaml -from fleetrec.core.trainers.single_trainer import SingleTrainer -from fleetrec.core.trainers.cluster_trainer import ClusterTrainer -from fleetrec.core.trainers.ctr_trainer import CtrPaddleTrainer - from fleetrec.core.utils import envs @@ -38,10 +34,13 @@ class TrainerFactory(object): train_mode = envs.get_runtime_envion("train.trainer") if train_mode == "SingleTraining": + from fleetrec.core.trainers.single_trainer import SingleTrainer trainer = SingleTrainer(yaml_path) elif train_mode == "ClusterTraining": + from fleetrec.core.trainers.cluster_trainer import ClusterTrainer trainer = ClusterTrainer(yaml_path) elif train_mode == "CtrTraining": + from fleetrec.core.trainers.ctr_modul_trainer import CtrPaddleTrainer trainer = CtrPaddleTrainer(config) elif train_mode == "UserDefineTraining": train_location = envs.get_global_env("train.location") diff --git a/fleetrec/core/trainers/ctr_trainer.py b/fleetrec/core/trainers/ctr_modul_trainer.py similarity index 100% rename from fleetrec/core/trainers/ctr_trainer.py rename to fleetrec/core/trainers/ctr_modul_trainer.py diff --git a/fleetrec/run.py b/fleetrec/run.py index 5b5fa783dbc55c641cc0201c82faa2d0b5c8325e..a3ea24f8a5d51abb1edaab547b48184d6e61648a 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -8,8 +8,6 @@ from paddle.fluid.incubate.fleet.parameter_server import version from fleetrec.core.factory import TrainerFactory from fleetrec.core.utils import envs from fleetrec.core.utils import util -from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine -from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine def run(model_yaml): @@ -25,22 +23,16 @@ def single_engine(single_envs, model_yaml): def local_cluster_engine(cluster_envs, model_yaml): + from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine + print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value"))) envs.set_runtime_envions(cluster_envs) launch = LocalClusterEngine(cluster_envs, model_yaml) launch.run() -def local_mpi_engine(model_yaml): - print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) - - mpi_path = util.run_which("mpirun") - - if not mpi_path: - raise RuntimeError("can not find mpirun, please check environment") - - cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"} - +def local_mpi_engine(cluster_envs, model_yaml): + from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value"))) envs.set_runtime_envions(cluster_envs) @@ -81,7 +73,14 @@ if __name__ == "__main__": single_envs = {"train.trainer": "SingleTraining"} single_engine(single_envs, args.model) else: - local_mpi_engine(args.model) + print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) + + mpi_path = util.run_which("mpirun") + if not mpi_path: + raise RuntimeError("can not find mpirun, please check environment") + + cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"} + local_mpi_engine(cluster_envs, args.model) elif args.engine.upper() == "LOCAL_CLUSTER": print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) if version.is_transpiler(): @@ -95,7 +94,14 @@ if __name__ == "__main__": local_cluster_engine(cluster_envs, args.model) else: - local_mpi_engine(args.model) + print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) + + mpi_path = util.run_which("mpirun") + if not mpi_path: + raise RuntimeError("can not find mpirun, please check environment") + + cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"} + local_mpi_engine(cluster_envs, args.model) elif args.engine.upper() == "CLUSTER": print("launch ClusterTraining with cluster to run model: {}".format(args.model)) run(args.model)