提交 17323dbd 编写于 作者: M Megvii Engine Team

feat(dist): collect return values in dist.launcher

GitOrigin-RevId: 519e768ce916a28ea3910f4d1437ea443be4472f
上级 5b697c71
...@@ -8,15 +8,30 @@ ...@@ -8,15 +8,30 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools import functools
import multiprocessing as mp import multiprocessing as mp
import queue
from ..core._imperative_rt.core2 import sync from ..core._imperative_rt.core2 import sync
from ..logger import get_logger
from .group import group_barrier, init_process_group from .group import group_barrier, init_process_group
from .helper import get_device_count_by_fork from .helper import get_device_count_by_fork
from .server import Client, Server from .server import Client, Server
# Warning logged when a launched subprocess exits cleanly (exit code 0) but the
# result queue is empty, i.e. the worker returned early (e.g. via exit(0))
# without pushing its (dev, return-value) pair onto the queue.
WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
    "subprocess exited with code 0 but did not return a value"
)
def _run_wrapped( def _run_wrapped(
func, is_multimachine, master_ip, port, world_size, rank, dev, args, kwargs func,
is_multimachine,
master_ip,
port,
world_size,
rank,
dev,
args,
kwargs,
queue: mp.Queue,
): ):
"""Init distributed process group and run wrapped function.""" """Init distributed process group and run wrapped function."""
init_process_group( init_process_group(
...@@ -24,7 +39,8 @@ def _run_wrapped( ...@@ -24,7 +39,8 @@ def _run_wrapped(
) )
if is_multimachine: if is_multimachine:
group_barrier() group_barrier()
func(*args, **kwargs) ret = func(*args, **kwargs)
queue.put((dev, ret))
sync() sync()
if is_multimachine: if is_multimachine:
group_barrier() group_barrier()
...@@ -70,6 +86,8 @@ class launcher: ...@@ -70,6 +86,8 @@ class launcher:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
procs = [] procs = []
queue = mp.Queue(self.n_gpus)
results = [None] * self.n_gpus
for dev in range(self.n_gpus): for dev in range(self.n_gpus):
p = mp.Process( p = mp.Process(
target=_run_wrapped, target=_run_wrapped,
...@@ -83,6 +101,7 @@ class launcher: ...@@ -83,6 +101,7 @@ class launcher:
dev, dev,
args, args,
kwargs, kwargs,
queue,
), ),
) )
p.start() p.start()
...@@ -90,6 +109,11 @@ class launcher: ...@@ -90,6 +109,11 @@ class launcher:
devs = list(range(self.n_gpus)) devs = list(range(self.n_gpus))
def terminate():
for dev in devs:
procs[dev].terminate()
devs.clear()
while len(devs) > 0: while len(devs) > 0:
left = [] left = []
# check all processes in one second # check all processes in one second
...@@ -99,11 +123,17 @@ class launcher: ...@@ -99,11 +123,17 @@ class launcher:
code = procs[dev].exitcode code = procs[dev].exitcode
# terminate processes if one of them has failed # terminate processes if one of them has failed
if code != 0 and code != None: if code != 0 and code != None:
for i in devs: terminate()
procs[i].terminate()
assert ( assert (
code == 0 or code == None code == 0 or code == None
), "subprocess {} exit with code {}".format(dev + self.rank_start, code) ), "subprocess {} exit with code {}".format(dev + self.rank_start, code)
if code == None: if code == None:
left.append(dev) left.append(dev)
elif queue.empty():
get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN)
else:
dev, ret = queue.get_nowait()
results[dev] = ret
devs = left devs = left
return results
...@@ -195,3 +195,24 @@ def test_param_pack_concat(): ...@@ -195,3 +195,24 @@ def test_param_pack_concat():
offsets = mge.Tensor(offsets_val, np.int32) offsets = mge.Tensor(offsets_val, np.int32)
c = param_pack_concat([a, b], offsets, offsets_val) c = param_pack_concat([a, b], offsets, offsets_val)
assert np.allclose(np.concatenate([a.numpy(), b.numpy().flatten()]), c.numpy()) assert np.allclose(np.concatenate([a.numpy(), b.numpy().flatten()]), c.numpy())
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"])
@pytest.mark.isolated_distributed
def test_collect_results(early_return):
    """Check that dist.launcher collects per-rank return values.

    Two scenarios, selected by ``early_return``:

    * ``common``: every worker returns ``(rank, world_size)`` and the launcher
      must gather them into a list indexed by rank.
    * ``early_return``: every worker calls ``exit(0)`` before returning, so the
      launcher must report ``None`` for each rank (and, per the launcher
      implementation, warn that the subprocess exited without a return value).
    """

    @dist.launcher
    def worker():
        # NOTE: `early_return` is a closure variable captured into each
        # launched subprocess.
        if early_return:
            # Exit code 0 but no value pushed to the launcher's result queue.
            exit(0)
        return (dist.get_rank(), dist.get_world_size())

    results = worker()
    # The launcher returns one slot per launched device/rank.
    world_size = len(results)
    assert world_size > 0
    expects = (
        [None] * world_size
        if early_return
        else [(dev, world_size) for dev in range(world_size)]
    )
    assert results == expects
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册