提交 f18eea21 编写于 作者: T tangwei

add mpi cluster

上级 e9011c05
...@@ -23,7 +23,7 @@ import copy ...@@ -23,7 +23,7 @@ import copy
from fleetrec.core.engine.engine import Engine from fleetrec.core.engine.engine import Engine
class QSubClusterEngine(Engine): class ClusterEngine(Engine):
def __init_impl__(self): def __init_impl__(self):
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
self.submit_script = os.path.join(abs_dir, "submit.sh") self.submit_script = os.path.join(abs_dir, "submit.sh")
......
...@@ -108,6 +108,7 @@ def single_engine(args): ...@@ -108,6 +108,7 @@ def single_engine(args):
def cluster_engine(args): def cluster_engine(args):
from fleetrec.core.engine.cluster.cluster import ClusterEngine
trainer = get_trainer_prefix(args) + "ClusterTrainer" trainer = get_trainer_prefix(args) + "ClusterTrainer"
cluster_envs = {} cluster_envs = {}
cluster_envs["train.trainer.trainer"] = trainer cluster_envs["train.trainer.trainer"] = trainer
...@@ -117,8 +118,8 @@ def cluster_engine(args): ...@@ -117,8 +118,8 @@ def cluster_engine(args):
print("launch {} engine with cluster to run model: {}".format(trainer, args.model)) print("launch {} engine with cluster to run model: {}".format(trainer, args.model))
set_runtime_envs(cluster_envs, args.model) set_runtime_envs(cluster_envs, args.model)
trainer = TrainerFactory.create(args.model) launch = LocalClusterEngine(cluster_envs, args.model)
return trainer return launch
def cluster_mpi_engine(args): def cluster_mpi_engine(args):
...@@ -201,8 +202,10 @@ if __name__ == "__main__": ...@@ -201,8 +202,10 @@ if __name__ == "__main__":
parser.add_argument("-e", "--engine", type=str, parser.add_argument("-e", "--engine", type=str,
choices=["single", "local_cluster", "cluster", choices=["single", "local_cluster", "cluster",
"tdm_single", "tdm_local_cluster", "tdm_cluster"]) "tdm_single", "tdm_local_cluster", "tdm_cluster"])
parser.add_argument("-d", "--device", type=str,
choices=["cpu", "gpu"], default="cpu") parser.add_argument("-d", "--device", type=str, choices=["cpu", "gpu"], default="cpu")
parser.add_argument("-b", "--backend", type=str, default=None)
parser.add_argument("-r", "--role", type=str, choices=["master", "worker"], default="master")
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
envs.set_runtime_environs({"PACKAGE_BASE": abs_dir}) envs.set_runtime_environs({"PACKAGE_BASE": abs_dir})
...@@ -210,6 +213,8 @@ if __name__ == "__main__": ...@@ -210,6 +213,8 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
args.engine = args.engine.upper() args.engine = args.engine.upper()
args.device = args.device.upper() args.device = args.device.upper()
args.role = args.role.upper()
model_name = args.model.split('.')[-1] model_name = args.model.split('.')[-1]
args.model = get_abs_model(args.model) args.model = get_abs_model(args.model)
engine_registry() engine_registry()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册