Commit fd5e7f94 authored by tangwei

add mpi engine

Parent 0e42fd80
import abc


class Engine:
    __metaclass__ = abc.ABCMeta

    def __init__(self, envs, trainer):
        self.envs = envs
        self.trainer = trainer

    @abc.abstractmethod
    def run(self):
        pass
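For orientation (not part of this commit): every engine receives an env dict and a trainer spec, and a subclass only has to implement run(). A minimal hypothetical subclass:

# hypothetical example, not in this commit
class EchoEngine(Engine):
    def run(self):
        # a real engine would launch the trainer described by self.trainer
        print("launching trainer %s with envs %s" % (self.trainer, self.envs))

EchoEngine({"log_dir": "logs"}, "model.yaml").run()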
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import unicode_literals

import subprocess
import sys
import os
import copy

from fleetrec.core.engine.engine import Engine


class LocalClusterEngine(Engine):
    def start_procs(self):
        worker_num = self.envs["worker_num"]
        server_num = self.envs["server_num"]
        start_port = self.envs["start_port"]
        logs_dir = self.envs["log_dir"]

        default_env = os.environ.copy()
        current_env = copy.copy(default_env)
        current_env["CLUSTER_INSTANCE"] = "1"
        current_env.pop("http_proxy", None)
        current_env.pop("https_proxy", None)
        procs = []
        log_fns = []
        ports = range(start_port, start_port + server_num, 1)
        user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
        user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
        user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
        factory = "fleetrec.core.factory"
        cmd = [sys.executable, "-u", "-m", factory, self.trainer]

        # launch parameter servers
        for i in range(server_num):
            current_env.update({
                "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
                "PADDLE_PORT": user_endpoints_port[i],
                "TRAINING_ROLE": "PSERVER",
                "PADDLE_TRAINERS_NUM": str(worker_num),
                "POD_IP": user_endpoints_ips[i]
            })
            if logs_dir is not None:
                os.system("mkdir -p {}".format(logs_dir))
                fn = open("%s/server.%d" % (logs_dir, i), "w")
                log_fns.append(fn)
                proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
            else:
                proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
            procs.append(proc)

        # launch trainers
        for i in range(worker_num):
            current_env.update({
                "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
                "PADDLE_TRAINERS_NUM": str(worker_num),
                "TRAINING_ROLE": "TRAINER",
                "PADDLE_TRAINER_ID": str(i)
            })
            if logs_dir is not None:
                os.system("mkdir -p {}".format(logs_dir))
                fn = open("%s/worker.%d" % (logs_dir, i), "w")
                log_fns.append(fn)
                proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
            else:
                proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
            procs.append(proc)

        # only wait for workers to finish here; servers run until terminated below
        for i, proc in enumerate(procs):
            if i < server_num:
                continue
            procs[i].wait()
            if len(log_fns) > 0:
                log_fns[i].close()

        for i in range(server_num):
            if len(log_fns) > 0:
                log_fns[i].close()
            procs[i].terminate()
        print("all workers and parameter servers already completed", file=sys.stderr)

    def run(self):
        self.start_procs()
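For reference, a hedged usage sketch (not part of this commit; the env keys mirror the ones start_procs reads above, and the YAML path is a placeholder):

cluster_envs = {
    "worker_num": 2,      # number of trainer processes
    "server_num": 2,      # number of parameter-server processes
    "start_port": 36001,  # servers bind 36001, 36002, ...
    "log_dir": "logs",    # per-process logs: logs/server.N, logs/worker.N
}
LocalClusterEngine(cluster_envs, "ctr-dnn.yaml").run()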
fleetrec/core/engine/local_mpi_engine.py
@@ -19,82 +19,38 @@ import sys
 import os
 import copy
+from fleetrec.core.engine.engine import Engine

-def start_procs(args, yaml):
-    worker_num = args["worker_num"]
-    server_num = args["server_num"]
-    start_port = args["start_port"]
-    logs_dir = args["log_dir"]
+class LocalMPIEngine(Engine):
+    def start_procs(self):
+        logs_dir = self.envs["log_dir"]
         default_env = os.environ.copy()
         current_env = copy.copy(default_env)
         current_env["CLUSTER_INSTANCE"] = "1"
         current_env.pop("http_proxy", None)
         current_env.pop("https_proxy", None)
         procs = []
         log_fns = []
-    ports = range(start_port, start_port + server_num, 1)
-    user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
-    user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
-    user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
         factory = "fleetrec.core.factory"
-    cmd = [sys.executable, "-u", "-m", factory, yaml]
-    for i in range(server_num):
-        current_env.update({
-            "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
-            "PADDLE_PORT": user_endpoints_port[i],
-            "TRAINING_ROLE": "PSERVER",
-            "PADDLE_TRAINERS_NUM": str(worker_num),
-            "POD_IP": user_endpoints_ips[i]
-        })
-        if logs_dir is not None:
-            os.system("mkdir -p {}".format(logs_dir))
-            fn = open("%s/server.%d" % (logs_dir, i), "w")
-            log_fns.append(fn)
-            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
-        else:
-            proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
-        procs.append(proc)
-    for i in range(worker_num):
-        current_env.update({
-            "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
-            "PADDLE_TRAINERS_NUM": str(worker_num),
-            "TRAINING_ROLE": "TRAINER",
-            "PADDLE_TRAINER_ID": str(i)
-        })
+        mpi_cmd = "mpirun -npernode 2 -timestamp-output -tag-output".split(" ")
+        # build the full command list (list.extend() returns None, so use `+`)
+        cmd = mpi_cmd + [sys.executable, "-u", "-m", factory, self.trainer]
         if logs_dir is not None:
             os.system("mkdir -p {}".format(logs_dir))
-            fn = open("%s/worker.%d" % (logs_dir, i), "w")
+            fn = open("%s/job.log" % logs_dir, "w")
             log_fns.append(fn)
             proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
         else:
             proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
         procs.append(proc)
-    # only wait worker to finish here
-    for i, proc in enumerate(procs):
-        if i < server_num:
-            continue
-        procs[i].wait()
-        if len(log_fns) > 0:
-            log_fns[i].close()
-    for i in range(server_num):
+        for i in range(len(procs)):
             if len(log_fns) > 0:
                 log_fns[i].close()
-        procs[i].terminate()
+            procs[i].wait()
         print("all workers and parameter servers already completed", file=sys.stderr)

-class Launch:
-    def __init__(self, envs, trainer):
-        self.envs = envs
-        self.trainer = trainer
     def run(self):
-        start_procs(self.envs, self.trainer)
+        self.start_procs()
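Assuming the flags above, the MPI engine effectively launches the following command (the YAML filename is a placeholder), with stdout and stderr redirected to logs/job.log when log_dir is set:

mpirun -npernode 2 -timestamp-output -tag-output python -u -m fleetrec.core.factory ctr-dnn.yaml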
fleetrec/core/utils/util.py
@@ -30,6 +30,20 @@ def str2bool(v):
     raise ValueError('Boolean value expected.')


+def run_which(command):
+    # `which` prints "/usr/bin/which: no <cmd> in (...)" when the command is missing
+    regex = "/usr/bin/which: no {} in"
+    ret = run_shell_cmd("which {}".format(command))
+    if ret.startswith(regex.format(command)):
+        return None
+    else:
+        return ret
+
+
+def run_shell_cmd(command):
+    assert command is not None and isinstance(command, str)
+    return os.popen(command).read().strip()


 def get_env_value(env_name):
     """
     get os environment value
......
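A quick sketch of how these helpers behave (assuming a standard `which` on the PATH; values shown are illustrative):

mpi_path = util.run_which("mpirun")    # e.g. "/usr/bin/mpirun", or None when not installed
host = util.run_shell_cmd("hostname")  # stripped stdout of the shell command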
fleetrec/run.py
@@ -7,7 +7,9 @@ from paddle.fluid.incubate.fleet.parameter_server import version
 from fleetrec.core.factory import TrainerFactory
 from fleetrec.core.utils import envs
-from fleetrec.core.engine import local_engine
+from fleetrec.core.utils import util
+from fleetrec.core.engine.local_cluster_engine import LocalClusterEngine
+from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine

 def run(model_yaml):
@@ -25,23 +27,25 @@ def single_engine(single_envs, model_yaml):
 def local_cluster_engine(cluster_envs, model_yaml):
     print(envs.pretty_print_envs(cluster_envs, ("Local Cluster Envs", "Value")))
     envs.set_runtime_envions(cluster_envs)
-    launch = local_engine.Launch(cluster_envs, model_yaml)
+    launch = LocalClusterEngine(cluster_envs, model_yaml)
     launch.run()


 def local_mpi_engine(model_yaml):
     print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(model_yaml))
-    cluster_envs = {}
-    cluster_envs["server_num"] = 1
-    cluster_envs["worker_num"] = 1
-    cluster_envs["start_port"] = 36001
-    cluster_envs["log_dir"] = "logs"
-    cluster_envs["train.trainer"] = "CtrTraining"
+    mpi_path = util.run_which("mpirun")
+    if not mpi_path:
+        raise RuntimeError("can not find mpirun, please check environment")
+    cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"}
     print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value")))
     envs.set_runtime_envions(cluster_envs)
-    print("coming soon")
+    launch = LocalMPIEngine(cluster_envs, model_yaml)
+    launch.run()
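Note that LocalMPIEngine.start_procs also reads self.envs["log_dir"], which the dict above does not set. A hedged caller sketch with that key included (the extra key is an assumption, not in the diff):

cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"}
LocalMPIEngine(cluster_envs, model_yaml).run()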
def yaml_engine(engine_yaml, model_yaml):
@@ -55,7 +59,7 @@ def yaml_engine(engine_yaml, model_yaml):
     train_dirname = os.path.dirname(train_location)
     base_name = os.path.splitext(os.path.basename(train_location))[0]
     sys.path.append(train_dirname)
-    trainer_class = envs.lazy_instance(base_name, "UserDefineTrainer")
+    trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
     trainer = trainer_class(model_yaml)
     trainer.run()
......