diff --git a/fleetrec/run.py b/fleetrec/run.py index a4b51990fac298cd06aa38af4e6096b4d89e9f5b..d6b083205f472ba94bab4370f017e43dbaae48a7 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -1,5 +1,9 @@ import argparse import os +import sys + +import yaml +from paddle.fluid.incubate.fleet.parameter_server import version from fleetrec.core.factory import TrainerFactory from fleetrec.core.utils import envs @@ -25,20 +29,42 @@ def local_cluster_engine(cluster_envs, model_yaml): launch.run() -def local_mpi_engine(cluster_envs, model_yaml): +def local_mpi_engine(model_yaml): + print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) + + cluster_envs = {} + cluster_envs["server_num"] = 1 + cluster_envs["worker_num"] = 1 + cluster_envs["start_port"] = 36001 + cluster_envs["log_dir"] = "logs" + cluster_envs["train.trainer"] = "CtrTraining" + print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value"))) envs.set_runtime_envions(cluster_envs) print("coming soon") def yaml_engine(engine_yaml, model_yaml): - print("coming soon") + with open(engine_yaml, 'r') as rb: + _config = yaml.load(rb.read(), Loader=yaml.FullLoader) + assert _config is not None + + envs.set_global_envs(_config) + + train_location = envs.get_global_env("engine.file") + train_dirname = os.path.dirname(train_location) + base_name = os.path.splitext(os.path.basename(train_location))[0] + sys.path.append(train_dirname) + trainer_class = envs.lazy_instance(base_name, "UserDefineTrainer") + trainer = trainer_class(model_yaml) + trainer.run() if __name__ == "__main__": parser = argparse.ArgumentParser(description='fleet-rec run') parser.add_argument("--model", type=str) parser.add_argument("--engine", type=str) + parser.add_argument("--engine_extras", type=str) args = parser.parse_args() @@ -46,35 +72,34 @@ if __name__ == "__main__": raise ValueError("argument model: {} error, must specify a existed yaml file".format(args.model)) if args.engine.upper() == "SINGLE": - print("use SingleTraining to run model: {}".format(args.model)) - single_envs = {} - single_envs["train.trainer"] = "SingleTraining" - - single_engine(single_envs, args.model) + if version.is_transpiler(): + print("use SingleTraining to run model: {}".format(args.model)) + single_envs = {"train.trainer": "SingleTraining"} + single_engine(single_envs, args.model) + else: + local_mpi_engine(args.model) elif args.engine.upper() == "LOCAL_CLUSTER": print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) - - cluster_envs = {} - cluster_envs["server_num"] = 1 - cluster_envs["worker_num"] = 1 - cluster_envs["start_port"] = 36001 - cluster_envs["log_dir"] = "logs" - cluster_envs["train.trainer"] = "ClusterTraining" - cluster_envs["train.strategy.mode"] = "async" - - local_cluster_engine(cluster_envs, args.model) - elif args.engine.upper() == "LOCAL_MPI": - print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) - - cluster_envs = {} - cluster_envs["server_num"] = 1 - cluster_envs["worker_num"] = 1 - cluster_envs["start_port"] = 36001 - cluster_envs["log_dir"] = "logs" - cluster_envs["train.trainer"] = "CtrTraining" - - local_mpi_engine(cluster_envs, args.model) + if version.is_transpiler(): + cluster_envs = {} + cluster_envs["server_num"] = 1 + cluster_envs["worker_num"] = 1 + cluster_envs["start_port"] = 36001 + cluster_envs["log_dir"] = "logs" + cluster_envs["train.trainer"] = "ClusterTraining" + cluster_envs["train.strategy.mode"] = "async" + + local_cluster_engine(cluster_envs, args.model) + else: + local_mpi_engine(args.model) + elif args.engine.upper() == "CLUSTER": + print("launch ClusterTraining with cluster to run model: {}".format(args.model)) + run(args.model) + elif args.engine.upper() == "USER_DEFINE": + engine_file = args.engine_extras + if not os.path.exists(engine_file) or not os.path.isfile(engine_file): + raise ValueError( + "argument engine: user_define error, must specify a existed yaml file".format(args.engine_file)) + yaml_engine(engine_file, args.model) else: - if not os.path.exists(args.engine) or not os.path.isfile(args.engine): - raise ValueError("argument engine: {} error, must specify a existed yaml file".format(args.engine)) - yaml_engine(args.engine, args.model) + raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER/USER_DEFINE")