Commit 933dd9a4 authored by Megvii Engine Team, committed by huangxinda

feat(mge/distributed): add cuda env check before forked thread

style(core/comp_node): reformat code

GitOrigin-RevId: 372452a8eb9e84a2e82d466074f80f78d70531e8
Parent 2a541961
@@ -165,6 +165,18 @@ def _get_device_count_worker(queue, device_type):
     queue.put(num)
 
 
+def _check_device_initialized(device_type: str):
+    try:
+        test = Tensor(1, device=device_type)
+        inited = False
+        del test
+    except:
+        inited = True
+    errmsg = "The cuda env is set before the forked thread starts. Please do not use any cuda function or variable before forking."
+    if inited:
+        raise RuntimeError(errmsg)
+
+
 def get_device_count_by_fork(device_type: str):
     """
     Get device count in fork thread.
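The new helper probes the device by trying to allocate a tensor on it. A CUDA context created in a parent process does not survive fork(), so in a forked worker this allocation fails exactly when the parent touched CUDA first, and the helper converts that failure into an explicit RuntimeError. A minimal sketch of the same detection idea (illustrative only, not part of the commit; `probe` is a hypothetical stand-in for `Tensor(1, device=...)`):

def check_device_usable(probe) -> None:
    # Try a tiny device allocation; in a forked child this succeeds only
    # if the parent did not initialize CUDA before forking.
    try:
        t = probe()
        del t
    except Exception:
        raise RuntimeError(
            "CUDA was initialized before this process was forked; "
            "avoid CUDA use in the parent before spawning workers."
        )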
@@ -15,7 +15,7 @@ from .. import _exit
 from ..core._imperative_rt.core2 import full_sync
 from ..logger import get_logger
 from .group import group_barrier, init_process_group
-from .helper import get_device_count_by_fork
+from .helper import _check_device_initialized, get_device_count_by_fork
 from .server import Client, Server
 
 WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
@@ -37,6 +37,7 @@ def _run_wrapped(
     queue: mp.Queue,
 ):
     """Init distributed process group and run wrapped function."""
+    _check_device_initialized(device_type)
     init_process_group(
         master_ip=master_ip,
         port=port,
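Calling `_check_device_initialized(device_type)` at the top of `_run_wrapped` makes every worker validate its CUDA environment before `init_process_group` performs any device setup, so a polluted parent process fails fast with a clear message instead of an obscure CUDA error later. The corresponding rule for callers is to keep the parent free of CUDA work until the workers are launched; a rough usage sketch (assumed typical `dist.launcher` usage, not code from this commit):

import megengine as mge
import megengine.distributed as dist

@dist.launcher(n_gpus=2)
def worker():
    # First CUDA use happens here, inside the forked worker: allowed.
    x = mge.tensor(1)
    print(x.numpy())

if __name__ == "__main__":
    worker()  # do not create GPU tensors before this call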
@@ -246,3 +246,16 @@ def test_io_remote(shape):
 
     val = np.random.random_sample(shape).astype("float32")
     worker(val, shape)
+
+
+@pytest.mark.require_ngpu(2)
+def test_cuda_init_before_fork():
+    a = mge.tensor(1, device="gpu0")
+
+    @dist.launcher(n_gpus=2)
+    def worker():
+        a += 1
+        b = mge.tensor(2)
+
+    with pytest.raises(AssertionError):
+        worker()
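The test pins down the failure mode: `a = mge.tensor(1, device="gpu0")` initializes CUDA in the parent, so when the launcher forks the two workers, `_check_device_initialized` raises inside each of them before the wrapped function body ever runs. `pytest.raises(AssertionError)` then matches at the call site, presumably because the launcher asserts that every worker process exited cleanly.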