未验证 提交 858ffa0c 编写于 作者: G Guo Sheng 提交者: GitHub

Fix the dropout setting when not initialized in rnn_op. (#28561)

test=develop
上级 f78211d0
...@@ -89,15 +89,16 @@ class RNNDescriptors { ...@@ -89,15 +89,16 @@ class RNNDescriptors {
// ------------------- cudnn dropout descriptors --------------------- // ------------------- cudnn dropout descriptors ---------------------
size_t state_size; size_t state_size;
if (!is_test_ && !dropout_state->IsInitialized()) { bool is_initialized = dropout_state->IsInitialized();
if (!is_test_ && !is_initialized) {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)}, dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place); place);
} }
dropout_desc_.descriptor(handle, place, dropout_state->IsInitialized(), dropout_desc_.descriptor(handle, place, is_initialized, dropout_prob_,
dropout_prob_, is_test_ ? nullptr : dropout_state, is_test_ ? nullptr : dropout_state, seed_,
seed_, state_size); state_size);
// ------------------- cudnn rnn descriptors --------------------- // ------------------- cudnn rnn descriptors ---------------------
#if CUDNN_VERSION >= 6000 #if CUDNN_VERSION >= 6000
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册