未验证 提交 58f3e1ba 编写于 作者: G guru4elephant 提交者: GitHub

add paddle cloud role maker for customized usage, note this is only for...

add paddle cloud role maker for customized usage, note this is only for industrial users that have cloud environment pre-configuration (#18121)

add paddle cloud role maker for specific cloud usage. This pr will simplifies user's configuration in distributed training.
上级 80d2e66f
......@@ -188,17 +188,7 @@ class Fleet(object):
if role_maker and not isinstance(role_maker, RoleMakerBase):
raise ValueError("role_maker must be an instance of RoleMakerBase")
if isinstance(role_maker, MPISymetricRoleMaker):
self._role_maker = role_maker
self._role_maker.generate_role()
elif isinstance(role_maker, UserDefinedRoleMaker):
self._role_maker = role_maker
else:
raise ValueError(
"role_maker must be an instance of UserDefinedRoleMaker/MPISymetricRoleMaker"
)
self._role_maker.generate_role()
self._is_initialized = True
......
......@@ -16,7 +16,7 @@ from __future__ import print_function
__all__ = [
'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
'UserDefinedCollectiveRoleMaker'
'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
]
......@@ -292,6 +292,50 @@ class MPISymetricRoleMaker(MPIRoleMaker):
self._role_is_generated = True
class PaddleCloudRoleMaker(RoleMakerBase):
def __init__(self):
super(PaddleCloudRoleMaker, self).__init__()
def generate_role(self):
if not self._role_is_generated:
self.port = os.getenv("PADDLE_PORT", "6174")
self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
self.endpoints = ",".join(eplist)
self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
self.current_endpoint = os.getenv("POD_IP",
"localhost") + ":" + port
self.role = os.getenv("TRAINING_ROLE", "TRAINER")
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.eplist = eplist
self.endpoints = self.endpoints.split(",")
if self.role.upper() == "PSERVER":
self.current_id = self.endpoints.index(self.current_endpoint)
else:
self.current_id = self.trainer_id
self._role_is_generated = True
def is_wokrer(self):
return self._role == Role.WORKER
def is_server(self):
return self._role == Role.SERVER
def is_first_worker(self):
return self._role == Role.WORKER and self._current_id == 0
def worker_index(self):
return self._current_id
def server_index(self):
return self._current_id
def worker_num(self):
return self._worker_num
class UserDefinedRoleMaker(RoleMakerBase):
def __init__(self,
current_id=0,
......@@ -329,6 +373,9 @@ class UserDefinedRoleMaker(RoleMakerBase):
else:
self._server_endpoints = server_endpoints
def generate_role(self):
self._role_is_generated = True
def is_worker(self):
return self._role == Role.WORKER
......@@ -369,6 +416,9 @@ class UserDefinedCollectiveRoleMaker(RoleMakerBase):
self._worker_endpoints = worker_endpoints
self._worker_num = len(self._worker_endpoints)
def generate_role(self):
self._role_is_generated = True
def is_worker(self):
return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册