Unverified commit 30562e37 — authored by guru4elephant, committed via GitHub

refine launch_ps and role_maker (#18795)

refine launch_ps and role_maker
Parent commit: 292dfbce
......@@ -48,6 +48,9 @@ def parse_args():
default=True,
help="Print the config or not")
parser.add_argument(
"--endpoints", type=str, default="", help="User defined endpoints")
parser.add_argument(
"--worker_num", type=int, default=2, help="number of workers")
......@@ -87,13 +90,23 @@ def start_procs(args):
cmds = []
log_fns = []
ports = range(start_port, start_port + server_num, 1)
endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
default_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
user_endpoints = ""
if args.endpoints == "":
user_endpoints = default_endpoints
else:
user_endpoints = args.endpoints
user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
for i in range(server_num):
current_env.update({
"TRAINER_NUM": str(worker_num),
"CURRENT_ID": str(i),
"ENDPOINTS": endpoints,
"TRAINING_ROLE": "PSERVER"
"PADDLE_TRAINERS_NUM": str(server_num),
"PADDLE_PORT": ",".join(user_endpoints_port),
#"POD_IP": user_endpoints_ips[i],
"CURRENT_ENDPOINT":
user_endpoints_ips[i] + ":" + user_endpoints_port[i],
"PADDLE_PSERVERS": ",".join(user_endpoints_ips),
"PADDLE_TRAINING_ROLE": "PSERVER"
})
cmd = [sys.executable, "-u", args.training_script
] + args.training_script_args
......@@ -110,10 +123,11 @@ def start_procs(args):
for i in range(worker_num):
current_env.update({
"ENDPOINTS": endpoints,
"TRAINER_NUM": str(worker_num),
"TRAINING_ROLE": "TRAINER",
"CURRENT_ID": str(i)
"PADDLE_PSERVERS": ",".join(user_endpoints_ips),
"PADDLE_PORT": ",".join(user_endpoints_port),
"PADDLE_TRAINERS_NUM": str(worker_num),
"PADDLE_TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": str(i)
})
cmd = [sys.executable, "-u", args.training_script
] + args.training_script_args
......
......@@ -334,25 +334,41 @@ class PaddleCloudRoleMaker(RoleMakerBase):
def generate_role(self):
if not self._role_is_generated:
if not self._is_collective:
self.port = os.getenv("PADDLE_PORT", "6174")
self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
self.port = os.getenv("PADDLE_PORT",
"6174") # port of current server
self.pserver_ips = os.getenv("PADDLE_PSERVERS",
"") # ip of server
if "," in self.port:
ports = self.port.split(",")
else:
ports = [self.port for i in self.pserver_ips.split(",")]
eplist = []
for ip in self.pserver_ips.split(","):
eplist.append(':'.join([ip, self.port]))
# note that, we usually assign the same port to different ips
# if we run parameter server training in local mode
# port should be different in environment variables
for i, ip in enumerate(self.pserver_ips.split(",")):
eplist.append(':'.join([ip, ports[i]]))
self.endpoints = ",".join(eplist)
self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self.current_endpoint = os.getenv("POD_IP",
"localhost") + ":" + self.port
self.role = os.getenv("TRAINING_ROLE", "TRAINER")
# ip of current node, either a worker or a pserver
current_ip = os.getenv("POD_IP", "")
if current_ip == "":
self._current_endpoint = os.getenv("CURRENT_ENDPOINT")
else:
self._current_endpoint = current_ip + ports[0]
self.role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
# for trainer, only POD_IP and current trainer id is needed
# we usually do not need to know other trainer ips
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.eplist = eplist
self.endpoints = self.endpoints.split(",")
self._server_endpoints = self.endpoints
self._worker_endpoints = self.endpoints
if self.role.upper() == "PSERVER":
# current endpoint index among all pservers
self._current_id = self.endpoints.index(
self.current_endpoint)
self._current_endpoint)
self._role = Role.SERVER
else:
self._current_id = self.trainer_id
......@@ -369,6 +385,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._num_trainers = len(self._worker_endpoints)
self._role_is_generated = True
def get_pserver_endpoints(self):
    # Return the list of parameter-server endpoints for this job.
    # Role information is built lazily: if generate_role() has not run yet,
    # invoke it first so _server_endpoints is populated from the environment
    # (PADDLE_PSERVERS / PADDLE_PORT — see generate_role above).
    if not self._role_is_generated:
        self.generate_role()
    return self._server_endpoints
def is_worker(self):
if not self._role_is_generated:
self.generate_role()
......
Markdown is supported.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.