From ff399fd720be4a2d03ddb03d7e889962dadb5c49 Mon Sep 17 00:00:00 2001 From: guru4elephant <35550832+guru4elephant@users.noreply.github.com> Date: Sun, 23 Jun 2019 17:08:03 +0800 Subject: [PATCH] fix paddle cloud role maker bug (#18269) * fix paddle cloud role maker bug --- .../fluid/incubate/fleet/base/role_maker.py | 43 +++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index ae6768f8f56..a5802ac1fe7 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -19,6 +19,8 @@ __all__ = [ 'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker' ] +import os + class Role: WORKER = 1 @@ -295,45 +297,62 @@ class MPISymetricRoleMaker(MPIRoleMaker): class PaddleCloudRoleMaker(RoleMakerBase): def __init__(self): super(PaddleCloudRoleMaker, self).__init__() + self._role_is_generated = False def generate_role(self): if not self._role_is_generated: self.port = os.getenv("PADDLE_PORT", "6174") self.pserver_ips = os.getenv("PADDLE_PSERVERS", "") eplist = [] - for ip in pserver_ips.split(","): - eplist.append(':'.join([ip, port])) - self.endpoints = ",".join(eplist) - self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) - self.current_endpoint = os.getenv("POD_IP", - "localhost") + ":" + port - self.role = os.getenv("TRAINING_ROLE", "TRAINER") - self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + for ip in self.pserver_ips.split(","): + eplist.append(':'.join([ip, self.port])) + self.endpoints = ",".join(eplist) + self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) + self.current_endpoint = os.getenv("POD_IP", + "localhost") + ":" + self.port + self.role = os.getenv("TRAINING_ROLE", "TRAINER") + self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) self.eplist = eplist + print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints) self.endpoints = self.endpoints.split(",") + self._server_endpoints = self.endpoints if self.role.upper() == "PSERVER": - self.current_id = self.endpoints.index(self.current_endpoint) + self._current_id = self.endpoints.index(self.current_endpoint) + self._role = Role.SERVER else: - self.current_id = self.trainer_id + self._current_id = self.trainer_id + self._role = Role.WORKER self._role_is_generated = True - def is_wokrer(self): + def is_worker(self): + if not self._role_is_generated: + self.generate_role() return self._role == Role.WORKER def is_server(self): + if not self._role_is_generated: + self.generate_role() return self._role == Role.SERVER def is_first_worker(self): + if not self._role_is_generated: + self.generate_role() return self._role == Role.WORKER and self._current_id == 0 def worker_index(self): + if not self._role_is_generated: + self.generate_role() return self._current_id def server_index(self): + if not self._role_is_generated: + self.generate_role() return self._current_id def worker_num(self): - return self._worker_num + if not self._role_is_generated: + self.generate_role() + return self._trainers class UserDefinedRoleMaker(RoleMakerBase): -- GitLab