nccl.cmake 2.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
if(NOT WITH_GPU)
    return()
endif()

# Now we don't support NCCL on windows
if(WIN32)
    return()
endif()

set(NCCL_ROOT "/usr" CACHE PATH "NCCL ROOT")
find_path(NCCL_INCLUDE_DIR nccl.h
    PATHS ${NCCL_ROOT} ${NCCL_ROOT}/include ${NCCL_ROOT}/local/include
    $ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include $ENV{NCCL_ROOT}/local/include
    NO_DEFAULT_PATH
)

if(WITH_NCCL)
    file(READ ${NCCL_INCLUDE_DIR}/nccl.h NCCL_VERSION_FILE_CONTENTS)

    string(REGEX MATCH "define NCCL_VERSION_CODE +([0-9]+)"
        NCCL_VERSION "${NCCL_VERSION_FILE_CONTENTS}")
    string(REGEX REPLACE "define NCCL_VERSION_CODE +([0-9]+)" "\\1"
        NCCL_VERSION "${NCCL_VERSION}")

    if("${NCCL_VERSION}" GREATER "2000")
        message(STATUS "Current NCCL header is ${NCCL_INCLUDE_DIR}/nccl.h. "
            "Current NCCL version is v${NCCL_VERSION}. ")
    else()
        # in old version nccl, it may not define NCCL_VERSION_CODE
        string(REGEX MATCH "define NCCL_MAJOR +([0-9]+)" NCCL_MAJOR_VERSION
            "${NCCL_VERSION_FILE_CONTENTS}")
        string(REGEX REPLACE "define NCCL_MAJOR +([0-9]+)" "\\1"
            NCCL_MAJOR_VERSION "${NCCL_MAJOR_VERSION}")
        string(REGEX MATCH "define NCCL_MINOR +([0-9]+)" NCCL_MINOR_VERSION
            "${NCCL_VERSION_FILE_CONTENTS}")
        string(REGEX REPLACE "define NCCL_MINOR +([0-9]+)" "\\1"
            NCCL_MINOR_VERSION "${NCCL_MINOR_VERSION}")
        string(REGEX MATCH "define NCCL_PATCH +([0-9]+)"
            NCCL_PATCH_VERSION "${NCCL_VERSION_FILE_CONTENTS}")
        string(REGEX REPLACE "define NCCL_PATCH +([0-9]+)" "\\1"
            NCCL_PATCH_VERSION "${NCCL_PATCH_VERSION}")

        if(NOT NCCL_MAJOR_VERSION)
            set(NCCL_VERSION "0")
        else()
            math(EXPR NCCL_VERSION
                "${NCCL_MAJOR_VERSION} * 1000 +
                 ${NCCL_MINOR_VERSION} * 100 + ${NCCL_PATCH_VERSION}")
        endif()
        add_definitions("-DNCCL_VERSION_CODE=$NCCL_VERSION")

        message(STATUS "Current NCCL header is ${NCCL_INCLUDE_DIR}/nccl.h. "
            "Current NCCL version is v${NCCL_MAJOR_VERSION}.${NCCL_MINOR_VERSION}.${NCCL_PATCH_VERSION} ")
    endif()
endif()