未验证 提交 ff399fd7 编写于 作者: G guru4elephant 提交者: GitHub

fix paddle cloud role maker bug (#18269)

* fix paddle cloud role maker bug
上级 412951d7
...@@ -19,6 +19,8 @@ __all__ = [ ...@@ -19,6 +19,8 @@ __all__ = [
'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker' 'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
] ]
import os
class Role: class Role:
WORKER = 1 WORKER = 1
...@@ -295,45 +297,62 @@ class MPISymetricRoleMaker(MPIRoleMaker): ...@@ -295,45 +297,62 @@ class MPISymetricRoleMaker(MPIRoleMaker):
class PaddleCloudRoleMaker(RoleMakerBase): class PaddleCloudRoleMaker(RoleMakerBase):
def __init__(self): def __init__(self):
super(PaddleCloudRoleMaker, self).__init__() super(PaddleCloudRoleMaker, self).__init__()
self._role_is_generated = False
def generate_role(self): def generate_role(self):
if not self._role_is_generated: if not self._role_is_generated:
self.port = os.getenv("PADDLE_PORT", "6174") self.port = os.getenv("PADDLE_PORT", "6174")
self.pserver_ips = os.getenv("PADDLE_PSERVERS", "") self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
eplist = [] eplist = []
for ip in pserver_ips.split(","): for ip in self.pserver_ips.split(","):
eplist.append(':'.join([ip, port])) eplist.append(':'.join([ip, self.port]))
self.endpoints = ",".join(eplist) self.endpoints = ",".join(eplist)
self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self.current_endpoint = os.getenv("POD_IP", self.current_endpoint = os.getenv("POD_IP",
"localhost") + ":" + port "localhost") + ":" + self.port
self.role = os.getenv("TRAINING_ROLE", "TRAINER") self.role = os.getenv("TRAINING_ROLE", "TRAINER")
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.eplist = eplist self.eplist = eplist
print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints)
self.endpoints = self.endpoints.split(",") self.endpoints = self.endpoints.split(",")
self._server_endpoints = self.endpoints
if self.role.upper() == "PSERVER": if self.role.upper() == "PSERVER":
self.current_id = self.endpoints.index(self.current_endpoint) self._current_id = self.endpoints.index(self.current_endpoint)
self._role = Role.SERVER
else: else:
self.current_id = self.trainer_id self._current_id = self.trainer_id
self._role = Role.WORKER
self._role_is_generated = True self._role_is_generated = True
def is_wokrer(self): def is_worker(self):
if not self._role_is_generated:
self.generate_role()
return self._role == Role.WORKER return self._role == Role.WORKER
def is_server(self): def is_server(self):
if not self._role_is_generated:
self.generate_role()
return self._role == Role.SERVER return self._role == Role.SERVER
def is_first_worker(self): def is_first_worker(self):
if not self._role_is_generated:
self.generate_role()
return self._role == Role.WORKER and self._current_id == 0 return self._role == Role.WORKER and self._current_id == 0
def worker_index(self): def worker_index(self):
if not self._role_is_generated:
self.generate_role()
return self._current_id return self._current_id
def server_index(self): def server_index(self):
if not self._role_is_generated:
self.generate_role()
return self._current_id return self._current_id
def worker_num(self): def worker_num(self):
return self._worker_num if not self._role_is_generated:
self.generate_role()
return self._trainers
class UserDefinedRoleMaker(RoleMakerBase): class UserDefinedRoleMaker(RoleMakerBase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册