未验证 提交 d8e1e50a 编写于 作者: L lilong12 提交者: GitHub

[Cherry-pick] Fix bug in gloo that gloo initialization hangs (#29449)

* update, test=develop (#29331)
上级 49265879
...@@ -171,6 +171,7 @@ class Gloo(object): ...@@ -171,6 +171,7 @@ class Gloo(object):
def _init_http(self, ip, port, prefix, start_http_server, http_server_d): def _init_http(self, ip, port, prefix, start_http_server, http_server_d):
def __start_kv_server(http_server_d, size_d): def __start_kv_server(http_server_d, size_d):
print("start http_server: {}, {}".format(port, size_d))
from paddle.distributed.fleet.utils.http_server import KVServer from paddle.distributed.fleet.utils.http_server import KVServer
http_server = KVServer(port, size_d) http_server = KVServer(port, size_d)
http_server.start() http_server.start()
...@@ -181,11 +182,9 @@ class Gloo(object): ...@@ -181,11 +182,9 @@ class Gloo(object):
http_server.stop() http_server.stop()
def init_kv_server(http_server_d): def init_kv_server(http_server_d):
size_d = { worker_key = prefix + '_' + 'worker'
"trainer": self._worker_num, size_d = {worker_key: self._worker_num, }
"pserver": self._server_num, print("worker_key:{}, size: {}".format(worker_key, size_d))
"all": self._worker_num + self._server_num
}
http_server_d["running"] = True http_server_d["running"] = True
# child process for http server # child process for http server
...@@ -205,7 +204,7 @@ class Gloo(object): ...@@ -205,7 +204,7 @@ class Gloo(object):
gloo.set_iface(self._iface) gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds, gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds) self._run_timeout_seconds)
gloo.set_http_store(ip, port, role) gloo.set_http_store(ip, port, 'worker')
ep = ":".join([ip, str(port)]) ep = ":".join([ip, str(port)])
wait_server_ready([ep]) wait_server_ready([ep])
gloo.init() gloo.init()
...@@ -214,6 +213,7 @@ class Gloo(object): ...@@ -214,6 +213,7 @@ class Gloo(object):
port = int(port) port = int(port)
if start_http_server: if start_http_server:
print("to start http_server")
http_server = init_kv_server(http_server_d) http_server = init_kv_server(http_server_d)
if self._role == Role.WORKER: if self._role == Role.WORKER:
......
...@@ -112,8 +112,8 @@ class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): ...@@ -112,8 +112,8 @@ class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
_, scope, key = paths _, scope, key = paths
with self.server.delete_kv_lock: with self.server.delete_kv_lock:
if self.server.delete_kv.get(scope) is None: if self.server.delete_kv.get(scope) is None:
self.server.delete_kv[scope] = [] self.server.delete_kv[scope] = set()
self.server.delete_kv[scope].append(key) self.server.delete_kv[scope].add(key)
self.send_status_code(200) self.send_status_code(200)
_http_server_logger.info(log_str) _http_server_logger.info(log_str)
...@@ -151,7 +151,7 @@ class KVHTTPServer(HTTPServer, object): ...@@ -151,7 +151,7 @@ class KVHTTPServer(HTTPServer, object):
""" """
ret = 0 ret = 0
with self.delete_kv_lock: with self.delete_kv_lock:
ret = self.delete_kv.get(key, 0) ret = len(self.delete_kv.get(key, set()))
return ret return ret
...@@ -164,7 +164,7 @@ class KVServer: ...@@ -164,7 +164,7 @@ class KVServer:
"""Init.""" """Init."""
self.http_server = KVHTTPServer(port, KVHandler) self.http_server = KVHTTPServer(port, KVHandler)
self.listen_thread = None self.listen_thread = None
self.size = {} self.size = size
def start(self): def start(self):
""" """
......
...@@ -44,11 +44,11 @@ def _get_global_parallel_env(): ...@@ -44,11 +44,11 @@ def _get_global_parallel_env():
return _global_parallel_env return _global_parallel_env
def _start_kv_server(port, http_server_d): def _start_kv_server(port, http_server_d, size):
from paddle.distributed.fleet.utils.http_server import KVServer from paddle.distributed.fleet.utils.http_server import KVServer
http_server = KVServer(int(port)) http_server = KVServer(int(port), size=size)
http_server.start() http_server.start()
wait_seconds = 5 wait_seconds = 3
while http_server_d.get("running", False) or not http_server.should_stop(): while http_server_d.get("running", False) or not http_server.should_stop():
time.sleep(wait_seconds) time.sleep(wait_seconds)
http_server.stop() http_server.stop()
...@@ -149,8 +149,11 @@ def init_parallel_env(): ...@@ -149,8 +149,11 @@ def init_parallel_env():
http_server_d = manager.dict() http_server_d = manager.dict()
http_server_d["running"] = False http_server_d["running"] = False
if parallel_env.rank == 0: if parallel_env.rank == 0:
# The scope for worker used by http server is '_worker'
size = {'_worker': parallel_env.world_size}
http_server = Process( http_server = Process(
target=_start_kv_server, args=(int(ep_rank_0[1]), http_server_d)) target=_start_kv_server,
args=(int(ep_rank_0[1]), http_server_d, size))
http_server.daemon = True http_server.daemon = True
http_server_d["running"] = True http_server_d["running"] = True
http_server.start() http_server.start()
......
...@@ -274,7 +274,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase): ...@@ -274,7 +274,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
print("skip gloo UT on MacOS/Win") print("skip gloo UT on MacOS/Win")
return return
os.environ["TRAINING_ROLE"] = "PSERVER" os.environ["TRAINING_ROLE"] = "WORKER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001" os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1" os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001" os.environ["PADDLE_PORT"] = "36001"
...@@ -284,7 +284,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase): ...@@ -284,7 +284,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3" os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019" os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019"
role = role_maker.PaddleCloudRoleMaker() role = role_maker.PaddleCloudRoleMaker(is_collective=True)
role._generate_role() role._generate_role()
import time import time
time.sleep(3) time.sleep(3)
...@@ -532,7 +532,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase): ...@@ -532,7 +532,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
print("skip gloo UT on MacOS/Win") print("skip gloo UT on MacOS/Win")
return return
os.environ["TRAINING_ROLE"] = "PSERVER" os.environ["TRAINING_ROLE"] = "WORKER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001" os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1" os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001" os.environ["PADDLE_PORT"] = "36001"
...@@ -542,7 +542,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase): ...@@ -542,7 +542,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3" os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019" os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019"
role = role_maker.PaddleCloudRoleMaker() role = role_maker.PaddleCloudRoleMaker(is_collective=True)
role._generate_role() role._generate_role()
import time import time
time.sleep(3) time.sleep(3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册