提交 cda60311 编写于 作者: D Dang Qingqing

Fix compling with cuDNN v5

test=develop
上级 1d3e9bde
......@@ -36,15 +36,18 @@ endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op)
# warpctc_cudnn need cudnn 7 above
# warpctc_op needs cudnn 7 above
if (WITH_GPU)
if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
# conv_fusion_op needs cudnn 7 above
if (NOT ${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
endif()
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
......
......@@ -22,6 +22,7 @@ DECLARE_bool(cudnn_exhaustive_search);
namespace paddle {
namespace operators {
#if CUDNN_VERSION >= 7001
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
......@@ -178,10 +179,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
};
#endif
} // namespace operators
} // namespace paddle
#if CUDNN_VERSION >= 7001
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>);
#endif
......@@ -23,6 +23,10 @@ if(NOT WITH_DISTRIBUTE)
LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification)
endif(NOT WITH_DISTRIBUTE)
if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
endif()
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list(REMOVE_ITEM TEST_OPS test_modified_huber_loss_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184
list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册