From 858ffa0c8b6ff6c10b7f62a0a47d56fa7e37362f Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Wed, 18 Nov 2020 13:04:10 +0800 Subject: [PATCH] Fix the dropout setting when not initialized in rnn_op. (#28561) test=develop --- paddle/fluid/operators/rnn_op.cu.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/rnn_op.cu.cc b/paddle/fluid/operators/rnn_op.cu.cc index 568db797223..f38bfd59688 100644 --- a/paddle/fluid/operators/rnn_op.cu.cc +++ b/paddle/fluid/operators/rnn_op.cu.cc @@ -89,15 +89,16 @@ class RNNDescriptors { // ------------------- cudnn dropout descriptors --------------------- size_t state_size; - if (!is_test_ && !dropout_state->IsInitialized()) { + bool is_initialized = dropout_state->IsInitialized(); + if (!is_test_ && !is_initialized) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); dropout_state->mutable_data({static_cast(state_size)}, place); } - dropout_desc_.descriptor(handle, place, dropout_state->IsInitialized(), - dropout_prob_, is_test_ ? nullptr : dropout_state, - seed_, state_size); + dropout_desc_.descriptor(handle, place, is_initialized, dropout_prob_, + is_test_ ? nullptr : dropout_state, seed_, + state_size); // ------------------- cudnn rnn descriptors --------------------- #if CUDNN_VERSION >= 6000 -- GitLab