提交 968dd3c0 编写于 作者: L liuhongyu

add cudnn 5 support; test=develop

上级 50a69852
...@@ -177,11 +177,20 @@ struct CudnnRNNCache { ...@@ -177,11 +177,20 @@ struct CudnnRNNCache {
seed_)); seed_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_)); CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));
#if CUDNN_VERSION >= 6000
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6( CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6(
handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
CUDNN_LINEAR_INPUT, CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT)); CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT));
#else
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor(
rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_DATA_FLOAT));
#endif
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_)); CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_)); CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));
......
...@@ -125,8 +125,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); ...@@ -125,8 +125,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnRNNBackwardWeights); \ __macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \ __macro(cudnnRNNForwardInference); \
__macro(cudnnDestroyDropoutDescriptor); \ __macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \ __macro(cudnnDestroyRNNDescriptor);
__macro(cudnnSetRNNDescriptor_v6);
CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
...@@ -165,6 +164,13 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) ...@@ -165,6 +164,13 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif #endif
// APIs in R6
#if CUDNN_VERSION >= 6000
#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) \
__macro(cudnnSetRNNDescriptor_v6);
CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
#if CUDNN_VERSION >= 7001 #if CUDNN_VERSION >= 7001
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ #define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
__macro(cudnnSetConvolutionGroupCount); \ __macro(cudnnSetConvolutionGroupCount); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册