diff --git a/core/factory.py b/core/factory.py index 9430c88283800e69db7043aa141b6f735212c79f..95e0e7778141ad76d1166205213bccdaae67aff7 100755 --- a/core/factory.py +++ b/core/factory.py @@ -22,6 +22,19 @@ trainers = {} def trainer_registry(): + trainers["SingleTrainer"] = os.path.join(trainer_abs, "single_trainer.py") + trainers["ClusterTrainer"] = os.path.join(trainer_abs, + "cluster_trainer.py") + trainers["CtrCodingTrainer"] = os.path.join(trainer_abs, + "ctr_coding_trainer.py") + trainers["CtrModulTrainer"] = os.path.join(trainer_abs, + "ctr_modul_trainer.py") + trainers["TDMSingleTrainer"] = os.path.join(trainer_abs, + "tdm_single_trainer.py") + trainers["TDMClusterTrainer"] = os.path.join(trainer_abs, + "tdm_cluster_trainer.py") + trainers["OnlineLearningTrainer"] = os.path.join( + trainer_abs, "online_learning_trainer.py") # Definition of procedure execution process trainers["CtrCodingTrainer"] = os.path.join(trainer_abs, "ctr_coding_trainer.py") diff --git a/run.py b/run.py index 6340adfc1c6026d7c67f5576ba8d0230055ec19d..c916ecd0ab3b0efe71ef86a4bf1d7f357aa9d563 100755 --- a/run.py +++ b/run.py @@ -16,7 +16,6 @@ import os import subprocess import sys import argparse -import tempfile import warnings import copy @@ -39,6 +38,7 @@ def engine_registry(): engines["TRANSPILER"]["INFER"] = single_infer_engine engines["TRANSPILER"]["LOCAL_CLUSTER_TRAIN"] = local_cluster_engine engines["TRANSPILER"]["CLUSTER_TRAIN"] = cluster_engine + engines["TRANSPILER"]["ONLINE_LEARNING"] = online_learning engines["PSLIB"]["TRAIN"] = local_mpi_engine engines["PSLIB"]["LOCAL_CLUSTER_TRAIN"] = local_mpi_engine engines["PSLIB"]["CLUSTER_TRAIN"] = cluster_mpi_engine @@ -259,6 +259,20 @@ def single_infer_engine(args): return trainer +def online_learning(args): + trainer = "OnlineLearningTrainer" + single_envs = {} + single_envs["train.trainer.trainer"] = trainer + single_envs["train.trainer.threads"] = "2" + single_envs["train.trainer.engine"] = "online_learning" + single_envs["train.trainer.platform"] = envs.get_platform() + print("use {} engine to run model: {}".format(trainer, args.model)) + + set_runtime_envs(single_envs, args.model) + trainer = TrainerFactory.create(args.model) + return trainer + + def cluster_engine(args): def master(): from paddlerec.core.engine.cluster.cluster import ClusterEngine