// libcuda.cpp
#pragma GCC visibility push(default)

#include <cstdio>
// Prints a printf-style error line to stderr: "err: " is prepended and a
// newline is appended to `fmt`. Expands to the fprintf() call expression.
// Written with `...` / __VA_ARGS__ instead of a GNU named variadic parameter
// so that compilers lacking that extension (e.g. MSVC) can expand it too;
// `##` still drops the trailing comma when no arguments follow `fmt`.
#define LOGE(fmt, ...) fprintf(stderr, "err: " fmt "\n", ##__VA_ARGS__)

extern "C" {
#include <cuda.h>
}
#include <cudaProfiler.h>

#pragma GCC diagnostic ignored "-Wdeprecated-declarations"

#if defined(_WIN32)
#include <windows.h>
#define RTLD_LAZY 0

static void* dlopen(const char* file, int) {
18
    return static_cast<void*>(LoadLibraryA(file));
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
}

// Windows stand-in for dlerror(). No real loader diagnostics are collected
// here; a fixed explanatory message is returned instead.
// The return type stays void* (not char*) so existing callers are unaffected;
// the pointed-to text is a string literal and must not be modified.
static void* dlerror() {
    const char* errmsg = "dlerror not available in windows";
    return const_cast<char*>(errmsg);
}

static void* dlsym(void* handle, const char* name) {
    FARPROC symbol = GetProcAddress((HMODULE)handle, name);
    return reinterpret_cast<void*>(symbol);
}

#else
#include <dlfcn.h>
#include <unistd.h>
#endif

// Forward declaration; the definition appears later in this file.
static void log_failed_load(int func_idx);

namespace {

// Hook invoked with the index of a function that could not be loaded.
// Only declared generically: each supported return type supplies its own
// specialisation providing the value to return in that situation.
template <typename ReturnT>
ReturnT on_init_failed(int func_idx);

// CUresult flavour: log which function is missing, then report a generic
// driver error.
template <>
CUresult on_init_failed<CUresult>(int func_idx) {
    log_failed_load(func_idx);
    return CUDA_ERROR_UNKNOWN;
}

}  // namespace

// Define the two _WRAPLIB_* macros only for the duration of the wrapper
// header include, then remove them so they do not leak into the rest of the
// file.
// NOTE(review): presumably libcuda-wrap.h uses them as the calling-convention
// annotations for the wrapped API functions and callbacks, and also provides
// g_func_name used further down — confirm against that header.
#define _WRAPLIB_API_CALL CUDAAPI
#define _WRAPLIB_CALLBACK CUDA_CB
#include "./libcuda-wrap.h"
#undef _WRAPLIB_CALLBACK
#undef _WRAPLIB_API_CALL

// Harvested from cuda_drvapi_dynlink.c
// Candidate names/paths of the CUDA driver library, in the order in which
// get_library_handle() passes them to dlopen(). Exactly one platform branch
// below contributes entries; an unrecognised platform fails the build.
static const char* default_so_paths[] = {
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
    // Windows: the driver DLL, given by bare file name.
    "nvcuda.dll",
#elif defined(__unix__) || defined (__QNX__) || defined(__APPLE__) || defined(__MACOSX)
#if defined(__APPLE__) || defined(__MACOSX)
    "/usr/local/cuda/lib/libcuda.dylib",
#elif defined(__ANDROID__)
    // NOTE(review): Android ABIs other than aarch64/arm add no entry here,
    // which leaves the list empty — confirm that this is intended.
#if defined (__aarch64__)
    "/system/vendor/lib64/libcuda.so",
#elif defined(__arm__)
    "/system/vendor/lib/libcuda.so",
#endif
#else
    // Bare soname first, so the dynamic loader's own search path is used.
    "libcuda.so.1",
    
    // Fallback absolute paths, in case a user does not have the correct
    // search path configured in /etc/ld.so.conf.
    "/usr/lib/x86_64-linux-gnu/libcuda.so",
    "/usr/local/nvidia/lib64/libcuda.so",
#endif
#else
#error "Unknown platform"
#endif
};
// Helpers for opening the driver library and resolving its entry points.

static void* get_library_handle() {
    void* handle = nullptr;
81 82 83 84
    for (size_t i = 0; i < (sizeof(default_so_paths) / sizeof(char*)); i++) {
        handle = dlopen(default_so_paths[i], RTLD_LAZY);
        if (handle) {
            break;
85 86 87
        }
    }

88 89
    if (!handle) {
        LOGE("Failed to load CUDA Driver API library");
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
        return nullptr;
    }
    return handle;
}

// Logs that the function with index `func_idx` could not be loaded; its
// printable name is taken from the g_func_name table. The index is used
// as-is, without range checking.
static void log_failed_load(int func_idx) {
    const char* const missing = g_func_name[func_idx];
    LOGE("failed to load cuda func: %s", missing);
}

// Looks up the symbol `func` in the library referred to by `handle`.
//
// Returns the address reported by dlsym(). Returns nullptr, after logging
// the reason, when `handle` is null or the symbol cannot be found.
static void* resolve_library_func(void* handle, const char* func) {
    if (handle == nullptr) {
        LOGE("handle should not be nullptr!");
        return nullptr;
    }

    void* const symbol = dlsym(handle, func);
    if (symbol != nullptr) {
        return symbol;
    }

    LOGE("failed to load cuda func: %s", func);
    return nullptr;
}