提交 16fc5e38 编写于 作者: C chengduoZH

refine cmake for cudnn

上级 95ea54fd
...@@ -11,6 +11,7 @@ function(op_library TARGET) ...@@ -11,6 +11,7 @@ function(op_library TARGET)
set(cc_srcs) set(cc_srcs)
set(cu_srcs) set(cu_srcs)
set(cu_cc_srcs) set(cu_cc_srcs)
set(cudnn_cu_cc_srcs)
set(op_common_deps operator op_registry math_function) set(op_common_deps operator op_registry math_function)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
...@@ -34,6 +35,8 @@ function(op_library TARGET) ...@@ -34,6 +35,8 @@ function(op_library TARGET)
foreach(src ${op_library_SRCS}) foreach(src ${op_library_SRCS})
if (${src} MATCHES ".*\\.cu$") if (${src} MATCHES ".*\\.cu$")
list(APPEND cu_srcs ${src}) list(APPEND cu_srcs ${src})
elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
list(APPEND cudnn_cu_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cu.cc$") elseif(${src} MATCHES ".*\\.cu.cc$")
list(APPEND cu_cc_srcs ${src}) list(APPEND cu_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cc$") elseif(${src} MATCHES ".*\\.cc$")
...@@ -54,7 +57,7 @@ function(op_library TARGET) ...@@ -54,7 +57,7 @@ function(op_library TARGET)
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
endif() endif()
if (WITH_GPU) if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps}) ${op_common_deps})
else() else()
cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS} cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
...@@ -98,6 +101,12 @@ function(op_library TARGET) ...@@ -98,6 +101,12 @@ function(op_library TARGET)
set(pybind_flag 1) set(pybind_flag 1)
endif() endif()
# pybind USE_OP_DEVICE_KERNEL for CUDNN
list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
if (${cudnn_cu_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
endif()
# pybind USE_OP # pybind USE_OP
if (${pybind_flag} EQUAL 0) if (${pybind_flag} EQUAL 0)
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
...@@ -159,21 +168,16 @@ op_library(create_reader_op DEPS reader) ...@@ -159,21 +168,16 @@ op_library(create_reader_op DEPS reader)
# Regist multiple Kernel to pybind # Regist multiple Kernel to pybind
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
vol2col depthwise_conv) vol2col depthwise_conv)
op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function) op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling) op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
conv_transpose_cudnn_op.cu.cc DEPS vol2col) conv_transpose_cudnn_op.cu.cc DEPS vol2col)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d, CUDNN);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(pool2d, CUDNN);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);\n")
else() else()
op_library(conv_op SRCS conv_op.cc DEPS vol2col) op_library(conv_op SRCS conv_op.cc DEPS vol2col)
op_library(pool_op SRCS pool_op.cc DEPS pooling) op_library(pool_op SRCS pool_op.cc DEPS pooling)
op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col) op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
endif() endif()
cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry) cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册