未验证 提交 d8e1e50a 编写于 作者: L lilong12 提交者: GitHub

[Cherry-pick] Fix bug in gloo that gloo initialization hangs (#29449)

* update, test=develop (#29331)
上级 49265879
......@@ -171,6 +171,7 @@ class Gloo(object):
def _init_http(self, ip, port, prefix, start_http_server, http_server_d):
def __start_kv_server(http_server_d, size_d):
print("start http_server: {}, {}".format(port, size_d))
from paddle.distributed.fleet.utils.http_server import KVServer
http_server = KVServer(port, size_d)
http_server.start()
......@@ -181,11 +182,9 @@ class Gloo(object):
http_server.stop()
def init_kv_server(http_server_d):
size_d = {
"trainer": self._worker_num,
"pserver": self._server_num,
"all": self._worker_num + self._server_num
}
worker_key = prefix + '_' + 'worker'
size_d = {worker_key: self._worker_num, }
print("worker_key:{}, size: {}".format(worker_key, size_d))
http_server_d["running"] = True
# child process for http server
......@@ -205,7 +204,7 @@ class Gloo(object):
gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds)
gloo.set_http_store(ip, port, role)
gloo.set_http_store(ip, port, 'worker')
ep = ":".join([ip, str(port)])
wait_server_ready([ep])
gloo.init()
......@@ -214,6 +213,7 @@ class Gloo(object):
port = int(port)
if start_http_server:
print("to start http_server")
http_server = init_kv_server(http_server_d)
if self._role == Role.WORKER:
......
......@@ -112,8 +112,8 @@ class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
_, scope, key = paths
with self.server.delete_kv_lock:
if self.server.delete_kv.get(scope) is None:
self.server.delete_kv[scope] = []
self.server.delete_kv[scope].append(key)
self.server.delete_kv[scope] = set()
self.server.delete_kv[scope].add(key)
self.send_status_code(200)
_http_server_logger.info(log_str)
......@@ -151,7 +151,7 @@ class KVHTTPServer(HTTPServer, object):
"""
ret = 0
with self.delete_kv_lock:
ret = self.delete_kv.get(key, 0)
ret = len(self.delete_kv.get(key, set()))
return ret
......@@ -164,7 +164,7 @@ class KVServer:
"""Init."""
self.http_server = KVHTTPServer(port, KVHandler)
self.listen_thread = None
self.size = {}
self.size = size
def start(self):
"""
......
......@@ -44,11 +44,11 @@ def _get_global_parallel_env():
return _global_parallel_env
def _start_kv_server(port, http_server_d):
def _start_kv_server(port, http_server_d, size):
from paddle.distributed.fleet.utils.http_server import KVServer
http_server = KVServer(int(port))
http_server = KVServer(int(port), size=size)
http_server.start()
wait_seconds = 5
wait_seconds = 3
while http_server_d.get("running", False) or not http_server.should_stop():
time.sleep(wait_seconds)
http_server.stop()
......@@ -149,8 +149,11 @@ def init_parallel_env():
http_server_d = manager.dict()
http_server_d["running"] = False
if parallel_env.rank == 0:
# The scope for worker used by http server is '_worker'
size = {'_worker': parallel_env.world_size}
http_server = Process(
target=_start_kv_server, args=(int(ep_rank_0[1]), http_server_d))
target=_start_kv_server,
args=(int(ep_rank_0[1]), http_server_d, size))
http_server.daemon = True
http_server_d["running"] = True
http_server.start()
......
......@@ -274,7 +274,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
print("skip gloo UT on MacOS/Win")
return
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["TRAINING_ROLE"] = "WORKER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
......@@ -284,7 +284,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019"
role = role_maker.PaddleCloudRoleMaker()
role = role_maker.PaddleCloudRoleMaker(is_collecitve=True)
role._generate_role()
import time
time.sleep(3)
......@@ -532,7 +532,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
print("skip gloo UT on MacOS/Win")
return
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["TRAINING_ROLE"] = "WORKER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
......@@ -542,7 +542,7 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
os.environ["PADDLE_GLOO_HTTP_ENDPOINT"] = "127.0.0.1:30019"
role = role_maker.PaddleCloudRoleMaker()
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
role._generate_role()
import time
time.sleep(3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册