Commit 17323dbd authored by Megvii Engine Team

feat(dist): collect return values in dist.launcher

GitOrigin-RevId: 519e768ce916a28ea3910f4d1437ea443be4472f
Parent 5b697c71
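For context, here is a minimal usage sketch of the behavior this commit introduces, based on the test added at the end of the diff. `dist.launcher`, `dist.get_rank`, and the `megengine.distributed` import alias are the APIs exercised in the diff; the worker body is illustrative:

```python
import megengine.distributed as dist

@dist.launcher
def worker():
    # each spawned GPU process returns a value to the parent
    return dist.get_rank()

# After this commit, calling the launched function returns a list
# indexed by device: results[i] is the value returned by the process
# on device i, or None if that process exited without returning one.
results = worker()
```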
@@ -8,15 +8,30 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import multiprocessing as mp
+import queue
 
 from ..core._imperative_rt.core2 import sync
 from ..logger import get_logger
 from .group import group_barrier, init_process_group
 from .helper import get_device_count_by_fork
 from .server import Client, Server
 
+WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
+    "subprocess exited with code 0 but did not return a value"
+)
+
 
 def _run_wrapped(
-    func, is_multimachine, master_ip, port, world_size, rank, dev, args, kwargs
+    func,
+    is_multimachine,
+    master_ip,
+    port,
+    world_size,
+    rank,
+    dev,
+    args,
+    kwargs,
+    queue: mp.Queue,
 ):
     """Init distributed process group and run wrapped function."""
     init_process_group(
@@ -24,7 +39,8 @@ def _run_wrapped(
     )
     if is_multimachine:
         group_barrier()
-    func(*args, **kwargs)
+    ret = func(*args, **kwargs)
+    queue.put((dev, ret))
     sync()
    if is_multimachine:
        group_barrier()
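The hunk above follows the standard multiprocessing pattern of tagging each result with its producer's rank so the parent can order results deterministically. A self-contained sketch of that pattern, independent of MegEngine (all names here are illustrative):

```python
import multiprocessing as mp

def _worker(rank, q):
    # tag the result with the producer's rank; queue arrival order
    # is nondeterministic, so the tag is what restores ordering
    q.put((rank, rank * rank))

if __name__ == "__main__":
    n = 4
    q = mp.Queue(n)
    procs = [mp.Process(target=_worker, args=(i, q)) for i in range(n)]
    for p in procs:
        p.start()
    results = [None] * n
    for _ in range(n):
        rank, value = q.get()
        results[rank] = value  # slot by rank, not by arrival
    for p in procs:
        p.join()
    print(results)  # [0, 1, 4, 9]
```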
@@ -70,6 +86,8 @@ class launcher:
     def __call__(self, *args, **kwargs):
         procs = []
+        queue = mp.Queue(self.n_gpus)
+        results = [None] * self.n_gpus
         for dev in range(self.n_gpus):
             p = mp.Process(
                 target=_run_wrapped,
@@ -83,6 +101,7 @@ class launcher:
                     dev,
                     args,
                     kwargs,
+                    queue,
                 ),
             )
             p.start()
@@ -90,6 +109,11 @@ class launcher:
         devs = list(range(self.n_gpus))
 
+        def terminate():
+            for dev in devs:
+                procs[dev].terminate()
+            devs.clear()
+
         while len(devs) > 0:
             left = []
             # check all processes in one second
@@ -99,11 +123,17 @@ class launcher:
                 code = procs[dev].exitcode
                 # terminate processes if one of them has failed
                 if code != 0 and code is not None:
-                    for i in devs:
-                        procs[i].terminate()
+                    terminate()
                 assert (
                     code == 0 or code is None
                 ), "subprocess {} exited with code {}".format(dev + self.rank_start, code)
                 if code is None:
                     left.append(dev)
+                elif queue.empty():
+                    get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN)
+                else:
+                    dev, ret = queue.get_nowait()
+                    results[dev] = ret
             devs = left
+        return results
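Two details of the collection loop above are worth noting. Results are slotted into `results[dev]` by the device id carried inside the queue item rather than by arrival order, since subprocesses finish in arbitrary order. And a subprocess that exits cleanly (code 0) without having put anything on the queue leaves its slot as `None` and logs `WARN_SUBPROCESS_EXIT_WITHOUT_RETURN`; the `early_return` case of the test below exercises exactly that path.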
@@ -195,3 +195,24 @@ def test_param_pack_concat():
     offsets = mge.Tensor(offsets_val, np.int32)
     c = param_pack_concat([a, b], offsets, offsets_val)
     assert np.allclose(np.concatenate([a.numpy(), b.numpy().flatten()]), c.numpy())
+
+
+@pytest.mark.require_ngpu(2)
+@pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"])
+@pytest.mark.isolated_distributed
+def test_collect_results(early_return):
+    @dist.launcher
+    def worker():
+        if early_return:
+            exit(0)
+        return (dist.get_rank(), dist.get_world_size())
+
+    results = worker()
+    world_size = len(results)
+    assert world_size > 0
+    expects = (
+        [None] * world_size
+        if early_return
+        else [(dev, world_size) for dev in range(world_size)]
+    )
+    assert results == expects