From e0e986273ffdda827f4f5e39d3e455365be0cff7 Mon Sep 17 00:00:00 2001 From: wawltor Date: Fri, 15 Jan 2021 11:19:12 +0800 Subject: [PATCH] Cherrypick fix rnn batch size diff (#30462) * fix the rnn mask memory bug for out of read * update the code for the rnn --- paddle/fluid/operators/rnn_op.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/rnn_op.h b/paddle/fluid/operators/rnn_op.h index 253765bb419..b993f5ac174 100644 --- a/paddle/fluid/operators/rnn_op.h +++ b/paddle/fluid/operators/rnn_op.h @@ -960,9 +960,10 @@ class RNNCPUKernel : public framework::OpKernel { if (has_seq_length) { sequence_length = ctx.Input("SequenceLength"); } - if (!dropout_mask->IsInitialized()) { - dropout_mask->mutable_data(output->dims(), ctx.GetPlace()); + if (dropout_mask->IsInitialized()) { + if (dropout_mask->numel() != output->numel()) dropout_mask->clear(); } + dropout_mask->mutable_data(output->dims(), ctx.GetPlace()); // init the output and allocate the memory output->mutable_data(ctx.GetPlace()); -- GitLab