diff --git a/paddle/fluid/operators/cudnn_lstm_cache.h b/paddle/fluid/operators/cudnn_lstm_cache.h
index 4b46e2b475e8bd59c59744bdfde7bfb1248bc99a..3181e4b1d990b775f2c80a5d13f391886f83b080 100644
--- a/paddle/fluid/operators/cudnn_lstm_cache.h
+++ b/paddle/fluid/operators/cudnn_lstm_cache.h
@@ -54,6 +54,8 @@ class ScopedRNNBase {
       x_descs_.emplace_back(x_desc_.descriptor<T>(dims_x, strides_x));
       y_descs_.emplace_back(y_desc_.descriptor<T>(dims_y, strides_y));
     }
+
+#if CUDNN_VERSION >= 7201
     if (!sequence_length.empty()) {
       x_seq_desc_.descriptor<T>(seq_length_, batch_size_, input_size_, true,
                                 sequence_length);
@@ -61,6 +63,7 @@ class ScopedRNNBase {
                                 hidden_size_ * numDirections, true,
                                 sequence_length);
     }
+#endif
 
     // ------------------- cudnn hx, hy, cx, cy descriptors----------
     std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_,
@@ -96,10 +99,13 @@ class ScopedRNNBase {
         is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
         cudnn_type));
 #endif
+
+#if CUDNN_VERSION >= 7201
     if (!sequence_length.empty()) {
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
           rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED));
     }
+#endif
 
     // ------------------- cudnn weights_size ---------------------
     size_t weights_size_;
@@ -125,8 +131,10 @@ class ScopedRNNBase {
   }
   cudnnTensorDescriptor_t* x_descs() { return x_descs_.data(); }
   cudnnTensorDescriptor_t* y_descs() { return y_descs_.data(); }
+#if CUDNN_VERSION >= 7201
   cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_.desc(); }
   cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_.desc(); }
+#endif
   cudnnTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); }
   cudnnTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); }
   cudnnTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); }
@@ -151,8 +159,10 @@ class ScopedRNNBase {
 
   platform::ScopedTensorDescriptor x_desc_;
   platform::ScopedTensorDescriptor y_desc_;
+#if CUDNN_VERSION >= 7201
   platform::ScopedRNNTensorDescriptor x_seq_desc_;
   platform::ScopedRNNTensorDescriptor y_seq_desc_;
+#endif
   platform::ScopedTensorDescriptor init_h_desc_;
   platform::ScopedTensorDescriptor init_c_desc_;
   platform::ScopedTensorDescriptor last_h_desc_;
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index bb4c2a89f6fa5e531aa322b69218cf58d3e94285..4b9c5c429dabc32fad6f05e4f066ab063057e733 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -294,6 +294,7 @@ class ScopedTensorDescriptor {
   DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor);
 };
 
+#if CUDNN_VERSION >= 7201
 class ScopedRNNTensorDescriptor {
  public:
   ScopedRNNTensorDescriptor() {
@@ -337,6 +338,7 @@ class ScopedRNNTensorDescriptor {
   cudnnRNNDataDescriptor_t desc_;
   DISABLE_COPY_AND_ASSIGN(ScopedRNNTensorDescriptor);
 };
+#endif
 
 class ScopedDropoutDescriptor {
  public:
diff --git a/paddle/fluid/platform/dynload/cudnn.cc b/paddle/fluid/platform/dynload/cudnn.cc
index 44a03d6f14a3ba07d73cfbc944d8db9601394103..1166dc5e4ad93fa23ef00623de6777b78b56ea09 100644
--- a/paddle/fluid/platform/dynload/cudnn.cc
+++ b/paddle/fluid/platform/dynload/cudnn.cc
@@ -46,6 +46,10 @@ CUDNN_DNN_ROUTINE_EACH_R6(DEFINE_WRAP);
 CUDNN_DNN_ROUTINE_EACH_R7(DEFINE_WRAP);
 #endif
 
+#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7
+CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DEFINE_WRAP);
+#endif
+
 #ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R7
 CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP);
 #endif
diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h
index 7e85cb57f339331d5dd4233c2cad562c56d1d3af..fba41417648ba606727d00e71f48766f47479989 100644
--- a/paddle/fluid/platform/dynload/cudnn.h
+++ b/paddle/fluid/platform/dynload/cudnn.h
@@ -101,9 +101,6 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
   __macro(cudnnDropoutGetStatesSize); \
   __macro(cudnnSetDropoutDescriptor); \
   __macro(cudnnRestoreDropoutDescriptor); \
-  __macro(cudnnCreateRNNDataDescriptor); \
-  __macro(cudnnDestroyRNNDataDescriptor); \
-  __macro(cudnnSetRNNDataDescriptor); \
   __macro(cudnnCreateRNNDescriptor); \
   __macro(cudnnGetRNNParamsSize); \
   __macro(cudnnGetRNNWorkspaceSize); \
@@ -112,11 +109,6 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
   __macro(cudnnRNNBackwardData); \
   __macro(cudnnRNNBackwardWeights); \
   __macro(cudnnRNNForwardInference); \
-  __macro(cudnnRNNForwardTrainingEx); \
-  __macro(cudnnSetRNNPaddingMode); \
-  __macro(cudnnRNNBackwardDataEx); \
-  __macro(cudnnRNNBackwardWeightsEx); \
-  __macro(cudnnRNNForwardInferenceEx); \
   __macro(cudnnDestroyDropoutDescriptor); \
   __macro(cudnnDestroyRNNDescriptor); \
   __macro(cudnnSetTensorNdDescriptorEx);
@@ -188,6 +180,19 @@ CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
 #endif
 
+#if CUDNN_VERSION >= 7201
+#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(__macro) \
+  __macro(cudnnCreateRNNDataDescriptor); \
+  __macro(cudnnDestroyRNNDataDescriptor); \
+  __macro(cudnnSetRNNDataDescriptor); \
+  __macro(cudnnSetRNNPaddingMode); \
+  __macro(cudnnRNNForwardTrainingEx); \
+  __macro(cudnnRNNBackwardDataEx); \
+  __macro(cudnnRNNBackwardWeightsEx); \
+  __macro(cudnnRNNForwardInferenceEx);
+CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
+#endif
+
 #if CUDNN_VERSION >= 7401
 #define CUDNN_DNN_ROUTINE_EACH_AFTER_R7(__macro) \
   __macro(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize); \
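Note (reviewer illustration, not part of the patch): cudnn.h defines CUDNN_VERSION as CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL, so the guard value 7201 corresponds to cuDNN 7.2.1, which is assumed here to be the first release that ships cudnnRNNDataDescriptor_t, cudnnSetRNNPaddingMode and the padded-IO "Ex" RNN entry points that the patch fences off. A minimal standalone C++ sketch of the same compile-time guard, using raw cuDNN calls rather than the Paddle dynload wrappers:

// Illustrative sketch only: same #if guard as the patch, raw cuDNN API.
#include <cudnn.h>

#include <cstdio>

int main() {
#if CUDNN_VERSION >= 7201
  // cudnnRNNDataDescriptor_t and its create/destroy functions only exist in
  // cuDNN 7.2.1+, so the type and every call that touches it sit behind the
  // same guard.
  cudnnRNNDataDescriptor_t seq_desc = nullptr;
  if (cudnnCreateRNNDataDescriptor(&seq_desc) == CUDNN_STATUS_SUCCESS) {
    std::printf("padded-IO RNN descriptors available (CUDNN_VERSION=%d)\n",
                CUDNN_VERSION);
    cudnnDestroyRNNDataDescriptor(seq_desc);
  }
#else
  // On older cuDNN the symbols are absent from the header, so guarded code
  // must not reference them at all; only this branch is compiled.
  std::printf("CUDNN_VERSION=%d predates the RNNDataDescriptor API\n",
              CUDNN_VERSION);
#endif
  return 0;
}

Guarding the dynload declaration list with the same condition matters because DECLARE_DYNAMIC_LOAD_CUDNN_WRAP expands to code that references the cuDNN function by name, so building against a cudnn.h older than 7.2.1 would otherwise fail at compile time, before any runtime symbol lookup happens.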