diff --git a/paddle/fluid/operators/rnn_op.h b/paddle/fluid/operators/rnn_op.h index 253765bb41940a5f62874df95e8a110b5d853b3d..b993f5ac17479544e127f669c94ce0606ab47399 100644 --- a/paddle/fluid/operators/rnn_op.h +++ b/paddle/fluid/operators/rnn_op.h @@ -960,9 +960,10 @@ class RNNCPUKernel : public framework::OpKernel { if (has_seq_length) { sequence_length = ctx.Input("SequenceLength"); } - if (!dropout_mask->IsInitialized()) { - dropout_mask->mutable_data(output->dims(), ctx.GetPlace()); + if (dropout_mask->IsInitialized()) { + if (dropout_mask->numel() != output->numel()) dropout_mask->clear(); } + dropout_mask->mutable_data(output->dims(), ctx.GetPlace()); // init the output and allocate the memory output->mutable_data(ctx.GetPlace());