提交 840c310a 编写于 作者: T tangwei

update code

上级 45ec57e9
...@@ -11,3 +11,4 @@ class Engine: ...@@ -11,3 +11,4 @@ class Engine:
@abc.abstractmethod @abc.abstractmethod
def run(self): def run(self):
pass pass
...@@ -12,11 +12,15 @@ ...@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import sys
import abc import abc
import time import time
import yaml import yaml
from paddle import fluid from paddle import fluid
from fleetrec.core.utils import envs
class Trainer(object): class Trainer(object):
...@@ -78,3 +82,18 @@ class Trainer(object): ...@@ -78,3 +82,18 @@ class Trainer(object):
self.context_process(self._context) self.context_process(self._context)
if self._context['is_exit']: if self._context['is_exit']:
break break
def user_define_engine(engine_yaml):
    """Load the user-defined trainer described by an engine YAML file.

    The YAML at ``engine_yaml`` is parsed and passed to
    ``envs.set_runtime_envions``. The path stored under the ``engine.file``
    global env is then split into a directory and a module name; the
    directory is put on ``sys.path`` and ``envs.lazy_instance`` is asked for
    ``UserDefineTraining`` from that module.

    Args:
        engine_yaml: path of the engine YAML file.

    Returns:
        Whatever ``envs.lazy_instance(module_name, "UserDefineTraining")``
        returns (presumably the trainer class -- confirm against ``envs``).

    Raises:
        ValueError: if the YAML file is empty or ``engine.file`` is not set.
    """
    # NOTE(review): FullLoader can still build non-trivial Python objects;
    # only point this at engine files you trust.
    with open(engine_yaml, 'r') as yaml_file:
        _config = yaml.load(yaml_file.read(), Loader=yaml.FullLoader)

    # Explicit raise instead of `assert`: asserts vanish under `python -O`.
    if _config is None:
        raise ValueError(
            "engine yaml: {} is empty, nothing to load".format(engine_yaml))

    # NOTE(review): assumes values set through set_runtime_envions are
    # visible to get_global_env -- confirm against envs.
    envs.set_runtime_envions(_config)

    train_location = envs.get_global_env("engine.file")
    if not train_location:
        # Without this guard os.path.dirname(None) fails with a bare TypeError.
        raise ValueError(
            "engine yaml: {} must define engine.file".format(engine_yaml))

    train_dirname = os.path.dirname(train_location)
    base_name = os.path.splitext(os.path.basename(train_location))[0]

    # Avoid growing sys.path with the same directory on repeated calls.
    if train_dirname not in sys.path:
        sys.path.append(train_dirname)

    trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
    return trainer_class
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
train: train:
threads: 12
epochs: 10 epochs: 10
reader: reader:
......
trainer:
trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
threads: 4
# for cluster training
communicator:
strategy: "async"
send_queue_size: 4
min_send_grad_num_before_recv: 4
thread_pool_size: 5
max_merge_var_num: 4
...@@ -9,115 +9,122 @@ from fleetrec.core.factory import TrainerFactory ...@@ -9,115 +9,122 @@ from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs from fleetrec.core.utils import envs
from fleetrec.core.utils import util from fleetrec.core.utils import util
engines = {"TRAINSPILER": {}, "PSLIB": {}}
def run(model_yaml):
    """Create a trainer for ``model_yaml`` through ``TrainerFactory`` and run it."""
    TrainerFactory.create(model_yaml).run()
def engine_registry():
    """Fill the module-level ``engines`` table with the launcher functions.

    Each backend ("TRAINSPILER", "PSLIB") maps an upper-case engine mode to
    the function that builds the matching engine. Under "PSLIB" both SINGLE
    and LOCAL_CLUSTER map to ``local_mpi_engine``.
    """
    engines["TRAINSPILER"].update({
        "SINGLE": single_engine,
        "LOCAL_CLUSTER": local_cluster_engine,
        "CLUSTER": cluster_engine,
    })
    engines["PSLIB"].update({
        "SINGLE": local_mpi_engine,
        "LOCAL_CLUSTER": local_mpi_engine,
        "CLUSTER": cluster_mpi_engine,
    })
def get_engine(engine):
    """Return the launcher function registered for an engine mode.

    Args:
        engine: engine mode name, matched case-insensitively
            (SINGLE / LOCAL_CLUSTER / CLUSTER).

    Returns:
        The function stored in ``engines`` for the active backend:
        "TRAINSPILER" when ``version.is_transpiler()`` is true, otherwise
        "PSLIB".

    Raises:
        ValueError: if ``engine`` is None or not a registered mode.
    """
    # Nothing visible in this file calls engine_registry() before the lookup,
    # so populate the table on first use; otherwise every mode is "unknown".
    if not any(engines.values()):
        engine_registry()

    # A missing --engine flag arrives as None; report it like any other
    # unsupported value instead of failing on None.upper().
    if engine is None:
        raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER")

    engine = engine.upper()
    backend = "TRAINSPILER" if version.is_transpiler() else "PSLIB"
    run_engine = engines[backend].get(engine)

    if run_engine is None:
        raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER")
    return run_engine
def single_engine(args):
print("use SingleTraining to run model: {}".format(args.model))
single_envs = {"train.trainer": "SingleTraining"}
def single_engine(single_envs, model_yaml):
print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value"))) print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value")))
envs.set_runtime_envions(single_envs) envs.set_runtime_envions(single_envs)
run(model_yaml)
trainer = TrainerFactory.create(args.model)
return trainer
def cluster_engine(args):
    """Build the trainer for cluster mode.

    Prints the launch message, sets the ``train.trainer`` runtime env to
    ``ClusterTraining`` and returns the trainer that ``TrainerFactory``
    creates for ``args.model``.
    """
    model_path = args.model
    print("launch ClusterTraining with cluster to run model: {}".format(model_path))

    envs.set_runtime_envions({"train.trainer": "ClusterTraining"})
    return TrainerFactory.create(model_path)
def cluster_mpi_engine(args):
print("launch ClusterTraining with cluster to run model: {}".format(args.model))
def local_cluster_engine(cluster_envs, model_yaml): cluster_envs = {"train.trainer": "CtrTraining"}
envs.set_runtime_envions(cluster_envs)
trainer = TrainerFactory.create(args.model)
return trainer
def local_cluster_engine(args):
from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
cluster_envs = {}
cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "ClusterTraining"
cluster_envs["train.strategy.mode"] = "async"
print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value"))) print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value")))
envs.set_runtime_envions(cluster_envs) envs.set_runtime_envions(cluster_envs)
launch = LocalClusterEngine(cluster_envs, model_yaml)
launch.run()
launch = LocalClusterEngine(cluster_envs, args.model)
return launch
def local_mpi_engine(cluster_envs, model_yaml):
from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value"))) def local_mpi_engine(args):
envs.set_runtime_envions(cluster_envs) from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
launch = LocalMPIEngine(cluster_envs, model_yaml)
launch.run()
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
def yaml_engine(engine_yaml, model_yaml): mpi = util.run_which("mpirun")
with open(engine_yaml, 'r') as rb: if not mpi:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader) raise RuntimeError("can not find mpirun, please check environment")
assert _config is not None
envs.set_global_envs(_config) cluster_envs = {"mpirun": mpi, "train.trainer": "CtrTraining", "log_dir": "logs"}
train_location = envs.get_global_env("engine.file") print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
train_dirname = os.path.dirname(train_location) envs.set_runtime_envions(cluster_envs)
base_name = os.path.splitext(os.path.basename(train_location))[0] launch = LocalMPIEngine(cluster_envs, args.model)
sys.path.append(train_dirname) return launch
trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
trainer = trainer_class(model_yaml)
trainer.run() #
# def yaml_engine(engine_yaml, model_yaml):
# with open(engine_yaml, 'r') as rb:
# _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
# assert _config is not None
#
# envs.set_global_envs(_config)
#
# train_location = envs.get_global_env("engine.file")
# train_dirname = os.path.dirname(train_location)
# base_name = os.path.splitext(os.path.basename(train_location))[0]
# sys.path.append(train_dirname)
# trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
# trainer = trainer_class(model_yaml)
# return trainer
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run') parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("--model", type=str) parser.add_argument("-m", "--model", type=str)
parser.add_argument("--engine", type=str) parser.add_argument("-e", "--engine", type=str)
parser.add_argument("--engine_extras", type=str) parser.add_argument("-ex", "--engine_extras", type=str)
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.model) or not os.path.isfile(args.model): if not os.path.exists(args.model) or not os.path.isfile(args.model):
raise ValueError("argument model: {} error, must specify a existed yaml file".format(args.model)) raise ValueError("argument model: {} error, must specify an existed YAML file".format(args.model))
if args.engine.upper() == "SINGLE": which_engine = get_engine(args.engine)
if version.is_transpiler(): engine = which_engine(args)
print("use SingleTraining to run model: {}".format(args.model)) engine.run()
single_envs = {"train.trainer": "SingleTraining"}
single_engine(single_envs, args.model)
else:
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
mpi_path = util.run_which("mpirun")
if not mpi_path:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "LOCAL_CLUSTER":
print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
if version.is_transpiler():
cluster_envs = {}
cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "ClusterTraining"
cluster_envs["train.strategy.mode"] = "async"
local_cluster_engine(cluster_envs, args.model)
else:
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
mpi_path = util.run_which("mpirun")
if not mpi_path:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "CLUSTER":
print("launch ClusterTraining with cluster to run model: {}".format(args.model))
if version.is_transpiler():
print("use ClusterTraining to run model: {}".format(args.model))
cluster_envs = {"train.trainer": "ClusterTraining"}
envs.set_runtime_envions(cluster_envs)
else:
cluster_envs = {"train.trainer": "CtrTraining"}
envs.set_runtime_envions(cluster_envs)
run(args.model)
elif args.engine.upper() == "USER_DEFINE":
engine_file = args.engine_extras
if not os.path.exists(engine_file) or not os.path.isfile(engine_file):
raise ValueError(
"argument engine: user_define error, must specify a existed yaml file".format(args.engine_file))
yaml_engine(engine_file, args.model)
else:
raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER/USER_DEFINE")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册