From e7986cb451ed2c07726e197f294e863b3900d6d5 Mon Sep 17 00:00:00 2001 From: tangwei Date: Mon, 20 Apr 2020 15:48:27 +0800 Subject: [PATCH] fix readme --- fleetrec/core/factory.py | 2 +- fleetrec/run.py | 45 ++++++++++++++-------------------------- 2 files changed, 17 insertions(+), 30 deletions(-) diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 9545364e..332381e4 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -46,7 +46,7 @@ class TrainerFactory(object): if trainer_abs is None: if not os.path.exists(train_mode) or os.path.isfile(train_mode): - raise ValueError("trainer {} can not be recognized") + raise ValueError("trainer {} can not be recognized".format(train_mode)) trainer_abs = train_mode train_mode = "UserDefineTrainer" diff --git a/fleetrec/run.py b/fleetrec/run.py index f1f5d1c5..aad2ece3 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -16,6 +16,8 @@ def set_runtime_envs(cluster_envs, engine_yaml): if engine_yaml is not None: with open(engine_yaml, 'r') as rb: _envs = yaml.load(rb.read(), Loader=yaml.FullLoader) + else: + _envs = {} if cluster_envs is None: cluster_envs = {} @@ -24,15 +26,6 @@ def set_runtime_envs(cluster_envs, engine_yaml): print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value"))) -def engine_registry(): - engines["TRAINSPILER"]["SINGLE"] = single_engine - engines["TRAINSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine - engines["TRAINSPILER"]["CLUSTER"] = cluster_engine - engines["PSLIB"]["SINGLE"] = local_mpi_engine - engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine - engines["PSLIB"]["CLUSTER"] = cluster_mpi_engine - - def get_engine(engine): engine = engine.upper() if version.is_transpiler(): @@ -47,7 +40,7 @@ def get_engine(engine): def single_engine(args): print("use single engine to run model: {}".format(args.model)) - single_envs = {"trainer.trainer": "SingleTraining"} + single_envs = {"trainer.trainer": "SingleTrainer"} set_runtime_envs(single_envs, args.engine_extras) @@ -58,7 +51,7 @@ def single_engine(args): def cluster_engine(args): print("launch cluster engine with cluster to run model: {}".format(args.model)) - cluster_envs = {"trainer.trainer": "ClusterTraining"} + cluster_envs = {"trainer.trainer": "ClusterTrainer"} set_runtime_envs(cluster_envs, args.engine_extras) envs.set_runtime_envions(cluster_envs) @@ -69,7 +62,7 @@ def cluster_engine(args): def cluster_mpi_engine(args): print("launch cluster engine with cluster to run model: {}".format(args.model)) - cluster_envs = {"trainer.trainer": "CtrTraining"} + cluster_envs = {"trainer.trainer": "CtrCodingTrainer"} set_runtime_envs(cluster_envs, args.engine_extras) trainer = TrainerFactory.create(args.model) @@ -85,7 +78,7 @@ def local_cluster_engine(args): cluster_envs["worker_num"] = 1 cluster_envs["start_port"] = 36001 cluster_envs["log_dir"] = "logs" - cluster_envs["trainer.trainer"] = "ClusterTraining" + cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["trainer.strategy.mode"] = "async" set_runtime_envs(cluster_envs, args.engine_extras) @@ -104,28 +97,22 @@ def local_mpi_engine(args): if not mpi: raise RuntimeError("can not find mpirun, please check environment") - cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrTraining", "log_dir": "logs"} + cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrCodingTrainer", "log_dir": "logs"} set_runtime_envs(cluster_envs, args.engine_extras) launch = LocalMPIEngine(cluster_envs, args.model) return launch -# -# def yaml_engine(engine_yaml, model_yaml): -# with open(engine_yaml, 'r') as rb: -# _config = yaml.load(rb.read(), Loader=yaml.FullLoader) -# assert _config is not None -# -# envs.set_global_envs(_config) -# -# train_location = envs.get_global_env("engine.file") -# train_dirname = os.path.dirname(train_location) -# base_name = os.path.splitext(os.path.basename(train_location))[0] -# sys.path.append(train_dirname) -# trainer_class = envs.lazy_instance(base_name, "UserDefineTraining") -# trainer = trainer_class(model_yaml) -# return trainer +def engine_registry(): + engines["TRAINSPILER"]["SINGLE"] = single_engine + engines["TRAINSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine + engines["TRAINSPILER"]["CLUSTER"] = cluster_engine + engines["PSLIB"]["SINGLE"] = local_mpi_engine + engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine + engines["PSLIB"]["CLUSTER"] = cluster_mpi_engine + +engine_registry() if __name__ == "__main__": parser = argparse.ArgumentParser(description='fleet-rec run') -- GitLab