提交 34bd0045 编写于 作者: M MrChengmo

refine fleetrun.ps_launch

上级 f4c750d7
......@@ -112,10 +112,6 @@ void RecvSelectedRows(const CommContext &rpc_ctx,
template <typename T>
void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
......@@ -125,10 +121,12 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
if (rpc_ctx.origin_varnames.size() == 1 &&
rpc_ctx.splited_varnames.size() == 1) {
auto varname = rpc_ctx.origin_varnames[0];
VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0];
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(place);
VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0] << " in gpu? "
<< platform::is_gpu_place(place);
rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
scope, varname, varname));
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
......
......@@ -508,7 +508,7 @@ class RoleMakerBase(object):
and No.1 and No.3 cpu-trainer will work with No.1 gpu-trainer
"""
assert self._heter_trainer_endpoints != []
return self._heter_trainer_endpoints[(self._current_id + 1) %
return self._heter_trainer_endpoints[(self._current_id) %
self._heter_worker_num()]
def _get_heter_worker_device(self):
......
......@@ -89,14 +89,40 @@ def _parse_args():
description='''start paddle training using multi-process mode.
see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/training/cluster_howto.html#permalink-8--nccl2-
''')
base_group = parser.add_argument_group("Base Parameters")
base_group.add_argument(
"-d",
"--distributed_mode",
type=str,
choices=["collective", "ps", "ps_heter", "ps_gpu", ""],
default="",
help="Distributed running mode: collective/ps/ps_gpu/ps_heter")
base_group.add_argument(
"--log_dir",
type=str,
default="log",
help="The path for each process's log.If it's not set, the log will printed to default pipe."
)
base_group.add_argument(
"training_script",
type=str,
help="The full path to the single GPU training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script")
# Optional arguments for the launch helper
parser.add_argument(
# for collective
collective_group = parser.add_argument_group("Collective Parameters")
collective_group.add_argument(
"--ips",
type=str,
default="127.0.0.1",
help="Paddle cluster nodes ips, such as 192.168.0.16,192.168.0.17..")
parser.add_argument(
collective_group.add_argument(
"--gpus",
type=str,
default=None,
......@@ -104,31 +130,30 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
"each process is bound to a single GPU. And if it's not set, this module will use all the gpu cards for training."
)
parser.add_argument(
ps_group = parser.add_argument_group("Parameter-Server Parameters")
# for parameter server
ps_group.add_argument(
"--servers", type=str, default="", help="User defined servers ip:port")
parser.add_argument(
ps_group.add_argument(
"--workers", type=str, default="", help="User defined workers ip:port")
parser.add_argument("--worker_num", type=int, help="number of workers")
ps_group.add_argument(
"--heter_workers",
type=str,
default="",
help="User defined heter workers ip:port")
parser.add_argument("--server_num", type=int, help="number of servers")
ps_group.add_argument("--worker_num", type=int, help="number of workers")
ps_group.add_argument("--server_num", type=int, help="number of servers")
ps_group.add_argument(
"--heter_worker_num", type=int, help="number of heter_workers")
parser.add_argument(
"--log_dir",
ps_group.add_argument(
"--heter_worker_device",
type=str,
default="log",
help="The path for each process's log.If it's not set, the log will printed to default pipe."
)
# positional
parser.add_argument(
"training_script",
type=str,
help="The full path to the single GPU training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script")
default="gpu",
choices=["gpu", "xpu"],
help="heter worker device")
# rest from the training program
parser.add_argument('training_script_args', nargs=REMAINDER)
return parser.parse_args()
......@@ -246,209 +271,32 @@ def launch_collective(args):
def launch_ps(args):
    """Start a parameter-server job described by the parsed launch ``args``.

    On PaddleCloud with ``distributed_mode`` of ``"ps"`` or ``""`` the training
    script is started directly via ``direct_start``. On PaddleCloud with
    ``"ps_heter"`` the cloud environment is translated by
    ``cloud_ps_heter_env_set`` and copied into ``args``. In every other case
    (and after the ``ps_heter`` translation) the launch is delegated to
    ``ParameterServerLauncher``.

    Args:
        args: namespace produced by ``_parse_args``; the ``ps_heter`` cloud
            branch overwrites ``args.servers``, ``args.workers``,
            ``args.heter_workers`` and ``args.trainers``.
    """
    cloud_flag = cloud_utils.use_paddlecloud()

    # for ps-cpu on paddlecloud
    direct_start_mode = ["ps", ""]
    if cloud_flag and (args.distributed_mode in direct_start_mode):
        direct_start(args)
        return
    elif cloud_flag and args.distributed_mode == "ps_heter":
        cloud_ps_heter_env_set(args)
        # ParameterServerLauncher.get_role_endpoints reads args.servers,
        # args.workers and args.heter_workers, so every role must be filled
        # from the variable that actually carries its endpoints.
        args.workers = os.getenv("PADDLE_TRAINER_ENDPOINTS")
        args.servers = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
        args.heter_workers = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST")
        # Kept for callers that still look at args.trainers.
        args.trainers = os.getenv("PADDLE_TRAINER_ENDPOINTS")

    ps_launcher = ParameterServerLauncher(args)
    ps_launcher.start_ps(args)
    return
def launch():
args = _parse_args()
logger = get_logger()
_print_arguments(args)
ps_args = ['--worker_num', '--server_num', '--servers', '--workers']
ps_args = [
'--worker_num', '--server_num', '--heter_worker_num', '--servers',
'--workers', '--heter_worrkers', 'heter_worker_device'
]
collective_args = ['--ips', '--gpus']
has_ps_args = [
ps_arg for ps_arg in ps_args if ps_arg in " ".join(sys.argv[1:-1])
......@@ -462,9 +310,10 @@ def launch():
else:
cuda_device_num = 0
if len(has_ps_args) > 0 or cuda_device_num == 0:
ps_mode = ['ps', 'ps_gpu', 'ps_heter']
if len(has_ps_args) > 0 or args.distributed_mode in ps_mode:
logger.info(
"Run parameter-sever cpu mode. pserver arguments:{}, cuda count:{}".
"Run parameter-sever mode. pserver arguments:{}, cuda count:{}".
format(has_ps_args, cuda_device_num))
launch_ps(args)
elif len(has_collective_args) > 0:
......
......@@ -21,9 +21,13 @@ import signal
import copy
import sys
import subprocess
import tempfile
import shutil
from contextlib import closing
import socket
import paddle
import paddle.fluid as fluid
logger = logging.getLogger("root")
logger.propagate = False
......@@ -144,14 +148,16 @@ class Pod(object):
self.trainers = []
self.servers = []
self.workers = []
self.heter_workers = []
self.gpus = []
def __str__(self):
return "rank:{} id:{} addr:{} port:{} visible_gpu:{} trainers:{} servers:{} \
workers:{}".format(self.rank, self.id, self.addr, self.port,
self.gpus, [str(t) for t in self.trainers],
[str(s) for s in self.servers],
[str(w) for w in self.workers])
workers:{} heter_workers:{}".format(
self.rank, self.id, self.addr, self.port, self.gpus, [
str(t) for t in self.trainers
], [str(s) for s in self.servers], [str(w) for w in self.workers],
[str(h) for h in self.heter_workers])
def __eq__(self, pod):
if self.rank != pod.rank or \
......@@ -262,7 +268,7 @@ def terminate_local_procs(procs):
p.log_fn.close()
logger.debug("terminate process id:{}".format(p.proc.pid))
#wait all process terminiated
# wait for all processes to terminate
time.sleep(3)
for step in range(0, 50):
alive = False
......@@ -406,10 +412,10 @@ def start_local_trainers(cluster,
else:
current_env = copy.copy(envs)
#paddle broadcast ncclUniqueId use socket, and
#proxy maybe make trainers unreachable, so delete them.
#if we set them to "", grpc will log error message "bad uri"
#so just delete them.
# paddle broadcast ncclUniqueId use socket, and
# proxy maybe make trainers unreachable, so delete them.
# if we set them to "", grpc will log error message "bad uri"
# so just delete them.
current_env.pop("http_proxy", None)
current_env.pop("https_proxy", None)
......@@ -510,3 +516,450 @@ def watch_local_trainers(procs, nranks):
raise
return alive
def direct_start(args):
    """Run the training script once with the current environment and block
    until it exits (used for ps-cpu mode on paddlecloud, where the required
    environment variables are already provided).
    """
    command = [sys.executable, "-u", args.training_script]
    command = command + args.training_script_args
    child = subprocess.Popen(command)
    child.wait()
    return
def get_custom_endpoints(origin_endpoints, offset=0):
    """Shift the port of every endpoint in a comma-separated list.

    origin_endpoint: ip:port
    user_define_endpoint: ip:(port+offset)

    Args:
        origin_endpoints (str): comma-separated ``ip:port`` entries.
            Whitespace around an entry is ignored, matching how endpoint
            lists are parsed elsewhere in this file.
        offset (int): value added to each port.

    Returns:
        str: comma-separated ``ip:(port+offset)`` entries in the input order.

    Raises:
        AssertionError: if ``origin_endpoints`` is None.
        IndexError: if an entry has no ``:port`` part.
        ValueError: if a port is not an integer.
    """
    assert origin_endpoints is not None
    shifted_endpoints = []
    for ip_port in origin_endpoints.split(","):
        fields = ip_port.strip().split(":")
        ip, port = fields[0], fields[1]
        shifted_endpoints.append(":".join((ip, str(int(port) + offset))))
    return ",".join(shifted_endpoints)
def cloud_ps_heter_env_set(args):
    """Translate the PaddleCloud environment into the heter parameter-server
    variables and export them through ``os.environ``.

    Reads ``TRAINER_IP_PORT_LIST``, ``PSERVER_IP_PORT_LIST`` and
    ``TRAINER_PORTS`` from the environment and ``args.heter_worker_device``.

    Exported variables:
        PADDLE_TRAINER_ENDPOINTS: the pserver endpoints with every port
            shifted by one (via ``get_custom_endpoints``).
        PADDLE_PSERVERS_IP_PORT_LIST: ``PSERVER_IP_PORT_LIST`` unchanged.
        PADDLE_TRAINERS_NUM / TRAINERS_NUM: number of pserver endpoints.
        PADDLE_HETER_TRAINER_DEVICE: ``args.heter_worker_device``.
        PADDLE_HETER_TRAINER_IP_PORT_LIST: ``TRAINER_IP_PORT_LIST`` unchanged.

    Raises:
        AssertionError: if a required variable is missing/empty or too few
            ports are available.
    """
    # os.getenv(..., "") never returns None, so presence has to be checked
    # through truthiness; otherwise a missing variable is only noticed later
    # (or not at all).
    paddle_trainer_endpoints = os.getenv("TRAINER_IP_PORT_LIST", "")
    assert paddle_trainer_endpoints, \
        "TRAINER_IP_PORT_LIST is not set in the environment"

    paddle_pserver_endpoints = os.getenv("PSERVER_IP_PORT_LIST", "")
    assert paddle_pserver_endpoints, \
        "PSERVER_IP_PORT_LIST is not set in the environment"

    available_ports = os.getenv("TRAINER_PORTS", "").split(",")
    assert len(
        available_ports
    ) > 3, "set paddle_ports_num >= 2 in config.ini for paddlecloud job submit"

    # hard code for paddlecloud custom-framework
    trainers_num = len(paddle_pserver_endpoints.split(","))

    environs = {}
    environs["PADDLE_TRAINER_ENDPOINTS"] = get_custom_endpoints(
        paddle_pserver_endpoints, 1)
    environs["PADDLE_PSERVERS_IP_PORT_LIST"] = paddle_pserver_endpoints
    environs["PADDLE_TRAINERS_NUM"] = trainers_num
    environs["TRAINERS_NUM"] = trainers_num
    environs["PADDLE_HETER_TRAINER_DEVICE"] = args.heter_worker_device
    # hard code for paddlecloud custom-framework
    environs["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = paddle_trainer_endpoints

    for k, v in environs.items():
        os.environ[k] = str(v)
    logger.info("Set heter parameter server env: {}".format(
        pretty_print_envs(environs)))
class ParameterServerLauncher(object):
    """Resolve the endpoints of every parameter-server role from the launch
    arguments and start the processes that belong to the current node.

    The constructor only parses ``args`` (see ``get_role_endpoints``);
    processes are spawned, awaited and cleaned up by ``start_ps``.
    """

    def __init__(self, args):
        self.server_num = 0
        self.worker_num = 0
        self.heter_worker_num = 0

        # Endpoint lists are comma-separated "ip:port" strings. They are
        # exported verbatim as environment variables, so they must be strings
        # even when a role is unused.
        self.server_endpoints = ""
        self.server_endpoints_ips = []
        self.server_endpoints_port = []

        self.worker_endpoints = ""
        self.worker_endpoints_ips = []
        self.worker_endpoints_port = []

        self.heter_worker_endpoints = ""
        self.heter_worker_endpoints_ips = []
        self.heter_worker_endpoints_port = []

        self.is_local = True
        self.current_node_ip = ""

        self.get_role_endpoints(args)

    @staticmethod
    def _resolve_endpoints(role, mode_desc, num, endpoints, port_offset):
        """Return ``(count, endpoints)`` for one role.

        ``num`` and ``endpoints`` are the user supplied ``--<role>_num`` and
        ``--<role>s`` values. When only ``num`` is given, local endpoints are
        generated with ``get_ports(num, port_offset)``.
        """
        if num:
            if endpoints:
                received = len(endpoints.split(","))
                assert received == num, (
                    "The {0}_num and {0}s doesn't match. Expect {0}s endpoints "
                    "num epual to {0}_num, but received {0}s enpoint num: {1} "
                    "and {0}_num {2}").format(role, received, num)
                return num, endpoints
            ports = get_ports(num, port_offset)
            return num, ",".join(["127.0.0.1:" + str(x) for x in ports])

        # Truthiness also rejects None, which would otherwise fail later on
        # ``.split`` with a less helpful error.
        assert endpoints, "The setting of {} must has {}_num or {}s.".format(
            mode_desc, role, role)
        return len(endpoints.split(",")), endpoints

    @staticmethod
    def _split_endpoints(endpoints):
        """Split a comma-separated "ip:port" string into ``(ips, ports)``."""
        ips = []
        ports = []
        for endpoint in endpoints.split(","):
            fields = endpoint.strip().split(":")
            ips.append(fields[0])
            ports.append(fields[1])
        return ips, ports

    def get_role_endpoints(self, args):
        """Fill the endpoint, node and rank attributes from ``args``.

        Reads ``server_num``/``servers``, ``worker_num``/``workers`` and, when
        ``args.distributed_mode == "ps_heter"``, ``heter_worker_num``/
        ``heter_workers``.

        Raises:
            AssertionError: if a role has neither a count nor endpoints, if
                both are given but disagree, or if the local IP is not part
                of the configured nodes.
        """
        is_heter = args.distributed_mode == "ps_heter"

        # get server envs
        self.server_num, self.server_endpoints = self._resolve_endpoints(
            "server", "Parameter-Server", args.server_num, args.servers, 0)

        # get worker envs
        self.worker_num, self.worker_endpoints = self._resolve_endpoints(
            "worker", "Parameter-Server", args.worker_num, args.workers,
            self.server_num)

        # get heter worker envs
        if is_heter:
            self.heter_worker_num, self.heter_worker_endpoints = \
                self._resolve_endpoints(
                    "heter_worker", "Parameter-Server heter mode",
                    args.heter_worker_num, args.heter_workers,
                    self.server_num + self.worker_num)

        # check local or user define
        self.server_endpoints_ips, self.server_endpoints_port = \
            self._split_endpoints(self.server_endpoints)
        self.worker_endpoints_ips, self.worker_endpoints_port = \
            self._split_endpoints(self.worker_endpoints)
        role_ips = self.server_endpoints_ips + self.worker_endpoints_ips
        if is_heter:
            self.heter_worker_endpoints_ips, self.heter_worker_endpoints_port = \
                self._split_endpoints(self.heter_worker_endpoints)
            role_ips = role_ips + self.heter_worker_endpoints_ips

        # De-duplicate while keeping first-seen order. Ranks are handed out
        # by walking node_ips, so the order has to be identical on every node;
        # a set would not guarantee that.
        self.node_ips = []
        for ip in role_ips:
            if ip not in self.node_ips:
                self.node_ips.append(ip)

        if len(self.node_ips) == 1:
            self.is_local = True
            self.current_node_ip = self.node_ips[0]
        else:
            self.is_local = False
            _, self.current_node_ip = get_host_name_ip()

        assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
            % (self.current_node_ip, self.node_ips)
        self.node_rank = self.node_ips.index(self.current_node_ip)

        logger.debug(
            "parsed from args: node_ips:{} current_node_ip:{} node_rank:{}".
            format(self.node_ips, self.current_node_ip, self.node_rank))

    def start_ps(self, args):
        """Start this node's servers, workers and heter workers, wait for the
        workers to exit, then terminate heter workers and servers and remove
        the gloo rendezvous directory.
        """
        cluster = Cluster(hdfs=None)
        server_rank = 0
        worker_rank = 0
        heter_worker_rank = 0

        for node_rank, ip in enumerate(self.node_ips):
            pod = Pod()
            pod.rank = node_rank
            pod.addr = ip
            for server_ip, port in zip(self.server_endpoints_ips,
                                       self.server_endpoints_port):
                if ip == server_ip:
                    server = Trainer()
                    server.endpoint = "%s:%s" % (ip, port)
                    server.rank = server_rank
                    server_rank += 1
                    pod.servers.append(server)
            for worker_ip, port in zip(self.worker_endpoints_ips,
                                       self.worker_endpoints_port):
                if ip == worker_ip:
                    worker = Trainer()
                    worker.endpoint = "%s:%s" % (ip, port)
                    worker.rank = worker_rank
                    worker_rank += 1
                    pod.workers.append(worker)
            for heter_ip, port in zip(self.heter_worker_endpoints_ips,
                                      self.heter_worker_endpoints_port):
                if ip == heter_ip:
                    heter_worker = Trainer()
                    heter_worker.endpoint = "%s:%s" % (ip, port)
                    heter_worker.rank = heter_worker_rank
                    heter_worker_rank += 1
                    pod.heter_workers.append(heter_worker)

            cluster.pods.append(pod)

        pod = cluster.pods[self.node_rank]
        self.gloo_rendezvous_dir = tempfile.mkdtemp()

        # subprocess start; self.procs is ordered servers, workers,
        # heter workers, which the index arithmetic below relies on.
        self.procs = []
        self.cmds = []
        self.log_fns = []

        try:
            self.start_pod_server(args, pod)
            self.start_pod_worker(args, pod)
            self.start_pod_heter_worker(args, pod)

            logger.info(
                "Please check servers, workers and heter_worker logs in {}/workerlog.*, {}/serverlog.* and {}/heterlog.*".
                format(args.log_dir, args.log_dir, args.log_dir))

            worker_begin = len(pod.servers)
            worker_end = worker_begin + len(pod.workers)
            heter_end = worker_end + len(pod.heter_workers)

            # only wait worker to finish here; servers and heter workers do
            # not exit on their own and are terminated afterwards.
            for i in range(worker_begin, worker_end):
                self.procs[i].proc.wait()
                if len(self.log_fns) > 0:
                    self.log_fns[i].close()
            print(
                "all workers exit, going to finish parameter server and heter_worker",
                file=sys.stderr)

            for i in range(worker_end, heter_end):
                if len(self.log_fns) > 0:
                    self.log_fns[i].close()
                self.procs[i].proc.terminate()
            print("all heter worker are killed", file=sys.stderr)

            for i in range(worker_begin):
                if len(self.log_fns) > 0:
                    self.log_fns[i].close()
                self.procs[i].proc.terminate()
            print("all parameter server are killed", file=sys.stderr)
        finally:
            if os.path.exists(self.gloo_rendezvous_dir):
                shutil.rmtree(self.gloo_rendezvous_dir)

    @staticmethod
    def _base_env():
        """Copy of the current environment without the http(s) proxy
        variables."""
        current_env = os.environ.copy()
        current_env.pop("http_proxy", None)
        current_env.pop("https_proxy", None)
        return current_env

    @staticmethod
    def _heter_device_num(args):
        """Number of local devices of the kind named by
        ``args.heter_worker_device`` (0 for an unknown kind)."""
        if args.heter_worker_device == "gpu":
            return fluid.core.get_cuda_device_count()
        elif args.heter_worker_device == "xpu":
            return fluid.core.get_xpu_device_count()
        return 0

    @staticmethod
    def _device_env(idx, device_num):
        """Device-selection variables for the ``idx``-th local process.

        Values are strings because subprocess environments only accept
        strings. Returns an empty dict when there is no device to select.
        """
        if device_num == 0:
            return {}
        device_id = str(idx % device_num)
        return {
            "FLAGS_selected_gpus": device_id,
            "FLAGS_selected_xpus": device_id,
            "CUDA_VISIBLE_DEVICES": device_id,
            "XPU_VISIBLE_DEVICES": device_id,
        }

    def _launch_proc(self, args, current_env, log_prefix, idx, rank):
        """Spawn the training script with ``current_env`` and record it in
        ``self.cmds``/``self.log_fns``/``self.procs``.

        Output goes to ``<log_dir>/<log_prefix>.<idx>`` when ``args.log_dir``
        is set, otherwise it is inherited from this process.
        """
        cmd = [sys.executable, "-u", args.training_script
               ] + args.training_script_args
        self.cmds.append(cmd)

        fn = None
        if args.log_dir is not None:
            if not os.path.isdir(args.log_dir):
                os.makedirs(args.log_dir)
            fn = open("%s/%s.%d" % (args.log_dir, log_prefix, idx), "w")
            self.log_fns.append(fn)
            proc = subprocess.Popen(
                cmd, env=current_env, stdout=fn, stderr=fn)
        else:
            proc = subprocess.Popen(cmd, env=current_env)

        tp = TrainerProc()
        tp.proc = proc
        tp.rank = rank
        tp.local_rank = idx
        tp.log_fn = fn
        tp.log_offset = fn.tell() if fn else None
        tp.cmd = cmd
        self.procs.append(tp)

    def start_pod_server(self, args, pod):
        """Start one ``PSERVER`` process per entry of ``pod.servers``."""
        current_env = self._base_env()
        for idx, cur_server in enumerate(pod.servers):
            proc_env = {
                "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
                "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
                "PADDLE_HETER_TRAINER_IP_PORT_LIST":
                self.heter_worker_endpoints,
                "PADDLE_HETER_TRAINER_DEVICE": args.heter_worker_device,
                "PADDLE_PORT": cur_server.endpoint.split(":")[1],
                "TRAINING_ROLE": "PSERVER",
                "PADDLE_TRAINERS_NUM": str(self.worker_num),
                "POD_IP": cur_server.endpoint.split(":")[0],
                "PADDLE_WITH_GLOO": "1",
                "PADDLE_GLOO_RENDEZVOUS": "2",
                "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir
            }
            current_env.update(proc_env)

            if idx == 0:
                logger.info(
                    "Local server start {} processes. First process distributed "
                    "environment info (Only For Debug): {}".format(
                        len(pod.servers),
                        pretty_print_envs(proc_env, ("Distributed Envs",
                                                     "Value"))))

            self._launch_proc(args, current_env, "serverlog", idx,
                              cur_server.rank)

    def start_pod_worker(self, args, pod):
        """Start one ``TRAINER`` process per entry of ``pod.workers``.

        Device-selection variables are only exported when at least one device
        of kind ``args.heter_worker_device`` is present on this host.
        """
        current_env = self._base_env()
        heter_device_num = self._heter_device_num(args)

        for idx, cur_worker in enumerate(pod.workers):
            proc_env = {
                "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
                "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
                "PADDLE_TRAINERS_NUM": str(self.worker_num),
                "PADDLE_HETER_TRAINER_IP_PORT_LIST":
                self.heter_worker_endpoints,
                "PADDLE_HETER_TRAINER_DEVICE": args.heter_worker_device,
                "TRAINING_ROLE": "TRAINER",
                "PADDLE_TRAINER_ID": str(cur_worker.rank),
                "PADDLE_WITH_GLOO": "1",
                "PADDLE_GLOO_RENDEZVOUS": "2",
                "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
            }
            proc_env.update(self._device_env(idx, heter_device_num))
            current_env.update(proc_env)

            if idx == 0:
                logger.info(
                    "Local worker start {} processes. First process distributed "
                    "environment info (Only For Debug): {}".format(
                        len(pod.workers),
                        pretty_print_envs(proc_env, ("Distributed Envs",
                                                     "Value"))))

            self._launch_proc(args, current_env, "workerlog", idx,
                              cur_worker.rank)

    def start_pod_heter_worker(self, args, pod):
        """Start one ``HETER_TRAINER`` process per entry of
        ``pod.heter_workers``; a no-op when the pod has none.

        Raises:
            AssertionError: if heter workers are configured but no device of
                kind ``args.heter_worker_device`` is available.
        """
        # Plain ps jobs have no heter workers and must not require a device.
        if not pod.heter_workers:
            return

        current_env = self._base_env()
        heter_device_num = self._heter_device_num(args)
        assert heter_device_num != 0

        for idx, cur_heter_worker in enumerate(pod.heter_workers):
            proc_env = {
                "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
                "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
                "PADDLE_HETER_TRAINER_IP_PORT_LIST":
                self.heter_worker_endpoints,
                "PADDLE_HETER_TRAINER_DEVICE": args.heter_worker_device,
                "PADDLE_PORT": cur_heter_worker.endpoint.split(":")[1],
                "TRAINING_ROLE": "HETER_TRAINER",
                "PADDLE_TRAINERS_NUM": str(self.worker_num),
                "POD_IP": cur_heter_worker.endpoint.split(":")[0],
                "PADDLE_WITH_GLOO": "1",
                "PADDLE_GLOO_RENDEZVOUS": "2",
                "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
            }
            proc_env.update(self._device_env(idx, heter_device_num))
            current_env.update(proc_env)

            if idx == 0:
                logger.info(
                    "Local heter_worker start {} processes. First process distributed "
                    "environment info (Only For Debug): {}".format(
                        len(pod.heter_workers),
                        pretty_print_envs(proc_env, ("Distributed Envs",
                                                     "Value"))))

            self._launch_proc(args, current_env, "heterlog", idx,
                              cur_heter_worker.rank)
......@@ -13,6 +13,7 @@
from paddle import fluid
from .meta_optimizer_base import MetaOptimizerBase
from ..base.private_helper_function import wait_server_ready
from paddle.fluid import core
import subprocess
import re
......@@ -74,6 +75,8 @@ class ParameterServerOptimizer(MetaOptimizerBase):
_startup = worker.delet_extra_optimizes_pass(_startup,
compiled_config)
compiled_config.set_origin_ps_main_program(_main)
compiled_config.set_origin_ps_startup_program(_startup)
# for heter program
if self.role_maker._is_heter_parameter_server_mode:
from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker
......@@ -91,6 +94,16 @@ class ParameterServerOptimizer(MetaOptimizerBase):
else:
_main = worker.append_send_ops_pass(_main, compiled_config)
_startup = _startup
compiled_config.set_origin_ps_main_program(_main)
compiled_config.set_origin_ps_startup_program(_startup)
# for trainer wait server ready
wait_server_ready(self.role_maker._get_pserver_endpoints())
# for ps-heter mode, wait heter worker ready
if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker(
):
wait_server_ready(self.role_maker._get_heter_worker_endpoints())
return _main, _startup
......
......@@ -458,13 +458,13 @@ class ParameterServerRuntime(RuntimeBase):
def _save_distributed_persistables(self, executor, dirname, main_program):
dense_ctx = self.compiled_strategy.get_communicator_recv_context(
recv_type=1)
recv_type=1, use_origin_program=True)
sparse_ctx = self.compiled_strategy.get_communicator_recv_context(
recv_type=2)
recv_type=2, use_origin_program=True)
distributed_ctx = self.compiled_strategy.get_communicator_recv_context(
recv_type=3)
recv_type=3, use_origin_program=True)
recv_dense_varnames = self._save_dense_params(executor, dirname,
dense_ctx, main_program)
......@@ -516,7 +516,7 @@ class ParameterServerRuntime(RuntimeBase):
)
if main_program is None:
main_program = fluid.default_main_program()
main_program = self.compiled_strategy.get_origin_ps_main_program()
if isinstance(main_program, CompiledProgram):
raise TypeError(
......
......@@ -133,6 +133,8 @@ class CompileTimeStrategy(object):
self.origin_main_program = main_program
self.origin_startup_program = startup_program
self.origin_ps_main_program = main_program
self.origin_ps_startup_program = startup_program
self.strategy = strategy
self.role_maker = role_maker
......@@ -153,6 +155,11 @@ class CompileTimeStrategy(object):
self._build_var_distributed()
# for heter-ps save variables
self.origin_merged_variables_pairs = list(self.merged_variables_pairs)
self.origin_merged_dense_pairs = list(self.merged_dense_pairs)
self.origin_merged_sparse_pairs = list(self.merged_sparse_pairs)
def get_distributed_mode(self):
    """Return the ``mode`` of the strategy's trainer runtime config."""
    return self.strategy.get_trainer_runtime_config().mode
......@@ -214,6 +221,18 @@ class CompileTimeStrategy(object):
def get_origin_startup_program(self):
    """Return the stored ``origin_startup_program``."""
    startup_program = self.origin_startup_program
    return startup_program
def set_origin_ps_main_program(self, program):
    """Store ``program`` as ``origin_ps_main_program``."""
    setattr(self, "origin_ps_main_program", program)
def set_origin_ps_startup_program(self, program):
    """Store ``program`` as ``origin_ps_startup_program``."""
    setattr(self, "origin_ps_startup_program", program)
def get_origin_ps_main_program(self):
    """Return the stored ``origin_ps_main_program``."""
    main_program = self.origin_ps_main_program
    return main_program
def get_origin_ps_startup_program(self):
    """Return the stored ``origin_ps_startup_program``."""
    startup_program = self.origin_ps_startup_program
    return startup_program
def get_sparse_varname_on_ps(self, is_distributed, endpoint=None):
if not endpoint:
endpoint = self.get_ps_endpoint()
......@@ -378,7 +397,9 @@ class CompileTimeStrategy(object):
send_ctx[name] = ctx
return send_ctx
def get_communicator_recv_context(self, recv_type=1):
def get_communicator_recv_context(self,
recv_type=1,
use_origin_program=False):
# recv_type
# 1 : DENSE 2. SPARSE 3. DISTRIBUTED 4. ALL
distibuted_varnames = get_sparse_tablenames(self.origin_main_program,
......@@ -392,7 +413,8 @@ class CompileTimeStrategy(object):
sparse_recv_ctx = {}
distributed_recv_ctx = {}
for merged in self.merged_variables_pairs:
variables_pairs = self.merged_variables_pairs if not use_origin_program else self.origin_merged_variables_pairs
for merged in variables_pairs:
params = merged[0]
if params.merged_var.name in sparse_varnames:
continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册