diff --git a/cmake/flags.cmake b/cmake/flags.cmake
index a0069271252a4efad8189468c5e8b03678a737ef..e05cd52f473388d972cb4d88aa35b5270e4c3f52 100644
--- a/cmake/flags.cmake
+++ b/cmake/flags.cmake
@@ -74,13 +74,46 @@ endforeach()
 # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
 # So, don't set these flags here.
 
+# specify_cuda_arch(<cuda_version> <cuda_arch>)
+#
+# Appends the -gencode flag for <cuda_arch> to __arch_flags in the caller's
+# scope when the CUDA toolkit <cuda_version> can target that compute
+# capability. Otherwise warns and leaves __arch_flags untouched.
+function(specify_cuda_arch cuda_version cuda_arch)
+  # Architectures selectable on top of the common 30/35/50 set. The list is
+  # cumulative so newer toolkits keep accepting the older architectures.
+  set(supported_archs)
+  if("${cuda_version}" VERSION_GREATER "7.0")
+    list(APPEND supported_archs 52 53)
+  endif()
+  # sm_60/61/62 shipped with CUDA 8.0 itself, so 8.0 must be included.
+  if(NOT "${cuda_version}" VERSION_LESS "8.0")
+    list(APPEND supported_archs 60 61 62)
+  endif()
+
+  list(FIND supported_archs "${cuda_arch}" arch_index)
+  if(arch_index EQUAL -1)
+    message(WARNING
+      "CUDA_ARCH=${cuda_arch} is not supported by CUDA ${cuda_version}; ignoring it.")
+    return()
+  endif()
+
+  list(APPEND __arch_flags " -gencode arch=compute_${cuda_arch},code=sm_${cuda_arch}")
+  # function() opens a new variable scope; publish the result to the caller.
+  set(__arch_flags "${__arch_flags}" PARENT_SCOPE)
+endfunction()
+
+# Common cuda architectures
 foreach(capability 30 35 50)
   list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}")
 endforeach()
 
-if (CUDA_VERSION VERSION_GREATER "7.0")
-  list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52")
-endif()
+# Custom cuda architecture: pass -DCUDA_ARCH=<capability> (e.g. 61) to build
+# for one extra architecture. Empty by default.
+set(CUDA_ARCH "" CACHE STRING "Extra CUDA compute capability to build for (e.g. 61)")
+if(CUDA_ARCH)
+  specify_cuda_arch("${CUDA_VERSION}" "${CUDA_ARCH}")
+endif()
 
 set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})
 