提交 cda60311 编写于 作者: D Dang Qingqing

Fix compling with cuDNN v5

test=develop
上级 1d3e9bde
...@@ -36,15 +36,18 @@ endif() ...@@ -36,15 +36,18 @@ endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op) register_operators(EXCLUDES warpctc_op conv_fusion_op)
# warpctc_cudnn need cudnn 7 above # warpctc_op needs cudnn 7 above
if (WITH_GPU) if (WITH_GPU)
if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7) if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
else() else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif() endif()
op_library(conv_fusion_op) # conv_fusion_op needs cudnn 7 above
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n") if (NOT ${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
endif()
else() else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif() endif()
......
...@@ -22,6 +22,7 @@ DECLARE_bool(cudnn_exhaustive_search); ...@@ -22,6 +22,7 @@ DECLARE_bool(cudnn_exhaustive_search);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
#if CUDNN_VERSION >= 7001
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
...@@ -178,10 +179,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -178,10 +179,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
}; };
#endif
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#if CUDNN_VERSION >= 7001
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>, REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>); ops::CUDNNConvFusionOpKernel<double>);
#endif
...@@ -23,6 +23,10 @@ if(NOT WITH_DISTRIBUTE) ...@@ -23,6 +23,10 @@ if(NOT WITH_DISTRIBUTE)
LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification) LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification)
endif(NOT WITH_DISTRIBUTE) endif(NOT WITH_DISTRIBUTE)
if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
endif()
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290 list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list(REMOVE_ITEM TEST_OPS test_modified_huber_loss_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184 list(REMOVE_ITEM TEST_OPS test_modified_huber_loss_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184
list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185 list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册