提交 95ea5e58 编写于 作者: T tangwei

add user-define engine

上级 8a2a06e4
import argparse import argparse
import os import os
import sys
import yaml
from paddle.fluid.incubate.fleet.parameter_server import version
from fleetrec.core.factory import TrainerFactory from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs from fleetrec.core.utils import envs
...@@ -25,20 +29,42 @@ def local_cluster_engine(cluster_envs, model_yaml): ...@@ -25,20 +29,42 @@ def local_cluster_engine(cluster_envs, model_yaml):
launch.run() launch.run()
def local_mpi_engine(cluster_envs, model_yaml): def local_mpi_engine(model_yaml):
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
cluster_envs = {}
cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "CtrTraining"
print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value"))) print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
envs.set_runtime_envions(cluster_envs) envs.set_runtime_envions(cluster_envs)
print("coming soon") print("coming soon")
def yaml_engine(engine_yaml, model_yaml): def yaml_engine(engine_yaml, model_yaml):
print("coming soon") with open(engine_yaml, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
assert _config is not None
envs.set_global_envs(_config)
train_location = envs.get_global_env("engine.file")
train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, "UserDefineTrainer")
trainer = trainer_class(model_yaml)
trainer.run()
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run') parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("--model", type=str) parser.add_argument("--model", type=str)
parser.add_argument("--engine", type=str) parser.add_argument("--engine", type=str)
parser.add_argument("--engine_extras", type=str)
args = parser.parse_args() args = parser.parse_args()
...@@ -46,35 +72,34 @@ if __name__ == "__main__": ...@@ -46,35 +72,34 @@ if __name__ == "__main__":
raise ValueError("argument model: {} error, must specify a existed yaml file".format(args.model)) raise ValueError("argument model: {} error, must specify a existed yaml file".format(args.model))
if args.engine.upper() == "SINGLE": if args.engine.upper() == "SINGLE":
print("use SingleTraining to run model: {}".format(args.model)) if version.is_transpiler():
single_envs = {} print("use SingleTraining to run model: {}".format(args.model))
single_envs["train.trainer"] = "SingleTraining" single_envs = {"train.trainer": "SingleTraining"}
single_engine(single_envs, args.model)
single_engine(single_envs, args.model) else:
local_mpi_engine(args.model)
elif args.engine.upper() == "LOCAL_CLUSTER": elif args.engine.upper() == "LOCAL_CLUSTER":
print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
if version.is_transpiler():
cluster_envs = {} cluster_envs = {}
cluster_envs["server_num"] = 1 cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1 cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001 cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "ClusterTraining" cluster_envs["train.trainer"] = "ClusterTraining"
cluster_envs["train.strategy.mode"] = "async" cluster_envs["train.strategy.mode"] = "async"
local_cluster_engine(cluster_envs, args.model) local_cluster_engine(cluster_envs, args.model)
elif args.engine.upper() == "LOCAL_MPI": else:
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) local_mpi_engine(args.model)
elif args.engine.upper() == "CLUSTER":
cluster_envs = {} print("launch ClusterTraining with cluster to run model: {}".format(args.model))
cluster_envs["server_num"] = 1 run(args.model)
cluster_envs["worker_num"] = 1 elif args.engine.upper() == "USER_DEFINE":
cluster_envs["start_port"] = 36001 engine_file = args.engine_extras
cluster_envs["log_dir"] = "logs" if not os.path.exists(engine_file) or not os.path.isfile(engine_file):
cluster_envs["train.trainer"] = "CtrTraining" raise ValueError(
"argument engine: user_define error, must specify a existed yaml file".format(args.engine_file))
local_mpi_engine(cluster_envs, args.model) yaml_engine(engine_file, args.model)
else: else:
if not os.path.exists(args.engine) or not os.path.isfile(args.engine): raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER/USER_DEFINE")
raise ValueError("argument engine: {} error, must specify a existed yaml file".format(args.engine))
yaml_engine(args.engine, args.model)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册