From 18ebeec2accb31381e80554e0a0f60931cf0ad0f Mon Sep 17 00:00:00 2001 From: chenchaoxiu Date: Mon, 19 Dec 2016 16:26:15 +0800 Subject: [PATCH] Added support for cudnn v6 and cuda 8.0 --- paddle/cuda/src/hl_cuda_cudnn.cc | 44 ++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index 8cddf10d40c..c0c8b0e60db 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -175,11 +175,15 @@ void hl_cudnn_init(cudnnHandle_t* cudnn_handle, cudaStream_t stream) { << "PaddlePaddle Requirement: " << "(header v[2-3] with libcudnn v[2-3]) Or " << "(header v4 with libcudnn v4) Or " - << "(header v5 with libcudnn v5)."; + << "(header v5 with libcudnn v5) Or" + << "(header v6 with libcudnn v6)."; - CHECK(!(CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050)) + CHECK(!(CUDNN_VERSION < 6000 && CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050)) << "cudnn v5 requires cuda version >= 7.5"; + CHECK(!(CUDNN_VERSION >= 6000 && CUDA_VERSION < 8000)) + << "cudnn v6 requires cuda version >= 8.0"; + CHECK_CUDNN(dynload::cudnnCreate(cudnn_handle)); CHECK_CUDNN(dynload::cudnnSetStream(*cudnn_handle, stream)); @@ -610,6 +614,23 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv, CHECK_CUDNN(dynload::cudnnCreateConvolutionDescriptor(&hl_conv->desc)); cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; + +#if CUDNN_VERSION >= 6000 +#ifndef PADDLE_TYPE_DOUBLE + cudnnDataType_t data_type = CUDNN_DATA_FLOAT; +#else + cudnnDataType_t data_type = CUDNN_DATA_DOUBLE; +#endif + CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc, + padding_height, + padding_width, + stride_height, + stride_width, + 1, + 1, + mode, + data_type)); +#else CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc, padding_height, padding_width, @@ -618,6 +639,7 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv, 1, 1, mode)); +#endif hl_conv->input_image = image; hl_conv->filter = filter; @@ -645,6 +667,23 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, cudnnConvolutionDescriptor_t conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv); cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; + +#if CUDNN_VERSION >= 6000 +#ifndef PADDLE_TYPE_DOUBLE + cudnnDataType_t data_type = CUDNN_DATA_FLOAT; +#else + cudnnDataType_t data_type = CUDNN_DATA_DOUBLE; +#endif + CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(conv_desc, + padding_height, + padding_width, + stride_height, + stride_width, + 1, + 1, + mode, + data_type)); +#else CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(conv_desc, padding_height, padding_width, @@ -653,6 +692,7 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv, 1, 1, mode)); +#endif cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)conv; hl_conv->input_image = image; -- GitLab