提交 62fe2f28 编写于 作者: C chengduoZH

follow comments

上级 16fc5e38
...@@ -12,6 +12,7 @@ function(op_library TARGET) ...@@ -12,6 +12,7 @@ function(op_library TARGET)
set(cu_srcs) set(cu_srcs)
set(cu_cc_srcs) set(cu_cc_srcs)
set(cudnn_cu_cc_srcs) set(cudnn_cu_cc_srcs)
set(CUDNN_FILE)
set(op_common_deps operator op_registry math_function) set(op_common_deps operator op_registry math_function)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
...@@ -31,6 +32,10 @@ function(op_library TARGET) ...@@ -31,6 +32,10 @@ function(op_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
list(APPEND cu_srcs ${TARGET}.cu) list(APPEND cu_srcs ${TARGET}.cu)
endif() endif()
string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
endif()
else() else()
foreach(src ${op_library_SRCS}) foreach(src ${op_library_SRCS})
if (${src} MATCHES ".*\\.cu$") if (${src} MATCHES ".*\\.cu$")
...@@ -103,7 +108,7 @@ function(op_library TARGET) ...@@ -103,7 +108,7 @@ function(op_library TARGET)
# pybind USE_OP_DEVICE_KERNEL for CUDNN # pybind USE_OP_DEVICE_KERNEL for CUDNN
list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len) list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
if (${cudnn_cu_cc_srcs_len} GREATER 0) if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
endif() endif()
...@@ -161,38 +166,24 @@ op_library(lstm_op DEPS sequence2batch lstm_compute) ...@@ -161,38 +166,24 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(lstmp_op DEPS sequence2batch lstm_compute) op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute) op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor) op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor) op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(create_reader_op DEPS reader) op_library(create_reader_op DEPS reader)
# Regist multiple Kernel to pybind # Regist multiple Kernel to pybind
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS op_library(conv_op DEPS vol2col depthwise_conv)
vol2col depthwise_conv)
op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
conv_transpose_cudnn_op.cu.cc DEPS vol2col)
else() else()
op_library(conv_op SRCS conv_op.cc DEPS vol2col) op_library(conv_op DEPS vol2col)
op_library(pool_op SRCS pool_op.cc DEPS pooling)
op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
endif() endif()
op_library(pool_op DEPS pooling)
op_library(conv_transpose_op DEPS vol2col)
cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry) cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)
op_library(fill_constant_batch_size_like_op DEPS batch_size_like)
op_library(fill_constant_batch_size_like_op op_library(uniform_random_batch_size_like_op DEPS batch_size_like uniform_random_op)
SRCS fill_constant_batch_size_like_op.cc fill_constant_batch_size_like_op.cu.cc op_library(gaussian_random_batch_size_like_op DEPS batch_size_like gaussian_random_op)
DEPS batch_size_like)
op_library(uniform_random_batch_size_like_op
SRCS uniform_random_batch_size_like_op.cc
DEPS batch_size_like uniform_random_op)
op_library(gaussian_random_batch_size_like_op
SRCS gaussian_random_batch_size_like_op.cc
DEPS batch_size_like gaussian_random_op)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions # FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor) op_library(save_op DEPS lod_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册