提交 e6943017 编写于 作者: M Megvii Engine Team

fix(ops/jit): skip lookup include path when nvcc executable not found

GitOrigin-RevId: f5e7dce1c52fe099de8e90684bbe5b0566a5115a
上级 02c1a0c3
import functools import functools
import os
import platform
import subprocess
import sys
import numpy as np import numpy as np
import pytest import pytest
...@@ -152,3 +156,56 @@ def test_subgraph_jit_backward(): ...@@ -152,3 +156,56 @@ def test_subgraph_jit_backward():
y1 = x1 * x1 y1 = x1 * x1
y2 = mul(x2, x2) y2 = mul(x2, x2)
gm.backward(y2) gm.backward(y2)
@pytest.mark.skipif(
platform.system() != "Linux", reason="jit fusion is only available on Linux",
)
def test_subgraph_jit():
prog = """
import megengine
import numpy as np
from megengine.core.tensor.utils import subgraph_fn
# 3 * 4 * 5 > MEGDNN_MAX_NDIM
x_np = np.random.rand(3, 4, 5).astype("float32")
x1 = megengine.Tensor(x_np)
x2 = megengine.Tensor(x_np)
@subgraph_fn(
"Mul",
dtype=x1.dtype,
device=x1.device,
nr_inputs=2,
gopt_level=None,
jit_fusion=True,
custom_grad=True,
)
def mul(inputs, f, c):
x, y = inputs[0:2]
z = f("*", x, y)
(dz,) = yield (z,)
dx = f("*", dz, y)
dy = f("*", dz, x)
yield (dx, dy)
y, = mul(x1, x2)
# ensure execution
y.numpy()
"""
env = dict(os.environ)
if "PATH" in env:
# remove nvcc from environ["PATH"]
path = env["PATH"]
paths = path.split(os.pathsep)
paths = [
path
for path in paths
if not (os.path.isdir(path) and "nvcc" in os.listdir(path))
]
path = os.pathsep.join(paths)
env["PATH"] = path
# previous program may be stored in persistent cache
env["MGE_FASTRUN_CACHE_TYPE"] = "MEMORY"
subprocess.check_call([sys.executable, "-c", prog], env=env)
...@@ -17,6 +17,8 @@ using namespace mgb; ...@@ -17,6 +17,8 @@ using namespace mgb;
#include <unistd.h> #include <unistd.h>
#endif #endif
namespace {
#ifndef PATH_MAX #ifndef PATH_MAX
#define PATH_MAX 4096 #define PATH_MAX 4096
#endif #endif
...@@ -32,6 +34,7 @@ using namespace mgb; ...@@ -32,6 +34,7 @@ using namespace mgb;
#define PATH_SPLITER '\\' #define PATH_SPLITER '\\'
#define ENV_PATH "Path" #define ENV_PATH "Path"
#define NVCC_EXE "nvcc.exe" #define NVCC_EXE "nvcc.exe"
void* dlopen(const char* file, int) { void* dlopen(const char* file, int) {
return static_cast<void*>(LoadLibrary(file)); return static_cast<void*>(LoadLibrary(file));
} }
...@@ -108,15 +111,15 @@ std::string find_file_in_envs_with_intmd( ...@@ -108,15 +111,15 @@ std::string find_file_in_envs_with_intmd(
std::string get_nvcc_root_path() { std::string get_nvcc_root_path() {
auto nvcc_root_path = find_file_in_envs_with_intmd({ENV_PATH}, NVCC_EXE); auto nvcc_root_path = find_file_in_envs_with_intmd({ENV_PATH}, NVCC_EXE);
if (nvcc_root_path.empty()) { if (nvcc_root_path.empty()) {
mgb_throw( return {};
MegBrainError,
"nvcc not found. Add your nvcc to your environment Path");
} else { } else {
auto idx = nvcc_root_path.rfind(PATH_SPLITER); auto idx = nvcc_root_path.rfind(PATH_SPLITER);
return nvcc_root_path.substr(0, idx + 1); return nvcc_root_path.substr(0, idx + 1);
} }
} }
} // namespace
std::vector<std::string> mgb::get_cuda_include_path() { std::vector<std::string> mgb::get_cuda_include_path() {
#if MGB_CUDA #if MGB_CUDA
std::vector<std::string> paths; std::vector<std::string> paths;
...@@ -131,13 +134,15 @@ std::vector<std::string> mgb::get_cuda_include_path() { ...@@ -131,13 +134,15 @@ std::vector<std::string> mgb::get_cuda_include_path() {
// 2. use nvcc path // 2. use nvcc path
auto nvcc_path = get_nvcc_root_path(); auto nvcc_path = get_nvcc_root_path();
auto cudart_header_path = nvcc_path + ".." + PATH_SPLITER + "include" + if (!nvcc_path.empty()) {
PATH_SPLITER + "cuda_runtime.h"; auto cudart_header_path = nvcc_path + ".." + PATH_SPLITER + "include" +
//! double check path_to_nvcc/../include/cuda_runtime.h exists PATH_SPLITER + "cuda_runtime.h";
auto ret = check_file_exist(cudart_header_path.c_str(), F_OK); //! double check path_to_nvcc/../include/cuda_runtime.h exists
if (ret == 0) { auto ret = check_file_exist(cudart_header_path.c_str(), F_OK);
paths.emplace_back(nvcc_path + ".."); if (ret == 0) {
paths.emplace_back(nvcc_path + ".." + PATH_SPLITER + "include"); paths.emplace_back(nvcc_path + "..");
paths.emplace_back(nvcc_path + ".." + PATH_SPLITER + "include");
}
} }
// 3. use libcudart.so library path // 3. use libcudart.so library path
...@@ -161,4 +166,4 @@ std::vector<std::string> mgb::get_cuda_include_path() { ...@@ -161,4 +166,4 @@ std::vector<std::string> mgb::get_cuda_include_path() {
#else #else
mgb_throw(MegBrainError, "cuda disabled at compile time"); mgb_throw(MegBrainError, "cuda disabled at compile time");
#endif #endif
} }
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册