提交 fbbc7134 编写于 作者: T tangwei

fix import

上级 fd5e7f94
......@@ -17,10 +17,6 @@ import sys
import yaml
from fleetrec.core.trainers.single_trainer import SingleTrainer
from fleetrec.core.trainers.cluster_trainer import ClusterTrainer
from fleetrec.core.trainers.ctr_trainer import CtrPaddleTrainer
from fleetrec.core.utils import envs
......@@ -38,10 +34,13 @@ class TrainerFactory(object):
train_mode = envs.get_runtime_envion("train.trainer")
if train_mode == "SingleTraining":
from fleetrec.core.trainers.single_trainer import SingleTrainer
trainer = SingleTrainer(yaml_path)
elif train_mode == "ClusterTraining":
from fleetrec.core.trainers.cluster_trainer import ClusterTrainer
trainer = ClusterTrainer(yaml_path)
elif train_mode == "CtrTraining":
from fleetrec.core.trainers.ctr_modul_trainer import CtrPaddleTrainer
trainer = CtrPaddleTrainer(config)
elif train_mode == "UserDefineTraining":
train_location = envs.get_global_env("train.location")
......
......@@ -8,8 +8,6 @@ from paddle.fluid.incubate.fleet.parameter_server import version
from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs
from fleetrec.core.utils import util
from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
def run(model_yaml):
......@@ -25,22 +23,16 @@ def single_engine(single_envs, model_yaml):
def local_cluster_engine(cluster_envs, model_yaml):
    """Print the cluster envs, register them, and run a LocalClusterEngine.

    Args:
        cluster_envs: Environment settings; printed via
            ``envs.pretty_print_envs``, installed with
            ``envs.set_runtime_envions`` and handed to the engine.
        model_yaml: Passed through unchanged to ``LocalClusterEngine``.
    """
    # Function-scope import: the engine module is only loaded when this
    # code path actually runs.
    from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine

    table_header = ("Local Cluster Envs", "Value")
    print(envs.pretty_print_envs(cluster_envs, table_header))
    envs.set_runtime_envions(cluster_envs)

    LocalClusterEngine(cluster_envs, model_yaml).run()
def local_mpi_engine(model_yaml):
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
mpi_path = util.run_which("mpirun")
if not mpi_path:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"}
def local_mpi_engine(cluster_envs, model_yaml):
from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
envs.set_runtime_envions(cluster_envs)
......@@ -81,7 +73,14 @@ if __name__ == "__main__":
single_envs = {"train.trainer": "SingleTraining"}
single_engine(single_envs, args.model)
else:
local_mpi_engine(args.model)
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
mpi_path = util.run_which("mpirun")
if not mpi_path:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "LOCAL_CLUSTER":
print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
if version.is_transpiler():
......@@ -95,7 +94,14 @@ if __name__ == "__main__":
local_cluster_engine(cluster_envs, args.model)
else:
local_mpi_engine(args.model)
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
mpi_path = util.run_which("mpirun")
if not mpi_path:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "CLUSTER":
print("launch ClusterTraining with cluster to run model: {}".format(args.model))
run(args.model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册