From f18eea2108d469681d66c048d0e76998d3008639 Mon Sep 17 00:00:00 2001 From: tangwei Date: Tue, 12 May 2020 14:23:08 +0800 Subject: [PATCH] add mpi cluster --- .../engine/{mpi_cluster => cluster}/__init__.py | 0 .../core/engine/{mpi_cluster => cluster}/cluster.py | 2 +- .../core/engine/{mpi_cluster => cluster}/job.sh | 0 .../core/engine/{mpi_cluster => cluster}/submit.sh | 0 fleet_rec/run.py | 13 +++++++++---- 5 files changed, 10 insertions(+), 5 deletions(-) rename fleet_rec/core/engine/{mpi_cluster => cluster}/__init__.py (100%) rename fleet_rec/core/engine/{mpi_cluster => cluster}/cluster.py (97%) rename fleet_rec/core/engine/{mpi_cluster => cluster}/job.sh (100%) rename fleet_rec/core/engine/{mpi_cluster => cluster}/submit.sh (100%) diff --git a/fleet_rec/core/engine/mpi_cluster/__init__.py b/fleet_rec/core/engine/cluster/__init__.py similarity index 100% rename from fleet_rec/core/engine/mpi_cluster/__init__.py rename to fleet_rec/core/engine/cluster/__init__.py diff --git a/fleet_rec/core/engine/mpi_cluster/cluster.py b/fleet_rec/core/engine/cluster/cluster.py similarity index 97% rename from fleet_rec/core/engine/mpi_cluster/cluster.py rename to fleet_rec/core/engine/cluster/cluster.py index b8756aeb..abb3cf83 100644 --- a/fleet_rec/core/engine/mpi_cluster/cluster.py +++ b/fleet_rec/core/engine/cluster/cluster.py @@ -23,7 +23,7 @@ import copy from fleetrec.core.engine.engine import Engine -class QSubClusterEngine(Engine): +class ClusterEngine(Engine): def __init_impl__(self): abs_dir = os.path.dirname(os.path.abspath(__file__)) self.submit_script = os.path.join(abs_dir, "submit.sh") diff --git a/fleet_rec/core/engine/mpi_cluster/job.sh b/fleet_rec/core/engine/cluster/job.sh similarity index 100% rename from fleet_rec/core/engine/mpi_cluster/job.sh rename to fleet_rec/core/engine/cluster/job.sh diff --git a/fleet_rec/core/engine/mpi_cluster/submit.sh b/fleet_rec/core/engine/cluster/submit.sh similarity index 100% rename from fleet_rec/core/engine/mpi_cluster/submit.sh rename to fleet_rec/core/engine/cluster/submit.sh diff --git a/fleet_rec/run.py b/fleet_rec/run.py index 2dda11e7..e9b91a32 100755 --- a/fleet_rec/run.py +++ b/fleet_rec/run.py @@ -108,6 +108,7 @@ def single_engine(args): def cluster_engine(args): + from fleetrec.core.engine.cluster.cluster import ClusterEngine trainer = get_trainer_prefix(args) + "ClusterTrainer" cluster_envs = {} cluster_envs["train.trainer.trainer"] = trainer @@ -117,8 +118,8 @@ def cluster_engine(args): print("launch {} engine with cluster to run model: {}".format(trainer, args.model)) set_runtime_envs(cluster_envs, args.model) - trainer = TrainerFactory.create(args.model) - return trainer + launch = LocalClusterEngine(cluster_envs, args.model) + return launch def cluster_mpi_engine(args): @@ -201,8 +202,10 @@ if __name__ == "__main__": parser.add_argument("-e", "--engine", type=str, choices=["single", "local_cluster", "cluster", "tdm_single", "tdm_local_cluster", "tdm_cluster"]) - parser.add_argument("-d", "--device", type=str, - choices=["cpu", "gpu"], default="cpu") + + parser.add_argument("-d", "--device", type=str, choices=["cpu", "gpu"], default="cpu") + parser.add_argument("-b", "--backend", type=str, default=None) + parser.add_argument("-r", "--role", type=str, choices=["master", "worker"], default="master") abs_dir = os.path.dirname(os.path.abspath(__file__)) envs.set_runtime_environs({"PACKAGE_BASE": abs_dir}) @@ -210,6 +213,8 @@ if __name__ == "__main__": args = parser.parse_args() args.engine = args.engine.upper() args.device = args.device.upper() + args.role = args.role.upper() + model_name = args.model.split('.')[-1] args.model = get_abs_model(args.model) engine_registry() -- GitLab