提交 840c310a 编写于 作者: T tangwei

update code

上级 45ec57e9
...@@ -11,3 +11,4 @@ class Engine: ...@@ -11,3 +11,4 @@ class Engine:
@abc.abstractmethod @abc.abstractmethod
def run(self): def run(self):
pass pass
...@@ -12,11 +12,15 @@ ...@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import sys
import abc import abc
import time import time
import yaml import yaml
from paddle import fluid from paddle import fluid
from fleetrec.core.utils import envs
class Trainer(object): class Trainer(object):
...@@ -78,3 +82,18 @@ class Trainer(object): ...@@ -78,3 +82,18 @@ class Trainer(object):
self.context_process(self._context) self.context_process(self._context)
if self._context['is_exit']: if self._context['is_exit']:
break break
def user_define_engine(engine_yaml):
with open(engine_yaml, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
assert _config is not None
envs.set_runtime_envions(_config)
train_location = envs.get_global_env("engine.file")
train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
return trainer_class
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
train: train:
threads: 12
epochs: 10 epochs: 10
reader: reader:
......
trainer:
trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
threads: 4
# for cluster training
communicator:
strategy: "async"
send_queue_size: 4
min_send_grad_num_before_recv: 4
thread_pool_size: 5
max_merge_var_num: 4
...@@ -9,80 +9,62 @@ from fleetrec.core.factory import TrainerFactory ...@@ -9,80 +9,62 @@ from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs from fleetrec.core.utils import envs
from fleetrec.core.utils import util from fleetrec.core.utils import util
engines = {"TRAINSPILER": {}, "PSLIB": {}}
def run(model_yaml):
trainer = TrainerFactory.create(model_yaml)
trainer.run()
def engine_registry():
engines["TRAINSPILER"]["SINGLE"] = single_engine
engines["TRAINSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
engines["TRAINSPILER"]["CLUSTER"] = cluster_engine
engines["PSLIB"]["SINGLE"] = local_mpi_engine
engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
engines["PSLIB"]["CLUSTER"] = cluster_mpi_engine
def single_engine(single_envs, model_yaml):
print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value")))
envs.set_runtime_envions(single_envs)
run(model_yaml)
def local_cluster_engine(cluster_envs, model_yaml):
from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value"))) def get_engine(engine):
envs.set_runtime_envions(cluster_envs) engine = engine.upper()
launch = LocalClusterEngine(cluster_envs, model_yaml) if version.is_transpiler():
launch.run() run_engine = engines["TRAINSPILER"].get(engine, None)
else:
run_engine = engines["PSLIB"].get(engine, None)
if run_engine is None:
raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER")
return run_engine
def local_mpi_engine(cluster_envs, model_yaml):
from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value"))) def single_engine(args):
envs.set_runtime_envions(cluster_envs) print("use SingleTraining to run model: {}".format(args.model))
launch = LocalMPIEngine(cluster_envs, model_yaml) single_envs = {"train.trainer": "SingleTraining"}
launch.run()
print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value")))
envs.set_runtime_envions(single_envs)
def yaml_engine(engine_yaml, model_yaml): trainer = TrainerFactory.create(args.model)
with open(engine_yaml, 'r') as rb: return trainer
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
assert _config is not None
envs.set_global_envs(_config)
train_location = envs.get_global_env("engine.file") def cluster_engine(args):
train_dirname = os.path.dirname(train_location) print("launch ClusterTraining with cluster to run model: {}".format(args.model))
base_name = os.path.splitext(os.path.basename(train_location))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
trainer = trainer_class(model_yaml)
trainer.run()
cluster_envs = {"train.trainer": "ClusterTraining"}
envs.set_runtime_envions(cluster_envs)
trainer = TrainerFactory.create(args.model)
return trainer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("--model", type=str)
parser.add_argument("--engine", type=str)
parser.add_argument("--engine_extras", type=str)
args = parser.parse_args() def cluster_mpi_engine(args):
print("launch ClusterTraining with cluster to run model: {}".format(args.model))
if not os.path.exists(args.model) or not os.path.isfile(args.model): cluster_envs = {"train.trainer": "CtrTraining"}
raise ValueError("argument model: {} error, must specify a existed yaml file".format(args.model)) envs.set_runtime_envions(cluster_envs)
trainer = TrainerFactory.create(args.model)
return trainer
if args.engine.upper() == "SINGLE":
if version.is_transpiler():
print("use SingleTraining to run model: {}".format(args.model))
single_envs = {"train.trainer": "SingleTraining"}
single_engine(single_envs, args.model)
else:
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
mpi_path = util.run_which("mpirun") def local_cluster_engine(args):
if not mpi_path: from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "LOCAL_CLUSTER":
print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
if version.is_transpiler():
cluster_envs = {} cluster_envs = {}
cluster_envs["server_num"] = 1 cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1 cluster_envs["worker_num"] = 1
...@@ -91,33 +73,58 @@ if __name__ == "__main__": ...@@ -91,33 +73,58 @@ if __name__ == "__main__":
cluster_envs["train.trainer"] = "ClusterTraining" cluster_envs["train.trainer"] = "ClusterTraining"
cluster_envs["train.strategy.mode"] = "async" cluster_envs["train.strategy.mode"] = "async"
local_cluster_engine(cluster_envs, args.model) print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value")))
else: envs.set_runtime_envions(cluster_envs)
launch = LocalClusterEngine(cluster_envs, args.model)
return launch
def local_mpi_engine(args):
from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
mpi_path = util.run_which("mpirun") mpi = util.run_which("mpirun")
if not mpi_path: if not mpi:
raise RuntimeError("can not find mpirun, please check environment") raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"} cluster_envs = {"mpirun": mpi, "train.trainer": "CtrTraining", "log_dir": "logs"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "CLUSTER":
print("launch ClusterTraining with cluster to run model: {}".format(args.model))
if version.is_transpiler(): print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
print("use ClusterTraining to run model: {}".format(args.model))
cluster_envs = {"train.trainer": "ClusterTraining"}
envs.set_runtime_envions(cluster_envs)
else:
cluster_envs = {"train.trainer": "CtrTraining"}
envs.set_runtime_envions(cluster_envs) envs.set_runtime_envions(cluster_envs)
launch = LocalMPIEngine(cluster_envs, args.model)
return launch
#
# def yaml_engine(engine_yaml, model_yaml):
# with open(engine_yaml, 'r') as rb:
# _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
# assert _config is not None
#
# envs.set_global_envs(_config)
#
# train_location = envs.get_global_env("engine.file")
# train_dirname = os.path.dirname(train_location)
# base_name = os.path.splitext(os.path.basename(train_location))[0]
# sys.path.append(train_dirname)
# trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
# trainer = trainer_class(model_yaml)
# return trainer
run(args.model)
elif args.engine.upper() == "USER_DEFINE": if __name__ == "__main__":
engine_file = args.engine_extras parser = argparse.ArgumentParser(description='fleet-rec run')
if not os.path.exists(engine_file) or not os.path.isfile(engine_file): parser.add_argument("-m", "--model", type=str)
raise ValueError( parser.add_argument("-e", "--engine", type=str)
"argument engine: user_define error, must specify a existed yaml file".format(args.engine_file)) parser.add_argument("-ex", "--engine_extras", type=str)
yaml_engine(engine_file, args.model)
else: args = parser.parse_args()
raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER/USER_DEFINE")
if not os.path.exists(args.model) or not os.path.isfile(args.model):
raise ValueError("argument model: {} error, must specify an existed YAML file".format(args.model))
which_engine = get_engine(args.engine)
engine = which_engine(args)
engine.run()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册