Unverified commit 30562e37, authored by guru4elephant, committed via GitHub.

refine launch_ps and role_maker (#18795)

refine launch_ps and role_maker
Parent commit: 292dfbce
...@@ -48,6 +48,9 @@ def parse_args(): ...@@ -48,6 +48,9 @@ def parse_args():
default=True, default=True,
help="Print the config or not") help="Print the config or not")
parser.add_argument(
"--endpoints", type=str, default="", help="User defined endpoints")
parser.add_argument( parser.add_argument(
"--worker_num", type=int, default=2, help="number of workers") "--worker_num", type=int, default=2, help="number of workers")
...@@ -87,13 +90,23 @@ def start_procs(args): ...@@ -87,13 +90,23 @@ def start_procs(args):
cmds = [] cmds = []
log_fns = [] log_fns = []
ports = range(start_port, start_port + server_num, 1) ports = range(start_port, start_port + server_num, 1)
endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports]) default_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
user_endpoints = ""
if args.endpoints == "":
user_endpoints = default_endpoints
else:
user_endpoints = args.endpoints
user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
for i in range(server_num): for i in range(server_num):
current_env.update({ current_env.update({
"TRAINER_NUM": str(worker_num), "PADDLE_TRAINERS_NUM": str(server_num),
"CURRENT_ID": str(i), "PADDLE_PORT": ",".join(user_endpoints_port),
"ENDPOINTS": endpoints, #"POD_IP": user_endpoints_ips[i],
"TRAINING_ROLE": "PSERVER" "CURRENT_ENDPOINT":
user_endpoints_ips[i] + ":" + user_endpoints_port[i],
"PADDLE_PSERVERS": ",".join(user_endpoints_ips),
"PADDLE_TRAINING_ROLE": "PSERVER"
}) })
cmd = [sys.executable, "-u", args.training_script cmd = [sys.executable, "-u", args.training_script
] + args.training_script_args ] + args.training_script_args
...@@ -110,10 +123,11 @@ def start_procs(args): ...@@ -110,10 +123,11 @@ def start_procs(args):
for i in range(worker_num): for i in range(worker_num):
current_env.update({ current_env.update({
"ENDPOINTS": endpoints, "PADDLE_PSERVERS": ",".join(user_endpoints_ips),
"TRAINER_NUM": str(worker_num), "PADDLE_PORT": ",".join(user_endpoints_port),
"TRAINING_ROLE": "TRAINER", "PADDLE_TRAINERS_NUM": str(worker_num),
"CURRENT_ID": str(i) "PADDLE_TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": str(i)
}) })
cmd = [sys.executable, "-u", args.training_script cmd = [sys.executable, "-u", args.training_script
] + args.training_script_args ] + args.training_script_args
......
...@@ -334,25 +334,41 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -334,25 +334,41 @@ class PaddleCloudRoleMaker(RoleMakerBase):
def generate_role(self): def generate_role(self):
if not self._role_is_generated: if not self._role_is_generated:
if not self._is_collective: if not self._is_collective:
self.port = os.getenv("PADDLE_PORT", "6174") self.port = os.getenv("PADDLE_PORT",
self.pserver_ips = os.getenv("PADDLE_PSERVERS", "") "6174") # port of current server
self.pserver_ips = os.getenv("PADDLE_PSERVERS",
"") # ip of server
if "," in self.port:
ports = self.port.split(",")
else:
ports = [self.port for i in self.pserver_ips.split(",")]
eplist = [] eplist = []
for ip in self.pserver_ips.split(","): # note that, we usually assign the same port to different ips
eplist.append(':'.join([ip, self.port])) # if we run parameter server training in local mode
# port should be different in environment variables
for i, ip in enumerate(self.pserver_ips.split(",")):
eplist.append(':'.join([ip, ports[i]]))
self.endpoints = ",".join(eplist) self.endpoints = ",".join(eplist)
self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self.current_endpoint = os.getenv("POD_IP", # ip of current node, either a worker or a pserver
"localhost") + ":" + self.port current_ip = os.getenv("POD_IP", "")
self.role = os.getenv("TRAINING_ROLE", "TRAINER") if current_ip == "":
self._current_endpoint = os.getenv("CURRENT_ENDPOINT")
else:
self._current_endpoint = current_ip + ports[0]
self.role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
# for trainer, only POD_IP and current trainer id is needed
# we usually do not need to know other trainer ips
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.eplist = eplist self.eplist = eplist
self.endpoints = self.endpoints.split(",") self.endpoints = self.endpoints.split(",")
self._server_endpoints = self.endpoints self._server_endpoints = self.endpoints
self._worker_endpoints = self.endpoints self._worker_endpoints = self.endpoints
if self.role.upper() == "PSERVER": if self.role.upper() == "PSERVER":
# current endpoint index among all pservers
self._current_id = self.endpoints.index( self._current_id = self.endpoints.index(
self.current_endpoint) self._current_endpoint)
self._role = Role.SERVER self._role = Role.SERVER
else: else:
self._current_id = self.trainer_id self._current_id = self.trainer_id
...@@ -369,6 +385,11 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -369,6 +385,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._num_trainers = len(self._worker_endpoints) self._num_trainers = len(self._worker_endpoints)
self._role_is_generated = True self._role_is_generated = True
def get_pserver_endpoints(self):
    """Return the list of parameter-server endpoints.

    Lazily triggers role generation so the endpoint list is valid
    even when ``generate_role`` has not been called explicitly yet.
    """
    if self._role_is_generated:
        return self._server_endpoints
    # Role info has not been derived from the environment yet; do it now.
    self.generate_role()
    return self._server_endpoints
def is_worker(self): def is_worker(self):
if not self._role_is_generated: if not self._role_is_generated:
self.generate_role() self.generate_role()
......
Markdown is supported.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.