提交 aaa703d9 编写于 作者: Z zhhsplendid

Merge branch 'cudnn_dyload' of https://github.com/GaoWei8/Paddle into refine_api2, test=develop

...@@ -54,6 +54,8 @@ class ScopedRNNBase { ...@@ -54,6 +54,8 @@ class ScopedRNNBase {
x_descs_.emplace_back(x_desc_.descriptor<T>(dims_x, strides_x)); x_descs_.emplace_back(x_desc_.descriptor<T>(dims_x, strides_x));
y_descs_.emplace_back(y_desc_.descriptor<T>(dims_y, strides_y)); y_descs_.emplace_back(y_desc_.descriptor<T>(dims_y, strides_y));
} }
#if CUDNN_VERSION >= 7201
if (!sequence_length.empty()) { if (!sequence_length.empty()) {
x_seq_desc_.descriptor<T>(seq_length_, batch_size_, input_size_, true, x_seq_desc_.descriptor<T>(seq_length_, batch_size_, input_size_, true,
sequence_length); sequence_length);
...@@ -61,6 +63,7 @@ class ScopedRNNBase { ...@@ -61,6 +63,7 @@ class ScopedRNNBase {
hidden_size_ * numDirections, true, hidden_size_ * numDirections, true,
sequence_length); sequence_length);
} }
#endif
// ------------------- cudnn hx, hy, cx, cy descriptors---------- // ------------------- cudnn hx, hy, cx, cy descriptors----------
std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_, std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_,
...@@ -96,10 +99,13 @@ class ScopedRNNBase { ...@@ -96,10 +99,13 @@ class ScopedRNNBase {
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
cudnn_type)); cudnn_type));
#endif #endif
#if CUDNN_VERSION >= 7201
if (!sequence_length.empty()) { if (!sequence_length.empty()) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED)); rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED));
} }
#endif
// ------------------- cudnn weights_size --------------------- // ------------------- cudnn weights_size ---------------------
size_t weights_size_; size_t weights_size_;
...@@ -125,8 +131,10 @@ class ScopedRNNBase { ...@@ -125,8 +131,10 @@ class ScopedRNNBase {
} }
cudnnTensorDescriptor_t* x_descs() { return x_descs_.data(); } cudnnTensorDescriptor_t* x_descs() { return x_descs_.data(); }
cudnnTensorDescriptor_t* y_descs() { return y_descs_.data(); } cudnnTensorDescriptor_t* y_descs() { return y_descs_.data(); }
#if CUDNN_VERSION >= 7201
cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_.desc(); } cudnnRNNDataDescriptor_t x_seq_desc() { return x_seq_desc_.desc(); }
cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_.desc(); } cudnnRNNDataDescriptor_t y_seq_desc() { return y_seq_desc_.desc(); }
#endif
cudnnTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); } cudnnTensorDescriptor_t init_h_desc() { return init_h_desc_.desc(); }
cudnnTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); } cudnnTensorDescriptor_t init_c_desc() { return init_c_desc_.desc(); }
cudnnTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); } cudnnTensorDescriptor_t last_h_desc() { return last_h_desc_.desc(); }
...@@ -151,8 +159,10 @@ class ScopedRNNBase { ...@@ -151,8 +159,10 @@ class ScopedRNNBase {
platform::ScopedTensorDescriptor x_desc_; platform::ScopedTensorDescriptor x_desc_;
platform::ScopedTensorDescriptor y_desc_; platform::ScopedTensorDescriptor y_desc_;
#if CUDNN_VERSION >= 7201
platform::ScopedRNNTensorDescriptor x_seq_desc_; platform::ScopedRNNTensorDescriptor x_seq_desc_;
platform::ScopedRNNTensorDescriptor y_seq_desc_; platform::ScopedRNNTensorDescriptor y_seq_desc_;
#endif
platform::ScopedTensorDescriptor init_h_desc_; platform::ScopedTensorDescriptor init_h_desc_;
platform::ScopedTensorDescriptor init_c_desc_; platform::ScopedTensorDescriptor init_c_desc_;
platform::ScopedTensorDescriptor last_h_desc_; platform::ScopedTensorDescriptor last_h_desc_;
......
...@@ -294,6 +294,7 @@ class ScopedTensorDescriptor { ...@@ -294,6 +294,7 @@ class ScopedTensorDescriptor {
DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor); DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor);
}; };
#if CUDNN_VERSION >= 7201
class ScopedRNNTensorDescriptor { class ScopedRNNTensorDescriptor {
public: public:
ScopedRNNTensorDescriptor() { ScopedRNNTensorDescriptor() {
...@@ -337,6 +338,7 @@ class ScopedRNNTensorDescriptor { ...@@ -337,6 +338,7 @@ class ScopedRNNTensorDescriptor {
cudnnRNNDataDescriptor_t desc_; cudnnRNNDataDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedRNNTensorDescriptor); DISABLE_COPY_AND_ASSIGN(ScopedRNNTensorDescriptor);
}; };
#endif
class ScopedDropoutDescriptor { class ScopedDropoutDescriptor {
public: public:
......
...@@ -101,9 +101,6 @@ extern void EnforceCUDNNLoaded(const char* fn_name); ...@@ -101,9 +101,6 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnDropoutGetStatesSize); \ __macro(cudnnDropoutGetStatesSize); \
__macro(cudnnSetDropoutDescriptor); \ __macro(cudnnSetDropoutDescriptor); \
__macro(cudnnRestoreDropoutDescriptor); \ __macro(cudnnRestoreDropoutDescriptor); \
__macro(cudnnCreateRNNDataDescriptor); \
__macro(cudnnDestroyRNNDataDescriptor); \
__macro(cudnnSetRNNDataDescriptor); \
__macro(cudnnCreateRNNDescriptor); \ __macro(cudnnCreateRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \ __macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \ __macro(cudnnGetRNNWorkspaceSize); \
...@@ -112,11 +109,6 @@ extern void EnforceCUDNNLoaded(const char* fn_name); ...@@ -112,11 +109,6 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnRNNBackwardData); \ __macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights); \ __macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \ __macro(cudnnRNNForwardInference); \
__macro(cudnnRNNForwardTrainingEx); \
__macro(cudnnSetRNNPaddingMode); \
__macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx); \
__macro(cudnnRNNForwardInferenceEx); \
__macro(cudnnDestroyDropoutDescriptor); \ __macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \ __macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetTensorNdDescriptorEx); __macro(cudnnSetTensorNdDescriptorEx);
...@@ -188,6 +180,19 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) ...@@ -188,6 +180,19 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif #endif
#if CUDNN_VERSION >= 7201
#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(__macro) \
__macro(cudnnCreateRNNDataDescriptor); \
__macro(cudnnDestroyRNNDataDescriptor); \
__macro(cudnnSetRNNDataDescriptor); \
__macro(cudnnSetRNNPaddingMode); \
__macro(cudnnRNNForwardTrainingEx); \
__macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx); \
__macro(cudnnRNNForwardInferenceEx);
CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
#if CUDNN_VERSION >= 7401 #if CUDNN_VERSION >= 7401
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R7(__macro) \ #define CUDNN_DNN_ROUTINE_EACH_AFTER_R7(__macro) \
__macro(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize); \ __macro(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册