未验证 提交 c71d79b1 编写于 作者: W wangchaochaohu 提交者: GitHub

[cuda11 support] change the CMakeLists to support the cuda11 (#27124)

上级 f7d08b7d
......@@ -16,7 +16,7 @@ else()
# Known GPU architecture lists, stored as space-separated strings. The numeric
# suffix presumably matches the CUDA major version each list applies to —
# TODO confirm against the code that reads these variables.
set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75")
# NOTE(review): paddle_known_gpu_archs11 is assigned twice in a row; the second
# assignment replaces the first, so the effective value is "52 60 61 70 75 80".
set(paddle_known_gpu_archs11 "35 50 52 60 61 70 75 80")
set(paddle_known_gpu_archs11 "52 60 61 70 75 80")
endif()
######################################################################################
......
......@@ -18,7 +18,7 @@ SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc)
# Source and install locations for warp-ctc, both rooted at THIRD_PARTY_PATH.
SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
# Upstream git repository and the revision identifier to use from it.
set(WARPCTC_REPOSITORY https://github.com/baidu-research/warp-ctc.git)
# NOTE(review): WARPCTC_TAG is assigned twice; the later assignment is the one
# that takes effect.
set(WARPCTC_TAG bc29dcfff07ced1c7a19a4ecee48e5ad583cef8e)
set(WARPCTC_TAG fc7f226b93758216a03b1be9d24593a12819b984)
# Include directory under the install tree, written to the cache with FORCE so
# any previously cached value is overwritten.
SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE)
......
......@@ -28,7 +28,15 @@ function(CheckCompilerCXX11Flag)
endfunction()
CheckCompilerCXX11Flag()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
# Select the C++ standard flag appended to CMAKE_CXX_FLAGS:
#   - GPU build with a CUDA compiler version of 11.0 or newer -> -std=c++14
#   - every other configuration                               -> -std=c++11
#
# VERSION_GREATER_EQUAL compares dotted versions component by component, unlike
# the numeric GREATER_EQUAL, which is not meant for strings such as "10.1.243".
# Quoting the expansion keeps the condition well-formed when
# CMAKE_CUDA_COMPILER_VERSION is empty or unset.
if(WITH_GPU AND "${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL "11.0")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
else()
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
endif()
# safe_set_flag
#
# Set a compile flag only if the compiler supports it
......
......@@ -243,9 +243,10 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
ENDIF()
# GPU-only third-party setup.
if(WITH_GPU)
  # The external cub project is only pulled in for CUDA compilers older than
  # 11.0 (presumably because newer toolkits already provide cub — confirm).
  # VERSION_LESS does a component-wise version comparison; the quoted
  # expansion keeps the condition valid when the version variable is empty.
  if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "11.0")
    include(external/cub) # download cub
    list(APPEND third_party_deps extern_cub)
  endif()
  # Fetch and unpack the cudaErrorMessage archive for every GPU build.
  set(CUDAERROR_URL "http://paddlepaddledeps.bj.bcebos.com/cudaErrorMessage.tar.gz" CACHE STRING "" FORCE)
  file_download_and_uncompress(${CUDAERROR_URL} "cudaerror") # download file cudaErrorMessage
endif()
......
......@@ -45,7 +45,9 @@ endif()
# Dependencies shared by operator headers. cub is added only for GPU builds
# whose CUDA compiler version is below 11.0.
set(OP_HEADER_DEPS xxhash executor)
# VERSION_LESS compares dotted versions component-wise; quoting the expansion
# keeps the condition well-formed when the version variable is empty.
if(WITH_GPU AND "${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "11.0")
  list(APPEND OP_HEADER_DEPS cub)
endif()
set(OP_PREFETCH_DEPS "")
......
......@@ -41,9 +41,13 @@ detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_fo
detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc)
if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS memory cub)
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS memory cub)
# Dependencies for the GPU proposal operators declared below: always `memory`,
# plus `cub` when the CUDA compiler version is below 11.0.
# VERSION_LESS compares dotted versions component-wise; quoting the expansion
# keeps the condition well-formed when the version variable is empty.
set(TMPDEPS memory)
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "11.0")
  list(APPEND TMPDEPS cub)
endif()
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS ${TMPDEPS})
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS ${TMPDEPS})
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS ${TMPDEPS})
else()
detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc)
......
......@@ -9,7 +9,11 @@ function(math_library TARGET)
set(hip_srcs)
set(math_common_deps device_context framework_proto enforce)
# Add cub to the common math dependencies only for GPU builds whose CUDA
# compiler version is below 11.0. The former else-branch called list(APPEND)
# with no items, which changes nothing, so it is dropped.
# VERSION_LESS compares dotted versions component-wise; quoting the expansion
# keeps the condition well-formed when the version variable is empty.
if(WITH_GPU AND "${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "11.0")
  list(APPEND math_common_deps cub)
endif()
set(multiValueArgs DEPS)
cmake_parse_arguments(math_library "${options}" "${oneValueArgs}"
......
include(operators)
# Register the operators in this directory. cub is passed as a dependency only
# for GPU builds whose CUDA compiler version is below 11.0; every other
# configuration registers without extra dependencies.
# VERSION_LESS compares dotted versions component-wise; quoting the expansion
# keeps the condition well-formed when the version variable is empty.
if(WITH_GPU AND "${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "11.0")
  register_operators(DEPS cub)
else()
  register_operators()
endif()
......@@ -24,5 +28,9 @@ if(WITH_GPU)
endif()
# GPU-only test target. It always depends on `tensor`; `cub` is added when the
# CUDA compiler version is below 11.0.
if(WITH_GPU)
  set(check_reduce_rank_test_deps tensor)
  # VERSION_LESS compares dotted versions component-wise; quoting the expansion
  # keeps the condition well-formed when the version variable is empty.
  if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "11.0")
    list(APPEND check_reduce_rank_test_deps cub)
  endif()
  nv_test(check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS ${check_reduce_rank_test_deps})
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册