未验证 提交 858ffa0c 编写于 作者: G Guo Sheng 提交者: GitHub

Fix the dropout setting when not initialized in rnn_op. (#28561)

test=develop
上级 f78211d0
......@@ -89,15 +89,16 @@ class RNNDescriptors {
// ------------------- cudnn dropout descriptors ---------------------
size_t state_size;
if (!is_test_ && !dropout_state->IsInitialized()) {
bool is_initialized = dropout_state->IsInitialized();
if (!is_test_ && !is_initialized) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
dropout_state->mutable_data<uint8_t>({static_cast<int64_t>(state_size)},
place);
}
dropout_desc_.descriptor(handle, place, dropout_state->IsInitialized(),
dropout_prob_, is_test_ ? nullptr : dropout_state,
seed_, state_size);
dropout_desc_.descriptor(handle, place, is_initialized, dropout_prob_,
is_test_ ? nullptr : dropout_state, seed_,
state_size);
// ------------------- cudnn rnn descriptors ---------------------
#if CUDNN_VERSION >= 6000
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册