提交 4a73a159 编写于 作者: T tangwei12

add online training

上级 2101532f
......@@ -37,7 +37,8 @@ def trainer_registry():
trainer_abs, "tdm_single_trainer.py")
trainers["TDMClusterTrainer"] = os.path.join(
trainer_abs, "tdm_cluster_trainer.py")
trainers["OnlineLearningTrainer"] = os.path.join(
trainer_abs, "online_learning_trainer.py")
trainer_registry()
......
......@@ -39,6 +39,7 @@ def engine_registry():
engines["TRANSPILER"]["SINGLE"] = single_engine
engines["TRANSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
engines["TRANSPILER"]["CLUSTER"] = cluster_engine
engines["TRANSPILER"]["ONLINE_LEARNING"] = online_learning
engines["PSLIB"]["SINGLE"] = local_mpi_engine
engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
......@@ -125,6 +126,19 @@ def single_engine(args):
trainer = TrainerFactory.create(args.model)
return trainer
def online_learning(args):
trainer = "OnlineLearningTrainer"
single_envs = {}
single_envs["train.trainer.trainer"] = trainer
single_envs["train.trainer.threads"] = "2"
single_envs["train.trainer.engine"] = "online_learning"
single_envs["train.trainer.platform"] = envs.get_platform()
print("use {} engine to run model: {}".format(trainer, args.model))
set_runtime_envs(single_envs, args.model)
trainer = TrainerFactory.create(args.model)
return trainer
def cluster_engine(args):
def update_workspace(cluster_envs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册