未验证 提交 8c68cdf8 编写于 作者: W wuzhihua 提交者: GitHub

Merge pull request #176 from MrChengmo/online_training

Support Online training
...@@ -22,6 +22,19 @@ trainers = {} ...@@ -22,6 +22,19 @@ trainers = {}
def trainer_registry(): def trainer_registry():
trainers["SingleTrainer"] = os.path.join(trainer_abs, "single_trainer.py")
trainers["ClusterTrainer"] = os.path.join(trainer_abs,
"cluster_trainer.py")
trainers["CtrCodingTrainer"] = os.path.join(trainer_abs,
"ctr_coding_trainer.py")
trainers["CtrModulTrainer"] = os.path.join(trainer_abs,
"ctr_modul_trainer.py")
trainers["TDMSingleTrainer"] = os.path.join(trainer_abs,
"tdm_single_trainer.py")
trainers["TDMClusterTrainer"] = os.path.join(trainer_abs,
"tdm_cluster_trainer.py")
trainers["OnlineLearningTrainer"] = os.path.join(
trainer_abs, "online_learning_trainer.py")
# Definition of procedure execution process # Definition of procedure execution process
trainers["CtrCodingTrainer"] = os.path.join(trainer_abs, trainers["CtrCodingTrainer"] = os.path.join(trainer_abs,
"ctr_coding_trainer.py") "ctr_coding_trainer.py")
......
...@@ -16,7 +16,6 @@ import os ...@@ -16,7 +16,6 @@ import os
import subprocess import subprocess
import sys import sys
import argparse import argparse
import tempfile
import warnings import warnings
import copy import copy
...@@ -39,6 +38,7 @@ def engine_registry(): ...@@ -39,6 +38,7 @@ def engine_registry():
engines["TRANSPILER"]["INFER"] = single_infer_engine engines["TRANSPILER"]["INFER"] = single_infer_engine
engines["TRANSPILER"]["LOCAL_CLUSTER_TRAIN"] = local_cluster_engine engines["TRANSPILER"]["LOCAL_CLUSTER_TRAIN"] = local_cluster_engine
engines["TRANSPILER"]["CLUSTER_TRAIN"] = cluster_engine engines["TRANSPILER"]["CLUSTER_TRAIN"] = cluster_engine
engines["TRANSPILER"]["ONLINE_LEARNING"] = online_learning
engines["PSLIB"]["TRAIN"] = local_mpi_engine engines["PSLIB"]["TRAIN"] = local_mpi_engine
engines["PSLIB"]["LOCAL_CLUSTER_TRAIN"] = local_mpi_engine engines["PSLIB"]["LOCAL_CLUSTER_TRAIN"] = local_mpi_engine
engines["PSLIB"]["CLUSTER_TRAIN"] = cluster_mpi_engine engines["PSLIB"]["CLUSTER_TRAIN"] = cluster_mpi_engine
...@@ -259,6 +259,20 @@ def single_infer_engine(args): ...@@ -259,6 +259,20 @@ def single_infer_engine(args):
return trainer return trainer
def online_learning(args):
trainer = "OnlineLearningTrainer"
single_envs = {}
single_envs["train.trainer.trainer"] = trainer
single_envs["train.trainer.threads"] = "2"
single_envs["train.trainer.engine"] = "online_learning"
single_envs["train.trainer.platform"] = envs.get_platform()
print("use {} engine to run model: {}".format(trainer, args.model))
set_runtime_envs(single_envs, args.model)
trainer = TrainerFactory.create(args.model)
return trainer
def cluster_engine(args): def cluster_engine(args):
def master(): def master():
from paddlerec.core.engine.cluster.cluster import ClusterEngine from paddlerec.core.engine.cluster.cluster import ClusterEngine
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册