diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py
index 7c707a1f44853b93f6fe7c6b2be8c3530c532b6c..6282e82cfc1d5336efd7bb0bfe1a8f000c4a08be 100644
--- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py
+++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py
@@ -188,17 +188,10 @@ class Fleet(object):
         if role_maker and not isinstance(role_maker, RoleMakerBase):
             raise ValueError("role_maker must be an instance of RoleMakerBase")
 
-        if isinstance(role_maker, MPISymetricRoleMaker):
-            self._role_maker = role_maker
-            self._role_maker.generate_role()
-
-        elif isinstance(role_maker, UserDefinedRoleMaker):
-            self._role_maker = role_maker
-
-        else:
-            raise ValueError(
-                "role_maker must be an instance of UserDefinedRoleMaker/MPISymetricRoleMaker"
-            )
+        # Assign before use: the branches removed above were what stored the
+        # role maker on the instance.
+        self._role_maker = role_maker
+        self._role_maker.generate_role()
 
         self._is_initialized = True
 
diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py
index dc4d98cf61ccb14912a2d0a13a3819759b4bcd5d..ae6768f8f568f6877c591134d9766d6542f956e7 100644
--- a/python/paddle/fluid/incubate/fleet/base/role_maker.py
+++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py
@@ -16,7 +16,7 @@ from __future__ import print_function
 
 __all__ = [
     'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
-    'UserDefinedCollectiveRoleMaker'
+    'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
 ]
 
 
@@ -292,6 +292,84 @@ class MPISymetricRoleMaker(MPIRoleMaker):
             self._role_is_generated = True
 
 
+class PaddleCloudRoleMaker(RoleMakerBase):
+    """Role maker configured from PaddleCloud-style environment variables.
+
+    generate_role() reads PADDLE_PORT, PADDLE_PSERVERS, PADDLE_TRAINERS_NUM,
+    POD_IP, TRAINING_ROLE and PADDLE_TRAINER_ID.
+    """
+
+    def __init__(self):
+        super(PaddleCloudRoleMaker, self).__init__()
+
+    def generate_role(self):
+        """Derive role, index and endpoints from the environment, once."""
+        if self._role_is_generated:
+            return
+
+        # Local import: os is not guaranteed to be imported at module level.
+        import os
+
+        self.port = os.getenv("PADDLE_PORT", "6174")
+        self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
+        # Skip empty entries so an unset PADDLE_PSERVERS yields no endpoints
+        # rather than a bogus ":<port>" entry.
+        eplist = [
+            ':'.join([ip, self.port]) for ip in self.pserver_ips.split(",")
+            if ip
+        ]
+        self.eplist = eplist
+        self.endpoints = list(eplist)
+        self._server_endpoints = self.endpoints
+        self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
+        self._worker_num = self.trainers
+        self.current_endpoint = os.getenv("POD_IP",
+                                          "localhost") + ":" + self.port
+        self.role = os.getenv("TRAINING_ROLE", "TRAINER")
+        self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+
+        if self.role.upper() == "PSERVER":
+            if self.current_endpoint not in self.endpoints:
+                raise ValueError(
+                    "current endpoint {} is not listed in PADDLE_PSERVERS".
+                    format(self.current_endpoint))
+            self.current_id = self.endpoints.index(self.current_endpoint)
+            self._role = Role.SERVER
+        else:
+            self.current_id = self.trainer_id
+            self._role = Role.WORKER
+        # The query methods below read the underscore-prefixed attributes.
+        self._current_id = self.current_id
+        self._role_is_generated = True
+
+    def is_worker(self):
+        """Return True when TRAINING_ROLE selected the worker role."""
+        return self._role == Role.WORKER
+
+    # Alias for the misspelled name this method was first introduced under.
+    is_wokrer = is_worker
+
+    def is_server(self):
+        """Return True when TRAINING_ROLE selected the server role."""
+        return self._role == Role.SERVER
+
+    def is_first_worker(self):
+        """Return True for the worker whose index is 0."""
+        return self._role == Role.WORKER and self._current_id == 0
+
+    def worker_index(self):
+        """Return this process's index as generated by generate_role()."""
+        return self._current_id
+
+    def server_index(self):
+        """Return this process's index as generated by generate_role()."""
+        return self._current_id
+
+    def worker_num(self):
+        """Return the trainer count read from PADDLE_TRAINERS_NUM."""
+        return self._worker_num
+
+
 class UserDefinedRoleMaker(RoleMakerBase):
     def __init__(self,
                  current_id=0,
@@ -329,6 +407,10 @@ class UserDefinedRoleMaker(RoleMakerBase):
         else:
             self._server_endpoints = server_endpoints
 
+    def generate_role(self):
+        """Mark the role as generated; no discovery is performed here."""
+        self._role_is_generated = True
+
     def is_worker(self):
         return self._role == Role.WORKER
 
@@ -369,6 +451,10 @@ class UserDefinedCollectiveRoleMaker(RoleMakerBase):
             self._worker_endpoints = worker_endpoints
             self._worker_num = len(self._worker_endpoints)
 
+    def generate_role(self):
+        """Mark the role as generated; no discovery is performed here."""
+        self._role_is_generated = True
+
     def is_worker(self):
         return True
 