diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 6011e3e9253c99b28cd821e8b2ff832944733074..3fd428033eebbf5edd0dc9f30e348df28250901f 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -7,6 +7,10 @@ if(NOT "$ENV{LIBRARY_PATH}" STREQUAL "") string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS $ENV{LIBRARY_PATH}) endif() +if("${CUDNN_ROOT_DIR}" STREQUAL "" AND NOT "$ENV{CUDNN_ROOT_DIR}" STREQUAL "") + set(CUDNN_ROOT_DIR $ENV{CUDNN_ROOT_DIR}) +endif() + if(MGE_CUDA_USE_STATIC) find_library(CUDNN_LIBRARY NAMES libcudnn_static.a cudnn.lib diff --git a/cmake/tensorrt.cmake b/cmake/tensorrt.cmake index 25d21aeabf1d8dcac965b67bbd82b334a50d72a9..7ac5b5b2cc109a1f5becdf663f5eee524ae43d34 100644 --- a/cmake/tensorrt.cmake +++ b/cmake/tensorrt.cmake @@ -2,6 +2,10 @@ if(NOT "$ENV{LIBRARY_PATH}" STREQUAL "") string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS $ENV{LIBRARY_PATH}) endif() +if("${TRT_ROOT_DIR}" STREQUAL "" AND NOT "$ENV{TRT_ROOT_DIR}" STREQUAL "") + set(TRT_ROOT_DIR $ENV{TRT_ROOT_DIR}) +endif() + if(MGE_CUDA_USE_STATIC) find_library(TRT_LIBRARY NAMES libnvinfer_static.a nvinfer.lib diff --git a/dnn/src/CMakeLists.txt b/dnn/src/CMakeLists.txt index 430a6b946d1d22b470b8a2241d49a5aacae0003f..029f1bf8ed6040e1d23129fc5d798bd076913dd1 100644 --- a/dnn/src/CMakeLists.txt +++ b/dnn/src/CMakeLists.txt @@ -134,6 +134,7 @@ add_library(megdnn EXCLUDE_FROM_ALL OBJECT ${SOURCES}) target_link_libraries(megdnn PUBLIC opr_param_defs) if(MGE_WITH_CUDA) target_link_libraries(megdnn PRIVATE $) + target_include_directories(megdnn PRIVATE ${CUDNN_INCLUDE_DIR}) endif() if(MGE_WITH_ROCM) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8c0db789c1e2a518eaaed0961be66449e3d5e3c7..5fddd520e1760888c8af518c16de1d6425a7680c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,6 +57,9 @@ endif() add_library(megbrain OBJECT ${SOURCES}) target_link_libraries(megbrain PUBLIC mgb_opr_param_defs) +if(MGE_WITH_CUDA) + target_include_directories(megbrain PUBLIC ${TRT_INCLUDE_DIR}) +endif() target_include_directories(megbrain PUBLIC $ PRIVATE ${PROJECT_SOURCE_DIR}/third_party/midout/src