cuda_helper.cpp 4.8 KB
Newer Older
M
Megvii Engine Team 已提交
1
#include "megbrain/utils/cuda_helper.h"
2 3 4 5
#include "megbrain/common.h"
#include "megbrain/exception.h"

#include <fstream>
M
Megvii Engine Team 已提交
6
#include <set>
7
#include <sstream>
M
Megvii Engine Team 已提交
8
#include <string>
9 10 11 12 13 14

#ifdef WIN32
#include <io.h>
#include <windows.h>
#else
#include <dlfcn.h>
M
Megvii Engine Team 已提交
15
#include <unistd.h>
16 17
#endif

18 19
#if MGB_CUDA

20 21
namespace {

22 23 24 25 26
// Fallback buffer length for file-system paths when the platform headers do
// not provide PATH_MAX; it sizes the module-path buffers used further below.
#if !defined(PATH_MAX)
#define PATH_MAX 4096
#endif

#ifdef WIN32
M
Megvii Engine Team 已提交
27 28 29 30
// POSIX names used by the shared code below, mapped onto Windows. The dl*
// flags are accepted but ignored by the LoadLibrary-based shims.
#define F_OK           0
#define RTLD_LAZY      0
#define RTLD_GLOBAL    0
#define RTLD_NOLOAD    0
#define RTLD_DI_ORIGIN 0
// Forward to the CRT's _access (declared in <io.h>) instead of a constant:
// a constant 0/false would make every access() call report success.
#define access(a, b)   _access(a, b)
// Platform path conventions: list separator in PATH-like variables,
// directory separator, name of the search-path variable, nvcc binary name.
#define SPLITER        ';'
#define PATH_SPLITER   '\\'
#define ENV_PATH       "Path"
#define NVCC_EXE       "nvcc.exe"
37

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
void* dlopen(const char* file, int) {
    return static_cast<void*>(LoadLibrary(file));
}

int dlinfo(void* handle, int request, char* path) {
    if (GetModuleFileName((HMODULE)handle, path, PATH_MAX))
        return 0;
    else
        return -1;
}

//! Windows stand-in for POSIX dlerror(). No loader error text is tracked, so a
//! fixed message is returned. The return type is char*, as in POSIX, so the
//! result can be passed straight to a "%s" format specifier.
char* dlerror() {
    static char errmsg[] = "dlerror not available in windows";
    return errmsg;
}

//! Thin wrapper over the CRT's _access: yields 0 when `path` is accessible in
//! `mode` (F_OK tests plain existence) and a non-zero value otherwise.
int check_file_exist(const char* path, int mode) {
    const int status = _access(path, mode);
    return status;
}
#else
M
Megvii Engine Team 已提交
58
// POSIX path conventions used by the search helpers below: nvcc binary name,
// name of the search-path variable, directory separator, and the separator
// between entries of a PATH-like variable.
#define NVCC_EXE     "nvcc"
#define ENV_PATH     "PATH"
#define PATH_SPLITER '/'
#define SPLITER      ':'
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
//! Thin wrapper over POSIX access(2): yields 0 when `path` is accessible in
//! `mode` (F_OK tests plain existence) and a non-zero value otherwise.
int check_file_exist(const char* path, int mode) {
    const int rc = access(path, mode);
    return rc;
}
#endif

//! Splits the value of a PATH-like environment variable on the platform list
//! separator (SPLITER) and returns the entries in order.
//! Empty entries (from "a::b", a leading separator, ...) are dropped. The
//! caller joins each entry with a file name, so an empty entry would otherwise
//! turn into a lookup at the file-system / drive root.
//! A null `env` yields an empty vector.
std::vector<std::string> split_env(const char* env) {
    std::vector<std::string> ret;
    if (env == nullptr)
        return ret;
    std::istringstream stream{std::string(env)};
    std::string entry;
    while (std::getline(stream, entry, SPLITER)) {
        if (!entry.empty())
            ret.push_back(entry);
    }
    return ret;
}

//! Looks for `file_name` in every directory listed by the environment
//! variables named in `envs`. Each variable's value is split with split_env,
//! and unset variables are skipped.
//! For each directory `dir`, `dir/file_name` is tried first, then
//! `dir/<m>/file_name` for each intermediate sub-path `m` in `itmedias`, in
//! order. Returns the first candidate that check_file_exist reports as present,
//! or an empty string when nothing matches.
std::string find_file_in_envs_with_intmd(
        const std::vector<std::string>& envs, const std::string& file_name,
        const std::vector<std::string>& itmedias = {}) {
    auto exists = [](const std::string& candidate) {
        return check_file_exist(candidate.c_str(), F_OK) == 0;
    };
    for (const auto& var_name : envs) {
        const char* value = getenv(var_name.c_str());
        if (value == nullptr)
            continue;
        for (const auto& dir : split_env(value)) {
            const std::string direct = dir + PATH_SPLITER + file_name;
            if (exists(direct))
                return direct;
            for (const auto& medium : itmedias) {
                const std::string nested =
                        dir + PATH_SPLITER + medium + PATH_SPLITER + file_name;
                if (exists(nested))
                    return nested;
            }
        }
    }
    return {};
}

//! Returns the directory of the first nvcc executable found on the search path
//! (ENV_PATH), including its trailing separator so callers can append directly
//! to it. Returns an empty string when nvcc is not found.
std::string get_nvcc_root_path() {
    const std::string nvcc = find_file_in_envs_with_intmd({ENV_PATH}, NVCC_EXE);
    if (nvcc.empty())
        return {};
    // Keep everything up to and including the last directory separator.
    const auto last_sep = nvcc.rfind(PATH_SPLITER);
    return nvcc.substr(0, last_sep + 1);
}

116 117
}  // namespace

118 119 120 121 122 123
std::vector<std::string> mgb::get_cuda_include_path() {
    std::vector<std::string> paths;
    // 1. use CUDA_BIN_PATH
    auto cuda_path = getenv("CUDA_BIN_PATH");
    if (cuda_path) {
        paths.emplace_back(std::string(cuda_path) + PATH_SPLITER + "include");
M
Megvii Engine Team 已提交
124 125 126
        paths.emplace_back(
                std::string(cuda_path) + PATH_SPLITER + ".." + PATH_SPLITER +
                "include");
127 128 129 130
    }

    // 2. use nvcc path
    auto nvcc_path = get_nvcc_root_path();
131 132 133 134 135 136 137 138 139
    if (!nvcc_path.empty()) {
        auto cudart_header_path = nvcc_path + ".." + PATH_SPLITER + "include" +
                                  PATH_SPLITER + "cuda_runtime.h";
        //! double check path_to_nvcc/../include/cuda_runtime.h exists
        auto ret = check_file_exist(cudart_header_path.c_str(), F_OK);
        if (ret == 0) {
            paths.emplace_back(nvcc_path + "..");
            paths.emplace_back(nvcc_path + ".." + PATH_SPLITER + "include");
        }
140 141 142 143 144
    }

    // 3. use libcudart.so library path
    char cuda_lib_path[PATH_MAX];
    auto handle = dlopen("libcudart.so", RTLD_GLOBAL | RTLD_LAZY);
M
Megvii Engine Team 已提交
145 146 147 148 149 150
    if (handle != nullptr) {
        mgb_assert(
                dlinfo(handle, RTLD_DI_ORIGIN, cuda_lib_path) != -1, "%s", dlerror());
        paths.emplace_back(
                std::string(cuda_lib_path) + PATH_SPLITER + ".." + PATH_SPLITER +
                "include");
151
    }
M
Megvii Engine Team 已提交
152 153 154 155 156 157 158
    mgb_assert(
            paths.size() > 0,
            "can't find cuda include path, check your environment of cuda, "
            "try one of this solutions "
            "1. set CUDA_BIN_PATH to cuda home path "
            "2. add nvcc path in PATH "
            "3. add libcudart.so path in LD_LIBRARY_PATH");
159
    return paths;
160 161
}

162
#else
163 164

std::vector<std::string> mgb::get_cuda_include_path() {
165
    mgb_throw(MegBrainError, "cuda disabled at compile time");
166
}
167 168

#endif