From fbbc7134cc81d1763e819a2b7648b4229e2a6e8d Mon Sep 17 00:00:00 2001 From: tangwei Date: Thu, 16 Apr 2020 13:52:45 +0800 Subject: [PATCH] fix import --- fleetrec/core/factory.py | 7 ++-- .../{ctr_trainer.py => ctr_modul_trainer.py} | 0 fleetrec/run.py | 34 +++++++++++-------- 3 files changed, 23 insertions(+), 18 deletions(-) rename fleetrec/core/trainers/{ctr_trainer.py => ctr_modul_trainer.py} (100%) diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 3aa3197d..ef882e9e 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -17,10 +17,6 @@ import sys import yaml -from fleetrec.core.trainers.single_trainer import SingleTrainer -from fleetrec.core.trainers.cluster_trainer import ClusterTrainer -from fleetrec.core.trainers.ctr_trainer import CtrPaddleTrainer - from fleetrec.core.utils import envs @@ -38,10 +34,13 @@ class TrainerFactory(object): train_mode = envs.get_runtime_envion("train.trainer") if train_mode == "SingleTraining": + from fleetrec.core.trainers.single_trainer import SingleTrainer trainer = SingleTrainer(yaml_path) elif train_mode == "ClusterTraining": + from fleetrec.core.trainers.cluster_trainer import ClusterTrainer trainer = ClusterTrainer(yaml_path) elif train_mode == "CtrTraining": + from fleetrec.core.trainers.ctr_modul_trainer import CtrPaddleTrainer trainer = CtrPaddleTrainer(config) elif train_mode == "UserDefineTraining": train_location = envs.get_global_env("train.location") diff --git a/fleetrec/core/trainers/ctr_trainer.py b/fleetrec/core/trainers/ctr_modul_trainer.py similarity index 100% rename from fleetrec/core/trainers/ctr_trainer.py rename to fleetrec/core/trainers/ctr_modul_trainer.py diff --git a/fleetrec/run.py b/fleetrec/run.py index 5b5fa783..a3ea24f8 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -8,8 +8,6 @@ from paddle.fluid.incubate.fleet.parameter_server import version from fleetrec.core.factory import TrainerFactory from fleetrec.core.utils import envs from fleetrec.core.utils import util -from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine -from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine def run(model_yaml): @@ -25,22 +23,16 @@ def single_engine(single_envs, model_yaml): def local_cluster_engine(cluster_envs, model_yaml): + from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine + print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value"))) envs.set_runtime_envions(cluster_envs) launch = LocalClusterEngine(cluster_envs, model_yaml) launch.run() -def local_mpi_engine(model_yaml): - print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) - - mpi_path = util.run_which("mpirun") - - if not mpi_path: - raise RuntimeError("can not find mpirun, please check environment") - - cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"} - +def local_mpi_engine(cluster_envs, model_yaml): + from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value"))) envs.set_runtime_envions(cluster_envs) @@ -81,7 +73,14 @@ if __name__ == "__main__": single_envs = {"train.trainer": "SingleTraining"} single_engine(single_envs, args.model) else: - local_mpi_engine(args.model) + print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) + + mpi_path = util.run_which("mpirun") + if not mpi_path: + raise RuntimeError("can not find mpirun, please check environment") + + cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"} + local_mpi_engine(cluster_envs, args.model) elif args.engine.upper() == "LOCAL_CLUSTER": print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) if version.is_transpiler(): @@ -95,7 +94,14 @@ if __name__ == "__main__": local_cluster_engine(cluster_envs, args.model) else: - local_mpi_engine(args.model) + print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) + + mpi_path = util.run_which("mpirun") + if not mpi_path: + raise RuntimeError("can not find mpirun, please check environment") + + cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"} + local_mpi_engine(cluster_envs, args.model) elif args.engine.upper() == "CLUSTER": print("launch ClusterTraining with cluster to run model: {}".format(args.model)) run(args.model) -- GitLab