diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index 68997246959c554ef4fb8c8bef72b7d601b1c4d7..1eff52fdb9485828fab1989d26381675f014c1e5 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -114,6 +114,7 @@ class launcher: procs[dev].terminate() devs.clear() + result_count = 0 while len(devs) > 0: left = [] # check all processes in one second @@ -129,11 +130,21 @@ class launcher: ), "subprocess {} exit with code {}".format(dev + self.rank_start, code) if code == None: left.append(dev) - elif queue.empty(): - get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN) - else: + + # DO NOT delete it, multiprocess.Queue has small buffer + # fetch data early to avoid dead lock + if not queue.empty(): + result_count += 1 dev, ret = queue.get_nowait() results[dev] = ret devs = left + while not queue.empty(): + result_count += 1 + dev, ret = queue.get_nowait() + results[dev] = ret + + if result_count < self.n_gpus: + get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN) + return results diff --git a/imperative/python/test/unit/distributed/test_distributed.py b/imperative/python/test/unit/distributed/test_distributed.py index d62136536627bf4e9ad7f2c82f9a36e04c1295bd..3973dcddbdae24c7a78bb73bc61348e2f9e53548 100644 --- a/imperative/python/test/unit/distributed/test_distributed.py +++ b/imperative/python/test/unit/distributed/test_distributed.py @@ -199,13 +199,14 @@ def test_param_pack_concat(): @pytest.mark.require_ngpu(2) @pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"]) +@pytest.mark.parametrize("output_size", [10, 10000], ids=["small_size", "large_size"]) @pytest.mark.isolated_distributed -def test_collect_results(early_return): +def test_collect_results(early_return, output_size): @dist.launcher def worker(): if early_return: exit(0) - return (dist.get_rank(), dist.get_world_size()) + return [dist.get_rank()] * output_size results = worker() world_size = len(results) @@ -213,6 +214,6 @@ def test_collect_results(early_return): expects = ( [None] * world_size if early_return - else [(dev, world_size) for dev in range(world_size)] + else [[dev] * output_size for dev in range(world_size)] ) assert results == expects