# Nothing below applies unless LITE_WITH_CUDA is set; return() stops
# processing the rest of this file.
if(NOT LITE_WITH_CUDA)
    return()
endif()

# Root directory probed for cuDNN headers and libraries.
# Non-Windows: a cache entry defaulting to /usr (overridable by the user).
# Windows: always the CUDA toolkit root.
if(NOT WIN32)
    set(CUDNN_ROOT "/usr" CACHE PATH "CUDNN ROOT")
else()
    set(CUDNN_ROOT ${CUDA_TOOLKIT_ROOT_DIR})
endif()

# Directory containing cudnn.h. Only the listed locations are searched
# (NO_DEFAULT_PATH): the CUDNN_ROOT variable and the CUDNN_ROOT environment
# variable (each as-is and with /include appended), then CUDA_TOOLKIT_INCLUDE.
find_path(CUDNN_INCLUDE_DIR
    NAMES cudnn.h
    PATHS
        ${CUDNN_ROOT}
        ${CUDNN_ROOT}/include
        $ENV{CUDNN_ROOT}
        $ENV{CUDNN_ROOT}/include
        ${CUDA_TOOLKIT_INCLUDE}
    NO_DEFAULT_PATH
)

# Directory holding the CUDA runtime library; find_library(CUDNN_LIBRARY)
# uses it as an extra search location. The argument is quoted so an unset or
# empty CUDA_CUDART_LIBRARY produces an empty result instead of an
# "incorrect number of arguments" error.
get_filename_component(__libpath_hist "${CUDA_CUDART_LIBRARY}" DIRECTORY)

# Processor name used to build the multiarch library paths searched for
# cuDNN/cuBLAS (e.g. /usr/lib/<arch>-linux-gnu). Defaults to x86_64 and is
# replaced by CMAKE_SYSTEM_PROCESSOR whenever CMake reports one. The variable
# is tested by name: the previous `if(NOT ${CMAKE_SYSTEM_PROCESSOR})` tested
# a variable *named* after the processor, which only worked by accident.
set(TARGET_ARCH "x86_64")
if(CMAKE_SYSTEM_PROCESSOR)
    set(TARGET_ARCH "${CMAKE_SYSTEM_PROCESSOR}")
endif()

# Candidate directories for the cuDNN and cuBLAS libraries, passed as PATHS
# to the find_library() calls in this file (searched in list order).
# A bare unset variable such as ${CUDNN_ROOT} contributes no entry.
list(APPEND CUDNN_CHECK_LIBRARY_DIRS
    ${CUDNN_ROOT}
    ${CUDNN_ROOT}/lib64
    ${CUDNN_ROOT}/lib
    ${CUDNN_ROOT}/lib/${TARGET_ARCH}-linux-gnu
    /usr/local/cuda-${CUDA_VERSION}/targets/${TARGET_ARCH}-linux/lib/
    /usr/lib/${TARGET_ARCH}-linux-gnu/
    $ENV{CUDNN_ROOT}
    $ENV{CUDNN_ROOT}/lib64
    $ENV{CUDNN_ROOT}/lib
    /usr/lib
    ${CUDA_TOOLKIT_ROOT_DIR}
    ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
)

# cuBLAS library list to link against. For CUDA >= 10.0, libcublas.so is
# looked up in CUDNN_CHECK_LIBRARY_DIRS; older toolkits reuse the pre-set
# CUDA_CUBLAS_LIBRARIES. VERSION_LESS compares dotted versions component-wise,
# and the quoted expansion keeps the condition well-formed when CUDA_VERSION
# is empty (the old unquoted GREATER/EQUAL form was a syntax error then).
if(NOT "${CUDA_VERSION}" VERSION_LESS "10.0")
    find_library(CUBLAS_LIBRARY
        NAMES libcublas.so
        PATHS ${CUDNN_CHECK_LIBRARY_DIRS}
        NO_DEFAULT_PATH)
    set(CUBLAS_LIBRARIES ${CUBLAS_LIBRARY})
else()
    set(CUBLAS_LIBRARIES ${CUDA_CUBLAS_LIBRARIES})
endif()

# File name(s) of the cuDNN library handed to find_library(CUDNN_LIBRARY),
# chosen per platform; later assignments override the Linux default.
set(CUDNN_LIB_NAME "libcudnn.so")

if(WIN32)
    # only support cudnn7
    set(CUDNN_LIB_NAME "cudnn.lib" "cudnn64_7.dll")
endif()

if(APPLE)
    set(CUDNN_LIB_NAME "libcudnn.dylib" "libcudnn.so")
endif()

# Resolve the cuDNN library using only CUDNN_CHECK_LIBRARY_DIRS, the cuDNN
# header directory and the CUDA runtime's directory (NO_DEFAULT_PATH).
find_library(CUDNN_LIBRARY
    NAMES ${CUDNN_LIB_NAME}
    PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist}
    NO_DEFAULT_PATH
    DOC "Path to cuDNN dynamic library.")

# cuDNN counts as found only when both the header directory and the library
# resolved (a "-NOTFOUND" value is false in if()).
set(CUDNN_FOUND OFF)
if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY)
    set(CUDNN_FOUND ON)
endif()

# Runs only when cuDNN was found. Defines the imported target cudnn_static
# (libcudnn_static.a next to the found library) and parses the cuDNN version
# into CUDNN_MAJOR_VERSION / CUDNN_MINOR_VERSION / CUDNN_PATCHLEVEL_VERSION and
# CUDNN_VERSION (major * 1000 + minor * 100 + patchlevel, or "???" when the
# major version cannot be parsed).
if(CUDNN_FOUND)
    # cuDNN <= 7 defines the version macros in cudnn.h; cuDNN 8+ moved them to
    # cudnn_version.h. Read cudnn.h first and fall back to cudnn_version.h only
    # when cudnn.h has no CUDNN_MAJOR definition, so older layouts behave as
    # before. (The contents are tested by variable name, not expanded, because
    # the header text contains semicolons.)
    file(READ "${CUDNN_INCLUDE_DIR}/cudnn.h" CUDNN_VERSION_FILE_CONTENTS)
    if(NOT CUDNN_VERSION_FILE_CONTENTS MATCHES "define CUDNN_MAJOR +[0-9]+"
       AND EXISTS "${CUDNN_INCLUDE_DIR}/cudnn_version.h")
        file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h"
            CUDNN_VERSION_FILE_CONTENTS)
    endif()

    get_filename_component(CUDNN_LIB_PATH "${CUDNN_LIBRARY}" DIRECTORY)
    # Guarded so processing this file a second time does not fail on an
    # already-existing target name.
    if(NOT TARGET cudnn_static)
        add_library(cudnn_static STATIC IMPORTED GLOBAL)
        set_property(TARGET cudnn_static PROPERTY IMPORTED_LOCATION
                   "${CUDNN_LIB_PATH}/libcudnn_static.a")
    endif()

    # A literal numeric "#define CUDNN_VERSION <n>" is only acted on when it is
    # 2000 (cuDNN v2); otherwise CUDNN_VERSION is recomputed below.
    string(REGEX MATCH "define CUDNN_VERSION +([0-9]+)"
        CUDNN_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
    string(REGEX REPLACE "define CUDNN_VERSION +([0-9]+)" "\\1"
        CUDNN_VERSION "${CUDNN_VERSION}")

    if("${CUDNN_VERSION}" STREQUAL "2000")
        message(STATUS "Current cuDNN version is v2. ")
    else()
        # Each MATCH captures "define CUDNN_<PART> <n>" (empty if absent) and
        # the following REPLACE reduces it to just "<n>".
        string(REGEX MATCH "define CUDNN_MAJOR +([0-9]+)" CUDNN_MAJOR_VERSION
            "${CUDNN_VERSION_FILE_CONTENTS}")
        string(REGEX REPLACE "define CUDNN_MAJOR +([0-9]+)" "\\1"
            CUDNN_MAJOR_VERSION "${CUDNN_MAJOR_VERSION}")
        string(REGEX MATCH "define CUDNN_MINOR +([0-9]+)" CUDNN_MINOR_VERSION
            "${CUDNN_VERSION_FILE_CONTENTS}")
        string(REGEX REPLACE "define CUDNN_MINOR +([0-9]+)" "\\1"
            CUDNN_MINOR_VERSION "${CUDNN_MINOR_VERSION}")
        string(REGEX MATCH "define CUDNN_PATCHLEVEL +([0-9]+)"
            CUDNN_PATCHLEVEL_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
        string(REGEX REPLACE "define CUDNN_PATCHLEVEL +([0-9]+)" "\\1"
            CUDNN_PATCHLEVEL_VERSION "${CUDNN_PATCHLEVEL_VERSION}")

        if(NOT CUDNN_MAJOR_VERSION)
            set(CUDNN_VERSION "???")
        else()
            # Directory-scoped definition: applies to targets created in this
            # directory and its subdirectories after this point.
            add_definitions("-DPADDLE_CUDNN_BINVER=\"${CUDNN_MAJOR_VERSION}\"")
            math(EXPR CUDNN_VERSION
                "${CUDNN_MAJOR_VERSION} * 1000 +
                 ${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}")
        endif()

        message(STATUS "Current cuDNN header is ${CUDNN_INCLUDE_DIR}/cudnn.h. "
            "Current cuDNN version is v${CUDNN_MAJOR_VERSION}. ")

    endif()
endif()