提交 207a3463 编写于 作者: M Megvii Engine Team

chore(mge): run get_device_count("gpu") in subprocess

GitOrigin-RevId: 0f0dc001cfc45fc0d04de1a86c27f8bba8185d6b
上级 869a0327
...@@ -181,11 +181,6 @@ def synchronized(func: Callable): ...@@ -181,11 +181,6 @@ def synchronized(func: Callable):
return wrapper return wrapper
def _get_device_count_worker(queue, device_type):
num = get_device_count(device_type)
queue.put(num)
def _check_device_initialized(device_type: str, rank: int): def _check_device_initialized(device_type: str, rank: int):
try: try:
test = Tensor(1, device=(device_type + str(rank))) test = Tensor(1, device=(device_type + str(rank)))
...@@ -198,19 +193,6 @@ def _check_device_initialized(device_type: str, rank: int): ...@@ -198,19 +193,6 @@ def _check_device_initialized(device_type: str, rank: int):
raise RuntimeError(errmsg) raise RuntimeError(errmsg)
def get_device_count_by_fork(device_type: str):
"""
Get device count in fork thread.
See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
for more information.
"""
q = mp.Queue()
p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
p.start()
p.join()
return q.get()
def bcast_list_(inps: list, group: Group = WORLD): def bcast_list_(inps: list, group: Group = WORLD):
""" """
Broadcast tensors between given group. Broadcast tensors between given group.
......
...@@ -13,9 +13,10 @@ import queue ...@@ -13,9 +13,10 @@ import queue
from .. import _exit from .. import _exit
from ..core._imperative_rt.core2 import full_sync from ..core._imperative_rt.core2 import full_sync
from ..device import get_device_count
from ..logger import get_logger from ..logger import get_logger
from .group import _set_machine_ranks, group_barrier, init_process_group from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized, get_device_count_by_fork from .helper import _check_device_initialized
from .server import Client, Server from .server import Client, Server
WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = ( WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
...@@ -91,9 +92,7 @@ class launcher: ...@@ -91,9 +92,7 @@ class launcher:
backend="auto", backend="auto",
): ):
self.func = func self.func = func
self.n_gpus = ( self.n_gpus = n_gpus if n_gpus is not None else get_device_count(device_type)
n_gpus if n_gpus is not None else get_device_count_by_fork(device_type)
)
self.world_size = world_size if world_size is not None else self.n_gpus self.world_size = world_size if world_size is not None else self.n_gpus
self.rank_start = rank_start self.rank_start = rank_start
self.master_ip = master_ip self.master_ip = master_ip
......
...@@ -1188,11 +1188,11 @@ def copy(inp, device=None): ...@@ -1188,11 +1188,11 @@ def copy(inp, device=None):
import numpy as np import numpy as np
import platform import platform
from megengine import tensor from megengine import tensor
from megengine.distributed.helper import get_device_count_by_fork from megengine.device import get_device_count
import megengine.functional as F import megengine.functional as F
x = tensor([1, 2, 3], np.int32) x = tensor([1, 2, 3], np.int32)
if 1 == get_device_count_by_fork("gpu"): if 1 == get_device_count("gpu"):
y = F.copy(x, "cpu1") y = F.copy(x, "cpu1")
print(y.numpy()) print(y.numpy())
else: else:
......
...@@ -15,7 +15,7 @@ import megengine.functional ...@@ -15,7 +15,7 @@ import megengine.functional
import megengine.module import megengine.module
from megengine import Parameter from megengine import Parameter
from megengine.core._imperative_rt.core2 import sync from megengine.core._imperative_rt.core2 import sync
from megengine.distributed.helper import get_device_count_by_fork from megengine.device import get_device_count
from megengine.experimental.autograd import ( from megengine.experimental.autograd import (
disable_higher_order_directive, disable_higher_order_directive,
enable_higher_order_directive, enable_higher_order_directive,
...@@ -25,7 +25,7 @@ from megengine.module import Linear, Module ...@@ -25,7 +25,7 @@ from megengine.module import Linear, Module
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
_ngpu = get_device_count_by_fork("gpu") _ngpu = get_device_count("gpu")
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
......
...@@ -16,7 +16,6 @@ import megengine.autodiff as ad ...@@ -16,7 +16,6 @@ import megengine.autodiff as ad
import megengine.distributed as dist import megengine.distributed as dist
import megengine.optimizer as optimizer import megengine.optimizer as optimizer
from megengine import Parameter, tensor from megengine import Parameter, tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.module import Module from megengine.module import Module
from megengine.optimizer import SGD from megengine.optimizer import SGD
......
...@@ -18,7 +18,6 @@ import megengine.functional as F ...@@ -18,7 +18,6 @@ import megengine.functional as F
import megengine.module as M import megengine.module as M
import megengine.optimizer as optim import megengine.optimizer as optim
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace from megengine.jit import trace
......
...@@ -20,7 +20,6 @@ from megengine.core._imperative_rt import CompNode, TensorAttr, imperative ...@@ -20,7 +20,6 @@ from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise, Identity from megengine.core.ops.builtin import Elemwise, Identity
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import remote_recv, remote_send from megengine.functional.distributed import remote_recv, remote_send
......
...@@ -31,7 +31,7 @@ from megengine.core.tensor.dtype import ( ...@@ -31,7 +31,7 @@ from megengine.core.tensor.dtype import (
quint4, quint4,
quint8, quint8,
) )
from megengine.distributed.helper import get_device_count_by_fork from megengine.device import get_device_count
from megengine.tensor import Tensor from megengine.tensor import Tensor
...@@ -184,8 +184,7 @@ def test_dtype_int4_ffi_handle(): ...@@ -184,8 +184,7 @@ def test_dtype_int4_ffi_handle():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("gpu") != 0, get_device_count("gpu") != 0, reason="TypeCvt to quint4 is not supported on GPU",
reason="TypeCvt to quint4 is not supported on GPU",
) )
def test_quint4_typecvt(): def test_quint4_typecvt():
device = "xpux" device = "xpux"
......
...@@ -17,11 +17,7 @@ import megengine as mge ...@@ -17,11 +17,7 @@ import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit
from megengine.device import get_default_device from megengine.device import get_default_device
from megengine.distributed.helper import ( from megengine.distributed.helper import param_pack_concat, param_pack_split
get_device_count_by_fork,
param_pack_concat,
param_pack_split,
)
def _assert_q_empty(q): def _assert_q_empty(q):
......
...@@ -22,8 +22,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor ...@@ -22,8 +22,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple from megengine.core.tensor.utils import make_shape_tuple
from megengine.distributed.helper import get_device_count_by_fork from megengine.device import get_device_count
from megengine.jit import trace
def test_where(): def test_where():
...@@ -613,7 +612,7 @@ def test_nms(): ...@@ -613,7 +612,7 @@ def test_nms():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8" get_device_count("gpu") > 0, reason="cuda does not support nchw int8"
) )
def test_conv_bias(): def test_conv_bias():
inp_scale = 1.5 inp_scale = 1.5
...@@ -715,9 +714,7 @@ def test_conv_bias(): ...@@ -715,9 +714,7 @@ def test_conv_bias():
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu") run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
@pytest.mark.skipif( @pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
)
def test_batch_conv_bias(): def test_batch_conv_bias():
inp_scale = 1.5 inp_scale = 1.5
w_scale = 2.5 w_scale = 2.5
......
...@@ -16,7 +16,6 @@ import megengine.distributed as dist ...@@ -16,7 +16,6 @@ import megengine.distributed as dist
from megengine import Parameter, tensor from megengine import Parameter, tensor
from megengine.core._imperative_rt.core2 import sync from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_default_device, set_default_device from megengine.device import get_default_device, set_default_device
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import ( from megengine.functional.distributed import (
all_gather, all_gather,
all_reduce_max, all_reduce_max,
......
...@@ -18,7 +18,6 @@ from megengine import tensor ...@@ -18,7 +18,6 @@ from megengine import tensor
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace from megengine.jit import trace
from megengine.utils.network import Network, set_symbolic_shape from megengine.utils.network import Network, set_symbolic_shape
from megengine.utils.network_node import VarNode from megengine.utils.network_node import VarNode
......
...@@ -16,7 +16,6 @@ import megengine as mge ...@@ -16,7 +16,6 @@ import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
from megengine import Tensor from megengine import Tensor
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.distributed.helper import get_device_count_by_fork
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6) _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import megengine.utils.comp_graph_tools as cgtools import megengine.utils.comp_graph_tools as cgtools
from megengine import jit, tensor from megengine import jit, tensor
from megengine.distributed.helper import get_device_count_by_fork from megengine.device import get_device_count
from megengine.functional import expand_dims from megengine.functional import expand_dims
from megengine.module import ( from megengine.module import (
BatchMatMulActivation, BatchMatMulActivation,
...@@ -101,9 +101,7 @@ def test_qat_conv(): ...@@ -101,9 +101,7 @@ def test_qat_conv():
np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
@pytest.mark.skipif( @pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
)
def test_qat_batchmatmul_activation(): def test_qat_batchmatmul_activation():
batch = 4 batch = 4
in_features = 8 in_features = 8
......
...@@ -13,7 +13,7 @@ import pytest ...@@ -13,7 +13,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
from megengine.distributed.helper import get_device_count_by_fork from megengine.device import get_device_count
from megengine.quantization import QuantMode, create_qparams from megengine.quantization import QuantMode, create_qparams
from megengine.quantization.observer import ( from megengine.quantization.observer import (
ExponentialMovingAverageObserver, ExponentialMovingAverageObserver,
...@@ -78,7 +78,7 @@ def test_passive_observer(): ...@@ -78,7 +78,7 @@ def test_passive_observer():
@pytest.mark.require_ngpu(2) @pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_sync_min_max_observer(): def test_sync_min_max_observer():
word_size = get_device_count_by_fork("gpu") word_size = get_device_count("gpu")
x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
np_min, np_max = x.min(), x.max() np_min, np_max = x.min(), x.max()
...@@ -96,7 +96,7 @@ def test_sync_min_max_observer(): ...@@ -96,7 +96,7 @@ def test_sync_min_max_observer():
@pytest.mark.require_ngpu(2) @pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_sync_exponential_moving_average_observer(): def test_sync_exponential_moving_average_observer():
word_size = get_device_count_by_fork("gpu") word_size = get_device_count("gpu")
t = np.random.rand() t = np.random.rand()
x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
......
...@@ -12,7 +12,7 @@ import pytest ...@@ -12,7 +12,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine.functional as F import megengine.functional as F
from megengine.core.tensor import dtype from megengine.core.tensor import dtype
from megengine.distributed.helper import get_device_count_by_fork from megengine.device import get_device_count
from megengine.functional.elemwise import _elemwise_multi_type, _elwise from megengine.functional.elemwise import _elemwise_multi_type, _elwise
from megengine.quantization import QuantMode, create_qparams from megengine.quantization import QuantMode, create_qparams
...@@ -68,7 +68,7 @@ def test_elemwise(kind): ...@@ -68,7 +68,7 @@ def test_elemwise(kind):
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8" get_device_count("gpu") > 0, reason="cuda does not support nchw int8"
) )
def test_conv_bias(): def test_conv_bias():
inp_scale = np.float32(np.random.rand() + 1) inp_scale = np.float32(np.random.rand() + 1)
......
...@@ -26,12 +26,12 @@ from megengine.core.ops.builtin import ( ...@@ -26,12 +26,12 @@ from megengine.core.ops.builtin import (
PoissonRNG, PoissonRNG,
UniformRNG, UniformRNG,
) )
from megengine.distributed.helper import get_device_count_by_fork from megengine.device import get_device_count
from megengine.random import RNG, seed, uniform from megengine.random import RNG, seed, uniform
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", get_device_count("xpu") <= 2, reason="xpu counts need > 2",
) )
def test_gaussian_op(): def test_gaussian_op():
shape = ( shape = (
...@@ -61,7 +61,7 @@ def test_gaussian_op(): ...@@ -61,7 +61,7 @@ def test_gaussian_op():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", get_device_count("xpu") <= 2, reason="xpu counts need > 2",
) )
def test_uniform_op(): def test_uniform_op():
shape = ( shape = (
...@@ -89,7 +89,7 @@ def test_uniform_op(): ...@@ -89,7 +89,7 @@ def test_uniform_op():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", get_device_count("xpu") <= 2, reason="xpu counts need > 2",
) )
def test_gamma_op(): def test_gamma_op():
_shape, _scale = 2, 0.8 _shape, _scale = 2, 0.8
...@@ -117,7 +117,7 @@ def test_gamma_op(): ...@@ -117,7 +117,7 @@ def test_gamma_op():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", get_device_count("xpu") <= 2, reason="xpu counts need > 2",
) )
def test_beta_op(): def test_beta_op():
_alpha, _beta = 2, 0.8 _alpha, _beta = 2, 0.8
...@@ -148,7 +148,7 @@ def test_beta_op(): ...@@ -148,7 +148,7 @@ def test_beta_op():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", get_device_count("xpu") <= 2, reason="xpu counts need > 2",
) )
def test_poisson_op(): def test_poisson_op():
lam = F.full([8, 9, 11, 12], value=2, dtype="float32") lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
...@@ -171,7 +171,7 @@ def test_poisson_op(): ...@@ -171,7 +171,7 @@ def test_poisson_op():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", get_device_count("xpu") <= 2, reason="xpu counts need > 2",
) )
def test_permutation_op(): def test_permutation_op():
n = 1000 n = 1000
...@@ -205,7 +205,7 @@ def test_permutation_op(): ...@@ -205,7 +205,7 @@ def test_permutation_op():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", get_device_count("xpu") <= 1, reason="xpu counts need > 1",
) )
def test_UniformRNG(): def test_UniformRNG():
m1 = RNG(seed=111, device="xpu0") m1 = RNG(seed=111, device="xpu0")
...@@ -233,7 +233,7 @@ def test_UniformRNG(): ...@@ -233,7 +233,7 @@ def test_UniformRNG():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", get_device_count("xpu") <= 1, reason="xpu counts need > 1",
) )
def test_NormalRNG(): def test_NormalRNG():
m1 = RNG(seed=111, device="xpu0") m1 = RNG(seed=111, device="xpu0")
...@@ -262,7 +262,7 @@ def test_NormalRNG(): ...@@ -262,7 +262,7 @@ def test_NormalRNG():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", get_device_count("xpu") <= 1, reason="xpu counts need > 1",
) )
def test_GammaRNG(): def test_GammaRNG():
m1 = RNG(seed=111, device="xpu0") m1 = RNG(seed=111, device="xpu0")
...@@ -295,7 +295,7 @@ def test_GammaRNG(): ...@@ -295,7 +295,7 @@ def test_GammaRNG():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", get_device_count("xpu") <= 1, reason="xpu counts need > 1",
) )
def test_BetaRNG(): def test_BetaRNG():
m1 = RNG(seed=111, device="xpu0") m1 = RNG(seed=111, device="xpu0")
...@@ -330,7 +330,7 @@ def test_BetaRNG(): ...@@ -330,7 +330,7 @@ def test_BetaRNG():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", get_device_count("xpu") <= 1, reason="xpu counts need > 1",
) )
def test_PoissonRNG(): def test_PoissonRNG():
m1 = RNG(seed=111, device="xpu0") m1 = RNG(seed=111, device="xpu0")
...@@ -359,7 +359,7 @@ def test_PoissonRNG(): ...@@ -359,7 +359,7 @@ def test_PoissonRNG():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", get_device_count("xpu") <= 1, reason="xpu counts need > 1",
) )
def test_PermutationRNG(): def test_PermutationRNG():
m1 = RNG(seed=111, device="xpu0") m1 = RNG(seed=111, device="xpu0")
......
...@@ -13,8 +13,7 @@ import megengine.random as rand ...@@ -13,8 +13,7 @@ import megengine.random as rand
from megengine.core._imperative_rt.core2 import apply from megengine.core._imperative_rt.core2 import apply
from megengine.core._wrap import Device from megengine.core._wrap import Device
from megengine.core.ops import builtin from megengine.core.ops import builtin
from megengine.device import is_cuda_available from megengine.device import get_device_count, is_cuda_available
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.external import tensorrt_runtime_opr from megengine.functional.external import tensorrt_runtime_opr
from megengine.jit.tracing import trace from megengine.jit.tracing import trace
from megengine.tensor import Tensor from megengine.tensor import Tensor
...@@ -273,7 +272,7 @@ def test_deformable_ps_roi_pooling(): ...@@ -273,7 +272,7 @@ def test_deformable_ps_roi_pooling():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count_by_fork("gpu") > 0, get_device_count("gpu") > 0,
reason="does not support int8 when gpu compute capability less than 6.1", reason="does not support int8 when gpu compute capability less than 6.1",
) )
def test_convbias(): def test_convbias():
......
...@@ -27,8 +27,14 @@ using namespace mgb; ...@@ -27,8 +27,14 @@ using namespace mgb;
#include <thread> #include <thread>
#include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#ifdef __unix__
#include <unistd.h>
#include <sys/wait.h>
#endif
using CudaCompNodeImpl = CudaCompNode::CompNodeImpl; using CudaCompNodeImpl = CudaCompNode::CompNodeImpl;
namespace { namespace {
...@@ -700,19 +706,90 @@ void CudaCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) { ...@@ -700,19 +706,90 @@ void CudaCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) {
/* ===================== CudaCompNode static methods ===================== */ /* ===================== CudaCompNode static methods ===================== */
namespace {
#ifndef __unix__
CUresult get_device_count_forksafe(int* pcnt) {
cuInit(0);
return cuDeviceGetCount(pcnt);
}
#else
struct RAIICloseFD : NonCopyableObj {
int m_fd = -1;
RAIICloseFD(int fd) : m_fd(fd) {}
~RAIICloseFD() {close();}
void close() {
if (m_fd != -1) {
::close(m_fd);
m_fd = -1;
}
}
};
// an implementation that does not call cuInit
CUresult get_device_count_forksafe(int* pcnt) {
auto err = cuDeviceGetCount(pcnt);
if (err != CUDA_ERROR_NOT_INITIALIZED) return err;
// cuInit not called, call it in child process
int fd[2];
mgb_assert(pipe(fd) == 0, "pipe() failed");
int fdr = fd[0], fdw = fd[1];
RAIICloseFD fdr_guard(fdr);
RAIICloseFD fdw_guard(fdw);
auto cpid = fork();
mgb_assert(cpid != -1, "fork() failed");
if (cpid == 0) {
fdr_guard.close();
do {
err = cuInit(0);
if (err != CUDA_SUCCESS) break;
err = cuDeviceGetCount(pcnt);
} while (0);
auto sz = write(fdw, &err, sizeof(err));
if (sz == sizeof(err) && err == CUDA_SUCCESS) {
sz = write(fdw, pcnt, sizeof(*pcnt));
}
fdw_guard.close();
std::quick_exit(0);
}
fdw_guard.close();
auto sz = read(fdr, &err, sizeof(err));
mgb_assert(sz == sizeof(err), "failed to read error code from child");
if (err == CUDA_SUCCESS) {
sz = read(fdr, pcnt, sizeof(*pcnt));
mgb_assert(sz == sizeof(*pcnt), "failed to read device count from child");
return err;
}
// try again, maybe another thread called cuInit while we fork
auto err2 = cuDeviceGetCount(pcnt);
if (err2 == CUDA_SUCCESS) return err2;
if (err2 == CUDA_ERROR_NOT_INITIALIZED) return err;
return err2;
}
#endif
const char* cu_get_error_string(CUresult err) {
const char* ret = nullptr;
cuGetErrorString(err, &ret);
if (!ret) ret = "unknown cuda error";
return ret;
}
} // namespace
bool CudaCompNode::available() { bool CudaCompNode::available() {
static int result = -1; static int result = -1;
static Spinlock mtx; static Spinlock mtx;
MGB_LOCK_GUARD(mtx); MGB_LOCK_GUARD(mtx);
if (result == -1) { if (result == -1) {
int ndev = -1; int ndev = -1;
auto err = cudaGetDeviceCount(&ndev); auto err = get_device_count_forksafe(&ndev);
result = err == cudaSuccess && ndev > 0; result = err == CUDA_SUCCESS && ndev > 0;
if (!result) { if (!result) {
mgb_log_warn("cuda unavailable: %s(%d) ndev=%d", mgb_log_warn("cuda unavailable: %s(%d) ndev=%d",
cudaGetErrorString(err), static_cast<int>(err), ndev); cu_get_error_string(err), static_cast<int>(err), ndev);
} }
if (err == cudaErrorInitializationError) { if (err == CUDA_ERROR_NOT_INITIALIZED) {
mgb_throw(std::runtime_error, "cuda initialization error."); mgb_throw(std::runtime_error, "cuda initialization error.");
} }
} }
...@@ -857,11 +934,11 @@ size_t CudaCompNode::get_device_count(bool warn) { ...@@ -857,11 +934,11 @@ size_t CudaCompNode::get_device_count(bool warn) {
static Spinlock mtx; static Spinlock mtx;
MGB_LOCK_GUARD(mtx); MGB_LOCK_GUARD(mtx);
if (cnt == -1) { if (cnt == -1) {
auto err = cudaGetDeviceCount(&cnt); auto err = get_device_count_forksafe(&cnt);
if (err != cudaSuccess) { if (err != CUDA_SUCCESS) {
if (warn) if (warn)
mgb_log_error("cudaGetDeviceCount failed: %s (err %d)", mgb_log_error("cudaGetDeviceCount failed: %s (err %d)",
cudaGetErrorString(err), int(err)); cu_get_error_string(err), int(err));
cnt = 0; cnt = 0;
} }
mgb_assert(cnt >= 0); mgb_assert(cnt >= 0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册