diff --git a/core/factory.py b/core/factory.py index 4c08f1f6bbd70cc65011e8430e3acf039d7b6c8f..64023a0a9b24ad95738f35542678616135590d52 100755 --- a/core/factory.py +++ b/core/factory.py @@ -37,7 +37,8 @@ def trainer_registry(): trainer_abs, "tdm_single_trainer.py") trainers["TDMClusterTrainer"] = os.path.join( trainer_abs, "tdm_cluster_trainer.py") - + trainers["OnlineLearningTrainer"] = os.path.join( + trainer_abs, "online_learning_trainer.py") trainer_registry() diff --git a/run.py b/run.py index 56999935f21bc1de2b2bc7b4a080da023559174a..944d99304f15303af18f22471dcc1a7eb5062645 100755 --- a/run.py +++ b/run.py @@ -39,6 +39,7 @@ def engine_registry(): engines["TRANSPILER"]["SINGLE"] = single_engine engines["TRANSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine engines["TRANSPILER"]["CLUSTER"] = cluster_engine + engines["TRANSPILER"]["ONLINE_LEARNING"] = online_learning engines["PSLIB"]["SINGLE"] = local_mpi_engine engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine @@ -125,6 +126,19 @@ def single_engine(args): trainer = TrainerFactory.create(args.model) return trainer +def online_learning(args): + trainer = "OnlineLearningTrainer" + single_envs = {} + single_envs["train.trainer.trainer"] = trainer + single_envs["train.trainer.threads"] = "2" + single_envs["train.trainer.engine"] = "online_learning" + single_envs["train.trainer.platform"] = envs.get_platform() + print("use {} engine to run model: {}".format(trainer, args.model)) + + set_runtime_envs(single_envs, args.model) + trainer = TrainerFactory.create(args.model) + return trainer + def cluster_engine(args): def update_workspace(cluster_envs):