提交 244ba709 编写于 作者: alinag's avatar alinag

polish codes

上级 b437ff00
...@@ -54,6 +54,7 @@ class ScopedRNNBase { ...@@ -54,6 +54,7 @@ class ScopedRNNBase {
x_descs_.emplace_back(x_desc_.descriptor<T>(dims_x, strides_x)); x_descs_.emplace_back(x_desc_.descriptor<T>(dims_x, strides_x));
y_descs_.emplace_back(y_desc_.descriptor<T>(dims_y, strides_y)); y_descs_.emplace_back(y_desc_.descriptor<T>(dims_y, strides_y));
} }
#if CUDNN_VERSION >= 7201 #if CUDNN_VERSION >= 7201
if (!sequence_length.empty()) { if (!sequence_length.empty()) {
x_seq_desc_.descriptor<T>(seq_length_, batch_size_, input_size_, true, x_seq_desc_.descriptor<T>(seq_length_, batch_size_, input_size_, true,
...@@ -63,6 +64,7 @@ class ScopedRNNBase { ...@@ -63,6 +64,7 @@ class ScopedRNNBase {
sequence_length); sequence_length);
} }
#endif #endif
// ------------------- cudnn hx, hy, cx, cy descriptors---------- // ------------------- cudnn hx, hy, cx, cy descriptors----------
std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_, std::vector<int> dims_hx = {num_layers_ * numDirections, batch_size_,
hidden_size_}; hidden_size_};
...@@ -97,10 +99,13 @@ class ScopedRNNBase { ...@@ -97,10 +99,13 @@ class ScopedRNNBase {
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
cudnn_type)); cudnn_type));
#endif #endif
#if CUDNN_VERSION >= 7201
if (!sequence_length.empty()) { if (!sequence_length.empty()) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNPaddingMode(
rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED)); rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED));
} }
#endif
// ------------------- cudnn weights_size --------------------- // ------------------- cudnn weights_size ---------------------
size_t weights_size_; size_t weights_size_;
......
...@@ -190,7 +190,7 @@ CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) ...@@ -190,7 +190,7 @@ CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
__macro(cudnnRNNBackwardDataEx); \ __macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx); \ __macro(cudnnRNNBackwardWeightsEx); \
__macro(cudnnRNNForwardInferenceEx); __macro(cudnnRNNForwardInferenceEx);
CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_AFTER_R72(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif #endif
#if CUDNN_VERSION >= 7401 #if CUDNN_VERSION >= 7401
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册