提交 18ebeec2 编写于 作者: C chenchaoxiu

Added support for cudnn v6 and cuda 8.0

上级 8e25fbb9
......@@ -175,11 +175,15 @@ void hl_cudnn_init(cudnnHandle_t* cudnn_handle, cudaStream_t stream) {
<< "PaddlePaddle Requirement: "
<< "(header v[2-3] with libcudnn v[2-3]) Or "
<< "(header v4 with libcudnn v4) Or "
<< "(header v5 with libcudnn v5).";
<< "(header v5 with libcudnn v5) Or"
<< "(header v6 with libcudnn v6).";
CHECK(!(CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050))
CHECK(!(CUDNN_VERSION < 6000 && CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050))
<< "cudnn v5 requires cuda version >= 7.5";
CHECK(!(CUDNN_VERSION >= 6000 && CUDA_VERSION < 8000))
<< "cudnn v6 requires cuda version >= 8.0";
CHECK_CUDNN(dynload::cudnnCreate(cudnn_handle));
CHECK_CUDNN(dynload::cudnnSetStream(*cudnn_handle, stream));
......@@ -610,6 +614,23 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
CHECK_CUDNN(dynload::cudnnCreateConvolutionDescriptor(&hl_conv->desc));
cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
#if CUDNN_VERSION >= 6000
#ifndef PADDLE_TYPE_DOUBLE
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc,
padding_height,
padding_width,
stride_height,
stride_width,
1,
1,
mode,
data_type));
#else
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc,
padding_height,
padding_width,
......@@ -618,6 +639,7 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
1,
1,
mode));
#endif
hl_conv->input_image = image;
hl_conv->filter = filter;
......@@ -645,6 +667,23 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
cudnnConvolutionDescriptor_t conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv);
cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
#if CUDNN_VERSION >= 6000
#ifndef PADDLE_TYPE_DOUBLE
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(conv_desc,
padding_height,
padding_width,
stride_height,
stride_width,
1,
1,
mode,
data_type));
#else
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(conv_desc,
padding_height,
padding_width,
......@@ -653,6 +692,7 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
1,
1,
mode));
#endif
cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)conv;
hl_conv->input_image = image;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册