diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc index 76e84130015dc5683d30e5881d3ffd480fa6a876..dd64cc327fc383937bc9a9d6e7daa0cec488e4cc 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -180,13 +180,13 @@ struct CudnnRNNCache { #if CUDNN_VERSION >= 6000 CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6( - handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT, + handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, + CUDNN_LINEAR_INPUT, is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT)); #else CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor( - rnn_desc_, hidden_size_, num_layers_, dropout_desc_, - CUDNN_LINEAR_INPUT, + rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT, is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, CUDNN_DATA_FLOAT)); #endif