提交 840c310a 编写于 作者: T tangwei

update code

上级 45ec57e9
......@@ -11,3 +11,4 @@ class Engine:
@abc.abstractmethod
def run(self):
    """Abstract entry point for an engine; this base version does nothing."""
......@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import abc
import time
import yaml
from paddle import fluid
from fleetrec.core.utils import envs
class Trainer(object):
......@@ -78,3 +82,18 @@ class Trainer(object):
self.context_process(self._context)
if self._context['is_exit']:
break
def user_define_engine(engine_yaml):
    """Load a user-supplied engine config and resolve its trainer object.

    Reads the YAML file at ``engine_yaml``, hands the parsed content to
    ``envs.set_runtime_envions``, then reads the ``engine.file`` setting,
    makes that file's directory importable and asks ``envs.lazy_instance``
    for ``UserDefineTraining`` in the module named after the file.

    Args:
        engine_yaml: Path to the engine YAML configuration file.

    Returns:
        Whatever ``envs.lazy_instance`` yields for ``UserDefineTraining``
        (presumably the trainer class -- confirm against ``envs``).

    Raises:
        ValueError: If the YAML file parses to ``None`` (e.g. it is empty).
    """
    with open(engine_yaml, 'r') as rb:
        # NOTE(review): FullLoader accepts more than plain data types; only
        # point this at trusted config files (or consider yaml.safe_load).
        _config = yaml.load(rb.read(), Loader=yaml.FullLoader)

    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if _config is None:
        raise ValueError(
            "engine yaml: {} is empty, must specify a valid yaml file".format(engine_yaml))

    envs.set_runtime_envions(_config)

    train_location = envs.get_global_env("engine.file")
    train_dirname = os.path.dirname(train_location)
    base_name = os.path.splitext(os.path.basename(train_location))[0]

    # Avoid piling up duplicate entries when this is called more than once.
    if train_dirname not in sys.path:
        sys.path.append(train_dirname)

    trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
    return trainer_class
......@@ -13,7 +13,6 @@
# limitations under the License.
train:
threads: 12
epochs: 10
reader:
......
trainer:
trainer: "/root/FleetRec/fleetrec/examples/user_define_trainer.py"
threads: 4
# for cluster training
communicator:
strategy: "async"
send_queue_size: 4
min_send_grad_num_before_recv: 4
thread_pool_size: 5
max_merge_var_num: 4
......@@ -9,115 +9,122 @@ from fleetrec.core.factory import TrainerFactory
from fleetrec.core.utils import envs
from fleetrec.core.utils import util
engines = {"TRAINSPILER": {}, "PSLIB": {}}
def run(model_yaml):
    """Create a trainer for ``model_yaml`` through TrainerFactory and run it."""
    TrainerFactory.create(model_yaml).run()
def engine_registry():
    """Populate the module-level ``engines`` table with launcher functions.

    Each backend ("TRAINSPILER", "PSLIB") gets one launcher per engine mode.
    For PSLIB, SINGLE and LOCAL_CLUSTER share the same local MPI launcher.
    """
    table = (
        ("TRAINSPILER", "SINGLE", single_engine),
        ("TRAINSPILER", "LOCAL_CLUSTER", local_cluster_engine),
        ("TRAINSPILER", "CLUSTER", cluster_engine),
        ("PSLIB", "SINGLE", local_mpi_engine),
        ("PSLIB", "LOCAL_CLUSTER", local_mpi_engine),
        ("PSLIB", "CLUSTER", cluster_mpi_engine),
    )
    for backend, mode, launcher in table:
        engines[backend][mode] = launcher
def get_engine(engine):
    """Look up the launcher registered for ``engine`` on the active backend.

    The name is upper-cased before lookup. The backend table is
    "TRAINSPILER" when ``version.is_transpiler()`` is true, else "PSLIB".

    Args:
        engine: Engine mode name, e.g. "single", "local_cluster", "cluster".

    Returns:
        The callable stored in ``engines`` for that backend and mode.

    Raises:
        ValueError: If ``engine`` is None (e.g. ``--engine`` was not given)
            or no launcher is registered under that name.
    """
    # argparse leaves an omitted --engine as None; fail with the same
    # message as an unknown name rather than an opaque AttributeError.
    if engine is None:
        raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER")

    engine = engine.upper()
    backend = "TRAINSPILER" if version.is_transpiler() else "PSLIB"
    run_engine = engines[backend].get(engine, None)

    if run_engine is None:
        raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER")
    return run_engine
def single_engine(args):
print("use SingleTraining to run model: {}".format(args.model))
single_envs = {"train.trainer": "SingleTraining"}
def single_engine(single_envs, model_yaml):
print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value")))
envs.set_runtime_envions(single_envs)
run(model_yaml)
trainer = TrainerFactory.create(args.model)
return trainer
def cluster_engine(args):
    """Build the trainer for a cluster run of ``args.model``.

    Sets ``train.trainer`` to "ClusterTraining" in the runtime envs, then
    returns the trainer created by TrainerFactory; it does not start it.
    """
    print("launch ClusterTraining with cluster to run model: {}".format(args.model))
    envs.set_runtime_envions({"train.trainer": "ClusterTraining"})
    return TrainerFactory.create(args.model)
def cluster_mpi_engine(args):
print("launch ClusterTraining with cluster to run model: {}".format(args.model))
def local_cluster_engine(cluster_envs, model_yaml):
cluster_envs = {"train.trainer": "CtrTraining"}
envs.set_runtime_envions(cluster_envs)
trainer = TrainerFactory.create(args.model)
return trainer
def local_cluster_engine(args):
from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
cluster_envs = {}
cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "ClusterTraining"
cluster_envs["train.strategy.mode"] = "async"
print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value")))
envs.set_runtime_envions(cluster_envs)
launch = LocalClusterEngine(cluster_envs, model_yaml)
launch.run()
launch = LocalClusterEngine(cluster_envs, args.model)
return launch
def local_mpi_engine(cluster_envs, model_yaml):
    """Print and apply ``cluster_envs``, then run a LocalMPIEngine.

    Args:
        cluster_envs: Environment settings passed to the runtime envs and
            to the engine.
        model_yaml: Model config path forwarded to LocalMPIEngine.
    """
    # Imported here, as in the original, so the module loads only on use.
    from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine

    summary = envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value"))
    print(summary)
    envs.set_runtime_envions(cluster_envs)

    LocalMPIEngine(cluster_envs, model_yaml).run()
def local_mpi_engine(args):
from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
def yaml_engine(engine_yaml, model_yaml):
with open(engine_yaml, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
assert _config is not None
mpi = util.run_which("mpirun")
if not mpi:
raise RuntimeError("can not find mpirun, please check environment")
envs.set_global_envs(_config)
cluster_envs = {"mpirun": mpi, "train.trainer": "CtrTraining", "log_dir": "logs"}
train_location = envs.get_global_env("engine.file")
train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0]
sys.path.append(train_dirname)
trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
trainer = trainer_class(model_yaml)
trainer.run()
print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
envs.set_runtime_envions(cluster_envs)
launch = LocalMPIEngine(cluster_envs, args.model)
return launch
#
# def yaml_engine(engine_yaml, model_yaml):
# with open(engine_yaml, 'r') as rb:
# _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
# assert _config is not None
#
# envs.set_global_envs(_config)
#
# train_location = envs.get_global_env("engine.file")
# train_dirname = os.path.dirname(train_location)
# base_name = os.path.splitext(os.path.basename(train_location))[0]
# sys.path.append(train_dirname)
# trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
# trainer = trainer_class(model_yaml)
# return trainer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("--model", type=str)
parser.add_argument("--engine", type=str)
parser.add_argument("--engine_extras", type=str)
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-e", "--engine", type=str)
parser.add_argument("-ex", "--engine_extras", type=str)
args = parser.parse_args()
if not os.path.exists(args.model) or not os.path.isfile(args.model):
raise ValueError("argument model: {} error, must specify a existed yaml file".format(args.model))
if args.engine.upper() == "SINGLE":
if version.is_transpiler():
print("use SingleTraining to run model: {}".format(args.model))
single_envs = {"train.trainer": "SingleTraining"}
single_engine(single_envs, args.model)
else:
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
mpi_path = util.run_which("mpirun")
if not mpi_path:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "LOCAL_CLUSTER":
print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
if version.is_transpiler():
cluster_envs = {}
cluster_envs["server_num"] = 1
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "ClusterTraining"
cluster_envs["train.strategy.mode"] = "async"
local_cluster_engine(cluster_envs, args.model)
else:
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
mpi_path = util.run_which("mpirun")
if not mpi_path:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
local_mpi_engine(cluster_envs, args.model)
elif args.engine.upper() == "CLUSTER":
print("launch ClusterTraining with cluster to run model: {}".format(args.model))
if version.is_transpiler():
print("use ClusterTraining to run model: {}".format(args.model))
cluster_envs = {"train.trainer": "ClusterTraining"}
envs.set_runtime_envions(cluster_envs)
else:
cluster_envs = {"train.trainer": "CtrTraining"}
envs.set_runtime_envions(cluster_envs)
run(args.model)
elif args.engine.upper() == "USER_DEFINE":
engine_file = args.engine_extras
if not os.path.exists(engine_file) or not os.path.isfile(engine_file):
raise ValueError(
"argument engine: user_define error, must specify a existed yaml file".format(args.engine_file))
yaml_engine(engine_file, args.model)
else:
raise ValueError("engine only support SINGLE/LOCAL_CLUSTER/CLUSTER/USER_DEFINE")
raise ValueError("argument model: {} error, must specify an existed YAML file".format(args.model))
which_engine = get_engine(args.engine)
engine = which_engine(args)
engine.run()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册