From 17323dbd73773e7cd5de89803360c74f6dc937e5 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 26 Feb 2021 14:26:21 +0800
Subject: [PATCH] feat(dist): collect return values in dist.launcher

GitOrigin-RevId: 519e768ce916a28ea3910f4d1437ea443be4472f
---
 .../python/megengine/distributed/launcher.py  | 38 +++++++++++++++++--
 .../test/unit/distributed/test_distributed.py | 21 ++++++++++
 2 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py
index 2ae38ed7..68997246 100644
--- a/imperative/python/megengine/distributed/launcher.py
+++ b/imperative/python/megengine/distributed/launcher.py
@@ -8,15 +8,30 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import multiprocessing as mp
+import queue
 
 from ..core._imperative_rt.core2 import sync
+from ..logger import get_logger
 from .group import group_barrier, init_process_group
 from .helper import get_device_count_by_fork
 from .server import Client, Server
 
+WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
+    "subprocess exited with code 0 but did not return a value"
+)
+
 
 def _run_wrapped(
-    func, is_multimachine, master_ip, port, world_size, rank, dev, args, kwargs
+    func,
+    is_multimachine,
+    master_ip,
+    port,
+    world_size,
+    rank,
+    dev,
+    args,
+    kwargs,
+    queue: mp.Queue,
 ):
     """Init distributed process group and run wrapped function."""
     init_process_group(
@@ -24,7 +39,8 @@ def _run_wrapped(
     )
     if is_multimachine:
         group_barrier()
-    func(*args, **kwargs)
+    ret = func(*args, **kwargs)
+    queue.put((dev, ret))
     sync()
     if is_multimachine:
         group_barrier()
@@ -70,6 +86,8 @@ class launcher:
 
     def __call__(self, *args, **kwargs):
         procs = []
+        queue = mp.Queue(self.n_gpus)
+        results = [None] * self.n_gpus
         for dev in range(self.n_gpus):
             p = mp.Process(
                 target=_run_wrapped,
@@ -83,6 +101,7 @@ class launcher:
                     dev,
                     args,
                     kwargs,
+                    queue,
                 ),
             )
             p.start()
@@ -90,6 +109,11 @@ class launcher:
 
         devs = list(range(self.n_gpus))
 
+        def terminate():
+            for dev in devs:
+                procs[dev].terminate()
+            devs.clear()
+
         while len(devs) > 0:
             left = []
             # check all processes in one second
@@ -99,11 +123,17 @@ class launcher:
                 code = procs[dev].exitcode
                 # terminate processes if one of them has failed
                 if code != 0 and code != None:
-                    for i in devs:
-                        procs[i].terminate()
+                    terminate()
                 assert (
                     code == 0 or code == None
                 ), "subprocess {} exit with code {}".format(dev + self.rank_start, code)
                 if code == None:
                     left.append(dev)
+                elif queue.empty():
+                    get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN)
+                else:
+                    dev, ret = queue.get_nowait()
+                    results[dev] = ret
             devs = left
+
+        return results
diff --git a/imperative/python/test/unit/distributed/test_distributed.py b/imperative/python/test/unit/distributed/test_distributed.py
index 908e5b3a..d6213653 100644
--- a/imperative/python/test/unit/distributed/test_distributed.py
+++ b/imperative/python/test/unit/distributed/test_distributed.py
@@ -195,3 +195,24 @@ def test_param_pack_concat():
     offsets = mge.Tensor(offsets_val, np.int32)
     c = param_pack_concat([a, b], offsets, offsets_val)
     assert np.allclose(np.concatenate([a.numpy(), b.numpy().flatten()]), c.numpy())
+
+
+@pytest.mark.require_ngpu(2)
+@pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"])
+@pytest.mark.isolated_distributed
+def test_collect_results(early_return):
+    @dist.launcher
+    def worker():
+        if early_return:
+            exit(0)
+        return (dist.get_rank(), dist.get_world_size())
+
+    results = worker()
+    world_size = len(results)
+    assert world_size > 0
+    expects = (
+        [None] * world_size
+        if early_return
+        else [(dev, world_size) for dev in range(world_size)]
+    )
+    assert results == expects
-- 
GitLab
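
Usage note (illustrative, not part of the patch): with this change,
dist.launcher collects each rank's return value instead of discarding it.
A minimal sketch of the new calling pattern, assuming a host with 2 visible
GPUs; the worker body below is hypothetical:

    import megengine.distributed as dist

    @dist.launcher
    def worker():
        # each subprocess returns a value; _run_wrapped puts the
        # (dev, value) pair on the shared queue before it exits
        return dist.get_rank() * 10

    results = worker()  # list indexed by device: results[i] is rank i's value
    print(results)      # e.g. [0, 10] on a 2-GPU host

A rank that exits with code 0 without returning a value leaves its slot as
None and makes the parent log WARN_SUBPROCESS_EXIT_WITHOUT_RETURN, as
exercised by the early_return case of test_collect_results above.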