未验证 提交 58f3e1ba 编写于 作者: G guru4elephant 提交者: GitHub

add paddle cloud role maker for customized usage, note this is only for...

add paddle cloud role maker for customized usage, note this is only for industrial users that have cloud environment pre-configuration (#18121)

Add a PaddleCloud role maker for specific cloud usage. This PR simplifies users' configuration in distributed training.
上级 80d2e66f
...@@ -188,18 +188,8 @@ class Fleet(object): ...@@ -188,18 +188,8 @@ class Fleet(object):
if role_maker and not isinstance(role_maker, RoleMakerBase): if role_maker and not isinstance(role_maker, RoleMakerBase):
raise ValueError("role_maker must be an instance of RoleMakerBase") raise ValueError("role_maker must be an instance of RoleMakerBase")
if isinstance(role_maker, MPISymetricRoleMaker):
self._role_maker = role_maker
self._role_maker.generate_role() self._role_maker.generate_role()
elif isinstance(role_maker, UserDefinedRoleMaker):
self._role_maker = role_maker
else:
raise ValueError(
"role_maker must be an instance of UserDefinedRoleMaker/MPISymetricRoleMaker"
)
self._is_initialized = True self._is_initialized = True
@abc.abstractmethod @abc.abstractmethod
......
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
__all__ = [ __all__ = [
'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker', 'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
'UserDefinedCollectiveRoleMaker' 'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
] ]
...@@ -292,6 +292,50 @@ class MPISymetricRoleMaker(MPIRoleMaker): ...@@ -292,6 +292,50 @@ class MPISymetricRoleMaker(MPIRoleMaker):
self._role_is_generated = True self._role_is_generated = True
class PaddleCloudRoleMaker(RoleMakerBase):
    """Role maker that derives the node's role from environment variables.

    The following variables are read by ``generate_role``:

    * ``PADDLE_PORT`` (default ``"6174"``): port shared by all servers.
    * ``PADDLE_PSERVERS`` (default ``""``): comma-separated server IPs.
    * ``PADDLE_TRAINERS_NUM`` (default ``"1"``): number of trainers.
    * ``POD_IP`` (default ``"localhost"``): IP of the current node.
    * ``TRAINING_ROLE`` (default ``"TRAINER"``): ``"PSERVER"`` selects the
      server role; any other value selects the worker role.
    * ``PADDLE_TRAINER_ID`` (default ``"0"``): index of the current trainer.
    """

    def __init__(self):
        super(PaddleCloudRoleMaker, self).__init__()

    def generate_role(self):
        """Read the environment once and populate the role attributes.

        Subsequent calls are no-ops once ``_role_is_generated`` is set.

        Raises:
            ValueError: if this node is a server but its ``POD_IP:PADDLE_PORT``
                endpoint is not among the endpoints built from
                ``PADDLE_PSERVERS``.
        """
        if self._role_is_generated:
            return

        self.port = os.getenv("PADDLE_PORT", "6174")
        self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
        # Skip empty entries so an unset PADDLE_PSERVERS does not produce a
        # bogus ":<port>" endpoint.
        eplist = [
            ':'.join([ip, self.port]) for ip in self.pserver_ips.split(",")
            if ip
        ]
        self.eplist = eplist
        self.endpoints = list(eplist)
        # Same attribute name UserDefinedRoleMaker uses for its server list.
        self._server_endpoints = self.endpoints

        self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        self._worker_num = self.trainers
        self.current_endpoint = os.getenv("POD_IP",
                                          "localhost") + ":" + self.port
        self.role = os.getenv("TRAINING_ROLE", "TRAINER")
        self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))

        if self.role.upper() == "PSERVER":
            if self.current_endpoint not in self.endpoints:
                raise ValueError(
                    "current endpoint {} is not in the server endpoints {}".
                    format(self.current_endpoint, self.endpoints))
            self._role = Role.SERVER
            self._current_id = self.endpoints.index(self.current_endpoint)
        else:
            self._role = Role.WORKER
            self._current_id = self.trainer_id
        # Keep the public attribute the original code exposed.
        self.current_id = self._current_id

        self._role_is_generated = True

    def is_worker(self):
        """Return True when this node was assigned the worker role."""
        return self._role == Role.WORKER

    # Backward-compatible alias for the originally misspelled method name.
    is_wokrer = is_worker

    def is_server(self):
        """Return True when this node was assigned the server role."""
        return self._role == Role.SERVER

    def is_first_worker(self):
        """Return True when this node is a worker with index 0."""
        return self._role == Role.WORKER and self._current_id == 0

    def worker_index(self):
        """Return the index of the current node."""
        return self._current_id

    def server_index(self):
        """Return the index of the current node."""
        return self._current_id

    def worker_num(self):
        """Return the trainer count read from ``PADDLE_TRAINERS_NUM``."""
        return self._worker_num
class UserDefinedRoleMaker(RoleMakerBase): class UserDefinedRoleMaker(RoleMakerBase):
def __init__(self, def __init__(self,
current_id=0, current_id=0,
...@@ -329,6 +373,9 @@ class UserDefinedRoleMaker(RoleMakerBase): ...@@ -329,6 +373,9 @@ class UserDefinedRoleMaker(RoleMakerBase):
else: else:
self._server_endpoints = server_endpoints self._server_endpoints = server_endpoints
def generate_role(self):
# This method derives nothing itself; it only marks the role as generated.
self._role_is_generated = True
def is_worker(self): def is_worker(self):
return self._role == Role.WORKER return self._role == Role.WORKER
...@@ -369,6 +416,9 @@ class UserDefinedCollectiveRoleMaker(RoleMakerBase): ...@@ -369,6 +416,9 @@ class UserDefinedCollectiveRoleMaker(RoleMakerBase):
self._worker_endpoints = worker_endpoints self._worker_endpoints = worker_endpoints
self._worker_num = len(self._worker_endpoints) self._worker_num = len(self._worker_endpoints)
def generate_role(self):
# This method derives nothing itself; it only marks the role as generated.
self._role_is_generated = True
def is_worker(self): def is_worker(self):
return True return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册