Commit 29d464d3 authored by T tangwei

update code

Parent 3a3fe12f
......@@ -19,6 +19,19 @@ import yaml
from fleetrec.core.utils import envs
trainer_abs = os.path.join(os.path.dirname(os.path.abspath(__file__)), "trainers")
trainers = {}
def trainer_registry():
    trainers["SingleTrainer"] = os.path.join(trainer_abs, "single_trainer.py")
    trainers["ClusterTrainer"] = os.path.join(trainer_abs, "cluster_trainer.py")
    trainers["CtrCodingTrainer"] = os.path.join(trainer_abs, "ctr_coding_trainer.py")
    trainers["CtrModulTrainer"] = os.path.join(trainer_abs, "ctr_modul_trainer.py")


trainer_registry()


class TrainerFactory(object):
    def __init__(self):
......@@ -28,26 +41,21 @@ class TrainerFactory(object):
    def _build_trainer(yaml_path):
        print(envs.pretty_print_envs(envs.get_global_envs()))
        train_mode = envs.get_training_mode()
        if train_mode == "SingleTraining":
            from fleetrec.core.trainers.single_trainer import SingleTrainer
            trainer = SingleTrainer(yaml_path)
        elif train_mode == "ClusterTraining":
            from fleetrec.core.trainers.cluster_trainer import ClusterTrainer
            trainer = ClusterTrainer(yaml_path)
        elif train_mode == "CtrTraining":
            from fleetrec.core.trainers.ctr_coding_trainer import CtrPaddleTrainer
            trainer = CtrPaddleTrainer(yaml_path)
        elif train_mode == "UserDefineTraining":
        train_mode = envs.get_trainer()
        trainer_abs = trainers.get(train_mode, None)
        if trainer_abs is None:
            # not a registered trainer name: treat the value as a path to a user-defined trainer file
            if not os.path.exists(train_mode) or not os.path.isfile(train_mode):
                raise ValueError("trainer {} can not be recognized".format(train_mode))
            trainer_abs = train_mode
            train_mode = "UserDefineTrainer"
        train_location = envs.get_global_env("train.location")
        train_dirname = os.path.dirname(train_location)
        train_dirname = os.path.dirname(trainer_abs)
        base_name = os.path.splitext(os.path.basename(train_location))[0]
        sys.path.append(train_dirname)
        trainer_class = envs.lazy_instance(base_name, "UserDefineTrainer")
        trainer_class = envs.lazy_instance(base_name, train_mode)
        trainer = trainer_class(yaml_path)
        else:
            raise ValueError("trainer only support SingleTraining/ClusterTraining")
        return trainer
    @staticmethod
......
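For readers following the change: the factory no longer hard-codes an if/elif chain per training mode; it resolves the trainer name through the `trainers` registry and falls back to a user-supplied trainer file. The sketch below is a self-contained illustration of that lookup, not code from this commit; the helper names (`resolve_trainer`, `load_trainer_class`) and the use of `importlib` as a stand-in for `envs.lazy_instance` are assumptions.

import importlib
import os
import sys


def resolve_trainer(registry, name):
    # Registered name -> bundled trainer file; anything else must be a path to
    # an existing user-defined trainer module.
    path = registry.get(name, None)
    if path is None:
        if not os.path.isfile(name):
            raise ValueError("trainer {} can not be recognized".format(name))
        path, name = name, "UserDefineTrainer"
    return path, name


def load_trainer_class(path, class_name):
    # Roughly what envs.lazy_instance is assumed to do: import the module by
    # its file name and fetch the trainer class from it.
    sys.path.append(os.path.dirname(path))
    module = importlib.import_module(os.path.splitext(os.path.basename(path))[0])
    return getattr(module, class_name)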
......@@ -29,11 +29,8 @@ def get_runtime_envion(key):
    return os.getenv(key, None)


def get_training_mode():
    train_mode = get_global_env("train.trainer")
    if train_mode is None:
        train_mode = get_runtime_envion("train.trainer")


def get_trainer():
    train_mode = get_runtime_envion("trainer.trainer")
    return train_mode
......
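The envs change replaces the YAML-driven `get_training_mode` with `get_trainer`, which reads the trainer name from the process environment under the key `trainer.trainer`. A minimal round-trip sketch, assuming `set_runtime_envions` stores keys verbatim in `os.environ` (only the `os.getenv` side is visible in this hunk):

import os

# Assumption: runtime envs are plain process environment variables, keyed verbatim.
os.environ["trainer.trainer"] = "SingleTraining"


def get_runtime_envion(key):
    return os.getenv(key, None)


print(get_runtime_envion("trainer.trainer"))  # -> SingleTraining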
......@@ -60,12 +60,12 @@ class Model(ModelBase):
        self._data_var.append(self.label_input)

    def net(self):
        train_mode = envs.get_training_mode()
        trainer = envs.get_trainer()
        is_distributed = True if train_mode == "CtrTraining" else False
        is_distributed = True if trainer == "CtrTrainer" else False
        sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self.namespace)
        sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self.namespace)
        sparse_feature_dim = 9 if train_mode == "CtrTraining" else sparse_feature_dim
        sparse_feature_dim = 9 if trainer == "CtrTrainer" else sparse_feature_dim

        def embedding_layer(input):
            emb = fluid.layers.embedding(
......
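In the model, `is_distributed` and the forced `sparse_feature_dim = 9` now key off the trainer name (`CtrTrainer`) rather than the training mode. The truncated `embedding_layer` presumably continues with a Fluid 1.x embedding plus pooling; the sketch below shows that typical pattern, with the parameter name and the sum-pooling step as assumptions, not the file's actual continuation.

import paddle.fluid as fluid


def embedding_layer(input, sparse_feature_number, sparse_feature_dim, is_distributed):
    # Typical Fluid 1.x sparse embedding: sparse gradients for large vocabularies,
    # distributed lookup only under the parameter-server (CtrTrainer) mode.
    emb = fluid.layers.embedding(
        input=input,
        size=[sparse_feature_number, sparse_feature_dim],
        is_sparse=True,
        is_distributed=is_distributed,
        param_attr=fluid.ParamAttr(name="SparseFeatFactors"))
    # Collapse the per-token embeddings into one vector per example.
    return fluid.layers.sequence_pool(input=emb, pool_type="sum")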
import argparse
import os
import sys
import yaml
from paddle.fluid.incubate.fleet.parameter_server import version
......@@ -10,6 +9,19 @@ from fleetrec.core.utils import envs
from fleetrec.core.utils import util
engines = {"TRAINSPILER": {}, "PSLIB": {}}
clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
def set_runtime_envs(cluster_envs, engine_yaml):
    if engine_yaml is not None:
        with open(engine_yaml, 'r') as rb:
            _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)

        if cluster_envs is None:
            cluster_envs = {}
        cluster_envs.update(_envs)

    envs.set_runtime_envions(cluster_envs)
    print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
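The merge order matters here: the defaults an engine supplies arrive as `cluster_envs`, and keys loaded from the optional `--engine_extras` YAML override them via `dict.update`. A small hedged illustration with made-up keys:

# Hypothetical values, only to show the override direction of set_runtime_envs.
defaults = {"trainer.trainer": "SingleTraining", "log_dir": "logs"}
extras = {"log_dir": "custom_logs"}  # as if loaded from the engine_extras YAML
defaults.update(extras)
print(defaults)  # {'trainer.trainer': 'SingleTraining', 'log_dir': 'custom_logs'}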
def engine_registry():
......@@ -34,35 +46,38 @@ def get_engine(engine):
def single_engine(args):
    print("use SingleTraining to run model: {}".format(args.model))
    single_envs = {"train.trainer": "SingleTraining"}
    print("use single engine to run model: {}".format(args.model))
    single_envs = {"trainer.trainer": "SingleTraining"}
    print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value")))
    envs.set_runtime_envions(single_envs)
    set_runtime_envs(single_envs, args.engine_extras)
    trainer = TrainerFactory.create(args.model)
    return trainer


def cluster_engine(args):
    print("launch ClusterTraining with cluster to run model: {}".format(args.model))
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    cluster_envs = {"trainer.trainer": "ClusterTraining"}
    set_runtime_envs(cluster_envs, args.engine_extras)
    cluster_envs = {"train.trainer": "ClusterTraining"}
    envs.set_runtime_envions(cluster_envs)
    trainer = TrainerFactory.create(args.model)
    return trainer


def cluster_mpi_engine(args):
    print("launch ClusterTraining with cluster to run model: {}".format(args.model))
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    cluster_envs = {"trainer.trainer": "CtrTraining"}
    set_runtime_envs(cluster_envs, args.engine_extras)
    cluster_envs = {"train.trainer": "CtrTraining"}
    envs.set_runtime_envions(cluster_envs)
    trainer = TrainerFactory.create(args.model)
    return trainer


def local_cluster_engine(args):
    print("launch cluster engine with cluster to run model: {}".format(args.model))
    from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine

    cluster_envs = {}
......@@ -70,17 +85,17 @@ def local_cluster_engine(args):
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "ClusterTraining"
cluster_envs["train.strategy.mode"] = "async"
cluster_envs["trainer.trainer"] = "ClusterTraining"
cluster_envs["trainer.strategy.mode"] = "async"
print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value")))
envs.set_runtime_envions(cluster_envs)
set_runtime_envs(cluster_envs, args.engine_extras)
launch = LocalClusterEngine(cluster_envs, args.model)
return launch
def local_mpi_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model))
from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
......@@ -89,10 +104,8 @@ def local_mpi_engine(args):
    if not mpi:
        raise RuntimeError("can not find mpirun, please check environment")

    cluster_envs = {"mpirun": mpi, "train.trainer": "CtrTraining", "log_dir": "logs"}
    print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
    envs.set_runtime_envions(cluster_envs)
    cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrTraining", "log_dir": "logs"}
    set_runtime_envs(cluster_envs, args.engine_extras)
    launch = LocalMPIEngine(cluster_envs, args.model)
    return launch
......@@ -118,13 +131,21 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='fleet-rec run')
    parser.add_argument("-m", "--model", type=str)
    parser.add_argument("-e", "--engine", type=str)
    parser.add_argument("-ex", "--engine_extras", type=str)
    parser.add_argument("-ex", "--engine_extras", default=None, type=str)
    args = parser.parse_args()

    if not os.path.exists(args.model) or not os.path.isfile(args.model):
        raise ValueError("argument model: {} error, must specify an existed YAML file".format(args.model))

    if args.engine.upper() not in clusters:
        raise ValueError("argument engine: {} error, must in {}".format(args.engine, clusters))

    if args.engine_extras is not None:
        if not os.path.exists(args.engine_extras) or not os.path.isfile(args.engine_extras):
            raise ValueError(
                "argument engine_extras: {} error, must specify an existed YAML file".format(args.engine_extras))

    which_engine = get_engine(args.engine)
    engine = which_engine(args)
    engine.run()
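Putting the entry point together, a run boils down to picking an engine by name and handing it the model YAML. The snippet below is a hedged programmatic equivalent of `python run.py -m <model.yaml> -e single`; it assumes this script's helpers are importable and uses placeholder file names.

from argparse import Namespace

args = Namespace(model="models/ctr-dnn/config.yaml",  # placeholder path to an existing model YAML
                 engine="single",                      # must match one of: SINGLE / LOCAL_CLUSTER / CLUSTER
                 engine_extras=None)                   # optional runtime-env YAML

which_engine = get_engine(args.engine)  # dispatches to single_engine / cluster_engine / ...
engine = which_engine(args)             # builds the trainer or launcher for the given model
engine.run()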