From 4a73a159b96dc47a9736641562a3b2e4139c1c91 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 12 Aug 2020 19:30:45 +0800 Subject: [PATCH] add online training --- core/factory.py | 3 ++- run.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/factory.py b/core/factory.py index 4c08f1f6..64023a0a 100755 --- a/core/factory.py +++ b/core/factory.py @@ -37,7 +37,8 @@ def trainer_registry(): trainer_abs, "tdm_single_trainer.py") trainers["TDMClusterTrainer"] = os.path.join( trainer_abs, "tdm_cluster_trainer.py") - + trainers["OnlineLearningTrainer"] = os.path.join( + trainer_abs, "online_learning_trainer.py") trainer_registry() diff --git a/run.py b/run.py index 56999935..944d9930 100755 --- a/run.py +++ b/run.py @@ -39,6 +39,7 @@ def engine_registry(): engines["TRANSPILER"]["SINGLE"] = single_engine engines["TRANSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine engines["TRANSPILER"]["CLUSTER"] = cluster_engine + engines["TRANSPILER"]["ONLINE_LEARNING"] = online_learning engines["PSLIB"]["SINGLE"] = local_mpi_engine engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine @@ -125,6 +126,19 @@ def single_engine(args): trainer = TrainerFactory.create(args.model) return trainer +def online_learning(args): + trainer = "OnlineLearningTrainer" + single_envs = {} + single_envs["train.trainer.trainer"] = trainer + single_envs["train.trainer.threads"] = "2" + single_envs["train.trainer.engine"] = "online_learning" + single_envs["train.trainer.platform"] = envs.get_platform() + print("use {} engine to run model: {}".format(trainer, args.model)) + + set_runtime_envs(single_envs, args.model) + trainer = TrainerFactory.create(args.model) + return trainer + def cluster_engine(args): def update_workspace(cluster_envs): -- GitLab