提交 0fad63a3 编写于 作者: T tangwei12 提交者: guru4elephant

fix paddle cloud role maker bug (#18269) (#18311)

* fix paddle cloud role maker bug
上级 f6432604
......@@ -19,6 +19,8 @@ __all__ = [
'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
]
import os
class Role:
WORKER = 1
......@@ -295,45 +297,62 @@ class MPISymetricRoleMaker(MPIRoleMaker):
class PaddleCloudRoleMaker(RoleMakerBase):
def __init__(self):
super(PaddleCloudRoleMaker, self).__init__()
self._role_is_generated = False
def generate_role(self):
if not self._role_is_generated:
self.port = os.getenv("PADDLE_PORT", "6174")
self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
self.endpoints = ",".join(eplist)
self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self.current_endpoint = os.getenv("POD_IP",
"localhost") + ":" + port
self.role = os.getenv("TRAINING_ROLE", "TRAINER")
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
for ip in self.pserver_ips.split(","):
eplist.append(':'.join([ip, self.port]))
self.endpoints = ",".join(eplist)
self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self.current_endpoint = os.getenv("POD_IP",
"localhost") + ":" + self.port
self.role = os.getenv("TRAINING_ROLE", "TRAINER")
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.eplist = eplist
print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints)
self.endpoints = self.endpoints.split(",")
self._server_endpoints = self.endpoints
if self.role.upper() == "PSERVER":
self.current_id = self.endpoints.index(self.current_endpoint)
self._current_id = self.endpoints.index(self.current_endpoint)
self._role = Role.SERVER
else:
self.current_id = self.trainer_id
self._current_id = self.trainer_id
self._role = Role.WORKER
self._role_is_generated = True
def is_wokrer(self):
def is_worker(self):
if not self._role_is_generated:
self.generate_role()
return self._role == Role.WORKER
def is_server(self):
if not self._role_is_generated:
self.generate_role()
return self._role == Role.SERVER
def is_first_worker(self):
if not self._role_is_generated:
self.generate_role()
return self._role == Role.WORKER and self._current_id == 0
def worker_index(self):
if not self._role_is_generated:
self.generate_role()
return self._current_id
def server_index(self):
if not self._role_is_generated:
self.generate_role()
return self._current_id
def worker_num(self):
return self._worker_num
if not self._role_is_generated:
self.generate_role()
return self._trainers
class UserDefinedRoleMaker(RoleMakerBase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册