From 207a346351d1d7daa74435f43b1992f86331d650 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 19 Jan 2021 18:27:50 +0800 Subject: [PATCH] chore(mge): run get_device_count("gpu") in subprocess GitOrigin-RevId: 0f0dc001cfc45fc0d04de1a86c27f8bba8185d6b --- .../python/megengine/distributed/helper.py | 18 ---- .../python/megengine/distributed/launcher.py | 7 +- .../python/megengine/functional/tensor.py | 4 +- imperative/python/test/conftest.py | 4 +- .../test/integration/test_param_pack.py | 1 - .../test/unit/autodiff/test_grad_manger.py | 1 - .../python/test/unit/core/test_autodiff.py | 1 - .../python/test/unit/core/test_dtype_quant.py | 5 +- .../test/unit/distributed/test_distributed.py | 6 +- .../test/unit/functional/test_functional.py | 9 +- .../functional/test_functional_distributed.py | 1 - .../test/unit/functional/test_tensor.py | 1 - .../python/test/unit/module/test_batchnorm.py | 1 - .../python/test/unit/module/test_qat.py | 6 +- .../test/unit/quantization/test_observer.py | 6 +- .../python/test/unit/quantization/test_op.py | 4 +- .../python/test/unit/random/test_rng.py | 26 +++--- .../test/unit/utils/test_network_node.py | 5 +- src/core/impl/comp_node/cuda/comp_node.cpp | 91 +++++++++++++++++-- 19 files changed, 119 insertions(+), 78 deletions(-) diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index f2d83ba09..96d79a6c0 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -181,11 +181,6 @@ def synchronized(func: Callable): return wrapper -def _get_device_count_worker(queue, device_type): - num = get_device_count(device_type) - queue.put(num) - - def _check_device_initialized(device_type: str, rank: int): try: test = Tensor(1, device=(device_type + str(rank))) @@ -198,19 +193,6 @@ def _check_device_initialized(device_type: str, rank: int): raise RuntimeError(errmsg) -def get_device_count_by_fork(device_type: str): - """ - Get device count in fork thread. - See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork - for more information. - """ - q = mp.Queue() - p = mp.Process(target=_get_device_count_worker, args=(q, device_type)) - p.start() - p.join() - return q.get() - - def bcast_list_(inps: list, group: Group = WORLD): """ Broadcast tensors between given group. diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index 3e6d2b18d..f963bd611 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -13,9 +13,10 @@ import queue from .. 
import _exit from ..core._imperative_rt.core2 import full_sync +from ..device import get_device_count from ..logger import get_logger from .group import _set_machine_ranks, group_barrier, init_process_group -from .helper import _check_device_initialized, get_device_count_by_fork +from .helper import _check_device_initialized from .server import Client, Server WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = ( @@ -91,9 +92,7 @@ class launcher: backend="auto", ): self.func = func - self.n_gpus = ( - n_gpus if n_gpus is not None else get_device_count_by_fork(device_type) - ) + self.n_gpus = n_gpus if n_gpus is not None else get_device_count(device_type) self.world_size = world_size if world_size is not None else self.n_gpus self.rank_start = rank_start self.master_ip = master_ip diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index ae07da0af..c4ac041f0 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -1188,11 +1188,11 @@ def copy(inp, device=None): import numpy as np import platform from megengine import tensor - from megengine.distributed.helper import get_device_count_by_fork + from megengine.device import get_device_count import megengine.functional as F x = tensor([1, 2, 3], np.int32) - if 1 == get_device_count_by_fork("gpu"): + if 1 == get_device_count("gpu"): y = F.copy(x, "cpu1") print(y.numpy()) else: diff --git a/imperative/python/test/conftest.py b/imperative/python/test/conftest.py index ed598cedb..518d0584e 100644 --- a/imperative/python/test/conftest.py +++ b/imperative/python/test/conftest.py @@ -15,7 +15,7 @@ import megengine.functional import megengine.module from megengine import Parameter from megengine.core._imperative_rt.core2 import sync -from megengine.distributed.helper import get_device_count_by_fork +from megengine.device import get_device_count from megengine.experimental.autograd import ( disable_higher_order_directive, enable_higher_order_directive, @@ -25,7 +25,7 @@ from megengine.module import Linear, Module sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) -_ngpu = get_device_count_by_fork("gpu") +_ngpu = get_device_count("gpu") @pytest.fixture(autouse=True) diff --git a/imperative/python/test/integration/test_param_pack.py b/imperative/python/test/integration/test_param_pack.py index 42ffc1557..e672c5362 100644 --- a/imperative/python/test/integration/test_param_pack.py +++ b/imperative/python/test/integration/test_param_pack.py @@ -16,7 +16,6 @@ import megengine.autodiff as ad import megengine.distributed as dist import megengine.optimizer as optimizer from megengine import Parameter, tensor -from megengine.distributed.helper import get_device_count_by_fork from megengine.module import Module from megengine.optimizer import SGD diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py index 8511567d9..daf58823f 100644 --- a/imperative/python/test/unit/autodiff/test_grad_manger.py +++ b/imperative/python/test/unit/autodiff/test_grad_manger.py @@ -18,7 +18,6 @@ import megengine.functional as F import megengine.module as M import megengine.optimizer as optim from megengine.autodiff import GradManager -from megengine.distributed.helper import get_device_count_by_fork from megengine.jit import trace diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index fa90ab40e..192038b34 100644 --- 
a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -20,7 +20,6 @@ from megengine.core._imperative_rt import CompNode, TensorAttr, imperative from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync from megengine.core.autodiff.grad import Grad from megengine.core.ops.builtin import Elemwise, Identity -from megengine.distributed.helper import get_device_count_by_fork from megengine.functional.distributed import remote_recv, remote_send diff --git a/imperative/python/test/unit/core/test_dtype_quant.py b/imperative/python/test/unit/core/test_dtype_quant.py index 594d86e5a..0fbd99e23 100644 --- a/imperative/python/test/unit/core/test_dtype_quant.py +++ b/imperative/python/test/unit/core/test_dtype_quant.py @@ -31,7 +31,7 @@ from megengine.core.tensor.dtype import ( quint4, quint8, ) -from megengine.distributed.helper import get_device_count_by_fork +from megengine.device import get_device_count from megengine.tensor import Tensor @@ -184,8 +184,7 @@ def test_dtype_int4_ffi_handle(): @pytest.mark.skipif( - get_device_count_by_fork("gpu") != 0, - reason="TypeCvt to quint4 is not supported on GPU", + get_device_count("gpu") != 0, reason="TypeCvt to quint4 is not supported on GPU", ) def test_quint4_typecvt(): device = "xpux" diff --git a/imperative/python/test/unit/distributed/test_distributed.py b/imperative/python/test/unit/distributed/test_distributed.py index ce51f0107..0fcae8645 100644 --- a/imperative/python/test/unit/distributed/test_distributed.py +++ b/imperative/python/test/unit/distributed/test_distributed.py @@ -17,11 +17,7 @@ import megengine as mge import megengine.distributed as dist from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit from megengine.device import get_default_device -from megengine.distributed.helper import ( - get_device_count_by_fork, - param_pack_concat, - param_pack_split, -) +from megengine.distributed.helper import param_pack_concat, param_pack_split def _assert_q_empty(q): diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 7bbdc6ad5..a4366a6c4 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -22,8 +22,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor from megengine.core._trace_option import use_symbolic_shape from megengine.core.autodiff.grad import Grad from megengine.core.tensor.utils import make_shape_tuple -from megengine.distributed.helper import get_device_count_by_fork -from megengine.jit import trace +from megengine.device import get_device_count def test_where(): @@ -613,7 +612,7 @@ def test_nms(): @pytest.mark.skipif( - get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8" + get_device_count("gpu") > 0, reason="cuda does not support nchw int8" ) def test_conv_bias(): inp_scale = 1.5 @@ -715,9 +714,7 @@ def test_conv_bias(): run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu") -@pytest.mark.skipif( - get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda" -) +@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda") def test_batch_conv_bias(): inp_scale = 1.5 w_scale = 2.5 diff --git a/imperative/python/test/unit/functional/test_functional_distributed.py b/imperative/python/test/unit/functional/test_functional_distributed.py index 0eaaaa056..4c15b37d0 100644 --- 
a/imperative/python/test/unit/functional/test_functional_distributed.py +++ b/imperative/python/test/unit/functional/test_functional_distributed.py @@ -16,7 +16,6 @@ import megengine.distributed as dist from megengine import Parameter, tensor from megengine.core._imperative_rt.core2 import sync from megengine.device import get_default_device, set_default_device -from megengine.distributed.helper import get_device_count_by_fork from megengine.functional.distributed import ( all_gather, all_reduce_max, diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 63965d37b..dbc3fa83d 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -18,7 +18,6 @@ from megengine import tensor from megengine.core._trace_option import use_symbolic_shape from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor.utils import astensor1d -from megengine.distributed.helper import get_device_count_by_fork from megengine.jit import trace from megengine.utils.network import Network, set_symbolic_shape from megengine.utils.network_node import VarNode diff --git a/imperative/python/test/unit/module/test_batchnorm.py b/imperative/python/test/unit/module/test_batchnorm.py index 822ebe0b4..12d61a4f3 100644 --- a/imperative/python/test/unit/module/test_batchnorm.py +++ b/imperative/python/test/unit/module/test_batchnorm.py @@ -16,7 +16,6 @@ import megengine as mge import megengine.distributed as dist from megengine import Tensor from megengine.core._trace_option import use_symbolic_shape -from megengine.distributed.helper import get_device_count_by_fork from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6) diff --git a/imperative/python/test/unit/module/test_qat.py b/imperative/python/test/unit/module/test_qat.py index 9bc60e2bf..51b5206e5 100644 --- a/imperative/python/test/unit/module/test_qat.py +++ b/imperative/python/test/unit/module/test_qat.py @@ -6,7 +6,7 @@ import pytest import megengine.utils.comp_graph_tools as cgtools from megengine import jit, tensor -from megengine.distributed.helper import get_device_count_by_fork +from megengine.device import get_device_count from megengine.functional import expand_dims from megengine.module import ( BatchMatMulActivation, @@ -101,9 +101,7 @@ def test_qat_conv(): np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) -@pytest.mark.skipif( - get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda" -) +@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda") def test_qat_batchmatmul_activation(): batch = 4 in_features = 8 diff --git a/imperative/python/test/unit/quantization/test_observer.py b/imperative/python/test/unit/quantization/test_observer.py index 691e701c2..cddf07a61 100644 --- a/imperative/python/test/unit/quantization/test_observer.py +++ b/imperative/python/test/unit/quantization/test_observer.py @@ -13,7 +13,7 @@ import pytest import megengine as mge import megengine.distributed as dist -from megengine.distributed.helper import get_device_count_by_fork +from megengine.device import get_device_count from megengine.quantization import QuantMode, create_qparams from megengine.quantization.observer import ( ExponentialMovingAverageObserver, @@ -78,7 +78,7 @@ def test_passive_observer(): @pytest.mark.require_ngpu(2) 
@pytest.mark.isolated_distributed def test_sync_min_max_observer(): - word_size = get_device_count_by_fork("gpu") + word_size = get_device_count("gpu") x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") np_min, np_max = x.min(), x.max() @@ -96,7 +96,7 @@ def test_sync_min_max_observer(): @pytest.mark.require_ngpu(2) @pytest.mark.isolated_distributed def test_sync_exponential_moving_average_observer(): - word_size = get_device_count_by_fork("gpu") + word_size = get_device_count("gpu") t = np.random.rand() x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") diff --git a/imperative/python/test/unit/quantization/test_op.py b/imperative/python/test/unit/quantization/test_op.py index a6a10a912..beb151d9e 100644 --- a/imperative/python/test/unit/quantization/test_op.py +++ b/imperative/python/test/unit/quantization/test_op.py @@ -12,7 +12,7 @@ import pytest import megengine as mge import megengine.functional as F from megengine.core.tensor import dtype -from megengine.distributed.helper import get_device_count_by_fork +from megengine.device import get_device_count from megengine.functional.elemwise import _elemwise_multi_type, _elwise from megengine.quantization import QuantMode, create_qparams @@ -68,7 +68,7 @@ def test_elemwise(kind): @pytest.mark.skipif( - get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8" + get_device_count("gpu") > 0, reason="cuda does not support nchw int8" ) def test_conv_bias(): inp_scale = np.float32(np.random.rand() + 1) diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py index 2e5bd26b3..3150ef8ad 100644 --- a/imperative/python/test/unit/random/test_rng.py +++ b/imperative/python/test/unit/random/test_rng.py @@ -26,12 +26,12 @@ from megengine.core.ops.builtin import ( PoissonRNG, UniformRNG, ) -from megengine.distributed.helper import get_device_count_by_fork +from megengine.device import get_device_count from megengine.random import RNG, seed, uniform @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", + get_device_count("xpu") <= 2, reason="xpu counts need > 2", ) def test_gaussian_op(): shape = ( @@ -61,7 +61,7 @@ def test_gaussian_op(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", + get_device_count("xpu") <= 2, reason="xpu counts need > 2", ) def test_uniform_op(): shape = ( @@ -89,7 +89,7 @@ def test_uniform_op(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", + get_device_count("xpu") <= 2, reason="xpu counts need > 2", ) def test_gamma_op(): _shape, _scale = 2, 0.8 @@ -117,7 +117,7 @@ def test_gamma_op(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", + get_device_count("xpu") <= 2, reason="xpu counts need > 2", ) def test_beta_op(): _alpha, _beta = 2, 0.8 @@ -148,7 +148,7 @@ def test_beta_op(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", + get_device_count("xpu") <= 2, reason="xpu counts need > 2", ) def test_poisson_op(): lam = F.full([8, 9, 11, 12], value=2, dtype="float32") @@ -171,7 +171,7 @@ def test_poisson_op(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", + get_device_count("xpu") <= 2, reason="xpu counts need > 2", ) def test_permutation_op(): n = 1000 @@ -205,7 +205,7 @@ def test_permutation_op(): @pytest.mark.skipif( - 
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", + get_device_count("xpu") <= 1, reason="xpu counts need > 1", ) def test_UniformRNG(): m1 = RNG(seed=111, device="xpu0") @@ -233,7 +233,7 @@ def test_UniformRNG(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", + get_device_count("xpu") <= 1, reason="xpu counts need > 1", ) def test_NormalRNG(): m1 = RNG(seed=111, device="xpu0") @@ -262,7 +262,7 @@ def test_NormalRNG(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", + get_device_count("xpu") <= 1, reason="xpu counts need > 1", ) def test_GammaRNG(): m1 = RNG(seed=111, device="xpu0") @@ -295,7 +295,7 @@ def test_GammaRNG(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", + get_device_count("xpu") <= 1, reason="xpu counts need > 1", ) def test_BetaRNG(): m1 = RNG(seed=111, device="xpu0") @@ -330,7 +330,7 @@ def test_BetaRNG(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", + get_device_count("xpu") <= 1, reason="xpu counts need > 1", ) def test_PoissonRNG(): m1 = RNG(seed=111, device="xpu0") @@ -359,7 +359,7 @@ def test_PoissonRNG(): @pytest.mark.skipif( - get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", + get_device_count("xpu") <= 1, reason="xpu counts need > 1", ) def test_PermutationRNG(): m1 = RNG(seed=111, device="xpu0") diff --git a/imperative/python/test/unit/utils/test_network_node.py b/imperative/python/test/unit/utils/test_network_node.py index 99b178662..e6c71379b 100644 --- a/imperative/python/test/unit/utils/test_network_node.py +++ b/imperative/python/test/unit/utils/test_network_node.py @@ -13,8 +13,7 @@ import megengine.random as rand from megengine.core._imperative_rt.core2 import apply from megengine.core._wrap import Device from megengine.core.ops import builtin -from megengine.device import is_cuda_available -from megengine.distributed.helper import get_device_count_by_fork +from megengine.device import get_device_count, is_cuda_available from megengine.functional.external import tensorrt_runtime_opr from megengine.jit.tracing import trace from megengine.tensor import Tensor @@ -273,7 +272,7 @@ def test_deformable_ps_roi_pooling(): @pytest.mark.skipif( - get_device_count_by_fork("gpu") > 0, + get_device_count("gpu") > 0, reason="does not support int8 when gpu compute capability less than 6.1", ) def test_convbias(): diff --git a/src/core/impl/comp_node/cuda/comp_node.cpp b/src/core/impl/comp_node/cuda/comp_node.cpp index 64d7aa25f..b6b18268c 100644 --- a/src/core/impl/comp_node/cuda/comp_node.cpp +++ b/src/core/impl/comp_node/cuda/comp_node.cpp @@ -27,8 +27,14 @@ using namespace mgb; #include +#include #include +#ifdef __unix__ +#include +#include +#endif + using CudaCompNodeImpl = CudaCompNode::CompNodeImpl; namespace { @@ -700,19 +706,90 @@ void CudaCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) { /* ===================== CudaCompNode static methods ===================== */ +namespace { + +#ifndef __unix__ +CUresult get_device_count_forksafe(int* pcnt) { + cuInit(0); + return cuDeviceGetCount(pcnt); +} +#else +struct RAIICloseFD : NonCopyableObj { + int m_fd = -1; + + RAIICloseFD(int fd) : m_fd(fd) {} + ~RAIICloseFD() {close();} + void close() { + if (m_fd != -1) { + ::close(m_fd); + m_fd = -1; + } + } +}; +// an implementation that does not call cuInit +CUresult get_device_count_forksafe(int* pcnt) { + auto err = cuDeviceGetCount(pcnt); + if (err != 
CUDA_ERROR_NOT_INITIALIZED) return err;
+    // cuInit not called, call it in child process
+    int fd[2];
+    mgb_assert(pipe(fd) == 0, "pipe() failed");
+    int fdr = fd[0], fdw = fd[1];
+    RAIICloseFD fdr_guard(fdr);
+    RAIICloseFD fdw_guard(fdw);
+    auto cpid = fork();
+    mgb_assert(cpid != -1, "fork() failed");
+    if (cpid == 0) {
+        fdr_guard.close();
+        do {
+            err = cuInit(0);
+            if (err != CUDA_SUCCESS) break;
+            err = cuDeviceGetCount(pcnt);
+        } while (0);
+        auto sz = write(fdw, &err, sizeof(err));
+        if (sz == sizeof(err) && err == CUDA_SUCCESS) {
+            sz = write(fdw, pcnt, sizeof(*pcnt));
+        }
+        fdw_guard.close();
+        std::quick_exit(0);
+    }
+    fdw_guard.close();
+    auto sz = read(fdr, &err, sizeof(err));
+    mgb_assert(sz == sizeof(err), "failed to read error code from child");
+    if (err == CUDA_SUCCESS) {
+        sz = read(fdr, pcnt, sizeof(*pcnt));
+        mgb_assert(sz == sizeof(*pcnt), "failed to read device count from child");
+        return err;
+    }
+    // try again, maybe another thread called cuInit while we fork
+    auto err2 = cuDeviceGetCount(pcnt);
+    if (err2 == CUDA_SUCCESS) return err2;
+    if (err2 == CUDA_ERROR_NOT_INITIALIZED) return err;
+    return err2;
+}
+#endif
+
+const char* cu_get_error_string(CUresult err) {
+    const char* ret = nullptr;
+    cuGetErrorString(err, &ret);
+    if (!ret) ret = "unknown cuda error";
+    return ret;
+}
+
+} // namespace
+
 bool CudaCompNode::available() {
     static int result = -1;
     static Spinlock mtx;
     MGB_LOCK_GUARD(mtx);
     if (result == -1) {
         int ndev = -1;
-        auto err = cudaGetDeviceCount(&ndev);
-        result = err == cudaSuccess && ndev > 0;
+        auto err = get_device_count_forksafe(&ndev);
+        result = err == CUDA_SUCCESS && ndev > 0;
         if (!result) {
             mgb_log_warn("cuda unavailable: %s(%d) ndev=%d",
-                         cudaGetErrorString(err), static_cast<int>(err), ndev);
+                         cu_get_error_string(err), static_cast<int>(err), ndev);
         }
-        if (err == cudaErrorInitializationError) {
+        if (err == CUDA_ERROR_NOT_INITIALIZED) {
             mgb_throw(std::runtime_error, "cuda initialization error.");
         }
     }
@@ -857,11 +934,11 @@ size_t CudaCompNode::get_device_count(bool warn) {
     static Spinlock mtx;
     MGB_LOCK_GUARD(mtx);
     if (cnt == -1) {
-        auto err = cudaGetDeviceCount(&cnt);
-        if (err != cudaSuccess) {
+        auto err = get_device_count_forksafe(&cnt);
+        if (err != CUDA_SUCCESS) {
             if (warn)
                 mgb_log_error("cudaGetDeviceCount failed: %s (err %d)",
-                              cudaGetErrorString(err), int(err));
+                              cu_get_error_string(err), int(err));
             cnt = 0;
         }
         mgb_assert(cnt >= 0);
-- 
GitLab
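
Note for downstream code: this patch removes get_device_count_by_fork from megengine.distributed.helper; the updated tests and docstrings call megengine.device.get_device_count directly. A minimal before/after sketch for callers (the variable name ngpu is only for illustration):

# before this patch
# from megengine.distributed.helper import get_device_count_by_fork
# ngpu = get_device_count_by_fork("gpu")

# after this patch
from megengine.device import get_device_count
ngpu = get_device_count("gpu")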
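
Background on the comp_node.cpp change: a CUDA context created before fork() is unusable in the child (see the Stack Overflow link in the removed helper's docstring), so the old Python workaround queried the device count from a freshly forked worker process. get_device_count_forksafe moves that idea into C++: it forks only when cuDeviceGetCount reports CUDA_ERROR_NOT_INITIALIZED, lets the child run cuInit and cuDeviceGetCount, and ships the error code and count back through a pipe. The sketch below is a minimal, MegEngine-independent illustration of the same probe-in-a-child-process pattern in Python; _probe is a hypothetical stand-in for the real cuInit/cuDeviceGetCount query, not an actual CUDA call.

import multiprocessing as mp


def _probe(queue):
    # Hypothetical stand-in for the real device query (cuInit + cuDeviceGetCount).
    # Whatever driver initialization it triggers happens only in this child.
    queue.put(0)


def device_count_in_child():
    # Run the probe in a throwaway process so the parent never initializes
    # the driver and therefore remains safe to fork() later.
    q = mp.Queue()
    p = mp.Process(target=_probe, args=(q,))
    p.start()
    p.join()
    return q.get()


if __name__ == "__main__":
    print(device_count_in_child())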
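
Two details of the C++ implementation are worth noting. The child exits via std::quick_exit(0), presumably so that atexit handlers and static destructors inherited from the parent are skipped, and the parent retries cuDeviceGetCount once after the child reports CUDA_ERROR_NOT_INITIALIZED, covering the case noted in the code where another thread initialized CUDA while the fork was in flight. On non-unix builds the #ifndef __unix__ branch simply calls cuInit(0) followed by cuDeviceGetCount.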