提交 e6943017 编写于 作者: M Megvii Engine Team

fix(ops/jit): skip lookup include path when nvcc executable not found

GitOrigin-RevId: f5e7dce1c52fe099de8e90684bbe5b0566a5115a
上级 02c1a0c3
import functools
import os
import platform
import subprocess
import sys
import numpy as np
import pytest
......@@ -152,3 +156,56 @@ def test_subgraph_jit_backward():
y1 = x1 * x1
y2 = mul(x2, x2)
gm.backward(y2)
@pytest.mark.skipif(
platform.system() != "Linux", reason="jit fusion is only available on Linux",
)
def test_subgraph_jit():
prog = """
import megengine
import numpy as np
from megengine.core.tensor.utils import subgraph_fn
# 3 * 4 * 5 > MEGDNN_MAX_NDIM
x_np = np.random.rand(3, 4, 5).astype("float32")
x1 = megengine.Tensor(x_np)
x2 = megengine.Tensor(x_np)
@subgraph_fn(
"Mul",
dtype=x1.dtype,
device=x1.device,
nr_inputs=2,
gopt_level=None,
jit_fusion=True,
custom_grad=True,
)
def mul(inputs, f, c):
x, y = inputs[0:2]
z = f("*", x, y)
(dz,) = yield (z,)
dx = f("*", dz, y)
dy = f("*", dz, x)
yield (dx, dy)
y, = mul(x1, x2)
# ensure execution
y.numpy()
"""
env = dict(os.environ)
if "PATH" in env:
# remove nvcc from environ["PATH"]
path = env["PATH"]
paths = path.split(os.pathsep)
paths = [
path
for path in paths
if not (os.path.isdir(path) and "nvcc" in os.listdir(path))
]
path = os.pathsep.join(paths)
env["PATH"] = path
# previous program may be stored in persistent cache
env["MGE_FASTRUN_CACHE_TYPE"] = "MEMORY"
subprocess.check_call([sys.executable, "-c", prog], env=env)
......@@ -17,6 +17,8 @@ using namespace mgb;
#include <unistd.h>
#endif
namespace {
#ifndef PATH_MAX
#define PATH_MAX 4096
#endif
......@@ -32,6 +34,7 @@ using namespace mgb;
#define PATH_SPLITER '\\'
#define ENV_PATH "Path"
#define NVCC_EXE "nvcc.exe"
void* dlopen(const char* file, int) {
return static_cast<void*>(LoadLibrary(file));
}
......@@ -108,15 +111,15 @@ std::string find_file_in_envs_with_intmd(
std::string get_nvcc_root_path() {
auto nvcc_root_path = find_file_in_envs_with_intmd({ENV_PATH}, NVCC_EXE);
if (nvcc_root_path.empty()) {
mgb_throw(
MegBrainError,
"nvcc not found. Add your nvcc to your environment Path");
return {};
} else {
auto idx = nvcc_root_path.rfind(PATH_SPLITER);
return nvcc_root_path.substr(0, idx + 1);
}
}
} // namespace
std::vector<std::string> mgb::get_cuda_include_path() {
#if MGB_CUDA
std::vector<std::string> paths;
......@@ -131,6 +134,7 @@ std::vector<std::string> mgb::get_cuda_include_path() {
// 2. use nvcc path
auto nvcc_path = get_nvcc_root_path();
if (!nvcc_path.empty()) {
auto cudart_header_path = nvcc_path + ".." + PATH_SPLITER + "include" +
PATH_SPLITER + "cuda_runtime.h";
//! double check path_to_nvcc/../include/cuda_runtime.h exists
......@@ -139,6 +143,7 @@ std::vector<std::string> mgb::get_cuda_include_path() {
paths.emplace_back(nvcc_path + "..");
paths.emplace_back(nvcc_path + ".." + PATH_SPLITER + "include");
}
}
// 3. use libcudart.so library path
char cuda_lib_path[PATH_MAX];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册