提交 dcff115e 编写于 作者: M Megvii Engine Team

fix(distributed/launcher): fetch data early to avoid dead lock in launcher

GitOrigin-RevId: 9abcc956efc3e208a206d015e6e8d1a28b7d43cd
上级 7e22e9f0
...@@ -114,6 +114,7 @@ class launcher: ...@@ -114,6 +114,7 @@ class launcher:
procs[dev].terminate() procs[dev].terminate()
devs.clear() devs.clear()
result_count = 0
while len(devs) > 0: while len(devs) > 0:
left = [] left = []
# check all processes in one second # check all processes in one second
...@@ -129,11 +130,21 @@ class launcher: ...@@ -129,11 +130,21 @@ class launcher:
), "subprocess {} exit with code {}".format(dev + self.rank_start, code) ), "subprocess {} exit with code {}".format(dev + self.rank_start, code)
if code == None: if code == None:
left.append(dev) left.append(dev)
elif queue.empty():
get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN) # DO NOT delete it, multiprocess.Queue has small buffer
else: # fetch data early to avoid dead lock
if not queue.empty():
result_count += 1
dev, ret = queue.get_nowait() dev, ret = queue.get_nowait()
results[dev] = ret results[dev] = ret
devs = left devs = left
while not queue.empty():
result_count += 1
dev, ret = queue.get_nowait()
results[dev] = ret
if result_count < self.n_gpus:
get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN)
return results return results
...@@ -199,13 +199,14 @@ def test_param_pack_concat(): ...@@ -199,13 +199,14 @@ def test_param_pack_concat():
@pytest.mark.require_ngpu(2) @pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"]) @pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"])
@pytest.mark.parametrize("output_size", [10, 10000], ids=["small_size", "large_size"])
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_collect_results(early_return): def test_collect_results(early_return, output_size):
@dist.launcher @dist.launcher
def worker(): def worker():
if early_return: if early_return:
exit(0) exit(0)
return (dist.get_rank(), dist.get_world_size()) return [dist.get_rank()] * output_size
results = worker() results = worker()
world_size = len(results) world_size = len(results)
...@@ -213,6 +214,6 @@ def test_collect_results(early_return): ...@@ -213,6 +214,6 @@ def test_collect_results(early_return):
expects = ( expects = (
[None] * world_size [None] * world_size
if early_return if early_return
else [(dev, world_size) for dev in range(world_size)] else [[dev] * output_size for dev in range(world_size)]
) )
assert results == expects assert results == expects
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册