提交 29d464d3 编写于 作者: T tangwei

update code

上级 3a3fe12f
...@@ -19,6 +19,19 @@ import yaml ...@@ -19,6 +19,19 @@ import yaml
from fleetrec.core.utils import envs from fleetrec.core.utils import envs
trainer_abs = os.path.join(os.path.dirname(os.path.abspath(__file__)), "trainers")
trainers = {}
def trainer_registry():
trainers["SingleTrainer"] = os.path.join(trainer_abs, "single_trainer.py")
trainers["ClusterTrainer"] = os.path.join(trainer_abs, "cluster_trainer.py")
trainers["CtrCodingTrainer"] = os.path.join(trainer_abs, "ctr_coding_trainer.py")
trainers["CtrModulTrainer"] = os.path.join(trainer_abs, "ctr_modul_trainer.py")
trainer_registry()
class TrainerFactory(object): class TrainerFactory(object):
def __init__(self): def __init__(self):
...@@ -28,26 +41,21 @@ class TrainerFactory(object): ...@@ -28,26 +41,21 @@ class TrainerFactory(object):
def _build_trainer(yaml_path): def _build_trainer(yaml_path):
print(envs.pretty_print_envs(envs.get_global_envs())) print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_training_mode() train_mode = envs.get_trainer()
trainer_abs = trainers.get(train_mode, None)
if train_mode == "SingleTraining":
from fleetrec.core.trainers.single_trainer import SingleTrainer if trainer_abs is None:
trainer = SingleTrainer(yaml_path) if not os.path.exists(train_mode) or os.path.isfile(train_mode):
elif train_mode == "ClusterTraining": raise ValueError("trainer {} can not be recognized")
from fleetrec.core.trainers.cluster_trainer import ClusterTrainer trainer_abs = train_mode
trainer = ClusterTrainer(yaml_path) train_mode = "UserDefineTrainer"
elif train_mode == "CtrTraining":
from fleetrec.core.trainers.ctr_coding_trainer import CtrPaddleTrainer train_location = envs.get_global_env("train.location")
trainer = CtrPaddleTrainer(yaml_path) train_dirname = os.path.dirname(trainer_abs)
elif train_mode == "UserDefineTraining": base_name = os.path.splitext(os.path.basename(train_location))[0]
train_location = envs.get_global_env("train.location") sys.path.append(train_dirname)
train_dirname = os.path.dirname(train_location) trainer_class = envs.lazy_instance(base_name, train_mode)
base_name = os.path.splitext(os.path.basename(train_location))[0] trainer = trainer_class(yaml_path)
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, "UserDefineTrainer")
trainer = trainer_class(yaml_path)
else:
raise ValueError("trainer only support SingleTraining/ClusterTraining")
return trainer return trainer
@staticmethod @staticmethod
......
...@@ -29,11 +29,8 @@ def get_runtime_envion(key): ...@@ -29,11 +29,8 @@ def get_runtime_envion(key):
return os.getenv(key, None) return os.getenv(key, None)
def get_training_mode(): def get_trainer():
train_mode = get_global_env("train.trainer") train_mode = get_runtime_envion("trainer.trainer")
if train_mode is None:
train_mode = get_runtime_envion("train.trainer")
return train_mode return train_mode
......
...@@ -60,12 +60,12 @@ class Model(ModelBase): ...@@ -60,12 +60,12 @@ class Model(ModelBase):
self._data_var.append(self.label_input) self._data_var.append(self.label_input)
def net(self): def net(self):
train_mode = envs.get_training_mode() trainer = envs.get_trainer()
is_distributed = True if train_mode == "CtrTraining" else False is_distributed = True if trainer == "CtrTrainer" else False
sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self.namespace) sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self.namespace)
sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self.namespace) sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self.namespace)
sparse_feature_dim = 9 if train_mode == "CtrTraining" else sparse_feature_dim sparse_feature_dim = 9 if trainer == "CtrTrainer" else sparse_feature_dim
def embedding_layer(input): def embedding_layer(input):
emb = fluid.layers.embedding( emb = fluid.layers.embedding(
......
import argparse import argparse
import os import os
import sys
import yaml import yaml
from paddle.fluid.incubate.fleet.parameter_server import version from paddle.fluid.incubate.fleet.parameter_server import version
...@@ -10,6 +9,19 @@ from fleetrec.core.utils import envs ...@@ -10,6 +9,19 @@ from fleetrec.core.utils import envs
from fleetrec.core.utils import util from fleetrec.core.utils import util
engines = {"TRAINSPILER": {}, "PSLIB": {}} engines = {"TRAINSPILER": {}, "PSLIB": {}}
clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
def set_runtime_envs(cluster_envs, engine_yaml):
if engine_yaml is not None:
with open(engine_yaml, 'r') as rb:
_envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
if cluster_envs is None:
cluster_envs = {}
cluster_envs.update(_envs)
envs.set_runtime_envions(cluster_envs)
print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
def engine_registry(): def engine_registry():
...@@ -34,35 +46,38 @@ def get_engine(engine): ...@@ -34,35 +46,38 @@ def get_engine(engine):
def single_engine(args): def single_engine(args):
print("use SingleTraining to run model: {}".format(args.model)) print("use single engine to run model: {}".format(args.model))
single_envs = {"train.trainer": "SingleTraining"} single_envs = {"trainer.trainer": "SingleTraining"}
print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value"))) set_runtime_envs(single_envs, args.engine_extras)
envs.set_runtime_envions(single_envs)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
def cluster_engine(args): def cluster_engine(args):
print("launch ClusterTraining with cluster to run model: {}".format(args.model)) print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "ClusterTraining"}
set_runtime_envs(cluster_envs, args.engine_extras)
cluster_envs = {"train.trainer": "ClusterTraining"}
envs.set_runtime_envions(cluster_envs) envs.set_runtime_envions(cluster_envs)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
def cluster_mpi_engine(args): def cluster_mpi_engine(args):
print("launch ClusterTraining with cluster to run model: {}".format(args.model)) print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "CtrTraining"}
set_runtime_envs(cluster_envs, args.engine_extras)
cluster_envs = {"train.trainer": "CtrTraining"}
envs.set_runtime_envions(cluster_envs)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
def local_cluster_engine(args): def local_cluster_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model))
from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
cluster_envs = {} cluster_envs = {}
...@@ -70,17 +85,17 @@ def local_cluster_engine(args): ...@@ -70,17 +85,17 @@ def local_cluster_engine(args):
cluster_envs["worker_num"] = 1 cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001 cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "ClusterTraining" cluster_envs["trainer.trainer"] = "ClusterTraining"
cluster_envs["train.strategy.mode"] = "async" cluster_envs["trainer.strategy.mode"] = "async"
print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value"))) set_runtime_envs(cluster_envs, args.engine_extras)
envs.set_runtime_envions(cluster_envs)
launch = LocalClusterEngine(cluster_envs, args.model) launch = LocalClusterEngine(cluster_envs, args.model)
return launch return launch
def local_mpi_engine(args): def local_mpi_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model))
from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
...@@ -89,10 +104,8 @@ def local_mpi_engine(args): ...@@ -89,10 +104,8 @@ def local_mpi_engine(args):
if not mpi: if not mpi:
raise RuntimeError("can not find mpirun, please check environment") raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi, "train.trainer": "CtrTraining", "log_dir": "logs"} cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrTraining", "log_dir": "logs"}
set_runtime_envs(cluster_envs, args.engine_extras)
print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
envs.set_runtime_envions(cluster_envs)
launch = LocalMPIEngine(cluster_envs, args.model) launch = LocalMPIEngine(cluster_envs, args.model)
return launch return launch
...@@ -118,13 +131,21 @@ if __name__ == "__main__": ...@@ -118,13 +131,21 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run') parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("-m", "--model", type=str) parser.add_argument("-m", "--model", type=str)
parser.add_argument("-e", "--engine", type=str) parser.add_argument("-e", "--engine", type=str)
parser.add_argument("-ex", "--engine_extras", type=str) parser.add_argument("-ex", "--engine_extras", default=None, type=str)
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.model) or not os.path.isfile(args.model): if not os.path.exists(args.model) or not os.path.isfile(args.model):
raise ValueError("argument model: {} error, must specify an existed YAML file".format(args.model)) raise ValueError("argument model: {} error, must specify an existed YAML file".format(args.model))
if args.engine.upper() not in clusters:
raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters))
if args.engine_extras is not None:
if not os.path.exists(args.engine_extras) or not os.path.isfile(args.engine_extras):
raise ValueError(
"argument engine_extras: {} error, must specify an existed YAML file".format(args.engine_extras))
which_engine = get_engine(args.engine) which_engine = get_engine(args.engine)
engine = which_engine(args) engine = which_engine(args)
engine.run() engine.run()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册