diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py
index c2e02bf5e331b067639e58079bfae3091331f017..9545364e7653942ca0d1100a50bd5aa2c2829141 100644
--- a/fleetrec/core/factory.py
+++ b/fleetrec/core/factory.py
@@ -19,6 +19,22 @@ import yaml
 
 from fleetrec.core.utils import envs
 
+# Directory that holds the built-in trainer implementations.
+trainer_abs = os.path.join(os.path.dirname(os.path.abspath(__file__)), "trainers")
+# Built-in trainer name -> path of the file that implements it.
+trainers = {}
+
+
+def trainer_registry():
+    """Register the built-in trainers by name."""
+    trainers["SingleTrainer"] = os.path.join(trainer_abs, "single_trainer.py")
+    trainers["ClusterTrainer"] = os.path.join(trainer_abs, "cluster_trainer.py")
+    trainers["CtrCodingTrainer"] = os.path.join(trainer_abs, "ctr_coding_trainer.py")
+    trainers["CtrModulTrainer"] = os.path.join(trainer_abs, "ctr_modul_trainer.py")
+
+
+trainer_registry()
+
 
 class TrainerFactory(object):
     def __init__(self):
@@ -28,26 +44,22 @@ class TrainerFactory(object):
     def _build_trainer(yaml_path):
         print(envs.pretty_print_envs(envs.get_global_envs()))
 
-        train_mode = envs.get_training_mode()
-
-        if train_mode == "SingleTraining":
-            from fleetrec.core.trainers.single_trainer import SingleTrainer
-            trainer = SingleTrainer(yaml_path)
-        elif train_mode == "ClusterTraining":
-            from fleetrec.core.trainers.cluster_trainer import ClusterTrainer
-            trainer = ClusterTrainer(yaml_path)
-        elif train_mode == "CtrTraining":
-            from fleetrec.core.trainers.ctr_coding_trainer import CtrPaddleTrainer
-            trainer = CtrPaddleTrainer(yaml_path)
-        elif train_mode == "UserDefineTraining":
-            train_location = envs.get_global_env("train.location")
-            train_dirname = os.path.dirname(train_location)
-            base_name = os.path.splitext(os.path.basename(train_location))[0]
-            sys.path.append(train_dirname)
-            trainer_class = envs.lazy_instance(base_name, "UserDefineTrainer")
-            trainer = trainer_class(yaml_path)
-        else:
-            raise ValueError("trainer only support SingleTraining/ClusterTraining")
+        train_mode = envs.get_trainer()
+        trainer_path = trainers.get(train_mode, None)
+
+        if trainer_path is None:
+            # Not a registered name: it must be the path of a user-defined trainer file.
+            if train_mode is None or not os.path.isfile(train_mode):
+                raise ValueError("trainer {} can not be recognized".format(train_mode))
+            trainer_path = train_mode
+            train_mode = "UserDefineTrainer"
+
+        # Derive both directory and module name from the trainer file itself.
+        train_dirname = os.path.dirname(trainer_path)
+        base_name = os.path.splitext(os.path.basename(trainer_path))[0]
+        sys.path.append(train_dirname)
+        trainer_class = envs.lazy_instance(base_name, train_mode)
+        trainer = trainer_class(yaml_path)
         return trainer
 
     @staticmethod
diff --git a/fleetrec/core/utils/envs.py b/fleetrec/core/utils/envs.py
index fc5228f01e9caaac2daf494f2bfde417393570c0..aed9e0a590456b1ab5dc67655cd2a642521f73fe 100644
--- a/fleetrec/core/utils/envs.py
+++ b/fleetrec/core/utils/envs.py
@@ -29,11 +29,9 @@ def get_runtime_envion(key):
     return os.getenv(key, None)
 
 
-def get_training_mode():
-    train_mode = get_global_env("train.trainer")
-
-    if train_mode is None:
-        train_mode = get_runtime_envion("train.trainer")
+def get_trainer():
+    """Return the "trainer.trainer" runtime env value, or None when it is unset."""
+    train_mode = get_runtime_envion("trainer.trainer")
 
     return train_mode
 
diff --git a/fleetrec/models/ctr_dnn/model.py b/fleetrec/models/ctr_dnn/model.py
index 8a1e2bbe59751f94f579a24abee91c65154f62af..2890efe9d81aa81e38245a174286828130fa6695 100644
--- a/fleetrec/models/ctr_dnn/model.py
+++ b/fleetrec/models/ctr_dnn/model.py
@@ -60,12 +60,12 @@ class Model(ModelBase):
         self._data_var.append(self.label_input)
 
     def net(self):
-        train_mode = envs.get_training_mode()
+        trainer = envs.get_trainer()
 
-        is_distributed = True if train_mode == "CtrTraining" else False
+        is_distributed = trainer == "CtrCodingTrainer"
         sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self.namespace)
         sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self.namespace)
-        sparse_feature_dim = 9 if train_mode == "CtrTraining" else sparse_feature_dim
+        sparse_feature_dim = 9 if trainer == "CtrCodingTrainer" else sparse_feature_dim
 
         def embedding_layer(input):
             emb = fluid.layers.embedding(
diff --git a/fleetrec/run.py b/fleetrec/run.py
index f66e39ff7159d58422a7cfc56a19516391b4c208..f1f5d1c5b8d1ac6c712c26aa4b57ff8491c5029e 100644
--- a/fleetrec/run.py
+++ b/fleetrec/run.py
@@ -1,6 +1,5 @@
 import argparse
 import os
-import sys
 
 import yaml
 from paddle.fluid.incubate.fleet.parameter_server import version
@@ -10,6 +9,23 @@ from fleetrec.core.utils import envs
 from fleetrec.core.utils import util
 
 engines = {"TRAINSPILER": {}, "PSLIB": {}}
+clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
+
+
+def set_runtime_envs(cluster_envs, engine_yaml):
+    """Merge the optional engine-extras YAML into cluster_envs and export the result."""
+    if cluster_envs is None:
+        cluster_envs = {}
+
+    if engine_yaml is not None:
+        with open(engine_yaml, 'r') as rb:
+            _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
+        # An empty YAML document loads as None, so only merge real content.
+        if _envs:
+            cluster_envs.update(_envs)
+
+    envs.set_runtime_envions(cluster_envs)
+    print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
 
 
 def engine_registry():
@@ -34,35 +50,37 @@ def get_engine(engine):
 
 
 def single_engine(args):
-    print("use SingleTraining to run model: {}".format(args.model))
-    single_envs = {"train.trainer": "SingleTraining"}
+    print("use single engine to run model: {}".format(args.model))
+    single_envs = {"trainer.trainer": "SingleTrainer"}
 
-    print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value")))
-    envs.set_runtime_envions(single_envs)
+    set_runtime_envs(single_envs, args.engine_extras)
 
     trainer = TrainerFactory.create(args.model)
     return trainer
 
 
 def cluster_engine(args):
-    print("launch ClusterTraining with cluster to run model: {}".format(args.model))
+    print("launch cluster engine with cluster to run model: {}".format(args.model))
+
+    cluster_envs = {"trainer.trainer": "ClusterTrainer"}
+    set_runtime_envs(cluster_envs, args.engine_extras)
 
-    cluster_envs = {"train.trainer": "ClusterTraining"}
-    envs.set_runtime_envions(cluster_envs)
     trainer = TrainerFactory.create(args.model)
     return trainer
 
 
 def cluster_mpi_engine(args):
-    print("launch ClusterTraining with cluster to run model: {}".format(args.model))
+    print("launch cluster engine with cluster to run model: {}".format(args.model))
+
+    cluster_envs = {"trainer.trainer": "CtrCodingTrainer"}
+    set_runtime_envs(cluster_envs, args.engine_extras)
 
-    cluster_envs = {"train.trainer": "CtrTraining"}
-    envs.set_runtime_envions(cluster_envs)
     trainer = TrainerFactory.create(args.model)
     return trainer
 
 
 def local_cluster_engine(args):
+    print("launch cluster engine with cluster to run model: {}".format(args.model))
     from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
 
     cluster_envs = {}
@@ -70,17 +88,17 @@ def local_cluster_engine(args):
     cluster_envs["worker_num"] = 1
     cluster_envs["start_port"] = 36001
     cluster_envs["log_dir"] = "logs"
-    cluster_envs["train.trainer"] = "ClusterTraining"
-    cluster_envs["train.strategy.mode"] = "async"
+    cluster_envs["trainer.trainer"] = "ClusterTrainer"
+    cluster_envs["trainer.strategy.mode"] = "async"
 
-    print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value")))
-    envs.set_runtime_envions(cluster_envs)
+    set_runtime_envs(cluster_envs, args.engine_extras)
 
     launch = LocalClusterEngine(cluster_envs, args.model)
     return launch
 
 
 def local_mpi_engine(args):
+    print("launch cluster engine with cluster to run model: {}".format(args.model))
     from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
 
     print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
@@ -89,10 +107,8 @@ def local_mpi_engine(args):
     if not mpi:
         raise RuntimeError("can not find mpirun, please check environment")
 
-    cluster_envs = {"mpirun": mpi, "train.trainer": "CtrTraining", "log_dir": "logs"}
-
-    print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
-    envs.set_runtime_envions(cluster_envs)
+    cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrCodingTrainer", "log_dir": "logs"}
+    set_runtime_envs(cluster_envs, args.engine_extras)
 
     launch = LocalMPIEngine(cluster_envs, args.model)
     return launch
@@ -118,13 +134,21 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='fleet-rec run')
     parser.add_argument("-m", "--model", type=str)
     parser.add_argument("-e", "--engine", type=str)
-    parser.add_argument("-ex", "--engine_extras", type=str)
+    parser.add_argument("-ex", "--engine_extras", default=None, type=str)
 
     args = parser.parse_args()
 
     if not os.path.exists(args.model) or not os.path.isfile(args.model):
         raise ValueError("argument model: {} error, must specify an existed YAML file".format(args.model))
 
+    if args.engine is None or args.engine.upper() not in clusters:
+        raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters))
+
+    if args.engine_extras is not None:
+        if not os.path.exists(args.engine_extras) or not os.path.isfile(args.engine_extras):
+            raise ValueError(
+                "argument engine_extras: {} error, must specify an existed YAML file".format(args.engine_extras))
+
     which_engine = get_engine(args.engine)
     engine = which_engine(args)
     engine.run()